|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
import os |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.utils.data as data |
|
|
import torchvision.transforms as transforms |
|
|
import medmnist |
|
|
from medmnist import INFO, Evaluator |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
import matplotlib.pyplot as plt |
|
|
from medmnist.utils import montage2d |
|
|
from medimeta import MedIMeta |
|
|
|
|
|
class DataClass(Dataset):
    """Dataset backed by the ``images`` and ``labels`` arrays of a single ``.npz`` file."""

    def __init__(self, root, transform=None, size=224):
        """
        Load the ``images`` and ``labels`` arrays stored in ``root``.

        Args:
            root (str): Path to the .npz file (e.g., 'data_root/breastmnist.npz').
            transform (callable, optional): A function/transform that takes in a PIL image
                and returns a transformed version.
            size (int, optional): Image size. Defaults to 224. Only stored on the
                instance; this class does not resize images itself.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
        """
        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset file not found at {root}")

        self.root = root
        self.transform = transform
        self.size = size

        # np.load on an .npz returns an archive object that keeps the file open.
        # Read both arrays inside a ``with`` block so the handle is released
        # right away. No mmap_mode is passed: it has no effect on .npz archives,
        # whose members are always read fully into memory.
        with np.load(self.root) as npz_file:
            self.imgs = npz_file["images"]
            self.labels = npz_file["labels"]

        # A 4-D array whose last axis has length 3 is treated as 3-channel;
        # every other layout is treated as single-channel.
        self.n_channels = 3 if len(self.imgs.shape) == 4 and self.imgs.shape[-1] == 3 else 1

    def __len__(self):
        """Return the number of images (length of the first axis of ``imgs``)."""
        return self.imgs.shape[0]

    def __getitem__(self, index):
        """
        Fetch one sample.

        Returns:
            img: PIL image built from ``imgs[index]``, or the output of
                ``transform`` when a transform is set.
            target (int/array): Corresponding label, cast to ``int`` dtype.
        """
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, target

    def montage(self, length=10, replace=False, save_folder=None):
        """
        Create a ``length`` x ``length`` montage of images.

        Selection is deterministic: the first ``length * length`` images are
        used in order, wrapping around to the start when the dataset holds
        fewer images than that.

        Args:
            length (int): Number of images per row and column (default=10).
            replace (bool): Currently unused; kept so existing callers that
                pass it keep working.
            save_folder (str, optional): If provided, the folder is created if
                needed and the montage is saved there as ``montage1.jpg``.

        Returns:
            PIL.Image: The generated montage.
        """
        n_sel = length * length
        # Modulo wraps indices so small datasets repeat instead of running out.
        indices = np.arange(n_sel) % len(self)

        montage_img = montage2d(imgs=self.imgs, n_channels=self.n_channels, sel=indices)

        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "montage1.jpg")
            montage_img.save(save_path)
            print(f"Montage saved at {save_path}")

        return montage_img
|
|
|
|
|
def build_medmnist_dataset(data_root, transform):
    """
    Build a ``DataClass`` dataset from an .npz file.

    Args:
        data_root: Path passed to ``DataClass`` as ``root``.
        transform: Transform passed through to ``DataClass``.

    Returns:
        DataClass: The constructed dataset, created with ``size`` fixed at 224.
    """
    return DataClass(size=224, transform=transform, root=data_root)
|
|
|
|
|
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """
    Build a ``MedIMeta`` dataset.

    Args:
        data_root: First positional argument forwarded to ``MedIMeta``.
        task: Second positional argument forwarded to ``MedIMeta`` (default 'bus').
        disease: Third positional argument forwarded to ``MedIMeta``
            (default 'Disease').
        transform: Forwarded to ``MedIMeta`` as the ``transform`` keyword.

    Returns:
        The constructed ``MedIMeta`` instance.
    """
    return MedIMeta(data_root, task, disease, transform=transform)