diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..fc5dfa04549ce5e4870ffc786aaa93716dd8735c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +experiments/pretrained/training_states/100000.state filter=lfs diff=lfs merge=lfs -text diff --git a/Aberration_Correction/Options/Test_Aberration_Transformers.yml b/Aberration_Correction/Options/Test_Aberration_Transformers.yml new file mode 100644 index 0000000000000000000000000000000000000000..837c97c9ded4d5c4f4691129daef0046e33ab819 --- /dev/null +++ b/Aberration_Correction/Options/Test_Aberration_Transformers.yml @@ -0,0 +1,75 @@ +# general settings +name: sample_test +# name: batch8 +model_type: ImageCleanModel +scale: 1 +num_gpu: 4 # set num_gpu: 0 for cpu mode +manual_seed: 100 + +# dataset and data loader settings +datasets: + val: + name: ValSet + type: Dataset_PaddedImage # Use Dataset_PaddedImage_npy if load convolved images (lr images). Also please set dataroot_lq as well. + dataroot_gt: PATH_TO_TEST_SET # TODO + io_backend: + type: disk + + sensor_size: 1215 + psf_size: 135 + +# network structures +network_g: + type: ACFormer + inp_channels: 39 + out_channels: 3 + dim: 48 + num_blocks: [2,4,4,4] + num_refinement_blocks: 4 + channel_heads: [1,2,4,8] + spatial_heads: [1,2,4,8] + overlap_ratio: [0.5,0.5,0.5,0.5] + window_size: 8 + spatial_dim_head: 16 + ffn_expansion_factor: 2.66 + bias: False + LayerNorm_type: WithBias + ca_dim: 32 + ca_heads: 2 + M: 13 + window_size_ca: 8 + query_ksize: [15,11,7,3,3] + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ks: + start: -2 + end: -5 + num: 13 + +# validation settings +val: + window_size: 8 + save_img: true + rgb2bgr: true + use_image: true + max_minibatch: 8 + padding: 64 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: true + + +# dist training settings +dist_params: + backend: nccl + port: 29502 diff --git a/Aberration_Correction/Options/Train_Aberration_Transformers.yml b/Aberration_Correction/Options/Train_Aberration_Transformers.yml new file mode 100644 index 0000000000000000000000000000000000000000..aa8c128fd0f6a747aae6826327d4dc8ce2f34b99 --- /dev/null +++ b/Aberration_Correction/Options/Train_Aberration_Transformers.yml @@ -0,0 +1,141 @@ +# general settings +name: sample_test +# name: batch8 +model_type: ImageCleanModel +scale: 1 +num_gpu: 4 # set num_gpu: 0 for cpu mode +manual_seed: 100 + +# dataset and data loader settings +datasets: + train: + name: TrainSet + type: Dataset_PaddedImage # make lr image from gt image on the fly. 
+ dataroot_gt: PATH_TO_TRAIN_SET # TODO + + filename_tmpl: '{}' + io_backend: + type: disk + + # data loader + use_shuffle: true + num_worker_per_gpu: 8 # 8 + batch_size_per_gpu: 2 # 8 + + gt_size: 256 + + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + sensor_size: 1215 + psf_size: 135 + + val: + name: ValSet + type: Dataset_PaddedImage + dataroot_gt: PATH_TO_TEST_SET # TODO + io_backend: + type: disk + + sensor_size: 1215 + psf_size: 135 + +# network structures +network_g: + type: ACFormer + inp_channels: 39 + out_channels: 3 + dim: 48 + num_blocks: [2,4,4,4] + num_refinement_blocks: 4 + channel_heads: [1,2,4,8] + spatial_heads: [1,2,4,8] + overlap_ratio: [0.5,0.5,0.5,0.5] + window_size: 8 + spatial_dim_head: 16 + ffn_expansion_factor: 2.66 + bias: False + LayerNorm_type: WithBias + ca_dim: 32 + ca_heads: 2 + M: 13 + window_size_ca: 8 + query_ksize: [15,11,7,3,3] + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + eval_only: True + eval_name: Sample_data + real_psf: True + grid: True + total_iter: 100000 + warmup_iter: -1 # no warm up + use_grad_clip: true + contrast_tik: 2 + sensor_height: 1215 + + scheduler: + type: CosineAnnealingRestartCyclicLR + periods: [92000, 208000] + restart_weights: [1,1] + eta_mins: [0.0003,0.000001] + + mixing_augs: + mixup: false + mixup_beta: 1.2 + use_identity: true + + optim_g: + type: AdamW + lr: !!float 3e-4 + weight_decay: !!float 1e-4 + betas: [0.9, 0.999] + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + + ks: + start: -2 + end: -5 + num: 13 + + +# validation settings +val: + window_size: 8 + val_freq: !!float 1e8 # inactivated + save_img: false + rgb2bgr: true + use_image: true + max_minibatch: 8 + padding: 64 + apply_conv: True # Apply convolution to GT image to create lr image. 
False if load .npy data (already aberrated) + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: true + +# logging settings +logger: + print_freq: 500 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29502 diff --git a/Aberration_Correction/utils.py b/Aberration_Correction/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b44a8952dd24cc3da09b380ccadac7e5a60f42cd --- /dev/null +++ b/Aberration_Correction/utils.py @@ -0,0 +1,90 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + +import numpy as np +import os +import cv2 +import math + +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + +def load_img(filepath): + return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) + +def save_img(filepath, img): + cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + +def load_gray_img(filepath): + return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) + +def save_gray_img(filepath, img): + cv2.imwrite(filepath, img) diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..26aaba0e86632e4d537006e45b0ec918d780b3b4 --- 
/dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.2.0 diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12803a51ba6e85d46da3d26a71ea17604c175a17 --- /dev/null +++ b/basicsr/data/__init__.py @@ -0,0 +1,126 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info + +__all__ = ['create_dataset', 'create_dataloader'] + +# automatically scan and import dataset modules +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [ + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) + if v.endswith('_dataset.py') +] +# import all the dataset modules +_dataset_modules = [ + importlib.import_module(f'basicsr.data.{file_name}') + for file_name in dataset_filenames +] + + +def create_dataset(dataset_opt, mv=False): + """Create dataset. + + Args: + dataset_opt (dict): Configuration for dataset. It constains: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_type = dataset_opt['type'] + # dynamic instantiation + for module in _dataset_modules: + dataset_cls = getattr(module, dataset_type, None) + if dataset_cls is not None: + break + if dataset_cls is None: + raise ValueError(f'Dataset {dataset_type} is not found.') + + dataset = dataset_cls(dataset_opt) + + logger = get_root_logger() + logger.info( + f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' + 'is created.') + return dataset + + +def create_dataloader(dataset, + dataset_opt, + num_gpu=1, + dist=False, + sampler=None, + seed=None): + """Create dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + elif phase in ['val', 'test', 'val20']: # validation + dataloader_args = dict( + dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f'Wrong dataset phase: {phase}. 
' + "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: ' + f'num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader( + num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..0da5bf9c1ca4e32eb305e41efcc1f430f6d33421 --- /dev/null +++ b/basicsr/data/data_sampler.py @@ -0,0 +1,49 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. 
+ """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil( + len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f251385dd87cad21e82774d5d1e32ee2d3ea8ff5 --- /dev/null +++ b/basicsr/data/data_util.py @@ -0,0 +1,15 @@ +import cv2 +cv2.setNumThreads(1) +from os import path as osp +from basicsr.utils import scandir + + +def paths_from_folder(folder, key): + gt_paths = list(scandir(folder)) + paths = [] + for idx in range(len(gt_paths)): + gt_path = gt_paths[idx] + gt_path = osp.join(folder, gt_path) + paths.append( + dict([(f'{key}_path', gt_path)])) + return paths \ No newline at end of file diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd722391764df16fc2f30e795686adccd23b5f51 --- /dev/null +++ b/basicsr/data/paired_image_dataset.py @@ -0,0 +1,156 @@ +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paths_from_folder +from basicsr.utils import FileClient, imfrombytes, img2tensor, padding +from natsort import natsorted +import random +import numpy as np +import torch +import cv2 +import os +import random + + +class Dataset_PaddedImage(data.Dataset): + """Padded image dataset for image restoration. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + gt_size (int): Cropped patched size for gt patches. + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(Dataset_PaddedImage, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.paths = paths_from_folder(self.gt_folder, 'gt') + + self.sensor_size = opt['sensor_size'] + self.psf_size = opt['psf_size'] + self.padded_size = self.sensor_size + 2 * self.psf_size + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient( + self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + index = index % len(self.paths) + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
+ gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + try: + img_gt = imfrombytes(img_bytes, float32=True) + except: + raise Exception("gt path {} not working".format(gt_path)) + + + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # padding + img_gt = padding(img_gt, gt_size) # h,w,c + orig_h, orig_w, _ = img_gt.shape + + # Fit one axis to sensor height (width) + longer = max(orig_h, orig_w) + scale = float(longer / self.sensor_size) + resolution = (int(orig_w / scale), int(orig_h / scale)) + img_gt = cv2.resize(img_gt, resolution, interpolation=cv2.INTER_LINEAR) # sensor_size,x,3 or y,sensor_size,3 where x,y <= sensor_size + + resized_h, resized_w, _ = img_gt.shape + # add padding + pad_h = self.padded_size - resized_h + pad_w = self.padded_size - resized_w + pad_l = pad_r = pad_w // 2 + if pad_w % 2: + pad_r += 1 + pad_t = pad_b = pad_h // 2 + if pad_h % 2: + pad_b += 1 + img_gt = np.pad(img_gt, ((pad_t, pad_b), (pad_l, pad_r), (0,0))) # padded_size,padded_size,3 + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor(img_gt, bgr2rgb=True, + float32=True) + + return { + 'gt': img_gt, + 'gt_path': gt_path, + 'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size) + } + + def __len__(self): + return len(self.paths) + +class Dataset_PaddedImage_npy(data.Dataset): + # validation only + def __init__(self, opt): + super(Dataset_PaddedImage_npy, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.lq_paths = natsorted(os.listdir(self.lq_folder)) + self.gt_paths = natsorted(os.listdir(self.gt_folder)) + + self.sensor_size = opt['sensor_size'] + self.psf_size = opt['psf_size'] + self.padded_size = self.sensor_size + 2 * self.psf_size + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient( + self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + index = index % len(self.gt_paths) + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
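+        # Note: in this class the aberrated (lq) input is loaded from a precomputed
+        # .npy file rather than synthesized from the GT on the fly; the GT is returned
+        # unpadded, and only the padding offsets (again relative to the sensor frame,
+        # i.e. minus `psf_size`) are computed here.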
+ gt_path = f"{self.gt_folder}/{self.gt_paths[index]}" + lq_path = f"{self.lq_folder}/{self.lq_paths[index]}" + assert os.path.basename(gt_path).split(".")[0] == os.path.basename(lq_path).split(".")[0] + + img_bytes = self.file_client.get(gt_path, 'gt') + try: + img_gt = imfrombytes(img_bytes, float32=True) + except: + raise Exception("gt path {} not working".format(gt_path)) + + img_lq = torch.tensor(np.load(lq_path)) # 1,1,81,3,405,405 + + resized_h, resized_w, _ = img_gt.shape + pad_h = self.padded_size - resized_h + pad_w = self.padded_size - resized_w + pad_l = pad_r = pad_w // 2 + if pad_w % 2: + pad_r += 1 + pad_t = pad_b = pad_h // 2 + if pad_h % 2: + pad_b += 1 + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor(img_gt, bgr2rgb=True, + float32=True) + + return { + 'gt': img_gt, + 'lq': img_lq, + 'lq_path': lq_path, + 'gt_path': gt_path, + 'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size) + } + + def __len__(self): + return len(self.gt_paths) diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..ad78c4a52c410bd3fe9cc5e58436c7b17d935058 --- /dev/null +++ b/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,126 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. 
+ """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to( + device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c838aca888596e02542120e34722d770e6ec4632 --- /dev/null +++ b/basicsr/data/transforms.py @@ -0,0 +1,167 @@ +import cv2 +import random +import numpy as np + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. 
Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img + +def data_augmentation(image, mode): + """ + Performs data augmentation of the input image + Input: + image: a cv2 (OpenCV) image + mode: int. Choice of transformation to apply to the image + 0 - no transformation + 1 - flip up and down + 2 - rotate counterwise 90 degree + 3 - rotate 90 degree and flip up and down + 4 - rotate 180 degree + 5 - rotate 180 degree and flip + 6 - rotate 270 degree + 7 - rotate 270 degree and flip + """ + if mode == 0: + # original + out = image + elif mode == 1: + # flip up and down + out = np.flipud(image) + elif mode == 2: + # rotate counterwise 90 degree + out = np.rot90(image) + elif mode == 3: + # rotate 90 degree and flip up and down + out = np.rot90(image) + out = np.flipud(out) + elif mode == 4: + # rotate 180 degree + out = np.rot90(image, k=2) + elif mode == 5: + # rotate 180 degree and flip + out = np.rot90(image, k=2) + out = np.flipud(out) + elif mode == 6: + # rotate 270 degree + out = np.rot90(image, k=3) + elif mode == 7: + # rotate 270 degree and flip + out = np.rot90(image, k=3) + out = np.flipud(out) + else: + raise Exception('Invalid choice of image transformation') + + return out + +def random_augmentation(*args): + out = [] + flag_aug = random.randint(0,7) + for data in args: + if type(data) == list: + out.append([data_augmentation(_data, flag_aug).copy() for _data in data]) + else: + out.append(data_augmentation(data, flag_aug).copy()) + return out diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4804754dc35451d353d0390708462ee571c9d8 --- /dev/null +++ b/basicsr/metrics/__init__.py @@ -0,0 +1,4 @@ +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..35fc23db48ccc90020ab1d603fcad9fc215c12ef --- /dev/null +++ b/basicsr/metrics/fid.py @@ -0,0 +1,102 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from basicsr.models.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', + resize_input=True, + normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], + resize_input=resize_input, + normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, + inception, + len_generator=None, + device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. 
+ """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ( + 'Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. Adding {eps} to diagonal ' + 'of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fb38e1bc7c281f1cc2a23b2cdc994812baad6533 --- /dev/null +++ b/basicsr/metrics/metric_util.py @@ -0,0 +1,47 @@ +import numpy as np + +from basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. 
diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py new file mode 100644 index 0000000000000000000000000000000000000000..3ceb45da8ad76c2e9a414e2a95367a7ebd29fbcf --- /dev/null +++ b/basicsr/metrics/niqe.py @@ -0,0 +1,205 @@ +import cv2 +import math +import numpy as np +from scipy.ndimage.filters import convolve +from scipy.special import gamma + +from basicsr.metrics.metric_util import reorder_image, to_y_channel + + +def estimate_aggd_param(block): + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters. + + Args: + block (ndarray): 2D Image block. + + Returns: + tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD + distribution (Estimating the parames in Equation 7 in the paper). + """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / ( + gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * + (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, + mu_pris_param, + cov_pris_param, + gaussian_window, + block_size_h=96, + block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implemtation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. 
+ gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). + """ + assert img.ndim == 2, ( + 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + # crop image + h, w = img.shape + num_block_h = math.floor(h / block_size_h) + num_block_w = math.floor(w / block_size_w) + img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w] + + distparam = [] # dist param is actually the multiscale features + for scale in (1, 2): # perform on two scales (1, 2) + mu = convolve(img, gaussian_window, mode='nearest') + sigma = np.sqrt( + np.abs( + convolve(np.square(img), gaussian_window, mode='nearest') - + np.square(mu))) + # normalize, as in Eq. 1 in the paper + img_nomalized = (img - mu) / (sigma + 1) + + feat = [] + for idx_w in range(num_block_w): + for idx_h in range(num_block_h): + # process ecah block + block = img_nomalized[idx_h * block_size_h // + scale:(idx_h + 1) * block_size_h // + scale, idx_w * block_size_w // + scale:(idx_w + 1) * block_size_w // + scale] + feat.append(compute_feature(block)) + + distparam.append(np.array(feat)) + # TODO: matlab bicubic downsample with anti-aliasing + # for simplicity, now we use opencv instead, which will result in + # a slight difference. + if scale == 1: + h, w = img.shape + img = cv2.resize( + img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR) + img = img * 255. + + distparam = np.concatenate(distparam, axis=1) + + # fit a MVG (multivariate Gaussian) model to distorted patch features + mu_distparam = np.nanmean(distparam, axis=0) + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) + + # compute niqe quality, Eq. 10 in the paper + invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) + quality = np.matmul( + np.matmul((mu_pris_param - mu_distparam), invcov_param), + np.transpose((mu_pris_param - mu_distparam))) + quality = np.sqrt(quality) + + return quality + + +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + We use the official params estimated from the pristine dataset. + We use the recommended block size (96, 96) without overlaps. + + Args: + img (ndarray): Input image whose quality needs to be computed. + The input image must be in range [0, 255] with float/int type. + The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order) + If the input order is 'HWC' or 'CHW', it will be converted to gray + or Y (of YCbCr) image according to the ``convert_to`` argument. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the metric calculation. + input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. + Default: 'HWC'. + convert_to (str): Whether coverted to 'y' (of MATLAB YCbCr) or 'gray'. + Default: 'y'. + + Returns: + float: NIQE result. + """ + + # we use the official params estimated from the pristine dataset. 
+ niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz') + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result \ No newline at end of file diff --git a/basicsr/metrics/niqe_pris_params.npz b/basicsr/metrics/niqe_pris_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..42f06a9a18e6ed8bbf7933bec1477b189ef798de --- /dev/null +++ b/basicsr/metrics/niqe_pris_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296 +size 11850 diff --git a/basicsr/metrics/other_metrics.py b/basicsr/metrics/other_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..966a08a26dba3fe810cf06ff7e5ac5cc7ed31d6e --- /dev/null +++ b/basicsr/metrics/other_metrics.py @@ -0,0 +1,88 @@ +import torch +import numpy as np +import os +from PIL import Image +from natsort import natsorted +from glob import glob +from skimage import metrics +import torch.hub +from lpips.lpips import LPIPS +from tqdm import tqdm + + +photometric = { + "mse": None, + "ssim": None, + "psnr": None, + "lpips": None +} + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) + +def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor, + metric="mse", mask=None): + """ + im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1) + """ + if metric not in photometric.keys(): + raise RuntimeError(f"img_utils:: metric {metric} not recognized") + if photometric[metric] is None: + if metric == "mse": + photometric[metric] = metrics.mean_squared_error + elif metric == "ssim": + photometric[metric] = metrics.structural_similarity + elif metric == "psnr": + photometric[metric] = metrics.peak_signal_noise_ratio + elif metric == "lpips": + photometric[metric] = LPIPS().cpu() + + # convert from [0, 1] to [-1, 1] + im1t = (im1t * 2 - 1).clamp(-1, 1) + im2t = (im2t * 2 - 1).clamp(-1, 1) + + if im1t.dim() == 3: + im1t = im1t.unsqueeze(0) + im2t = im2t.unsqueeze(0) + im1t = im1t.detach().cpu() + im2t = im2t.detach().cpu() + + if im1t.shape[-1] == 3: + im1t = im1t.permute(0, 3, 1, 2) # BCHW + im2t = im2t.permute(0, 3, 1, 2) + + im1 = im1t.permute(0, 2, 3, 1).numpy() + im2 = im2t.permute(0, 2, 3, 1).numpy() + batchsz, hei, wid, _ = im1.shape + values = [] + + for i in range(batchsz): + if metric in ["mse", "psnr"]: + if mask is not None: + im1 = im1 * mask[i] + im2 = im2 * mask[i] + value = photometric[metric]( + im1[i], im2[i] + ) + if mask is not None: + hei, wid, _ = im1[i].shape + pixelnum = mask[i, ..., 0].sum() + value = value - 10 * np.log10(hei * wid / pixelnum) + elif metric in ["ssim"]: + value, ssimmap = photometric["ssim"]( + im1[i], im2[i], multichannel=True, full=True + ) + if mask is not None: + value = (ssimmap * mask[i]).sum() / mask[i].sum() + elif metric in ["lpips"]: + value = photometric[metric]( + im1t[i:i + 1], im2t[i:i + 1] + 
) + else: + raise NotImplementedError + values.append(value) + + return sum(values) / len(values) + + diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4721bcff5a6331d9d95785eb1d261c14c97699 --- /dev/null +++ b/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,303 @@ +import cv2 +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +import skimage.metrics +import torch + + +def calculate_psnr(img1, + img2, + crop_border, + input_order='HWC', + test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1,2,0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1,2,0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + max_value = 1. if img1.max() <= 1 else 255. + return 20. * np.log10(max_value / np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. 
+ """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + +def prepare_for_ssim(img, k): + import torch + with torch.no_grad(): + img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() + conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') + conv.weight.requires_grad = False + conv.weight[:, :, :, :] = 1. / (k * k) + + img = conv(img) + + img = img.squeeze(0).squeeze(0) + img = img[0::k, 0::k] + return img.detach().cpu().numpy() + +def prepare_for_ssim_rgb(img, k): + import torch + with torch.no_grad(): + img = torch.from_numpy(img).float() #HxWx3 + + conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') + conv.weight.requires_grad = False + conv.weight[:, :, :, :] = 1. / (k * k) + + new_img = [] + + for i in range(3): + new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) + + return torch.stack(new_img, dim=2).detach().cpu().numpy() + +def _3d_gaussian_calculator(img, conv3d): + out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + return out + +def _generate_3d_gaussian_kernel(): + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + kernel_3 = cv2.getGaussianKernel(11, 1.5) + kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) + conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') + conv3d.weight.requires_grad = False + conv3d.weight[0, 0, :, :, :] = kernel + return conv3d + +def _ssim_3d(img1, img2, max_value): + assert len(img1.shape) == 3 and len(img2.shape) == 3 + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + + Returns: + float: ssim result. + """ + C1 = (0.01 * max_value) ** 2 + C2 = (0.03 * max_value) ** 2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = _generate_3d_gaussian_kernel().cuda() + + img1 = torch.tensor(img1).float().cuda() + img2 = torch.tensor(img2).float().cuda() + + + mu1 = _3d_gaussian_calculator(img1, kernel) + mu2 = _3d_gaussian_calculator(img2, kernel) + + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq + sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq + sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return float(ssim_map.mean()) + +def _ssim_cly(img1, img2): + assert len(img1.shape) == 2 and len(img2.shape) == 2 + """Calculate SSIM (structural similarity) for one channel images. 
+ + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = cv2.getGaussianKernel(11, 1.5) + # print(kernel) + window = np.outer(kernel, kernel.transpose()) + + bt = cv2.BORDER_REPLICATE + + mu1 = cv2.filter2D(img1, -1, window, borderType=bt) + mu2 = cv2.filter2D(img2, -1, window,borderType=bt) + + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, + img2, + crop_border, + input_order='HWC', + test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1,2,0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1,2,0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + return _ssim_cly(img1[..., 0], img2[..., 0]) + + + ssims = [] + # ssims_before = [] + + # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True) + # print('.._skimage', + # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)) + max_value = 1 if img1.max() <= 1 else 255 + with torch.no_grad(): + final_ssim = _ssim_3d(img1, img2, max_value) + ssims.append(final_ssim) + + # for i in range(img1.shape[2]): + # ssims_before.append(_ssim(img1, img2)) + + # print('..ssim mean , new {:.4f} and before {:.4f} .... 
skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before)) + # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False)) + + return np.array(ssims).mean() diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10f3b9fd8d7fbc5213966c695cd230618cdf88f8 --- /dev/null +++ b/basicsr/models/__init__.py @@ -0,0 +1,42 @@ +import importlib +from os import path as osp + +from basicsr.utils import get_root_logger, scandir + +# automatically scan and import model modules +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [ + osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) + if v.endswith('_model.py') +] +# import all the model modules +_model_modules = [ + importlib.import_module(f'basicsr.models.{file_name}') + for file_name in model_filenames +] + + +def create_model(opt): + """Create model. + + Args: + opt (dict): Configuration. It constains: + model_type (str): Model type. + """ + model_type = opt['model_type'] + + # dynamic instantiation + for module in _model_modules: + model_cls = getattr(module, model_type, None) + if model_cls is not None: + break + if model_cls is None: + raise ValueError(f'Model {model_type} is not found.') + + model = model_cls(opt) + + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/basicsr/models/archs/__init__.py b/basicsr/models/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1276c66987b47c8c91a5fb7ca5df8d670e341e1a --- /dev/null +++ b/basicsr/models/archs/__init__.py @@ -0,0 +1,45 @@ +import importlib +from os import path as osp + +from basicsr.utils import scandir + +# automatically scan and import arch modules +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [ + osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) + if v.endswith('_arch.py') +] +# import all the arch modules +_arch_modules = [ + importlib.import_module(f'basicsr.models.archs.{file_name}') + for file_name in arch_filenames +] + + +def dynamic_instantiation(modules, cls_type, opt): + """Dynamically instantiate class. + + Args: + modules (list[importlib modules]): List of modules from importlib + files. + cls_type (str): Class type. + opt (dict): Class initialization kwargs. + + Returns: + class: Instantiated class. 
+ """ + for module in modules: + cls_ = getattr(module, cls_type, None) + if cls_ is not None: + break + if cls_ is None: + raise ValueError(f'{cls_type} is not found.') + return cls_(**opt) + + +def define_network(opt): + network_type = opt.pop('type') + net = dynamic_instantiation(_arch_modules, network_type, opt) + return net diff --git a/basicsr/models/archs/arch_util.py b/basicsr/models/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..66b4a527e451a9e7ede97a23eddbcbad242b1da8 --- /dev/null +++ b/basicsr/models/archs/arch_util.py @@ -0,0 +1,255 @@ +import math +import torch +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from basicsr.utils import get_root_logger + +# try: +# from basicsr.models.ops.dcn import (ModulatedDeformConvPack, +# modulated_deform_conv) +# except ImportError: +# # print('Cannot import dcn. Ignore this warning if dcn is not used. ' +# # 'Otherwise install BasicSR with compiling dcn.') +# + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. 
+ num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. 
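+
+    Example (shape check only):
+        >>> x = torch.randn(1, 3, 8, 8)
+        >>> pixel_unshuffle(x, scale=2).shape
+        torch.Size([1, 12, 4, 4])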
+ """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. +# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning( +# f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, +# self.stride, self.padding, self.dilation, +# self.groups, self.deformable_groups) diff --git a/basicsr/models/archs/restormer_arch.py b/basicsr/models/archs/restormer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..ace138ed7fc70f4fb6c4c32609ae96dee3126e94 --- /dev/null +++ b/basicsr/models/archs/restormer_arch.py @@ -0,0 +1,527 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numbers +from torch import einsum + +from einops import rearrange +from basicsr.utils.nano import psf2otf + +try: + from flash_attn import flash_attn_func +except: + print("Flash attention is required") + raise NotImplementedError + + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class 
FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias, ksize=0): + super(Attention, self).__init__() + self.num_heads = num_heads + self.ksize = ksize + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + if ksize: + self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize-1) //2) + + + def forward(self, x): + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + if self.ksize: + q = q - self.avg(q) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +def to(x): + return {'device': x.device, 'dtype': x.dtype} + +def pair(x): + return (x, x) if not isinstance(x, tuple) else x + +def expand_dim(t, dim, k): + t = t.unsqueeze(dim = dim) + expand_shape = [-1] * len(t.shape) + expand_shape[dim] = k + return t.expand(*expand_shape) + +def rel_to_abs(x): + b, l, m = x.shape + r = (m + 1) // 2 + + col_pad = torch.zeros((b, l, 1), **to(x)) + x = torch.cat((x, col_pad), dim = 2) + 
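+    # Flattening the padded (b, l, m+1) tensor and re-slicing below shifts each
+    # row by one position, converting relative-position logits (length 2r-1)
+    # into absolute-position logits of length r (the usual "skewing" trick
+    # used in relative self-attention).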
flat_x = rearrange(x, 'b l c -> b (l c)') + flat_pad = torch.zeros((b, m - l), **to(x)) + flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1) + final_x = flat_x_padded.reshape(b, l + 1, m) + final_x = final_x[:, :l, -r:] + return final_x + +def relative_logits_1d(q, rel_k): + b, h, w, _ = q.shape + r = (rel_k.shape[0] + 1) // 2 + + logits = einsum('b x y d, r d -> b x y r', q, rel_k) + logits = rearrange(logits, 'b x y r -> (b x) y r') + logits = rel_to_abs(logits) + + logits = logits.reshape(b, h, w, r) + logits = expand_dim(logits, dim = 2, k = r) + return logits + + +class RelPosEmb(nn.Module): + def __init__( + self, + block_size, + rel_size, + dim_head + ): + super().__init__() + height = width = rel_size + scale = dim_head ** -0.5 + + self.block_size = block_size + self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale) + self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale) + + def forward(self, q): + block = self.block_size + + q = rearrange(q, 'b (x y) c -> b x y c', x = block) + rel_logits_w = relative_logits_1d(q, self.rel_width) + rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)') + + q = rearrange(q, 'b x y d -> b y x d') + rel_logits_h = relative_logits_1d(q, self.rel_height) + rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)') + return rel_logits_w + rel_logits_h + + +########################################################################## +## Overlapping Cross-Attention (OCA) +class OCAB(nn.Module): + def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias, ksize=0): + super(OCAB, self).__init__() + self.num_spatial_heads = num_heads + self.dim = dim + self.window_size = window_size + self.overlap_win_size = int(window_size * overlap_ratio) + window_size + self.dim_head = dim_head + self.inner_dim = self.dim_head * self.num_spatial_heads + self.scale = self.dim_head**-0.5 + self.ksize = ksize + + self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2) + self.qkv = nn.Conv2d(self.dim, self.inner_dim*3, kernel_size=1, bias=bias) + self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias) + self.rel_pos_emb = RelPosEmb( + block_size = window_size, + rel_size = window_size + (self.overlap_win_size - window_size), + dim_head = self.dim_head + ) + if ksize: + self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize-1) //2) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv(x) + qs, ks, vs = qkv.chunk(3, dim=1) + + if self.ksize: + qs = qs - self.avg(qs) + + # spatial attention + qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.window_size, p2 = self.window_size) + ks, vs = map(lambda t: self.unfold(t), (ks, vs)) + ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.inner_dim), (ks, vs)) + + #split heads + qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head = self.num_spatial_heads), (qs, ks, vs)) + + # attention + qs = qs * self.scale + spatial_attn = (qs @ ks.transpose(-2, -1)) + spatial_attn += self.rel_pos_emb(qs) + spatial_attn = spatial_attn.softmax(dim=-1) + + out = (spatial_attn @ vs) + + out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head = self.num_spatial_heads, h = h // self.window_size, w = w // self.window_size, p1 = self.window_size, p2 = self.window_size) + + # merge spatial and channel + out = self.project_out(out) + + 
return out + + +class AttentionFusion(nn.Module): + def __init__(self, dim, bias, channel_fusion): + super(AttentionFusion, self).__init__() + + self.channel_fusion = channel_fusion + self.fusion = nn.Sequential( + nn.Conv2d(dim, dim // 2, kernel_size=1, bias=bias), + nn.GELU(), + nn.Conv2d(dim // 2, dim // 2, kernel_size=1, bias=bias) + ) + self.dim = dim // 2 + + def forward(self, x): + fusion_map = self.fusion(x) + if self.channel_fusion: + weight = F.sigmoid(torch.mean(fusion_map, 1, True)) + else: + weight = F.sigmoid(torch.mean(fusion_map, (2,3), True)) + fused_feature = x[:, :self.dim] * weight + x[:, self.dim:] * (1-weight) # [:, :self.dim] == SA + return fused_feature + + + +class Transformer_STAF(nn.Module): + def __init__(self, dim, window_size, overlap_ratio, num_channel_heads, num_spatial_heads, spatial_dim_head, ffn_expansion_factor, bias, LayerNorm_type, channel_fusion, query_ksize=0): + super(Transformer_STAF, self).__init__() + + self.spatial_attn = OCAB(dim, window_size, overlap_ratio, num_spatial_heads, spatial_dim_head, bias, ksize=query_ksize) + self.channel_attn = Attention(dim, num_channel_heads, bias, ksize=query_ksize) + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.norm3 = LayerNorm(dim, LayerNorm_type) + self.norm4 = LayerNorm(dim, LayerNorm_type) + + self.channel_ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.spatial_ffn = FeedForward(dim, ffn_expansion_factor, bias) + + self.fusion = AttentionFusion(dim*2, bias, channel_fusion) + + def forward(self, x): + sa = x + self.spatial_attn(self.norm1(x)) + sa = sa + self.spatial_ffn(self.norm2(sa)) + ca = x + self.channel_attn(self.norm3(x)) + ca = ca + self.channel_ffn(self.norm4(ca)) + fused = self.fusion(torch.cat([sa, ca], 1)) + + return fused + + +class MAFG_CA(nn.Module): + def __init__(self, embed_dim, num_heads, M, window_size=0, eps=1e-6): + super().__init__() + self.M = M + self.Q_idx = M // 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.M = M + self.wsize = window_size + + self.proj_high = nn.Conv2d(3, embed_dim, kernel_size=1) + self.proj_rgb = nn.Conv2d(embed_dim, 3, kernel_size=1) + + self.norm = nn.LayerNorm(embed_dim, eps=eps) + self.qkv = nn.Linear(embed_dim, embed_dim*3, bias=False) + self.proj_out = nn.Linear(embed_dim, embed_dim, bias=False) + self.max_seq = 2**16-1 + + # window based sliding similar to OCAB + self.overlap_wsize = int(self.wsize * 0.5) + self.wsize + self.unfold = nn.Unfold(kernel_size=(self.overlap_wsize, self.overlap_wsize), stride=window_size, padding=(self.overlap_wsize-self.wsize)//2) + self.scale = self.embed_dim ** -0.5 + self.pos_emb_q = nn.Parameter(torch.zeros(self.wsize**2, embed_dim)) + self.pos_emb_k = nn.Parameter(torch.zeros(self.overlap_wsize**2, embed_dim)) + nn.init.trunc_normal_(self.pos_emb_q, std=0.02) + nn.init.trunc_normal_(self.pos_emb_k, std=0.02) + + def forward(self, x): + x = self.proj_high(x) + BM,E,H,W = x.shape + + x_seq = x.view(BM,E,-1).permute(0,2,1) + x_seq = self.norm(x_seq) + B = BM // self.M + QKV = self.qkv(x_seq) + QKV = QKV.view(BM, H, W, 3, -1).permute(3,0,4,1,2).contiguous() + Q,K,V = QKV[0], QKV[1], QKV[2] + Q_bm = Q.view(B, self.M, E, H,W) + _Q = Q_bm[:, self.Q_idx:self.Q_idx+1] + Q = torch.stack([__Q.repeat(self.M,1,1,1) for __Q in _Q]).view(BM,E,H,W) + + Q = rearrange(Q, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.wsize, p2 = self.wsize) + K,V = map(lambda t: self.unfold(t), (K,V)) + if K.shape[-1] > 
10000: # Inference + b,_,pp = K.shape + K = K.view(b,self.embed_dim,-1,pp).permute(0,3,2,1).reshape(b*pp,-1,self.embed_dim) + V = V.view(b,self.embed_dim,-1,pp).permute(0,3,2,1).reshape(b*pp,-1,self.embed_dim) + else: + K,V = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.embed_dim), (K,V)) + + # Absolute positional embedding + Q = Q + self.pos_emb_q + K = K + self.pos_emb_k + + s, eq, _ = Q.shape + _, ek, _ = K.shape + Q = Q.view(s, eq, self.num_heads,self.head_dim).half() + K = K.view(s, ek, self.num_heads,self.head_dim).half() + V = V.view(s, ek, self.num_heads,self.head_dim).half() + if s > self.max_seq: # maximum allowed sequence of flash attention + outs = [] + sp = self.max_seq + _max = s // sp + 1 + for i in range(_max): + outs.append(flash_attn_func(Q[i*sp: (i+1)*sp], K[i*sp: (i+1)*sp], V[i*sp: (i+1)*sp], causal=False)) + out = torch.cat(outs).to(torch.float32) + else: + out = flash_attn_func(Q, K, V, causal=False).to(torch.float32) + out = rearrange(out, '(b nh nw) (ph pw) h d -> b (nh ph nw pw) (h d)', nh=H//self.wsize, nw=W//self.wsize, ph=self.wsize, pw=self.wsize) + out = self.proj_out(out) + + mixed_feature = out.view(BM,H,W,E).permute(0,3,1,2).contiguous() + x + return self.proj_rgb(mixed_feature).reshape(B,-1,H,W) + + +########################################################################## +## Aberration Correction Transformers for Metalens +class ACFormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + channel_heads = [1,2,4,8], + spatial_heads = [2,2,3,4], + overlap_ratio=[0.5, 0.5, 0.5, 0.5], + window_size = 8, + spatial_dim_head = 16, + bias = False, + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + M=13, + ca_heads=2, + ca_dim=32, + window_size_ca=0, + query_ksize=None + ): + + super(ACFormer, self).__init__() + self.center_idx = M // 2 + self.ca = MAFG_CA(embed_dim=ca_dim, num_heads=ca_heads, M=M, window_size=window_size_ca) + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[Transformer_STAF(dim=dim, window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**3), 
window_size = window_size, overlap_ratio=overlap_ratio[3], num_channel_heads=channel_heads[3], num_spatial_heads=spatial_heads[3], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=query_ksize[0] if i % 2 == 1 else 0) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[1] if i % 2 == 1 else 0) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[2] if i % 2 == 1 else 0) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[3] if i % 2 == 1 else 0) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[4] if i % 2 == 1 else 0) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + if inp_img.ndim == 5: + B,M,C,H,W = inp_img.shape + center_img = inp_img[:, self.center_idx] + inp_img = inp_img.view(B*M,C,H,W).contiguous() + else: + center_img = inp_img + + if self.ca is None: + inp_enc_level1 = inp_img.view(B,M*C,H,W) + else: + inp_enc_level1 = self.ca(inp_img) + + inp_enc_level1 = self.patch_embed(inp_enc_level1) + + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = 
self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + + out_dec_level1 = self.refinement(out_dec_level1) + out_dec_level1 = self.output(out_dec_level1) + center_img + + return out_dec_level1 \ No newline at end of file diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..df03cd24aee80f3f72bf9c55a74b329ee2aa5bf1 --- /dev/null +++ b/basicsr/models/base_model.py @@ -0,0 +1,376 @@ +import logging +import os +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils.dist_util import master_only + +logger = logging.getLogger('basicsr') + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True, psf=None, ks=None, val_conv=True): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True + use_image (bool): Whether to use saved images to compute metrics (PSNR, SSIM), if not, then use data directly from network' output. Default: True + """ + if self.opt['dist']: + return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv) + else: + return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv) + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_( + net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. 
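+
+        Note: in this fork the DistributedDataParallel / DataParallel wrapping
+        is commented out below, so the network is only moved to ``self.device``.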
+ + Args: + net (nn.Module) + """ + + net = net.to(self.device) + # if self.opt['dist']: + # find_unused_parameters = self.opt.get('find_unused_parameters', + # False) + # net = DistributedDataParallel( + # net, + # device_ids=[torch.cuda.current_device()], + # find_unused_parameters=find_unused_parameters) + # elif self.opt['num_gpu'] > 1: + # net = DataParallel(net) + return net + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepRestartLR(optimizer, + **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingRestartLR( + optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingWarmupRestarts': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingWarmupRestarts( + optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartCyclicLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingRestartCyclicLR( + optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'TrueCosineAnnealingLR': + print('..', 'cosineannealingLR') + for optimizer in self.optimizers: + self.schedulers.append( + torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingLRWithRestart': + print('..', 'CosineAnnealingLR_With_Restart') + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLRWithRestart(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'LinearLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.LinearLR( + optimizer, train_opt['total_iter'])) + elif scheduler_type == 'VibrateLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.VibrateLR( + optimizer, train_opt['total_iter'])) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = (f'{net.__class__.__name__} - ' + f'{net.module.__class__.__name__}') + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger.info( + f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. 
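+
+        During warmup, ``update_learning_rate`` scales each group's lr linearly
+        as ``initial_lr * current_iter / warmup_iter``.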
+ """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append( + [v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append( + [v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [ + param_group['lr'] + for param_group in self.optimizers[0].param_groups + ] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len( + param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + torch.save(save_dict, save_path) + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with differnet name or different size when loading models. + + 1. Print keys with differnet names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning( + f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. 
+ strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + net = self.get_bare_model(net) + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}.') + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + print(' load net keys', load_net.keys) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = { + 'epoch': epoch, + 'iter': current_iter, + 'optimizers': [], + 'schedulers': [] + } + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], + save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len( + self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len( + self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. 
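+
+        Returns:
+            OrderedDict: Losses averaged over GPUs in distributed training
+            (rank 0 divides the reduced sum by ``world_size``); each value is
+            returned as a Python float via ``.mean().item()``.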
+ """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/basicsr/models/image_restoration_model.py b/basicsr/models/image_restoration_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9d067b9a4a541ec2e72de8f34014fc43e187b5bb --- /dev/null +++ b/basicsr/models/image_restoration_model.py @@ -0,0 +1,392 @@ +import importlib +import torch +import os +import gc +import random +import torch.nn.functional as F + +from collections import OrderedDict +from copy import deepcopy +from os import path as osp +from tqdm import tqdm +from functools import partial + +from basicsr.models.archs import define_network +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.nano import apply_conv_n_deconv +from basicsr.metrics.other_metrics import compute_img_metric + +loss_module = importlib.import_module('basicsr.models.losses') +metric_module = importlib.import_module('basicsr.metrics') + + +class Mixing_Augment: + def __init__(self, mixup_beta, use_identity, device): + self.dist = torch.distributions.beta.Beta(torch.tensor([mixup_beta]), torch.tensor([mixup_beta])) + self.device = device + + self.use_identity = use_identity + + self.augments = [self.mixup] + + def mixup(self, target, input_): + lam = self.dist.rsample((1,1)).item() + + r_index = torch.randperm(target.size(0)).to(self.device) + + target = lam * target + (1-lam) * target[r_index, :] + input_ = lam * input_ + (1-lam) * input_[r_index, :] + + return target, input_ + + def __call__(self, target, input_): + if self.use_identity: + augment = random.randint(0, len(self.augments)) + if augment < len(self.augments): + target, input_ = self.augments[augment](target, input_) + else: + augment = random.randint(0, len(self.augments)-1) + target, input_ = self.augments[augment](target, input_) + return target, input_ + +class ImageCleanModel(BaseModel): + """Base Deblur model for single image deblur.""" + + def __init__(self, opt): + super(ImageCleanModel, self).__init__(opt) + + # define network + + self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False) + if self.mixing_flag: + mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2) + use_identity = self.opt['train']['mixing_augs'].get('use_identity', False) + self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device) + + self.net_g = define_network(deepcopy(opt['network_g'])) + self.net_g = self.model_to_device(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g, load_path, + self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params')) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info( + f'Use Exponential Moving Average with decay: {self.ema_decay}') + # 
define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = define_network(self.opt['network_g']).to(
+                self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path,
+                                  self.opt['path'].get('strict_load_g', True),
+                                  'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            pixel_type = train_opt['pixel_opt'].pop('type')
+            cri_pix_cls = getattr(loss_module, pixel_type)
+            self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
+                self.device)
+        else:
+            raise ValueError('pixel loss is None.')
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        optim_params = []
+
+        for k, v in self.net_g.named_parameters():
+            if v.requires_grad:
+                optim_params.append(v)
+            else:
+                logger = get_root_logger()
+                logger.warning(f'Params {k} will not be optimized.')
+
+        optim_type = train_opt['optim_g'].pop('type')
+        if optim_type == 'Adam':
+            self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
+        elif optim_type == 'AdamW':
+            self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
+        else:
+            raise NotImplementedError(
+                f'optimizer {optim_type} is not supported yet.')
+        self.optimizers.append(self.optimizer_g)
+
+    def feed_train_data(self, data):
+        self.lq = data['lq'].to(self.device)
+        if 'gt' in data:
+            self.gt = data['gt'].to(self.device)
+
+        if self.mixing_flag:
+            self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)
+
+    def feed_data(self, data, psf=None, ks=None, val_conv=True):
+        gt = data['gt'].to(self.device)
+        padding = data['padding']
+        padding = torch.stack(padding).T
+        otf = psf
+        M = ks.shape[1]
+        if val_conv:  # Apply convolution on the fly (use gt img to create lr image)
+            lq, gt = apply_conv_n_deconv(gt, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=True)
+            self.lq = lq[None]
+            self.gt = gt[None]  # TODO check dim. Previously the gt returned by the square helper was used as-is; now the original gt is used directly, so the shape may differ. Merge with the branch below later.
+            # TODO deconv(gt) could be handled in one place by selecting gt in the if/else above.
+
+        else:  # loaded npy for validation
+            lq = data['lq'].to(self.device)
+            lq, gt = apply_conv_n_deconv(lq, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=False)
+            self.lq = lq[None]
+            self.gt = gt
+
+    def optimize_parameters(self, current_iter):
+        self.optimizer_g.zero_grad()
+        preds = self.net_g(self.lq)
+        if not isinstance(preds, list):
+            preds = [preds]
+
+        self.output = preds[-1]
+
+        loss_dict = OrderedDict()
+        # pixel loss
+        l_pix = 0.
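+        # self.net_g may return a list of intermediate predictions; the
+        # configured pixel loss (self.cri_pix) is accumulated over all of them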
+ for pred in preds: + l_pix += self.cri_pix(pred, self.gt) + + loss_dict['l_pix'] = l_pix + + l_pix.backward() + if self.opt['train']['use_grad_clip']: + torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01) + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def pad_test(self, window_size): + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + h,w = self.lq.size()[-2:] + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq[0], (0, mod_pad_w, 0, mod_pad_h), 'reflect')[None] + self.nonpad_test(img) + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] + + def nonpad_test(self, img=None): + if img is None: + img = self.lq + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + pred = self.net_g_ema(img) + if isinstance(pred, list): + pred = pred[-1] + self.output = pred + else: + self.net_g.eval() + with torch.no_grad(): + pred = self.net_g(img) + + if isinstance(pred, list): + pred = pred[-1] + self.output = pred + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv): + if os.environ['LOCAL_RANK'] == '0': + return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv) + else: + return 0. + + + def pre_process(self, padding_size): + # pad to multiplication of window_size + self.mod_pad_h, self.mod_pad_w = 0, 0 + h,w = self.lq.size()[-2:] # BMCHW + if h % padding_size != 0: + self.mod_pad_h = padding_size - h % padding_size + if w % padding_size != 0: + self.mod_pad_w = padding_size - w % padding_size + self.lq = F.pad(self.lq[0], (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')[None] + + def post_process(self): + _, _, h, w = self.output.size() + self.output = self.output[...,0:h - self.mod_pad_h, 0:w - self.mod_pad_w] + + def nondist_validation(self, dataloader, current_iter, tb_logger, + save_img, rgb2bgr, use_image, psf, ks, val_conv): + dataset_name = dataloader.dataset.opt['name'] + base_path = self.opt['path']['visualization'] + + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = { + metric: 0 + for metric in self.opt['val']['metrics'].keys() + } + if save_img: + cur_other_metrics = {'ssim': 0., 'lpips': 0.} + else: + cur_other_metrics = None + + window_size = self.opt['val'].get('window_size', 0) + + if window_size: + test = partial(self.pad_test, window_size) + else: + test = self.nonpad_test + + cnt = 0 + + for idx, val_data in enumerate(tqdm(dataloader)): + img_name = osp.splitext(osp.basename(val_data['gt_path'][0]))[0] + self.feed_data(val_data, psf, ks, val_conv) + pad_for_OCB = self.opt['val'].get('padding') + if pad_for_OCB is not None: + self.pre_process(pad_for_OCB) + + torch.cuda.empty_cache() + gc.collect() + + test() + + if pad_for_OCB is not None: + self.post_process() + + if save_img and with_metrics and use_image: + visuals = self.get_current_visuals(to_cpu=False) + cur_other_metrics['ssim'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'ssim') + cur_other_metrics['lpips'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'lpips').item() + + visuals = self.get_current_visuals() + + sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr) + 
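+            # tensor2img maps the [0, 1] float tensors to uint8 numpy images
+            # (BGR when rgb2bgr=True) for saving and for image-based metrics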
if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + gc.collect() + + if save_img: + if self.opt['is_train']: + if 'eval_only' in self.opt['train']: + save_img_path = osp.join(base_path + self.opt['train']['eval_name'], + f'{img_name}_{current_iter}.png') + else: + save_img_path = osp.join(base_path, + f'{img_name}_{current_iter}.png') + else: + save_img_path = osp.join( + base_path, + f'{img_name}.png') + save_gt_img_path = osp.join( + base_path, dataset_name, + f'{img_name}_gt.png') + + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + opt_metric = deepcopy(self.opt['val']['metrics']) + if use_image: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(sr_img, gt_img, **opt_) + else: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_) + + cnt += 1 + + + # tentative for out of GPU memory + torch.cuda.empty_cache() + gc.collect() + + current_metric = 0. + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= cnt + current_metric = self.metric_results[metric] + if save_img: + cur_other_metrics['ssim'] /= cnt + cur_other_metrics['lpips'] /= cnt + + self._log_validation_metric_values(current_iter, dataset_name, + tb_logger) + return current_metric, cur_other_metrics + + + def _log_validation_metric_values(self, current_iter, dataset_name, + tb_logger): + log_str = f'Validation {dataset_name},\t' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + def get_current_visuals(self, to_cpu=True): + if to_cpu: + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + else: + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach() + out_dict['result'] = self.output.detach() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach() + return out_dict + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], + 'net_g', + current_iter, + param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/losses/__init__.py b/basicsr/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e7dafccc3453dc4efb985badf0ab71c4493ac4 --- /dev/null +++ b/basicsr/models/losses/__init__.py @@ -0,0 +1,5 @@ +from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss) + +__all__ = [ + 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss', +] diff --git a/basicsr/models/losses/loss_util.py b/basicsr/models/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698 --- /dev/null +++ b/basicsr/models/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. 
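+
+    For example, ``reduce_loss(loss, 'mean')`` returns ``loss.mean()`` and
+    ``reduce_loss(loss, 'none')`` returns ``loss`` unchanged.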
+ + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/basicsr/models/losses/losses.py b/basicsr/models/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7d66b174ad67b5d9cb96d6f3476fc0241825890a --- /dev/null +++ b/basicsr/models/losses/losses.py @@ -0,0 +1,180 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F +import numpy as np +from math import exp + +from basicsr.models.losses.loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +# @weighted_loss +# def charbonnier_loss(pred, target, eps=1e-12): +# return torch.sqrt((pred - target)**2 + eps) + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. 
+ reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' + f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss( + pred, target, weight, reduction=self.reduction) + +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' + f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * mse_loss( + pred, target, weight, reduction=self.reduction) + +class PSNRLoss(nn.Module): + + def __init__(self, loss_weight=1.0, reduction='mean', toY=False): + super(PSNRLoss, self).__init__() + assert reduction == 'mean' + self.loss_weight = loss_weight + self.scale = 10 / np.log(10) + self.toY = toY + self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) + self.first = True + + def forward(self, pred, target): + assert len(pred.size()) == 4 + if self.toY: + if self.first: + self.coef = self.coef.to(pred.device) + self.first = False + + pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + + pred, target = pred / 255., target / 255. 
+ pass + assert len(pred.size()) == 4 + + return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) + return loss + +class MS_SSIM(nn.Module): + def __init__(self, window_size=11, sigma=1.5, device="cuda"): + super(MS_SSIM, self).__init__() + self.device = device + self.channel = 3 + self.sigma=sigma + self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] + self.levels = len(self.weights) + self.window = self.create_window(window_size) + + def create_window(self, window_size): + self.window_size = window_size + # 1D gaussian kernel + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * self.sigma ** 2)) for x in range(window_size)]) + gauss = gauss / gauss.sum() + + # 2D Gaussian window + _1D_window = gauss.unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(self.channel, 1, window_size, window_size).contiguous().to(self.device) + + def update_window_size(self, window_size): + self.window = self.create_window(window_size) + + def ssim(self, img1, img2): + """Compute SSIM between two images.""" + mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel) + mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + return ssim_map.mean() + + def forward(self, pred, target): + msssim = [] + for i in range(self.levels): + ssim_val = self.ssim(pred, target) + msssim.append(ssim_val * self.weights[i]) + if i < self.levels - 1: + pred = F.avg_pool2d(pred, kernel_size=2, stride=2) + target = F.avg_pool2d(target, kernel_size=2, stride=2) + + return torch.prod(torch.stack(msssim)) + + diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4bf2cc0523d5c9a5cef53be51eb3203802fada --- /dev/null +++ b/basicsr/models/lr_scheduler.py @@ -0,0 +1,232 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler +import torch + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. 
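+
+    Example (illustrative values, not from a config in this repo):
+        ``milestones=[200000, 400000]`` with ``gamma=0.5`` halves the lr at
+        those iterations; the defaults ``restarts=(0,)`` and
+        ``restart_weights=(1,)`` keep a single decay schedule with no restarts.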
+ """ + + def __init__(self, + optimizer, + milestones, + gamma=0.1, + restarts=(0, ), + restart_weights=(1, ), + last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [ + group['initial_lr'] * weight + for group in self.optimizer.param_groups + ] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + +class LinearLR(_LRScheduler): + """ + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, + optimizer, + total_iter, + last_epoch=-1): + self.total_iter = total_iter + super(LinearLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + process = self.last_epoch / self.total_iter + weight = (1 - process) + # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) + return [weight * group['initial_lr'] for group in self.optimizer.param_groups] + +class VibrateLR(_LRScheduler): + """ + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, + optimizer, + total_iter, + last_epoch=-1): + self.total_iter = total_iter + super(VibrateLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + process = self.last_epoch / self.total_iter + + f = 0.1 + if process < 3 / 8: + f = 1 - process * 8 / 3 + elif process < 5 / 8: + f = 0.2 + + T = self.total_iter // 80 + Th = T // 2 + + t = self.last_epoch % T + + f2 = t / Th + if t >= Th: + f2 = 2 - f2 + + weight = f * f2 + + if self.last_epoch < Th: + weight = max(0.1, weight) + + # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) + return [weight * group['initial_lr'] for group in self.optimizer.param_groups] + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. 
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_min (float): The minimum lr. Default: 0.
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self,
+                 optimizer,
+                 periods,
+                 restart_weights=(1, ),
+                 eta_min=0,
+                 last_epoch=-1):
+        self.periods = periods
+        self.restart_weights = restart_weights
+        self.eta_min = eta_min
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        self.cumulative_period = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        idx = get_position_from_periods(self.last_epoch,
+                                        self.cumulative_period)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+        current_period = self.periods[idx]
+
+        return [
+            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+            (1 + math.cos(math.pi * (
+                (self.last_epoch - nearest_restart) / current_period)))
+            for base_lr in self.base_lrs
+        ]
+
+class CosineAnnealingRestartCyclicLR(_LRScheduler):
+    """ Cosine annealing with restarts learning rate scheme.
+    An example of config:
+    periods = [10, 10, 10, 10]
+    restart_weights = [1, 0.5, 0.5, 0.5]
+    eta_mins = [1e-7, 1e-7, 1e-7, 1e-7]
+    It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+    scheduler will restart with the weights in restart_weights.
+    Args:
+        optimizer (torch.nn.optimizer): Torch optimizer.
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_mins (list): The minimum lr for each period. Default: [0].
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self,
+                 optimizer,
+                 periods,
+                 restart_weights=(1, ),
+                 eta_mins=(0, ),
+                 last_epoch=-1):
+        self.periods = periods
+        self.restart_weights = restart_weights
+        self.eta_mins = eta_mins
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, + self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + eta_min = self.eta_mins[idx] + + return [ + eta_min + current_weight * 0.5 * (base_lr - eta_min) * + (1 + math.cos(math.pi * ( + (self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/basicsr/test.py b/basicsr/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c57d4e9b67d120d889cbe0196b505a2473b151 --- /dev/null +++ b/basicsr/test.py @@ -0,0 +1,142 @@ +import argparse +import random +import torch +from os import path as osp + +from basicsr.data import create_dataloader, create_dataset +from basicsr.models import create_model +from basicsr.utils import (check_resume, make_exp_dirs, mkdir_and_rename, set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist +from basicsr.utils.options import parse +from basicsr.utils.nano import psf2otf + +import numpy as np +from tqdm import tqdm + +def parse_options(is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument( + '-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm'], + default='none', + help='job launcher') + parser.add_argument( + '--name', + default=None, + help='job launcher') + import sys + vv = sys.version_info.minor + parser.add_argument('--local-rank', type=int, default=0) + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None) + + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + print('init dist .. ', args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def main(): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=True) + torch.backends.cudnn.benchmark = True + + # automatic resume .. 
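A standalone sketch (not part of the patch) of how the cyclic scheduler above behaves, using the example values from its own docstring; in training it is normally constructed from the options file rather than by hand.

import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartCyclicLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = CosineAnnealingRestartCyclicLR(
    optimizer,
    periods=[10, 10, 10, 10],
    restart_weights=[1, 0.5, 0.5, 0.5],
    eta_mins=[1e-7, 1e-7, 1e-7, 1e-7])

for it in range(40):
    optimizer.step()
    scheduler.step()
    # within each 10-iteration cycle the lr follows a cosine from roughly
    # base_lr * restart_weight down to eta_mins[cycle], then restarts
    print(it, optimizer.param_groups[0]['lr'])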
+ state_folder_path = 'experiments/{}/training_states/'.format(opt['name']) + import os + try: + states = os.listdir(state_folder_path) + except: + states = [] + resume_state = None + if len(states) > 0: + max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states])) + resume_state = os.path.join(state_folder_path, max_state_file) + opt['path']['resume_state'] = resume_state + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ + 'name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + + # define ks for Wiener filters + ks_params = opt['train'].get('ks', None) + if not ks_params: + raise NotImplementedError + M = ks_params['num'] + ks = torch.logspace(ks_params['start'], ks_params['end'], M) + ks = ks.view(1,M,1,1,1,1).to("cuda") + + val_conv = opt['val'].get("apply_conv", True) + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = create_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + current_iter = resume_state['iter'] + + else: + model = create_model(opt) + current_iter = 0 + + # load psf + psf = torch.tensor(np.load("./psf.npy")).to("cuda") + _,psf_h,psf_w,_ = psf.shape + otf = psf2otf(psf, h=psf_h*3, w=psf_w*3, permute=True)[None] + + dataset_opt = opt['datasets']['val'] + + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader( + val_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=None, + seed=opt['manual_seed']) + + print("Start validation on spatially varying aberrration") + rgb2bgr = opt['val'].get('rgb2bgr', True) + use_image = opt['val'].get('use_image', True) + psnr, others = model.validation(val_loader, current_iter, None, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv) + print("==================") + print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n") + + +if __name__ == '__main__': + main() diff --git a/basicsr/train.py b/basicsr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..07e351356c9950977741c6e0c6e9cbd36e318106 --- /dev/null +++ b/basicsr/train.py @@ -0,0 +1,328 @@ +import argparse +import datetime +import logging +import math +import random +import time +import torch +import gc +from os import path as osp + +from basicsr.data import create_dataloader, create_dataset +from basicsr.data.data_sampler import EnlargedSampler +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.models import create_model +from basicsr.utils import (MessageLogger, check_resume, get_env_info, + get_root_logger, get_time_str, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, + set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist +from basicsr.utils.options import dict2str, parse +from basicsr.utils.nano import apply_conv_n_deconv, psf2otf + +import numpy as np +from tqdm import tqdm + +def parse_options(is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument( + '-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument( + '--launcher', + 
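For reference, a small illustration (not in the patch) of the `ks` tensor built above; interpreting the values as per-filter regularization strengths follows the "define ks for Wiener filters" comment. The concrete numbers are only an example.

import torch

# e.g. start=-2, end=-5, num=13: 13 values evenly spaced in log10 between 1e-2 and 1e-5
ks = torch.logspace(-2, -5, 13)
print(ks[0].item(), ks[-1].item(), ks.shape)   # 0.01  1e-05  torch.Size([13])

# reshaped so it broadcasts along the M (filter) dimension, as in the view(1, M, 1, 1, 1, 1) call above
ks = ks.view(1, 13, 1, 1, 1, 1)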
choices=['none', 'pytorch', 'slurm'], + default='none', + help='job launcher') + parser.add_argument( + '--name', + default=None, + help='job launcher') + import sys + vv = sys.version_info.minor + parser.add_argument('--local-rank', type=int, default=0) + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None) + + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + print('init dist .. ', args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt['path']['log'], + f"train_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger( + logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt['logger'].get('wandb') + is not None) and (opt['logger']['wandb'].get('project') + is not None) and ('debug' not in opt['name']): + assert opt['logger'].get('use_tb_logger') is True, ( + 'should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = create_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], + opt['rank'], dataset_enlarge_ratio) + train_loader = create_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed'], + ) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / + (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info( + 'Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader( + val_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=None, + seed=opt['manual_seed'], + ) + logger.info( + f'Number of val images/folders in {dataset_opt["name"]}: ' + f'{len(val_set)}') + + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def main(): + # parse options, set distributed setting, set 
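To make the bookkeeping in create_train_val_dataloader concrete, a worked example with hypothetical numbers (dataset size and launch configuration are assumptions, not values from this repo):

import math

num_images, enlarge_ratio = 10000, 1      # hypothetical dataset size
batch_per_gpu, world_size = 2, 4          # hypothetical launch configuration
total_iters = 100000                      # hypothetical iteration budget

iters_per_epoch = math.ceil(num_images * enlarge_ratio / (batch_per_gpu * world_size))
total_epochs = math.ceil(total_iters / iters_per_epoch)
print(iters_per_epoch, total_epochs)      # 1250 iterations per epoch -> 80 epochs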
ramdom seed + opt = parse_options(is_train=True) + torch.backends.cudnn.benchmark = True + + # automatic resume .. + state_folder_path = 'experiments/{}/training_states/'.format(opt['name']) + import os + try: + states = os.listdir(state_folder_path) + except: + states = [] + resume_state = None + if len(states) > 0: + max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states])) + resume_state = os.path.join(state_folder_path, max_state_file) + opt['path']['resume_state'] = resume_state + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ + 'name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # define ks for Wiener filters + ks_params = opt['train'].get('ks', None) + if not ks_params: + raise NotImplementedError + M = ks_params['num'] + ks = torch.logspace(ks_params['start'], ks_params['end'], M) + ks = ks.view(1,M,1,1,1,1).to("cuda") + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = create_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " + f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + + else: + model = create_model(opt) + start_epoch = 0 + current_iter = 0 + + + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
+ "Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info( + f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + + + epoch = start_epoch + pbar = tqdm(total = total_iters+1) + pbar.update(current_iter) + + # load psf + psf = torch.tensor(np.load("./psf.npy")).to("cuda") + psf_n,psf_h,psf_w,_ = psf.shape + psf_n_row = int(psf_n ** 0.5) + sensor_h = opt['datasets']['train'].get('sensor_size') + otf = psf2otf(psf, h=psf_h*3, w=psf_w*3, permute=True)[None] + + + gt_size = opt['datasets']['train']['gt_size'] + val_conv = opt['val'].get("apply_conv", True) + + + while current_iter <= total_iters: + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + gt = train_data['gt'].to("cuda") # B,C,H,H + padding = train_data['padding'] + padding = torch.stack(padding).T + lq, gt = apply_conv_n_deconv(gt, otf, padding, M, gt_size, ks=ks, ph=psf_h, num_psf=psf_n_row, sensor_h=sensor_h) + + + # 3 H W . conv -> crop + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate( + current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + + + model.feed_train_data({'lq': lq, 'gt':gt}) + model.optimize_parameters(current_iter) + + + iter_time = time.time() - iter_time + + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and ((current_iter % opt['val']['val_freq'] == 0)): + rgb2bgr = opt['val'].get('rgb2bgr', True) + # wheather use uint8 image to compute metrics + use_image = opt['val'].get('use_image', True) + model.validation(val_loader, current_iter, tb_logger, False, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv) + gc.collect() + torch.cuda.empty_cache() + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + pbar.update(1) + # end of iter + epoch += 1 + + # end of epoch + + consumed_time = str( + datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. 
Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + rgb2bgr = opt['val'].get('rgb2bgr', True) + use_image = opt['val'].get('use_image', True) + psnr, others = model.validation(val_loader, current_iter, tb_logger, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv) + print("==================") + print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n") + + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + main() diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa3c5f4b1273fefa478473063dd3390a706f0ad --- /dev/null +++ b/basicsr/utils/__init__.py @@ -0,0 +1,45 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP +from .logger import (MessageLogger, get_env_info, get_root_logger, + init_tb_logger, init_wandb_logger) +from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, + scandir, scandir_mv, scandir_mv_flat, scandir_SIDD, set_random_seed, sizeof_fmt) +from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'scandir_mv', + 'scandir_mv_flat', + 'check_resume', + 'sizeof_fmt', + 'padding', + 'padding_DP', + 'imfrombytesDP', + 'create_lmdb_for_reds', + 'create_lmdb_for_gopro', + 'create_lmdb_for_rain13k', + # nano.py + 'psf2otf', + 'fft', + 'ifft', + 'get_edgetaper_weight', +] diff --git a/basicsr/utils/bundle_submissions.py b/basicsr/utils/bundle_submissions.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6a242624e77d2e83c43577bc5772b16936856b --- /dev/null +++ b/basicsr/utils/bundle_submissions.py @@ -0,0 +1,108 @@ + # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) + + # This file is part of the implementation as described in the CVPR 2017 paper: + # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. + # Please see the file LICENSE.txt for the license governing this code. + + +import numpy as np +import scipy.io as sio +import os +import h5py + +def bundle_submissions_raw(submission_folder,session): + ''' + Bundles submission data for raw denoising + + submission_folder Folder where denoised images reside + + Output is written to /bundled/. Please submit + the content of this folder. 
+ ''' + + out_folder = os.path.join(submission_folder, session) + # out_folder = os.path.join(submission_folder, "bundled/") + try: + os.mkdir(out_folder) + except:pass + + israw = True + eval_version="1.0" + + for i in range(50): + Idenoised = np.zeros((20,), dtype=np.object) + for bb in range(20): + filename = '%04d_%02d.mat'%(i+1,bb+1) + s = sio.loadmat(os.path.join(submission_folder,filename)) + Idenoised_crop = s["Idenoised_crop"] + Idenoised[bb] = Idenoised_crop + filename = '%04d.mat'%(i+1) + sio.savemat(os.path.join(out_folder, filename), + {"Idenoised": Idenoised, + "israw": israw, + "eval_version": eval_version}, + ) + +def bundle_submissions_srgb(submission_folder,session): + ''' + Bundles submission data for sRGB denoising + + submission_folder Folder where denoised images reside + + Output is written to /bundled/. Please submit + the content of this folder. + ''' + out_folder = os.path.join(submission_folder, session) + # out_folder = os.path.join(submission_folder, "bundled/") + try: + os.mkdir(out_folder) + except:pass + israw = False + eval_version="1.0" + + for i in range(50): + Idenoised = np.zeros((20,), dtype=np.object) + for bb in range(20): + filename = '%04d_%02d.mat'%(i+1,bb+1) + s = sio.loadmat(os.path.join(submission_folder,filename)) + Idenoised_crop = s["Idenoised_crop"] + Idenoised[bb] = Idenoised_crop + filename = '%04d.mat'%(i+1) + sio.savemat(os.path.join(out_folder, filename), + {"Idenoised": Idenoised, + "israw": israw, + "eval_version": eval_version}, + ) + + + +def bundle_submissions_srgb_v1(submission_folder,session): + ''' + Bundles submission data for sRGB denoising + + submission_folder Folder where denoised images reside + + Output is written to /bundled/. Please submit + the content of this folder. + ''' + out_folder = os.path.join(submission_folder, session) + # out_folder = os.path.join(submission_folder, "bundled/") + try: + os.mkdir(out_folder) + except:pass + israw = False + eval_version="1.0" + + for i in range(50): + Idenoised = np.zeros((20,), dtype=np.object) + for bb in range(20): + filename = '%04d_%d.mat'%(i+1,bb+1) + s = sio.loadmat(os.path.join(submission_folder,filename)) + Idenoised_crop = s["Idenoised_crop"] + Idenoised[bb] = Idenoised_crop + filename = '%04d.mat'%(i+1) + sio.savemat(os.path.join(out_folder, filename), + {"Idenoised": Idenoised, + "israw": israw, + "eval_version": eval_version}, + ) \ No newline at end of file diff --git a/basicsr/utils/create_lmdb.py b/basicsr/utils/create_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..e809df65f97846b1917d73e121a3d65bb8ca9b9c --- /dev/null +++ b/basicsr/utils/create_lmdb.py @@ -0,0 +1,124 @@ +import argparse +from os import path as osp + +from basicsr.utils import scandir +from basicsr.utils.lmdb_util import make_lmdb_from_imgs + +def prepare_keys(folder_path, suffix='png'): + """Prepare image path list and keys for DIV2K dataset. + + Args: + folder_path (str): Folder path. + + Returns: + list[str]: Image path list. + list[str]: Key list. 
+ """ + print('Reading image path list ...') + img_path_list = sorted( + list(scandir(folder_path, suffix=suffix, recursive=False))) + keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] + + return img_path_list, keys + +def create_lmdb_for_reds(): + folder_path = './datasets/REDS/val/sharp_300' + lmdb_path = './datasets/REDS/val/sharp_300.lmdb' + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + # + folder_path = './datasets/REDS/val/blur_300' + lmdb_path = './datasets/REDS/val/blur_300.lmdb' + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/REDS/train/train_sharp' + lmdb_path = './datasets/REDS/train/train_sharp.lmdb' + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/REDS/train/train_blur_jpeg' + lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + +def create_lmdb_for_gopro(): + folder_path = './datasets/GoPro/train/blur_crops' + lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/GoPro/train/sharp_crops' + lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/GoPro/test/target' + lmdb_path = './datasets/GoPro/test/target.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/GoPro/test/input' + lmdb_path = './datasets/GoPro/test/input.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + +def create_lmdb_for_rain13k(): + folder_path = './datasets/Rain13k/train/input' + lmdb_path = './datasets/Rain13k/train/input.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/Rain13k/train/target' + lmdb_path = './datasets/Rain13k/train/target.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + +def create_lmdb_for_SIDD(): + folder_path = './datasets/SIDD/train/input_crops' + lmdb_path = './datasets/SIDD/train/input_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'PNG') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/SIDD/train/gt_crops' + lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'PNG') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + #for val + folder_path = './datasets/SIDD/val/input_crops' + lmdb_path = './datasets/SIDD/val/input_crops.lmdb' + mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' + if not osp.exists(folder_path): + os.makedirs(folder_path) + assert osp.exists(mat_path) + data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] + N, B, H ,W, C = data.shape + data = data.reshape(N*B, H, W, C) + for i in 
tqdm(range(N*B)): + cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/SIDD/val/gt_crops' + lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' + mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' + if not osp.exists(folder_path): + os.makedirs(folder_path) + assert osp.exists(mat_path) + data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] + N, B, H ,W, C = data.shape + data = data.reshape(N*B, H, W, C) + for i in tqdm(range(N*B)): + cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..43cf4cda16db549c7961feac383f3813bcdd985f --- /dev/null +++ b/basicsr/utils/dist_util.py @@ -0,0 +1,83 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
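A hypothetical usage sketch for the LMDB helpers above (the dataset folder and lmdb path are placeholders); prepare_keys() derives the LMDB keys by stripping the file suffix from each relative image path.

from basicsr.utils.create_lmdb import prepare_keys
from basicsr.utils.lmdb_util import make_lmdb_from_imgs

folder_path = './datasets/MyData/train/gt'        # placeholder folder of PNG images
lmdb_path = './datasets/MyData/train/gt.lmdb'     # must end with '.lmdb'

img_path_list, keys = prepare_keys(folder_path, suffix='png')
# e.g. '0001.png' -> key '0001'
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)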
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..64a00161de42992b6427b30565727e426ec89a0a --- /dev/null +++ b/basicsr/utils/download_util.py @@ -0,0 +1,70 @@ +import math +import requests +from tqdm import tqdm + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + + Args: + file_id (str): File id. + save_path (str): Save path. 
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get( + URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int( + response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, + destination, + file_size=None, + chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' + f'/ {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() diff --git a/basicsr/utils/face_util.py b/basicsr/utils/face_util.py new file mode 100644 index 0000000000000000000000000000000000000000..33fe178a215af328118961f8ce67a3a3f11c2ed0 --- /dev/null +++ b/basicsr/utils/face_util.py @@ -0,0 +1,217 @@ +import cv2 +import numpy as np +import os +import torch +from skimage import transform as trans + +from basicsr.utils import imwrite + +try: + import dlib +except ImportError: + print('Please install dlib before testing face restoration.' 
+ 'Reference: https://github.com/davisking/dlib') + + +class FaceRestorationHelper(object): + """Helper for the face restoration pipeline.""" + + def __init__(self, upscale_factor, face_size=512): + self.upscale_factor = upscale_factor + self.face_size = (face_size, face_size) + + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], + [586.77227723, 493.59405941], + [337.91089109, 488.38613861], + [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + # for estimation the 2D similarity transformation + self.similarity_trans = trans.SimilarityTransform() + + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.save_png = True + + def init_dlib(self, detection_path, landmark5_path, landmark68_path): + """Initialize the dlib detectors and predictors.""" + self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) + self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) + self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) + + def free_dlib_gpu_memory(self): + del self.face_detector + del self.shape_predictor_5 + del self.shape_predictor_68 + + def read_input_image(self, img_path): + # self.input_img is Numpy array, (h, w, c) with RGB order + self.input_img = dlib.load_rgb_image(img_path) + + def detect_faces(self, + img_path, + upsample_num_times=1, + only_keep_largest=False): + """ + Args: + img_path (str): Image path. + upsample_num_times (int): Upsamples the image before running the + face detector + + Returns: + int: Number of detected faces. + """ + self.read_input_image(img_path) + det_faces = self.face_detector(self.input_img, upsample_num_times) + if len(det_faces) == 0: + print('No face detected. Try to increase upsample_num_times.') + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - + det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - + det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + return len(self.det_faces) + + def get_face_landmarks_5(self): + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + return len(self.all_landmarks_5) + + def get_face_landmarks_68(self): + """Get 68 densemarks for cropped images. + + Should only have one face at most in the cropped image. + """ + num_detected_face = 0 + for idx, face in enumerate(self.cropped_faces): + # face detection + det_face = self.face_detector(face, 1) # TODO: can we remove it? + if len(det_face) == 0: + print(f'Cannot find faces in cropped image with index {idx}.') + self.all_landmarks_68.append(None) + else: + if len(det_face) > 1: + print('Detect several faces in the cropped face. Use the ' + ' largest one. 
Note that it will also cause overlap ' + 'during paste_faces_to_input_image.') + face_areas = [] + for i in range(len(det_face)): + face_area = (det_face[i].rect.right() - + det_face[i].rect.left()) * ( + det_face[i].rect.bottom() - + det_face[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + face_rect = det_face[largest_idx].rect + else: + face_rect = det_face[0].rect + shape = self.shape_predictor_68(face, face_rect) + landmark = np.array([[part.x, part.y] + for part in shape.parts()]) + self.all_landmarks_68.append(landmark) + num_detected_face += 1 + + return num_detected_face + + def warp_crop_faces(self, + save_cropped_path=None, + save_inverse_affine_path=None): + """Get affine matrix, warp and cropped faces. + + Also get inverse affine matrix for post-processing. + """ + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + self.similarity_trans.estimate(landmark, self.face_template) + affine_matrix = self.similarity_trans.params[0:2, :] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + cropped_face = cv2.warpAffine(self.input_img, affine_matrix, + self.face_size) + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path, ext = os.path.splitext(save_cropped_path) + if self.save_png: + save_path = f'{path}_{idx:02d}.png' + else: + save_path = f'{path}_{idx:02d}{ext}' + + imwrite( + cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) + + # get inverse affine matrix + self.similarity_trans.estimate(self.face_template, + landmark * self.upscale_factor) + inverse_affine = self.similarity_trans.params[0:2, :] + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + def add_restored_face(self, face): + self.restored_faces.append(face) + + def paste_faces_to_input_image(self, save_path): + # operate in the BGR order + input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) + h, w, _ = input_img.shape + h_up, w_up = h * self.upscale_factor, w * self.upscale_factor + # simply resize the background + upsample_img = cv2.resize(input_img, (w_up, h_up)) + assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( + 'length of restored_faces and affine_matrices are different.') + for restored_face, inverse_affine in zip(self.restored_faces, + self.inverse_affine_matrices): + inv_restored = cv2.warpAffine(restored_face, inverse_affine, + (w_up, h_up)) + mask = np.ones((*self.face_size, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, + np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), + np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode( + inv_mask_erosion, + np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, + (blur_size + 1, blur_size + 1), 0) + upsample_img = inv_soft_mask * inv_restored_remove_border + ( + 1 - inv_soft_mask) * upsample_img + if self.save_png: + save_path = 
save_path.replace('.jpg', + '.png').replace('.jpeg', '.png') + imwrite(upsample_img.astype(np.uint8), save_path) + + def clean_all(self): + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..d17250706cc9d1d8b416385e1c1e6d8ba97051fd --- /dev/null +++ b/basicsr/utils/file_client.py @@ -0,0 +1,186 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. 
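A rough end-to-end sketch of the face-restoration helper above. The dlib model paths are placeholders that must point to downloaded dlib weights, and the identity "restoration" stands in for an actual restoration network.

from basicsr.utils.face_util import FaceRestorationHelper

helper = FaceRestorationHelper(upscale_factor=2, face_size=512)
helper.init_dlib('mmod_human_face_detector.dat',              # placeholder paths to dlib models
                 'shape_predictor_5_face_landmarks.dat',
                 'shape_predictor_68_face_landmarks.dat')

num_faces = helper.detect_faces('input.jpg', only_keep_largest=True)   # placeholder image
if num_faces > 0:
    helper.get_face_landmarks_5()
    helper.warp_crop_faces(save_cropped_path='cropped/face.png')
    for face in helper.cropped_faces:
        restored = face                      # stand-in for the actual restoration model output
        helper.add_restored_face(restored)
    helper.paste_faces_to_input_image('restored/output.png')
helper.clean_all()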
+ """ + + def __init__(self, + db_paths, + client_keys='default', + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ( + 'client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open( + path, + readonly=readonly, + lock=lock, + readahead=readahead, + map_size=8*1024*10485760, + # max_readers=1, + **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' + 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2b052cc9a9a71b4f1873ea09581effd0ffe8d1b3 --- /dev/null +++ b/basicsr/utils/flow_util.py @@ -0,0 +1,180 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. 
+ + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, ' + f'its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, ' + 'header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(filename, exist_ok=True) + cv2.imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [ + quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] + ] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. 
+ """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum( + np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..ab563b36375287cb154408f4fd52fcc812f7b3a2 --- /dev/null +++ b/basicsr/utils/img_util.py @@ -0,0 +1,216 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. 
+ min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or + (isinstance(tensor, list) + and all(torch.is_tensor(t) for t in tensor))): + raise TypeError( + f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid( + _tensor, nrow=int(math.sqrt(_tensor.size(0))), + normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' + f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = { + 'color': cv2.IMREAD_COLOR, + 'grayscale': cv2.IMREAD_GRAYSCALE, + 'unchanged': cv2.IMREAD_UNCHANGED + } + if img_np is None: + raise Exception('None .. !!!') + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + +def imfrombytesDP(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + if img_np is None: + raise Exception('None .. !!!') + img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED) + if float32: + img = img.astype(np.float32) / 65535. 
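Round-trip sketch for the conversion helpers above, assuming a BGR uint8 image such as the ones produced by imfrombytes:

import numpy as np
from basicsr.utils import img2tensor, tensor2img

img_bgr = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
t = img2tensor(img_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True)  # CHW, RGB, [0, 1]
back = tensor2img(t, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))          # HWC, BGR, uint8
print(np.abs(back.astype(int) - img_bgr.astype(int)).max())                    # 0, exact up to rounding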
+ return img + +def padding(img_gt, gt_size): + h, w, _ = img_gt.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_gt + + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + if img_gt.ndim == 2: + img_gt = np.expand_dims(img_gt, axis=2) + return img_gt + +def padding_DP(img_lqL, img_lqR, img_gt, gt_size): + h, w, _ = img_gt.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lqL, img_lqR, img_gt + + img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + # print('img_lq', img_lq.shape, img_gt.shape) + return img_lqL, img_lqR, img_gt + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [ + v[crop_border:-crop_border, crop_border:-crop_border, ...] + for v in imgs + ] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, + ...] diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a81278fc0d38c451b5f785a405c4c356cd248cab --- /dev/null +++ b/basicsr/utils/lmdb_util.py @@ -0,0 +1,208 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. 
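A short usage sketch for `make_lmdb_from_imgs` as documented above; the dataset folder and file names are placeholders.
```
import os
from os import path as osp

from basicsr.utils.lmdb_util import make_lmdb_from_imgs

data_path = 'datasets/example/gt'  # hypothetical folder of PNG images
img_paths = sorted(f for f in os.listdir(data_path) if f.endswith('.png'))
keys = [osp.splitext(p)[0] for p in img_paths]  # lmdb key = file name without extension

# Writes data.mdb / lock.mdb and the meta_info.txt described above.
make_lmdb_from_imgs(data_path, 'datasets/example/gt.lmdb', img_paths, keys,
                    batch=5000, compress_level=1)
```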
+ + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ( + 'img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async( + read_img_worker, + args=(osp.join(data_path, path), key, compress_level), + callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread( + osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker( + osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. 
+ """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, + [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, + lmdb_path, + map_size=1024**4, + batch=5000, + compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..bab2b6f6ec00a68d44b5b10be5445a9f425420a3 --- /dev/null +++ b/basicsr/utils/logger.py @@ -0,0 +1,175 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. 
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = logging.getLogger('basicsr') + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg \ No newline at end of file diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..cd96a2c8e822b633c8b5d5ffda35259ad0a8765b --- /dev/null +++ b/basicsr/utils/matlab_functions.py @@ -0,0 +1,361 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, + kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace( + 0, p - 1, p).view(1, p).expand(out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices( + in_h, out_h, scale, kernel, kernel_width, antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices( + in_w, out_w, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( + 0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = 
out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. 
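An illustrative sketch of the MATLAB-compatible helpers above: bicubic rescaling with `imresize` followed by a Y-channel conversion with `bgr2ycbcr` (e.g. for Y-channel PSNR). The input here is random data.
```
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr, imresize

# Random BGR image in [0, 1]; the numpy code path of imresize expects (H, W, C).
img = np.random.rand(256, 256, 3).astype(np.float32)

# MATLAB-compatible bicubic downscaling (antialiasing enabled by default).
small = imresize(img, scale=0.5)      # -> (128, 128, 3), float output without rounding

# Y channel only; the output keeps the input convention (float32 in [0, 1]).
y = bgr2ycbcr(small, y_only=True)
```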
+ + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. 
+ return img.astype(dst_type) diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7d8d0849084462e337c754edb49f81220824b0 --- /dev/null +++ b/basicsr/utils/misc.py @@ -0,0 +1,266 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + # if osp.exists(path): + # new_name = path + '_archived_' + get_time_str() + # print(f'Path already exists. Rename it to {new_name}', flush=True) + # os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' + not in key) and ('resume' + not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir( + entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + +def scandir_mv(dir_path, suffix=None, recursive=False, full_path=False, lq=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. 
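A quick usage sketch for the generic `scandir` above; the ground-truth folder is a placeholder.
```
from basicsr.utils.misc import scandir

# Lazily list PNG files under a placeholder GT folder; with full_path=False the
# generator yields paths relative to the given root.
gt_root = 'path/to/train_gt'
png_paths = sorted(scandir(gt_root, suffix='.png', recursive=True, full_path=False))
print(f'found {len(png_paths)} ground-truth images')
```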
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + _type = "no_noise" if lq else "images" + + + # 1,3K는 아직 안만들어져서 2K까지 받는다고 가정 + def _scandir(dir_path, suffix, recursive): + folders = os.listdir(dir_path) + all_files = [] + for folder in folders: # tag + all_files.append(osp.join(dir_path, folder, "images_4")) # ~~train/46/0398fdk3/no_noise + + # 아래는 1,2,3K 다 쓰는 경우 + # subfolders = os.listdir(osp.join(dir_path, folder)) # images4 + # for subfolder in subfolders: + # all_files.append(osp.join(dir_path, folder, subfolder, _type)) # ~~train/46/0398fdk3/no_noise + + + return all_files + return _scandir(dir_path, suffix, recursive) + +def scandir_mv_flat(dir_path, suffix=None, recursive=False, full_path=False, lq=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + _type = "no_noise" if lq else "images" + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + if entry.name in ["both_noises", "gaussian_only", "images", "no_noise", "no_noise_BGR", "poisson_only", "sparse"]: + if entry.name != _type: + continue + yield from _scandir( + entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + keywords (str | tuple(str), optional): File keywords that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (keywords is not None) and not isinstance(keywords, (str, tuple)): + raise TypeError('"keywords" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, keywords, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if keywords is None: + yield return_path + elif return_path.find(keywords) > 0: + yield return_path + else: + if recursive: + yield from _scandir( + entry.path, keywords=keywords, recursive=recursive) + else: + continue + + return _scandir(dir_path, keywords=keywords, recursive=recursive) + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. 
+ resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning( + 'pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or ( + basename not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join( + opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/basicsr/utils/nano.py b/basicsr/utils/nano.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0fdff2cd203a79dfba86ee1fe51ceb2416d9b0 --- /dev/null +++ b/basicsr/utils/nano.py @@ -0,0 +1,250 @@ +import torch +import torch.nn.functional as F +import numpy as np +from torch.distributions.poisson import Poisson +import random + + +def crop_to_bounding_box(image, offset_height, offset_width, target_height, + target_width, is_batch): + # BHWC -> BHWC + cropped = image[:, offset_height: offset_height + target_height, offset_width: offset_width + target_width, :] + + if not is_batch: + cropped = cropped[0] + + return cropped + +def crop_to_bounding_box_list(image, offset_height, offset_width, target_height, + target_width): + # HWC + cropped = [_image[offset_height: offset_height + target_height, offset_width: offset_width + target_width, :] for _image in image] + + return cropped + +def pad_to_bounding_box(image, offset_height, offset_width, target_height, + target_width, is_batch): + _,height,width,_ = image.shape + after_padding_width = target_width - offset_width - width + after_padding_height = target_height - offset_height - height + + paddings = (0, 0, offset_width, after_padding_width, offset_height, after_padding_height, 0, 0) + + padded = torch.nn.functional.pad(image, paddings) + if not is_batch: + padded = padded[0] + + return padded + +def resize_with_crop_or_pad_torch(image, target_height, target_width): + # BHWC -> BHWC + + is_batch = True + if image.ndim == 3: + is_batch = False + image = image[None] # 1HWC + + def max_(x, y): + return max(x, y) + + def min_(x, y): + return min(x, y) + + def equal_(x, y): + return x == y + + _, height, width, _ = image.shape + width_diff = target_width - width + offset_crop_width = max_(-width_diff // 2, 0) + offset_pad_width = max_(width_diff // 2, 0) + + height_diff = target_height - height + offset_crop_height = max_(-height_diff // 2, 0) + offset_pad_height = max_(height_diff // 2, 0) + + # Maybe crop if needed. + cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width, + min_(target_height, height), + min_(target_width, width), is_batch) + + # Maybe pad if needed. 
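`sizeof_fmt` above produces binary-prefixed sizes; a quick check using the two LFS file sizes added in this diff.
```
from basicsr.utils.misc import sizeof_fmt

print(sizeof_fmt(116763496))  # '111.4 MB' (net_g_100000.pth)
print(sizeof_fmt(233563982))  # '222.7 MB' (100000.state)
```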
+ if not is_batch and cropped.ndim == 3: + cropped = cropped[None] + resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, + target_height, target_width, is_batch) + + return resized + + + +def psf2otf(psf, h=None, w=None, permute=False): + ''' + psf = (b) h,w,c + ''' + if h is not None: + psf = resize_with_crop_or_pad_torch(psf, h, w) + if permute: + if psf.ndim == 3: + psf = psf.permute(2,0,1) # HWC -> CHW + else: + psf = psf.permute(0,3,1,2) # HWC -> CHW + psf = psf.to(torch.complex64) + psf = torch.fft.fftshift(psf, dim=(-1,-2)) + otf = torch.fft.fft2(psf) + return otf + +def fft(img): # CHW + img = img.to(torch.complex64) + Fimg = torch.fft.fft2(img) + return Fimg + +def ifft(Fimg): + img = torch.abs(torch.fft.ifft2(Fimg)).to(torch.float32) + return img + + +def create_contrast_mask(image): + return 1 - torch.mean(image, dim=(-1,-2), keepdim=True) # (B), C,1,1 + +def apply_tikhonov(lr_img, psf, K, norm=True, otf=None): + h,w = lr_img.shape[-2:] + if otf is None: + psf_norm = resize_with_crop_or_pad_torch(psf, h, w) + if norm: + psf_norm = psf_norm / psf_norm.sum((0, 1)) + otf = psf2otf(psf_norm, h, w, permute=True) + + otf = otf[:,None,...] # B,1,C,H,W + contrast_mask = create_contrast_mask(lr_img)[:,None,...] # B,1,C,1,1 + K_adjusted = K * contrast_mask # B,M,C,1,1 + tikhonov_filter = torch.conj(otf) / (torch.abs(otf) ** 2 + K_adjusted) # B,M,C,H,W + lr_fft = fft(lr_img)[:,None,...] # B,1,C,H,W + deconvolved_fft = lr_fft * tikhonov_filter + deconvolved_image = torch.fft.ifft2(deconvolved_fft).real + deconvolved_image = torch.clamp(deconvolved_image, min=0.0, max=1.0) + + return deconvolved_image # B,M,C,H,W + + +def add_noise_all_new(image, poss=4e-5, gaus=1e-5): + p = Poisson(image / poss) + sampled = p.sample((1,))[0] + poss_img = sampled * poss + gauss_noise = torch.randn_like(image) * gaus + noised_img = poss_img + gauss_noise + + noised_img = torch.clamp(noised_img, 0.0, 1.0) + + return noised_img + + +def apply_convolution(image, psf, pad): + ''' + input: hr img (b,c,h,w, [0,1]) + output: noised lr img (b,c,h+P,w+P, [0,1]) + ''' + + # metalens simulation + image = F.pad(image, (pad, pad, pad, pad)) + h,w = image.shape[-2:] + psf_norm = resize_with_crop_or_pad_torch(psf, h, w) + otf = psf2otf(psf_norm, h, w, permute=True) + lr_img = fft(image) * otf + lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0) + + # noise addition + noised_img = add_noise_all_new(lr_img) + + return noised_img, otf + +def apply_conv_n_deconv(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=True, conv=True): + ''' + input: hr img (b,c,h,w) + otf: 1,N,C,H,W + output: noised lr img (N,c,h,w) + ''' + + b,_,_,_ = image.shape + if conv: + img_patch = F.unfold(image, kernel_size=ph*3, stride=ph).view(b,3,ph*3,ph*3,num_psf**2).permute(0,4,1,2,3).contiguous() # B,N,C,H,W + + # metalens simulation + lr_img = fft(img_patch) * otf + lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0) + + # noise addtion + lr_img = add_noise_all_new(lr_img) + + else: # load convolved image for validation + b = 1 + lr_img = image + + # apply deconvolution + if ks is not None: + lr_img = apply_tikhonov(lr_img, None, ks, otf=otf) # B,M,N,C,405,405 + lr_img = lr_img[..., ph:-ph, ph:-ph] # BMNCHW + lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h) + else: + lr_img = lr_img[..., ph:-ph, ph:-ph] # BNCHW + lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h) + + lq_patches 
= [] + gt_patches = [] + for i in range(b): + cur = lr_img[i] # (M),C,H,W + cur_gt = image[i] + + # remove padding for lq and gt + pt,pb,pl,pr = padding[i] + if pb and pt: + cur = cur[...,pt: -pb, :] + cur_gt = cur_gt[...,pt+ph: -(pb+ph), ph:-ph] + elif pl and pr: + cur = cur[...,pl:-pr] + cur_gt = cur_gt[...,ph:-ph, pl+ph: -(pr+ph)] + else: + cur_gt = cur_gt[...,ph:-ph, ph: -ph] + h,w = cur.shape[-2:] + + # randomly crop patch for training + if crop: # train + top = random.randint(0, h - psize) + left = random.randint(0, w - psize) + lq_patches.append(cur[..., top:top + psize, left:left + psize]) + gt_patches.append(cur_gt[..., top:top + psize, left:left + psize]) + if crop: # training + lq_patches = torch.stack(lq_patches) + gt_patches = torch.stack(gt_patches) + else: # validation + return cur, cur_gt + + return lq_patches, gt_patches # B,(M),C,H,W + + +def apply_convolution_square_val(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=False): + ''' + merge to above one. + image = lr_image + ''' + lr_img = image + b = 1 + if M: # apply deconvolution + lr_img = apply_tikhonov(lr_img, None, ks, otf=otf) # B,M,N,C,H,W + lr_img = lr_img[..., ph:-ph, ph:-ph] # B,M,N,C,H,W + lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h) + else: + lr_img = lr_img[..., ph:-ph, ph:-ph] # B,N,C,H,W + lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h) + + + for i in range(b): + cur = lr_img[i] # (M),C,H,W + + # remove padding for lq and gt + pt,pb,pl,pr = padding[i] + if pb and pt: + cur = cur[...,pt: -pb, :] + elif pl and pr: + cur = cur[...,pl:-pr] + + return cur \ No newline at end of file diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..643c174840854f2ad996558febc534b1041979b2 --- /dev/null +++ b/basicsr/utils/options.py @@ -0,0 +1,112 @@ +import yaml +from collections import OrderedDict +from os import path as osp + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, is_train=True, name=None): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. 
+ """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + if name is not None: + opt['name'] = name + + # datasets + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key + or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + opt['path']['root'] = osp.abspath( + osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', + opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, + 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, + 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/basicsr/version.py b/basicsr/version.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c945791963c33df2b5abcd2ce0e42a72339064 --- /dev/null +++ b/basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Fri Mar 21 07:59:14 2025 +__version__ = '1.2.0+5ea673c' +short_version = '1.2.0' +version_info = (1, 2, 0) diff --git a/experiments/pretrained/models/net_g_100000.pth b/experiments/pretrained/models/net_g_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..10dfd10f970bc5c78ed11f0b47019e51cfe760c9 --- /dev/null +++ b/experiments/pretrained/models/net_g_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cc95533ca8a4dfdcfad5de2973346ad6b699c6abaf4e7e9d0de77007c4b855f +size 116763496 diff --git a/experiments/pretrained/training_states/100000.state b/experiments/pretrained/training_states/100000.state new file mode 100644 index 0000000000000000000000000000000000000000..0696f0dd3aacc15eb0673be36a332d1bdb748d11 --- /dev/null +++ b/experiments/pretrained/training_states/100000.state @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edb3104cc8f57a1100b4f0e3d87814a74b2c0fd1ed24a86d69b917b0e1973d2b +size 233563982 diff --git a/psf.npy b/psf.npy new file mode 100644 index 0000000000000000000000000000000000000000..50da07db2d92dc4bee92f96c49620f96546435a2 --- /dev/null +++ b/psf.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:337461630addd8dcc48a0293678b5ef75d9c35a5c7b6a0524154d2e8540741a8 +size 17714828 diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..a0f63a39a63be4097abcbff0127c58d7701d5a8d --- /dev/null +++ b/readme.md @@ -0,0 +1,73 @@ +# Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging + +Byeonghyeon Lee, Youbin Kim, Yongjae Jo, Hyunsu Kim, Hyemi Park, Yangkyu Kim, Debabrata Mandal, Praneeth Chakravarthula, Inki Kim, and Eunbyung Park + +[Project Page](https://benhenryl.github.io/Metalens-Transformer/)   [Paper](https://arxiv.org/abs/2412.04591) + + +We ran the experiments in the following environment: +``` +- ubuntu: 20.04 +- python: 3.10.13 +- cuda: 11.8 +- pytorch: 2.2.0 +- GPU: 4x A6000 ada +``` + +Our code is based on [Restormer](https://github.com/swz30/Restormer), [X-Restormer](https://github.com/Andrew0613/X-Restormer), and [Neural Nano Optics](https://github.com/princeton-computational-imaging/Neural_Nano-Optics). We appreciate their works. + +## 1. Environment Setting +### 1-1. Pytorch +Note: pytorch >= 2.2.0 is required for Flash Attention. + +### 1-2. [Flash Attention](https://github.com/Dao-AILab/flash-attention) +cf. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100) are supported now. +``` +pip install packaging ninja +pip install flash-attn --no-build-isolation +``` + +### 1-3. Other required packages +``` +pip install -r requirements.txt +``` + +### 1-4. Basicsr +``` +python setup.py develop --no_cuda_ext +``` + +## 2. 
Dataset & Pre-trained weights +You can download train/test dataset [here](https://drive.google.com/drive/folders/1e2wJwmcjXFvblVs0l5OXwpIkTqxd1Fhq?usp=drive_link) and pre-trained weights [here](https://drive.google.com/drive/folders/1q5pKE1Z0RJjHVmJlNq7nPSWcaGd9bDb7?usp=drive_link). +Please move the pre-trained weights to experiments/. +Note: The model creates aberrated images on the fly using clean (gt) images during training. +In case of validation, it also produces the aberrated images in the same manner, where the aberrated images can have different noises to what we used for our validation. +There will be only negligible difference in the results as it still uses the same noise distributions, but if you want a precise comparison with the validation set we used for our experiments, please contact us. + + +## 3. Training +Please set dataset path in ```./Aberration_Correction/Options/Train_Aberration_Transformers.yml``` +``` +bash train.sh GPU_IDS FOLDER_NAME +// ex. bash train.sh 0,1,2,3 training +// where it uses gpu 0 to 3 and make a directory experiments/training where log, weights and others will be stored. +``` + +## 4. Inference +Please set dataset path in ```./Aberration_Correction/Options/Test_Aberration_Transformers.yml``` +If you want to run a inference using the pre-trained model, you can use a command +``` +bash test.sh GPU_ID FOLDER_NAME +// ex. bash test.sh 0 pretrained +``` +Or you can designate the FOLDER_NAME with your weight path. + +## BibTeX +``` +@article{lee2024aberration, + title={Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging}, + author={Lee, Byeonghyeon and Kim, Youbin and Jo, Yongjae and Kim, Hyunsu and Park, Hyemi and Kim, Yangkyu and Mandal, Debabrata and Chakravarthula, Praneeth and Kim, Inki and Park, Eunbyung}, + journal={arXiv preprint arXiv:2412.04591}, + year={2024} +} +``` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e097cc999068bcb8c40345fc9a417c975fb50864 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +matplotlib +scikit-learn +scikit-image==0.19.3 +opencv-python +yacs +joblib +natsort +h5py +tqdm +einops +gdown +addict +future +lmdb +numpy +pyyaml +requests +scipy +tb-nightly +yapf +lpips +torchmetrics \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..07f403f3a3521e4e426279beaad6f05788b191a0 --- /dev/null +++ b/setup.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +import torch +from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + +version_file = 'basicsr/version.py' + + +def readme(): + return '' + # with open('README.md', encoding='utf-8') as f: + # content = f.read() + # return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif 
os.path.exists(version_file): + try: + from basicsr.version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +short_version = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join( + [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + VERSION = SHORT_VERSION + '+' + sha + + version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, + VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = {'cxx': []} + + if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda + else: + print(f'Compiling {name} without CUDA') + extension = CppExtension + + return extension( + name=f'{module}.{name}', + sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args) + + +def get_requirements(filename='requirements.txt'): + return [] + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + if '--no_cuda_ext' in sys.argv: + ext_modules = [] + sys.argv.remove('--no_cuda_ext') + else: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='basicsr.models.ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=[ + 'src/deform_conv_cuda.cpp', + 'src/deform_conv_cuda_kernel.cu' + ]), + make_cuda_ext( + name='fused_act_ext', + module='basicsr.models.ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='basicsr.models.ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + + write_version_py() + print("setup start") + setup( + name='basicsr', + version=get_version(), + description='Open Source Image and Video Super-Resolution Toolbox', + long_description=readme(), + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, restoration, super resolution', + url='https://github.com/xinntao/BasicSR', + packages=find_packages( + exclude=('options', 'datasets', 'experiments', 'results', + 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + zip_safe=False) diff --git 
a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..59e05c12b4f3869ea133155a9b00a23c0ccd4e50 --- /dev/null +++ b/test.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=$1 python -m torch.distributed.launch --nproc_per_node=1 --master_port=12321 basicsr/test.py -opt Aberration_Correction/Options/Eval_Aberration_Transformers.yml --launcher pytorch --name $2 \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..36a5d32ab36667d863e4056b2efb69e31472a0a5 --- /dev/null +++ b/train.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=$1 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12321 basicsr/train.py -opt Aberration_Correction/Options/Train_Aberration_Transformers.yml --launcher pytorch --name $2 \ No newline at end of file