TCube_Merging / medmnistc_data.py
razaimam45's picture
Upload 108 files
a96891a verified
raw
history blame
3.1 kB
from tqdm import tqdm
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import medmnist
from medmnist import INFO, Evaluator
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from medmnist.utils import montage2d
from medimeta import MedIMeta
class DataClass(Dataset):
    """Map-style dataset backed by a ``.npz`` archive of images and labels.

    The archive must contain an ``images`` array and a ``labels`` array; the
    first axis of ``images`` indexes the samples.
    """

    def __init__(self, root, transform=None, size=224):
        """
        Args:
            root (str): Path to the .npz file (e.g., 'data_root/breastmnist.npz').
            transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
            size (int, optional): Image size. Only stored on the instance; this class
                itself never resizes images. Defaults to 224.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
            KeyError: If the archive has no 'images' or 'labels' entry.
        """
        super().__init__()
        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset file not found at {root}")

        self.root = root
        self.transform = transform
        self.size = size

        # Load dataset. np.load ignores ``mmap_mode`` for .npz archives, so both
        # arrays are read fully into memory on access; the context manager makes
        # sure the underlying zip file handle is closed right afterwards.
        with np.load(self.root) as npz_file:
            self.imgs = npz_file["images"]  # Assuming key names are 'images' and 'labels'
            self.labels = npz_file["labels"]

        # RGB only when the array is 4-D with a trailing axis of length 3;
        # everything else is treated as single-channel.
        self.n_channels = 3 if self.imgs.ndim == 4 and self.imgs.shape[-1] == 3 else 1

    def __len__(self):
        """Return the number of samples (length of the first axis of ``imgs``)."""
        return self.imgs.shape[0]

    def __getitem__(self, index):
        """
        Args:
            index (int): Position of the sample along the first axis.

        Returns:
            img (PIL.Image): Image loaded and transformed (if applicable).
            target (int/array): Corresponding label, cast to ``int`` dtype.
        """
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, target

    def montage(self, length=10, replace=False, save_folder=None):
        """
        Create a ``length`` x ``length`` montage of dataset images.

        Images are taken deterministically in dataset order, starting at index 0
        and wrapping around when the dataset holds fewer than ``length * length``
        samples (no random sampling is performed).

        Args:
            length (int): Number of images per row and column (default=10).
            replace (bool): Currently unused; kept for interface compatibility.
            save_folder (str, optional): If provided, the folder is created when
                missing and the montage is saved there as 'montage1.jpg'.

        Returns:
            PIL.Image: The generated montage.
        """
        n_sel = length * length  # Total images in montage
        # Wrap indices so a small dataset simply repeats its images.
        indices = np.arange(n_sel) % len(self)

        # Generate montage using MedMNIST utility
        montage_img = montage2d(imgs=self.imgs, n_channels=self.n_channels, sel=indices)

        # Save montage if required
        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "montage1.jpg")
            montage_img.save(save_path)
            print(f"Montage saved at {save_path}")

        return montage_img
def build_medmnist_dataset(data_root, transform):
    """Build a ``DataClass`` dataset from the .npz file at ``data_root``.

    Args:
        data_root (str): Path handed to ``DataClass`` as ``root``.
        transform (callable or None): Transform handed to ``DataClass``.

    Returns:
        DataClass: Dataset constructed with ``size=224``.
    """
    return DataClass(root=data_root, transform=transform, size=224)
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Build a ``MedIMeta`` dataset.

    ``data_root``, ``task`` and ``disease`` are forwarded positionally, in that
    order, to ``MedIMeta``; ``transform`` is forwarded by keyword.
    NOTE(review): presumably ``task`` selects the MedIMeta dataset and
    ``disease`` the task within it -- confirm against the MedIMeta signature.

    Args:
        data_root: First positional argument for ``MedIMeta``.
        task: Second positional argument for ``MedIMeta`` (default 'bus').
        disease: Third positional argument for ``MedIMeta`` (default 'Disease').
        transform: Optional transform passed through as ``transform=``.

    Returns:
        The constructed ``MedIMeta`` instance.
    """
    return MedIMeta(data_root, task, disease, transform=transform)