from tqdm import tqdm
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import medmnist
from medmnist import INFO, Evaluator
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from medmnist.utils import montage2d
from medimeta import MedIMeta


class DataClass(Dataset):
    """Map-style dataset over a single ``.npz`` archive.

    The archive must contain an ``images`` array and a ``labels`` array that
    are indexed along their first axis (one entry per sample).
    """

    def __init__(self, root, transform=None, size=224):
        """
        Args:
            root (str): Path to the .npz file (e.g., 'data_root/breastminst.npz').
            transform (callable, optional): A function/transform that takes in a PIL image
                and returns a transformed version.
            size (int, optional): Image size, stored as ``self.size``. This class does not
                resize images itself; any resizing must be done by ``transform``.
                Defaults to 224.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
        """
        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset file not found at {root}")

        self.root = root
        self.transform = transform
        self.size = size

        # np.load does not memory-map members of a .npz archive, so both arrays
        # are read into memory here; the context manager then closes the archive.
        with np.load(self.root) as npz_file:
            self.imgs = npz_file["images"]  # Assuming key names are 'images' and 'labels'
            self.labels = npz_file["labels"]

        # Only arrays shaped (N, H, W, 3) are treated as RGB; everything else as grayscale.
        self.n_channels = 3 if len(self.imgs.shape) == 4 and self.imgs.shape[-1] == 3 else 1

    def __len__(self):
        """Return the number of samples (size of the first axis of ``images``)."""
        return self.imgs.shape[0]

    def __getitem__(self, index):
        """
        Returns:
            img (PIL.Image or transformed output): Image built with ``Image.fromarray``,
                passed through ``transform`` when one is set.
            target (np.ndarray or int-like): Corresponding label, cast with ``astype(int)``.
        """
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, target

    def montage(self, length=10, replace=False, save_folder=None):
        """
        Create a montage of randomly selected images.

        Args:
            length (int): Number of images per row and column (default=10).
            replace (bool): Whether to allow selecting the same image multiple times.
                When False and the grid has more cells (``length * length``) than the
                dataset has images, a random permutation of all images is repeated to
                fill the grid instead of raising an error.
            save_folder (str, optional): If provided, saves the montage as
                ``montage1.jpg`` inside this folder (created if missing).

        Returns:
            PIL.Image: The generated montage.

        Raises:
            ValueError: If the dataset is empty.
        """
        n_sel = length * length  # Total images in montage
        n_total = len(self)
        if n_total == 0:
            raise ValueError("Cannot build a montage from an empty dataset.")

        if replace:
            indices = np.random.choice(n_total, size=n_sel, replace=True)
        else:
            # Each image appears at most once per pass over the shuffled order;
            # np.resize repeats that order only if the grid is larger than the dataset.
            indices = np.resize(np.random.permutation(n_total), n_sel)

        # Generate montage using MedMNIST utility
        montage_img = montage2d(imgs=self.imgs, n_channels=self.n_channels, sel=indices)

        # Save montage if required
        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "montage1.jpg")
            montage_img.save(save_path)
            print(f"Montage saved at {save_path}")

        return montage_img


def build_medmnist_dataset(data_root, transform):
    """Build a ``DataClass`` over the .npz file at ``data_root`` (size fixed to 224)."""
    dataset = DataClass(root=data_root, transform=transform, size=224)
    return dataset


def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Build a ``MedIMeta`` dataset for ``task`` / ``disease`` rooted at ``data_root``.

    ``task`` and ``disease`` are forwarded positionally to ``MedIMeta``; their
    exact meaning is defined by the medimeta package (presumably the task name
    and target attribute — confirm against medimeta docs).
    """
    dataset = MedIMeta(data_root, task, disease, transform=transform)
    return dataset