blee committed on
Commit 6670ec8 · verified · 1 Parent(s): 66d36dd

Upload 53 files

Upload codes and weights

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. Aberration_Correction/Options/Test_Aberration_Transformers.yml +75 -0
  3. Aberration_Correction/Options/Train_Aberration_Transformers.yml +141 -0
  4. Aberration_Correction/utils.py +90 -0
  5. VERSION +1 -0
  6. basicsr/data/__init__.py +126 -0
  7. basicsr/data/data_sampler.py +49 -0
  8. basicsr/data/data_util.py +15 -0
  9. basicsr/data/paired_image_dataset.py +156 -0
  10. basicsr/data/prefetch_dataloader.py +126 -0
  11. basicsr/data/transforms.py +167 -0
  12. basicsr/metrics/__init__.py +4 -0
  13. basicsr/metrics/fid.py +102 -0
  14. basicsr/metrics/metric_util.py +47 -0
  15. basicsr/metrics/niqe.py +205 -0
  16. basicsr/metrics/niqe_pris_params.npz +3 -0
  17. basicsr/metrics/other_metrics.py +88 -0
  18. basicsr/metrics/psnr_ssim.py +303 -0
  19. basicsr/models/__init__.py +42 -0
  20. basicsr/models/archs/__init__.py +45 -0
  21. basicsr/models/archs/arch_util.py +255 -0
  22. basicsr/models/archs/restormer_arch.py +527 -0
  23. basicsr/models/base_model.py +376 -0
  24. basicsr/models/image_restoration_model.py +392 -0
  25. basicsr/models/losses/__init__.py +5 -0
  26. basicsr/models/losses/loss_util.py +95 -0
  27. basicsr/models/losses/losses.py +180 -0
  28. basicsr/models/lr_scheduler.py +232 -0
  29. basicsr/test.py +142 -0
  30. basicsr/train.py +328 -0
  31. basicsr/utils/__init__.py +45 -0
  32. basicsr/utils/bundle_submissions.py +108 -0
  33. basicsr/utils/create_lmdb.py +124 -0
  34. basicsr/utils/dist_util.py +83 -0
  35. basicsr/utils/download_util.py +70 -0
  36. basicsr/utils/face_util.py +217 -0
  37. basicsr/utils/file_client.py +186 -0
  38. basicsr/utils/flow_util.py +180 -0
  39. basicsr/utils/img_util.py +216 -0
  40. basicsr/utils/lmdb_util.py +208 -0
  41. basicsr/utils/logger.py +175 -0
  42. basicsr/utils/matlab_functions.py +361 -0
  43. basicsr/utils/misc.py +266 -0
  44. basicsr/utils/nano.py +250 -0
  45. basicsr/utils/options.py +112 -0
  46. basicsr/version.py +5 -0
  47. experiments/pretrained/models/net_g_100000.pth +3 -0
  48. experiments/pretrained/training_states/100000.state +3 -0
  49. psf.npy +3 -0
  50. readme.md +73 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ experiments/pretrained/training_states/100000.state filter=lfs diff=lfs merge=lfs -text
Aberration_Correction/Options/Test_Aberration_Transformers.yml ADDED
@@ -0,0 +1,75 @@
+ # general settings
+ name: sample_test
+ # name: batch8
+ model_type: ImageCleanModel
+ scale: 1
+ num_gpu: 4 # set num_gpu: 0 for cpu mode
+ manual_seed: 100
+
+ # dataset and data loader settings
+ datasets:
+   val:
+     name: ValSet
+     type: Dataset_PaddedImage # Use Dataset_PaddedImage_npy if load convolved images (lr images). Also please set dataroot_lq as well.
+     dataroot_gt: PATH_TO_TEST_SET # TODO
+     io_backend:
+       type: disk
+
+     sensor_size: 1215
+     psf_size: 135
+
+ # network structures
+ network_g:
+   type: ACFormer
+   inp_channels: 39
+   out_channels: 3
+   dim: 48
+   num_blocks: [2,4,4,4]
+   num_refinement_blocks: 4
+   channel_heads: [1,2,4,8]
+   spatial_heads: [1,2,4,8]
+   overlap_ratio: [0.5,0.5,0.5,0.5]
+   window_size: 8
+   spatial_dim_head: 16
+   ffn_expansion_factor: 2.66
+   bias: False
+   LayerNorm_type: WithBias
+   ca_dim: 32
+   ca_heads: 2
+   M: 13
+   window_size_ca: 8
+   query_ksize: [15,11,7,3,3]
+
+ # path
+ path:
+   pretrain_network_g: ~
+   strict_load_g: true
+   resume_state: ~
+
+ # training settings
+ train:
+   ks:
+     start: -2
+     end: -5
+     num: 13
+
+ # validation settings
+ val:
+   window_size: 8
+   save_img: true
+   rgb2bgr: true
+   use_image: true
+   max_minibatch: 8
+   padding: 64
+
+   metrics:
+     psnr: # metric name, can be arbitrary
+       type: calculate_psnr
+       crop_border: 0
+       test_y_channel: true
+
+
+ # dist training settings
+ dist_params:
+   backend: nccl
+   port: 29502
Aberration_Correction/Options/Train_Aberration_Transformers.yml ADDED
@@ -0,0 +1,141 @@
+ # general settings
+ name: sample_test
+ # name: batch8
+ model_type: ImageCleanModel
+ scale: 1
+ num_gpu: 4 # set num_gpu: 0 for cpu mode
+ manual_seed: 100
+
+ # dataset and data loader settings
+ datasets:
+   train:
+     name: TrainSet
+     type: Dataset_PaddedImage # make lr image from gt image on the fly.
+     dataroot_gt: PATH_TO_TRAIN_SET # TODO
+
+     filename_tmpl: '{}'
+     io_backend:
+       type: disk
+
+     # data loader
+     use_shuffle: true
+     num_worker_per_gpu: 8 # 8
+     batch_size_per_gpu: 2 # 8
+
+     gt_size: 256
+
+     dataset_enlarge_ratio: 1
+     prefetch_mode: ~
+
+     sensor_size: 1215
+     psf_size: 135
+
+   val:
+     name: ValSet
+     type: Dataset_PaddedImage
+     dataroot_gt: PATH_TO_TEST_SET # TODO
+     io_backend:
+       type: disk
+
+     sensor_size: 1215
+     psf_size: 135
+
+ # network structures
+ network_g:
+   type: ACFormer
+   inp_channels: 39
+   out_channels: 3
+   dim: 48
+   num_blocks: [2,4,4,4]
+   num_refinement_blocks: 4
+   channel_heads: [1,2,4,8]
+   spatial_heads: [1,2,4,8]
+   overlap_ratio: [0.5,0.5,0.5,0.5]
+   window_size: 8
+   spatial_dim_head: 16
+   ffn_expansion_factor: 2.66
+   bias: False
+   LayerNorm_type: WithBias
+   ca_dim: 32
+   ca_heads: 2
+   M: 13
+   window_size_ca: 8
+   query_ksize: [15,11,7,3,3]
+
+ # path
+ path:
+   pretrain_network_g: ~
+   strict_load_g: true
+   resume_state: ~
+
+ # training settings
+ train:
+   eval_only: True
+   eval_name: Sample_data
+   real_psf: True
+   grid: True
+   total_iter: 100000
+   warmup_iter: -1 # no warm up
+   use_grad_clip: true
+   contrast_tik: 2
+   sensor_height: 1215
+
+   scheduler:
+     type: CosineAnnealingRestartCyclicLR
+     periods: [92000, 208000]
+     restart_weights: [1,1]
+     eta_mins: [0.0003,0.000001]
+
+   mixing_augs:
+     mixup: false
+     mixup_beta: 1.2
+     use_identity: true
+
+   optim_g:
+     type: AdamW
+     lr: !!float 3e-4
+     weight_decay: !!float 1e-4
+     betas: [0.9, 0.999]
+
+   # losses
+   pixel_opt:
+     type: L1Loss
+     loss_weight: 1
+     reduction: mean
+
+   ks:
+     start: -2
+     end: -5
+     num: 13
+
+
+ # validation settings
+ val:
+   window_size: 8
+   val_freq: !!float 1e8 # inactivated
+   save_img: false
+   rgb2bgr: true
+   use_image: true
+   max_minibatch: 8
+   padding: 64
+   apply_conv: True # Apply convolution to GT image to create lr image. False if load .npy data (already aberrated)
+
+   metrics:
+     psnr: # metric name, can be arbitrary
+       type: calculate_psnr
+       crop_border: 0
+       test_y_channel: true
+
+ # logging settings
+ logger:
+   print_freq: 500
+   save_checkpoint_freq: !!float 5e3
+   use_tb_logger: true
+   wandb:
+     project: ~
+     resume_id: ~
+
+ # dist training settings
+ dist_params:
+   backend: nccl
+   port: 29502
Aberration_Correction/utils.py ADDED
@@ -0,0 +1,90 @@
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
+ ## https://arxiv.org/abs/2111.09881
+
+ import numpy as np
+ import os
+ import cv2
+ import math
+
+ def calculate_psnr(img1, img2, border=0):
+     # img1 and img2 have range [0, 255]
+     #img1 = img1.squeeze()
+     #img2 = img2.squeeze()
+     if not img1.shape == img2.shape:
+         raise ValueError('Input images must have the same dimensions.')
+     h, w = img1.shape[:2]
+     img1 = img1[border:h-border, border:w-border]
+     img2 = img2[border:h-border, border:w-border]
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+     mse = np.mean((img1 - img2)**2)
+     if mse == 0:
+         return float('inf')
+     return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+ # --------------------------------------------
+ # SSIM
+ # --------------------------------------------
+ def calculate_ssim(img1, img2, border=0):
+     '''calculate SSIM
+     the same outputs as MATLAB's
+     img1, img2: [0, 255]
+     '''
+     #img1 = img1.squeeze()
+     #img2 = img2.squeeze()
+     if not img1.shape == img2.shape:
+         raise ValueError('Input images must have the same dimensions.')
+     h, w = img1.shape[:2]
+     img1 = img1[border:h-border, border:w-border]
+     img2 = img2[border:h-border, border:w-border]
+
+     if img1.ndim == 2:
+         return ssim(img1, img2)
+     elif img1.ndim == 3:
+         if img1.shape[2] == 3:
+             ssims = []
+             for i in range(3):
+                 ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+             return np.array(ssims).mean()
+         elif img1.shape[2] == 1:
+             return ssim(np.squeeze(img1), np.squeeze(img2))
+     else:
+         raise ValueError('Wrong input image dimensions.')
+
+
+ def ssim(img1, img2):
+     C1 = (0.01 * 255)**2
+     C2 = (0.03 * 255)**2
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+     kernel = cv2.getGaussianKernel(11, 1.5)
+     window = np.outer(kernel, kernel.transpose())
+
+     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+     mu1_sq = mu1**2
+     mu2_sq = mu2**2
+     mu1_mu2 = mu1 * mu2
+     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                             (sigma1_sq + sigma2_sq + C2))
+     return ssim_map.mean()
+
+ def load_img(filepath):
+     return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
+
+ def save_img(filepath, img):
+     cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+ def load_gray_img(filepath):
+     return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
+
+ def save_gray_img(filepath, img):
+     cv2.imwrite(filepath, img)
VERSION ADDED
@@ -0,0 +1 @@
+ 1.2.0
basicsr/data/__init__.py ADDED
@@ -0,0 +1,126 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from functools import partial
7
+ from os import path as osp
8
+
9
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10
+ from basicsr.utils import get_root_logger, scandir
11
+ from basicsr.utils.dist_util import get_dist_info
12
+
13
+ __all__ = ['create_dataset', 'create_dataloader']
14
+
15
+ # automatically scan and import dataset modules
16
+ # scan all the files under the data folder with '_dataset' in file names
17
+ data_folder = osp.dirname(osp.abspath(__file__))
18
+ dataset_filenames = [
19
+ osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20
+ if v.endswith('_dataset.py')
21
+ ]
22
+ # import all the dataset modules
23
+ _dataset_modules = [
24
+ importlib.import_module(f'basicsr.data.{file_name}')
25
+ for file_name in dataset_filenames
26
+ ]
27
+
28
+
29
+ def create_dataset(dataset_opt, mv=False):
30
+ """Create dataset.
31
+
32
+ Args:
33
+ dataset_opt (dict): Configuration for dataset. It contains:
34
+ name (str): Dataset name.
35
+ type (str): Dataset type.
36
+ """
37
+ dataset_type = dataset_opt['type']
38
+ # dynamic instantiation
39
+ for module in _dataset_modules:
40
+ dataset_cls = getattr(module, dataset_type, None)
41
+ if dataset_cls is not None:
42
+ break
43
+ if dataset_cls is None:
44
+ raise ValueError(f'Dataset {dataset_type} is not found.')
45
+
46
+ dataset = dataset_cls(dataset_opt)
47
+
48
+ logger = get_root_logger()
49
+ logger.info(
50
+ f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
51
+ 'is created.')
52
+ return dataset
53
+
54
+
55
+ def create_dataloader(dataset,
56
+ dataset_opt,
57
+ num_gpu=1,
58
+ dist=False,
59
+ sampler=None,
60
+ seed=None):
61
+ """Create dataloader.
62
+
63
+ Args:
64
+ dataset (torch.utils.data.Dataset): Dataset.
65
+ dataset_opt (dict): Dataset options. It contains the following keys:
66
+ phase (str): 'train' or 'val'.
67
+ num_worker_per_gpu (int): Number of workers for each GPU.
68
+ batch_size_per_gpu (int): Training batch size for each GPU.
69
+ num_gpu (int): Number of GPUs. Used only in the train phase.
70
+ Default: 1.
71
+ dist (bool): Whether in distributed training. Used only in the train
72
+ phase. Default: False.
73
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
74
+ seed (int | None): Seed. Default: None
75
+ """
76
+ phase = dataset_opt['phase']
77
+ rank, _ = get_dist_info()
78
+ if phase == 'train':
79
+ if dist: # distributed training
80
+ batch_size = dataset_opt['batch_size_per_gpu']
81
+ num_workers = dataset_opt['num_worker_per_gpu']
82
+ else: # non-distributed training
83
+ multiplier = 1 if num_gpu == 0 else num_gpu
84
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
85
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
86
+ dataloader_args = dict(
87
+ dataset=dataset,
88
+ batch_size=batch_size,
89
+ shuffle=False,
90
+ num_workers=num_workers,
91
+ sampler=sampler,
92
+ drop_last=True)
93
+
94
+ if sampler is None:
95
+ dataloader_args['shuffle'] = True
96
+ dataloader_args['worker_init_fn'] = partial(
97
+ worker_init_fn, num_workers=num_workers, rank=rank,
98
+ seed=seed) if seed is not None else None
99
+ elif phase in ['val', 'test', 'val20']: # validation
100
+ dataloader_args = dict(
101
+ dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
102
+ else:
103
+ raise ValueError(f'Wrong dataset phase: {phase}. '
104
+ "Supported ones are 'train', 'val' and 'test'.")
105
+
106
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
107
+
108
+ prefetch_mode = dataset_opt.get('prefetch_mode')
109
+ if prefetch_mode == 'cpu': # CPUPrefetcher
110
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
111
+ logger = get_root_logger()
112
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: '
113
+ f'num_prefetch_queue = {num_prefetch_queue}')
114
+ return PrefetchDataLoader(
115
+ num_prefetch_queue=num_prefetch_queue, **dataloader_args)
116
+ else:
117
+ # prefetch_mode=None: Normal dataloader
118
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
119
+ return torch.utils.data.DataLoader(**dataloader_args)
120
+
121
+
122
+ def worker_init_fn(worker_id, num_workers, rank, seed):
123
+ # Set the worker seed to num_workers * rank + worker_id + seed
124
+ worker_seed = num_workers * rank + worker_id + seed
125
+ np.random.seed(worker_seed)
126
+ random.seed(worker_seed)
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,49 @@
+ import math
+ import torch
+ from torch.utils.data.sampler import Sampler
+
+
+ class EnlargedSampler(Sampler):
+     """Sampler that restricts data loading to a subset of the dataset.
+
+     Modified from torch.utils.data.distributed.DistributedSampler
+     Support enlarging the dataset for iteration-based training, for saving
+     time when restart the dataloader after each epoch
+
+     Args:
+         dataset (torch.utils.data.Dataset): Dataset used for sampling.
+         num_replicas (int | None): Number of processes participating in
+             the training. It is usually the world_size.
+         rank (int | None): Rank of the current process within num_replicas.
+         ratio (int): Enlarging ratio. Default: 1.
+     """
+
+     def __init__(self, dataset, num_replicas, rank, ratio=1):
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = math.ceil(
+             len(self.dataset) * ratio / self.num_replicas)
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+         indices = torch.randperm(self.total_size, generator=g).tolist()
+
+         dataset_size = len(self.dataset)
+         indices = [v % dataset_size for v in indices]
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
basicsr/data/data_util.py ADDED
@@ -0,0 +1,15 @@
+ import cv2
+ cv2.setNumThreads(1)
+ from os import path as osp
+ from basicsr.utils import scandir
+
+
+ def paths_from_folder(folder, key):
+     gt_paths = list(scandir(folder))
+     paths = []
+     for idx in range(len(gt_paths)):
+         gt_path = gt_paths[idx]
+         gt_path = osp.join(folder, gt_path)
+         paths.append(
+             dict([(f'{key}_path', gt_path)]))
+     return paths
basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,156 @@
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from basicsr.data.data_util import paths_from_folder
5
+ from basicsr.utils import FileClient, imfrombytes, img2tensor, padding
6
+ from natsort import natsorted
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+ import cv2
11
+ import os
12
+ import random
13
+
14
+
15
+ class Dataset_PaddedImage(data.Dataset):
16
+ """Padded image dataset for image restoration.
17
+
18
+ Args:
19
+ opt (dict): Config for train datasets. It contains the following keys:
20
+ dataroot_gt (str): Data root path for gt.
21
+ io_backend (dict): IO backend type and other kwarg.
22
+ gt_size (int): Cropped patched size for gt patches.
23
+ scale (bool): Scale, which will be added automatically.
24
+ phase (str): 'train' or 'val'.
25
+ """
26
+
27
+ def __init__(self, opt):
28
+ super(Dataset_PaddedImage, self).__init__()
29
+ self.opt = opt
30
+ # file client (io backend)
31
+ self.file_client = None
32
+ self.io_backend_opt = opt['io_backend']
33
+
34
+ self.gt_folder = opt['dataroot_gt']
35
+ self.paths = paths_from_folder(self.gt_folder, 'gt')
36
+
37
+ self.sensor_size = opt['sensor_size']
38
+ self.psf_size = opt['psf_size']
39
+ self.padded_size = self.sensor_size + 2 * self.psf_size
40
+
41
+ def __getitem__(self, index):
42
+ if self.file_client is None:
43
+ self.file_client = FileClient(
44
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
45
+
46
+ scale = self.opt['scale']
47
+ index = index % len(self.paths)
48
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
49
+ # image range: [0, 1], float32.
50
+ gt_path = self.paths[index]['gt_path']
51
+ img_bytes = self.file_client.get(gt_path, 'gt')
52
+ try:
53
+ img_gt = imfrombytes(img_bytes, float32=True)
54
+ except:
55
+ raise Exception("gt path {} not working".format(gt_path))
56
+
57
+
58
+ if self.opt['phase'] == 'train':
59
+ gt_size = self.opt['gt_size']
60
+ # padding
61
+ img_gt = padding(img_gt, gt_size) # h,w,c
62
+ orig_h, orig_w, _ = img_gt.shape
63
+
64
+ # Fit one axis to sensor height (width)
65
+ longer = max(orig_h, orig_w)
66
+ scale = float(longer / self.sensor_size)
67
+ resolution = (int(orig_w / scale), int(orig_h / scale))
68
+ img_gt = cv2.resize(img_gt, resolution, interpolation=cv2.INTER_LINEAR) # sensor_size,x,3 or y,sensor_size,3 where x,y <= sensor_size
69
+
70
+ resized_h, resized_w, _ = img_gt.shape
71
+ # add padding
72
+ pad_h = self.padded_size - resized_h
73
+ pad_w = self.padded_size - resized_w
74
+ pad_l = pad_r = pad_w // 2
75
+ if pad_w % 2:
76
+ pad_r += 1
77
+ pad_t = pad_b = pad_h // 2
78
+ if pad_h % 2:
79
+ pad_b += 1
80
+ img_gt = np.pad(img_gt, ((pad_t, pad_b), (pad_l, pad_r), (0,0))) # padded_size,padded_size,3
81
+
82
+ # BGR to RGB, HWC to CHW, numpy to tensor
83
+ img_gt = img2tensor(img_gt, bgr2rgb=True,
84
+ float32=True)
85
+
86
+ return {
87
+ 'gt': img_gt,
88
+ 'gt_path': gt_path,
89
+ 'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size)
90
+ }
91
+
92
+ def __len__(self):
93
+ return len(self.paths)
94
+
95
+ class Dataset_PaddedImage_npy(data.Dataset):
96
+ # validation only
97
+ def __init__(self, opt):
98
+ super(Dataset_PaddedImage_npy, self).__init__()
99
+ self.opt = opt
100
+ # file client (io backend)
101
+ self.file_client = None
102
+ self.io_backend_opt = opt['io_backend']
103
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
104
+ self.lq_paths = natsorted(os.listdir(self.lq_folder))
105
+ self.gt_paths = natsorted(os.listdir(self.gt_folder))
106
+
107
+ self.sensor_size = opt['sensor_size']
108
+ self.psf_size = opt['psf_size']
109
+ self.padded_size = self.sensor_size + 2 * self.psf_size
110
+
111
+
112
+ def __getitem__(self, index):
113
+ if self.file_client is None:
114
+ self.file_client = FileClient(
115
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
116
+
117
+ scale = self.opt['scale']
118
+ index = index % len(self.gt_paths)
119
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
120
+ # image range: [0, 1], float32.
121
+ gt_path = f"{self.gt_folder}/{self.gt_paths[index]}"
122
+ lq_path = f"{self.lq_folder}/{self.lq_paths[index]}"
123
+ assert os.path.basename(gt_path).split(".")[0] == os.path.basename(lq_path).split(".")[0]
124
+
125
+ img_bytes = self.file_client.get(gt_path, 'gt')
126
+ try:
127
+ img_gt = imfrombytes(img_bytes, float32=True)
128
+ except:
129
+ raise Exception("gt path {} not working".format(gt_path))
130
+
131
+ img_lq = torch.tensor(np.load(lq_path)) # 1,1,81,3,405,405
132
+
133
+ resized_h, resized_w, _ = img_gt.shape
134
+ pad_h = self.padded_size - resized_h
135
+ pad_w = self.padded_size - resized_w
136
+ pad_l = pad_r = pad_w // 2
137
+ if pad_w % 2:
138
+ pad_r += 1
139
+ pad_t = pad_b = pad_h // 2
140
+ if pad_h % 2:
141
+ pad_b += 1
142
+
143
+ # BGR to RGB, HWC to CHW, numpy to tensor
144
+ img_gt = img2tensor(img_gt, bgr2rgb=True,
145
+ float32=True)
146
+
147
+ return {
148
+ 'gt': img_gt,
149
+ 'lq': img_lq,
150
+ 'lq_path': lq_path,
151
+ 'gt_path': gt_path,
152
+ 'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size)
153
+ }
154
+
155
+ def __len__(self):
156
+ return len(self.gt_paths)
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,126 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Number of prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Number of prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(
116
+ device=self.device, non_blocking=True)
117
+
118
+ def next(self):
119
+ torch.cuda.current_stream().wait_stream(self.stream)
120
+ batch = self.batch
121
+ self.preload()
122
+ return batch
123
+
124
+ def reset(self):
125
+ self.loader = iter(self.ori_loader)
126
+ self.preload()
basicsr/data/transforms.py ADDED
@@ -0,0 +1,167 @@
1
+ import cv2
2
+ import random
3
+ import numpy as np
4
+
5
+ def mod_crop(img, scale):
6
+ """Mod crop images, used during testing.
7
+
8
+ Args:
9
+ img (ndarray): Input image.
10
+ scale (int): Scale factor.
11
+
12
+ Returns:
13
+ ndarray: Result image.
14
+ """
15
+ img = img.copy()
16
+ if img.ndim in (2, 3):
17
+ h, w = img.shape[0], img.shape[1]
18
+ h_remainder, w_remainder = h % scale, w % scale
19
+ img = img[:h - h_remainder, :w - w_remainder, ...]
20
+ else:
21
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
22
+ return img
23
+
24
+
25
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
26
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
27
+
28
+ We use vertical flip and transpose for rotation implementation.
29
+ All the images in the list use the same augmentation.
30
+
31
+ Args:
32
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
33
+ is an ndarray, it will be transformed to a list.
34
+ hflip (bool): Horizontal flip. Default: True.
35
+ rotation (bool): Rotation. Default: True.
36
+ flows (list[ndarray]: Flows to be augmented. If the input is an
37
+ ndarray, it will be transformed to a list.
38
+ Dimension is (h, w, 2). Default: None.
39
+ return_status (bool): Return the status of flip and rotation.
40
+ Default: False.
41
+
42
+ Returns:
43
+ list[ndarray] | ndarray: Augmented images and flows. If returned
44
+ results only have one element, just return ndarray.
45
+
46
+ """
47
+ hflip = hflip and random.random() < 0.5
48
+ vflip = rotation and random.random() < 0.5
49
+ rot90 = rotation and random.random() < 0.5
50
+
51
+ def _augment(img):
52
+ if hflip: # horizontal
53
+ cv2.flip(img, 1, img)
54
+ if vflip: # vertical
55
+ cv2.flip(img, 0, img)
56
+ if rot90:
57
+ img = img.transpose(1, 0, 2)
58
+ return img
59
+
60
+ def _augment_flow(flow):
61
+ if hflip: # horizontal
62
+ cv2.flip(flow, 1, flow)
63
+ flow[:, :, 0] *= -1
64
+ if vflip: # vertical
65
+ cv2.flip(flow, 0, flow)
66
+ flow[:, :, 1] *= -1
67
+ if rot90:
68
+ flow = flow.transpose(1, 0, 2)
69
+ flow = flow[:, :, [1, 0]]
70
+ return flow
71
+
72
+ if not isinstance(imgs, list):
73
+ imgs = [imgs]
74
+ imgs = [_augment(img) for img in imgs]
75
+ if len(imgs) == 1:
76
+ imgs = imgs[0]
77
+
78
+ if flows is not None:
79
+ if not isinstance(flows, list):
80
+ flows = [flows]
81
+ flows = [_augment_flow(flow) for flow in flows]
82
+ if len(flows) == 1:
83
+ flows = flows[0]
84
+ return imgs, flows
85
+ else:
86
+ if return_status:
87
+ return imgs, (hflip, vflip, rot90)
88
+ else:
89
+ return imgs
90
+
91
+
92
+ def img_rotate(img, angle, center=None, scale=1.0):
93
+ """Rotate image.
94
+
95
+ Args:
96
+ img (ndarray): Image to be rotated.
97
+ angle (float): Rotation angle in degrees. Positive values mean
98
+ counter-clockwise rotation.
99
+ center (tuple[int]): Rotation center. If the center is None,
100
+ initialize it as the center of the image. Default: None.
101
+ scale (float): Isotropic scale factor. Default: 1.0.
102
+ """
103
+ (h, w) = img.shape[:2]
104
+
105
+ if center is None:
106
+ center = (w // 2, h // 2)
107
+
108
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
109
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
110
+ return rotated_img
111
+
112
+ def data_augmentation(image, mode):
113
+ """
114
+ Performs data augmentation of the input image
115
+ Input:
116
+ image: a cv2 (OpenCV) image
117
+ mode: int. Choice of transformation to apply to the image
118
+ 0 - no transformation
119
+ 1 - flip up and down
120
+ 2 - rotate counterclockwise 90 degree
121
+ 3 - rotate 90 degree and flip up and down
122
+ 4 - rotate 180 degree
123
+ 5 - rotate 180 degree and flip
124
+ 6 - rotate 270 degree
125
+ 7 - rotate 270 degree and flip
126
+ """
127
+ if mode == 0:
128
+ # original
129
+ out = image
130
+ elif mode == 1:
131
+ # flip up and down
132
+ out = np.flipud(image)
133
+ elif mode == 2:
134
+ # rotate counterclockwise 90 degree
135
+ out = np.rot90(image)
136
+ elif mode == 3:
137
+ # rotate 90 degree and flip up and down
138
+ out = np.rot90(image)
139
+ out = np.flipud(out)
140
+ elif mode == 4:
141
+ # rotate 180 degree
142
+ out = np.rot90(image, k=2)
143
+ elif mode == 5:
144
+ # rotate 180 degree and flip
145
+ out = np.rot90(image, k=2)
146
+ out = np.flipud(out)
147
+ elif mode == 6:
148
+ # rotate 270 degree
149
+ out = np.rot90(image, k=3)
150
+ elif mode == 7:
151
+ # rotate 270 degree and flip
152
+ out = np.rot90(image, k=3)
153
+ out = np.flipud(out)
154
+ else:
155
+ raise Exception('Invalid choice of image transformation')
156
+
157
+ return out
158
+
159
+ def random_augmentation(*args):
160
+ out = []
161
+ flag_aug = random.randint(0,7)
162
+ for data in args:
163
+ if type(data) == list:
164
+ out.append([data_augmentation(_data, flag_aug).copy() for _data in data])
165
+ else:
166
+ out.append(data_augmentation(data, flag_aug).copy())
167
+ return out
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .niqe import calculate_niqe
+ from .psnr_ssim import calculate_psnr, calculate_ssim
+
+ __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
basicsr/metrics/fid.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from scipy import linalg
5
+ from tqdm import tqdm
6
+
7
+ from basicsr.models.archs.inception import InceptionV3
8
+
9
+
10
+ def load_patched_inception_v3(device='cuda',
11
+ resize_input=True,
12
+ normalize_input=False):
13
+ # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
14
+ # does resize the input.
15
+ inception = InceptionV3([3],
16
+ resize_input=resize_input,
17
+ normalize_input=normalize_input)
18
+ inception = nn.DataParallel(inception).eval().to(device)
19
+ return inception
20
+
21
+
22
+ @torch.no_grad()
23
+ def extract_inception_features(data_generator,
24
+ inception,
25
+ len_generator=None,
26
+ device='cuda'):
27
+ """Extract inception features.
28
+
29
+ Args:
30
+ data_generator (generator): A data generator.
31
+ inception (nn.Module): Inception model.
32
+ len_generator (int): Length of the data_generator to show the
33
+ progressbar. Default: None.
34
+ device (str): Device. Default: cuda.
35
+
36
+ Returns:
37
+ Tensor: Extracted features.
38
+ """
39
+ if len_generator is not None:
40
+ pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
41
+ else:
42
+ pbar = None
43
+ features = []
44
+
45
+ for data in data_generator:
46
+ if pbar:
47
+ pbar.update(1)
48
+ data = data.to(device)
49
+ feature = inception(data)[0].view(data.shape[0], -1)
50
+ features.append(feature.to('cpu'))
51
+ if pbar:
52
+ pbar.close()
53
+ features = torch.cat(features, 0)
54
+ return features
55
+
56
+
57
+ def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
58
+ """Numpy implementation of the Frechet Distance.
59
+
60
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
61
+ and X_2 ~ N(mu_2, C_2) is
62
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
63
+ Stable version by Dougal J. Sutherland.
64
+
65
+ Args:
66
+ mu1 (np.array): The sample mean over activations.
67
+ sigma1 (np.array): The covariance matrix over activations for
68
+ generated samples.
69
+ mu2 (np.array): The sample mean over activations, precalculated on a
70
+ representative data set.
71
+ sigma2 (np.array): The covariance matrix over activations,
72
+ precalculated on a representative data set.
73
+
74
+ Returns:
75
+ float: The Frechet Distance.
76
+ """
77
+ assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
78
+ assert sigma1.shape == sigma2.shape, (
79
+ 'Two covariances have different dimensions')
80
+
81
+ cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
82
+
83
+ # Product might be almost singular
84
+ if not np.isfinite(cov_sqrt).all():
85
+ print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
86
+ 'of cov estimates')
87
+ offset = np.eye(sigma1.shape[0]) * eps
88
+ cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
89
+
90
+ # Numerical error might give slight imaginary component
91
+ if np.iscomplexobj(cov_sqrt):
92
+ if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
93
+ m = np.max(np.abs(cov_sqrt.imag))
94
+ raise ValueError(f'Imaginary component {m}')
95
+ cov_sqrt = cov_sqrt.real
96
+
97
+ mean_diff = mu1 - mu2
98
+ mean_norm = mean_diff @ mean_diff
99
+ trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
100
+ fid = mean_norm + trace
101
+
102
+ return fid
basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,47 @@
+ import numpy as np
+
+ from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+ def reorder_image(img, input_order='HWC'):
+     """Reorder images to 'HWC' order.
+
+     If the input_order is (h, w), return (h, w, 1);
+     If the input_order is (c, h, w), return (h, w, c);
+     If the input_order is (h, w, c), return as it is.
+
+     Args:
+         img (ndarray): Input image.
+         input_order (str): Whether the input order is 'HWC' or 'CHW'.
+             If the input image shape is (h, w), input_order will not have
+             effects. Default: 'HWC'.
+
+     Returns:
+         ndarray: reordered image.
+     """
+
+     if input_order not in ['HWC', 'CHW']:
+         raise ValueError(
+             f'Wrong input_order {input_order}. Supported input_orders are '
+             "'HWC' and 'CHW'")
+     if len(img.shape) == 2:
+         img = img[..., None]
+     if input_order == 'CHW':
+         img = img.transpose(1, 2, 0)
+     return img
+
+
+ def to_y_channel(img):
+     """Change to Y channel of YCbCr.
+
+     Args:
+         img (ndarray): Images with range [0, 255].
+
+     Returns:
+         (ndarray): Images with range [0, 255] (float type) without round.
+     """
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 3 and img.shape[2] == 3:
+         img = bgr2ycbcr(img, y_only=True)
+         img = img[..., None]
+     return img * 255.
basicsr/metrics/niqe.py ADDED
@@ -0,0 +1,205 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ from scipy.ndimage.filters import convolve
5
+ from scipy.special import gamma
6
+
7
+ from basicsr.metrics.metric_util import reorder_image, to_y_channel
8
+
9
+
10
+ def estimate_aggd_param(block):
11
+ """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
12
+
13
+ Args:
14
+ block (ndarray): 2D Image block.
15
+
16
+ Returns:
17
+ tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
18
+ distribution (Estimating the parameters in Equation 7 in the paper).
19
+ """
20
+ block = block.flatten()
21
+ gam = np.arange(0.2, 10.001, 0.001) # len = 9801
22
+ gam_reciprocal = np.reciprocal(gam)
23
+ r_gam = np.square(gamma(gam_reciprocal * 2)) / (
24
+ gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
25
+
26
+ left_std = np.sqrt(np.mean(block[block < 0]**2))
27
+ right_std = np.sqrt(np.mean(block[block > 0]**2))
28
+ gammahat = left_std / right_std
29
+ rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
30
+ rhatnorm = (rhat * (gammahat**3 + 1) *
31
+ (gammahat + 1)) / ((gammahat**2 + 1)**2)
32
+ array_position = np.argmin((r_gam - rhatnorm)**2)
33
+
34
+ alpha = gam[array_position]
35
+ beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
36
+ beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
37
+ return (alpha, beta_l, beta_r)
38
+
39
+
40
+ def compute_feature(block):
41
+ """Compute features.
42
+
43
+ Args:
44
+ block (ndarray): 2D Image block.
45
+
46
+ Returns:
47
+ list: Features with length of 18.
48
+ """
49
+ feat = []
50
+ alpha, beta_l, beta_r = estimate_aggd_param(block)
51
+ feat.extend([alpha, (beta_l + beta_r) / 2])
52
+
53
+ # distortions disturb the fairly regular structure of natural images.
54
+ # This deviation can be captured by analyzing the sample distribution of
55
+ # the products of pairs of adjacent coefficients computed along
56
+ # horizontal, vertical and diagonal orientations.
57
+ shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
58
+ for i in range(len(shifts)):
59
+ shifted_block = np.roll(block, shifts[i], axis=(0, 1))
60
+ alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
61
+ # Eq. 8
62
+ mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
63
+ feat.extend([alpha, mean, beta_l, beta_r])
64
+ return feat
65
+
66
+
67
+ def niqe(img,
68
+ mu_pris_param,
69
+ cov_pris_param,
70
+ gaussian_window,
71
+ block_size_h=96,
72
+ block_size_w=96):
73
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
74
+
75
+ Ref: Making a "Completely Blind" Image Quality Analyzer.
76
+ This implementation could produce almost the same results as the official
77
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
78
+
79
+ Note that we do not include block overlap height and width, since they are
80
+ always 0 in the official implementation.
81
+
82
+ For good performance, it is advised by the official implementation to
83
+ divide the distorted image into patches of the same size as used for the
84
+ construction of multivariate Gaussian model.
85
+
86
+ Args:
87
+ img (ndarray): Input image whose quality needs to be computed. The
88
+ image must be a gray or Y (of YCbCr) image with shape (h, w).
89
+ Range [0, 255] with float type.
90
+ mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
91
+ model calculated on the pristine dataset.
92
+ cov_pris_param (ndarray): Covariance of a pre-defined multivariate
93
+ Gaussian model calculated on the pristine dataset.
94
+ gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
95
+ image.
96
+ block_size_h (int): Height of the blocks in to which image is divided.
97
+ Default: 96 (the official recommended value).
98
+ block_size_w (int): Width of the blocks in to which image is divided.
99
+ Default: 96 (the official recommended value).
100
+ """
101
+ assert img.ndim == 2, (
102
+ 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
103
+ # crop image
104
+ h, w = img.shape
105
+ num_block_h = math.floor(h / block_size_h)
106
+ num_block_w = math.floor(w / block_size_w)
107
+ img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
108
+
109
+ distparam = [] # dist param is actually the multiscale features
110
+ for scale in (1, 2): # perform on two scales (1, 2)
111
+ mu = convolve(img, gaussian_window, mode='nearest')
112
+ sigma = np.sqrt(
113
+ np.abs(
114
+ convolve(np.square(img), gaussian_window, mode='nearest') -
115
+ np.square(mu)))
116
+ # normalize, as in Eq. 1 in the paper
117
+ img_nomalized = (img - mu) / (sigma + 1)
118
+
119
+ feat = []
120
+ for idx_w in range(num_block_w):
121
+ for idx_h in range(num_block_h):
122
+ # process each block
123
+ block = img_nomalized[idx_h * block_size_h //
124
+ scale:(idx_h + 1) * block_size_h //
125
+ scale, idx_w * block_size_w //
126
+ scale:(idx_w + 1) * block_size_w //
127
+ scale]
128
+ feat.append(compute_feature(block))
129
+
130
+ distparam.append(np.array(feat))
131
+ # TODO: matlab bicubic downsample with anti-aliasing
132
+ # for simplicity, now we use opencv instead, which will result in
133
+ # a slight difference.
134
+ if scale == 1:
135
+ h, w = img.shape
136
+ img = cv2.resize(
137
+ img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
138
+ img = img * 255.
139
+
140
+ distparam = np.concatenate(distparam, axis=1)
141
+
142
+ # fit a MVG (multivariate Gaussian) model to distorted patch features
143
+ mu_distparam = np.nanmean(distparam, axis=0)
144
+ # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
145
+ distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
146
+ cov_distparam = np.cov(distparam_no_nan, rowvar=False)
147
+
148
+ # compute niqe quality, Eq. 10 in the paper
149
+ invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
150
+ quality = np.matmul(
151
+ np.matmul((mu_pris_param - mu_distparam), invcov_param),
152
+ np.transpose((mu_pris_param - mu_distparam)))
153
+ quality = np.sqrt(quality)
154
+
155
+ return quality
156
+
157
+
158
+ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
159
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
160
+
161
+ Ref: Making a "Completely Blind" Image Quality Analyzer.
162
+ This implementation could produce almost the same results as the official
163
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
164
+
165
+ We use the official params estimated from the pristine dataset.
166
+ We use the recommended block size (96, 96) without overlaps.
167
+
168
+ Args:
169
+ img (ndarray): Input image whose quality needs to be computed.
170
+ The input image must be in range [0, 255] with float/int type.
171
+ The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
172
+ If the input order is 'HWC' or 'CHW', it will be converted to gray
173
+ or Y (of YCbCr) image according to the ``convert_to`` argument.
174
+ crop_border (int): Cropped pixels in each edge of an image. These
175
+ pixels are not involved in the metric calculation.
176
+ input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
177
+ Default: 'HWC'.
178
+ convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
179
+ Default: 'y'.
180
+
181
+ Returns:
182
+ float: NIQE result.
183
+ """
184
+
185
+ # we use the official params estimated from the pristine dataset.
186
+ niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
187
+ mu_pris_param = niqe_pris_params['mu_pris_param']
188
+ cov_pris_param = niqe_pris_params['cov_pris_param']
189
+ gaussian_window = niqe_pris_params['gaussian_window']
190
+
191
+ img = img.astype(np.float32)
192
+ if input_order != 'HW':
193
+ img = reorder_image(img, input_order=input_order)
194
+ if convert_to == 'y':
195
+ img = to_y_channel(img)
196
+ elif convert_to == 'gray':
197
+ img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
198
+ img = np.squeeze(img)
199
+
200
+ if crop_border != 0:
201
+ img = img[crop_border:-crop_border, crop_border:-crop_border]
202
+
203
+ niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
204
+
205
+ return niqe_result
basicsr/metrics/niqe_pris_params.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
+ size 11850
basicsr/metrics/other_metrics.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ from natsort import natsorted
6
+ from glob import glob
7
+ from skimage import metrics
8
+ import torch.hub
9
+ from lpips.lpips import LPIPS
10
+ from tqdm import tqdm
11
+
12
+
13
+ photometric = {
14
+ "mse": None,
15
+ "ssim": None,
16
+ "psnr": None,
17
+ "lpips": None
18
+ }
19
+
20
+ def psnr(img1, img2):
21
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
22
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
23
+
24
+ def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
25
+ metric="mse", mask=None):
26
+ """
27
+ im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1)
28
+ """
29
+ if metric not in photometric.keys():
30
+ raise RuntimeError(f"img_utils:: metric {metric} not recognized")
31
+ if photometric[metric] is None:
32
+ if metric == "mse":
33
+ photometric[metric] = metrics.mean_squared_error
34
+ elif metric == "ssim":
35
+ photometric[metric] = metrics.structural_similarity
36
+ elif metric == "psnr":
37
+ photometric[metric] = metrics.peak_signal_noise_ratio
38
+ elif metric == "lpips":
39
+ photometric[metric] = LPIPS().cpu()
40
+
41
+ # convert from [0, 1] to [-1, 1]
42
+ im1t = (im1t * 2 - 1).clamp(-1, 1)
43
+ im2t = (im2t * 2 - 1).clamp(-1, 1)
44
+
45
+ if im1t.dim() == 3:
46
+ im1t = im1t.unsqueeze(0)
47
+ im2t = im2t.unsqueeze(0)
48
+ im1t = im1t.detach().cpu()
49
+ im2t = im2t.detach().cpu()
50
+
51
+ if im1t.shape[-1] == 3:
52
+ im1t = im1t.permute(0, 3, 1, 2) # BCHW
53
+ im2t = im2t.permute(0, 3, 1, 2)
54
+
55
+ im1 = im1t.permute(0, 2, 3, 1).numpy()
56
+ im2 = im2t.permute(0, 2, 3, 1).numpy()
57
+ batchsz, hei, wid, _ = im1.shape
58
+ values = []
59
+
60
+ for i in range(batchsz):
61
+ if metric in ["mse", "psnr"]:
62
+ if mask is not None:
63
+ im1 = im1 * mask[i]
64
+ im2 = im2 * mask[i]
65
+ value = photometric[metric](
66
+ im1[i], im2[i]
67
+ )
68
+ if mask is not None:
69
+ hei, wid, _ = im1[i].shape
70
+ pixelnum = mask[i, ..., 0].sum()
71
+ value = value - 10 * np.log10(hei * wid / pixelnum)
72
+ elif metric in ["ssim"]:
73
+ value, ssimmap = photometric["ssim"](
74
+ im1[i], im2[i], multichannel=True, full=True
75
+ )
76
+ if mask is not None:
77
+ value = (ssimmap * mask[i]).sum() / mask[i].sum()
78
+ elif metric in ["lpips"]:
79
+ value = photometric[metric](
80
+ im1t[i:i + 1], im2t[i:i + 1]
81
+ )
82
+ else:
83
+ raise NotImplementedError
84
+ values.append(value)
85
+
86
+ return sum(values) / len(values)
87
+
88
+
basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,303 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from basicsr.metrics.metric_util import reorder_image, to_y_channel
5
+ import skimage.metrics
6
+ import torch
7
+
8
+
9
+ def calculate_psnr(img1,
10
+ img2,
11
+ crop_border,
12
+ input_order='HWC',
13
+ test_y_channel=False):
14
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
15
+
16
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
17
+
18
+ Args:
19
+ img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
20
+ img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
21
+ crop_border (int): Cropped pixels in each edge of an image. These
22
+ pixels are not involved in the PSNR calculation.
23
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
24
+ Default: 'HWC'.
25
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
26
+
27
+ Returns:
28
+ float: psnr result.
29
+ """
30
+
31
+ assert img1.shape == img2.shape, (
32
+ f'Image shapes are different: {img1.shape}, {img2.shape}.')
33
+ if input_order not in ['HWC', 'CHW']:
34
+ raise ValueError(
35
+ f'Wrong input_order {input_order}. Supported input_orders are '
36
+ '"HWC" and "CHW"')
37
+ if type(img1) == torch.Tensor:
38
+ if len(img1.shape) == 4:
39
+ img1 = img1.squeeze(0)
40
+ img1 = img1.detach().cpu().numpy().transpose(1,2,0)
41
+ if type(img2) == torch.Tensor:
42
+ if len(img2.shape) == 4:
43
+ img2 = img2.squeeze(0)
44
+ img2 = img2.detach().cpu().numpy().transpose(1,2,0)
45
+
46
+ img1 = reorder_image(img1, input_order=input_order)
47
+ img2 = reorder_image(img2, input_order=input_order)
48
+ img1 = img1.astype(np.float64)
49
+ img2 = img2.astype(np.float64)
50
+
51
+ if crop_border != 0:
52
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
53
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
54
+
55
+ if test_y_channel:
56
+ img1 = to_y_channel(img1)
57
+ img2 = to_y_channel(img2)
58
+
59
+ mse = np.mean((img1 - img2)**2)
60
+ if mse == 0:
61
+ return float('inf')
62
+ max_value = 1. if img1.max() <= 1 else 255.
63
+ return 20. * np.log10(max_value / np.sqrt(mse))
64
+
65
+
66
+ def _ssim(img1, img2):
67
+ """Calculate SSIM (structural similarity) for one channel images.
68
+
69
+ It is called by func:`calculate_ssim`.
70
+
71
+ Args:
72
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
73
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
74
+
75
+ Returns:
76
+ float: ssim result.
77
+ """
78
+
79
+ C1 = (0.01 * 255)**2
80
+ C2 = (0.03 * 255)**2
81
+
82
+ img1 = img1.astype(np.float64)
83
+ img2 = img2.astype(np.float64)
84
+ kernel = cv2.getGaussianKernel(11, 1.5)
85
+ window = np.outer(kernel, kernel.transpose())
86
+
87
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
88
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
89
+ mu1_sq = mu1**2
90
+ mu2_sq = mu2**2
91
+ mu1_mu2 = mu1 * mu2
92
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
93
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
94
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
95
+
96
+ ssim_map = ((2 * mu1_mu2 + C1) *
97
+ (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
98
+ (sigma1_sq + sigma2_sq + C2))
99
+ return ssim_map.mean()
100
+
101
+ def prepare_for_ssim(img, k):
102
+ import torch
103
+ with torch.no_grad():
104
+ img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
105
+ conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect')
106
+ conv.weight.requires_grad = False
107
+ conv.weight[:, :, :, :] = 1. / (k * k)
108
+
109
+ img = conv(img)
110
+
111
+ img = img.squeeze(0).squeeze(0)
112
+ img = img[0::k, 0::k]
113
+ return img.detach().cpu().numpy()
114
+
115
+ def prepare_for_ssim_rgb(img, k):
116
+ import torch
117
+ with torch.no_grad():
118
+ img = torch.from_numpy(img).float() #HxWx3
119
+
120
+ conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
121
+ conv.weight.requires_grad = False
122
+ conv.weight[:, :, :, :] = 1. / (k * k)
123
+
124
+ new_img = []
125
+
126
+ for i in range(3):
127
+ new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
128
+
129
+ return torch.stack(new_img, dim=2).detach().cpu().numpy()
130
+
131
+ def _3d_gaussian_calculator(img, conv3d):
132
+ out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
133
+ return out
134
+
135
+ def _generate_3d_gaussian_kernel():
136
+ kernel = cv2.getGaussianKernel(11, 1.5)
137
+ window = np.outer(kernel, kernel.transpose())
138
+ kernel_3 = cv2.getGaussianKernel(11, 1.5)
139
+ kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
140
+ conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
141
+ conv3d.weight.requires_grad = False
142
+ conv3d.weight[0, 0, :, :, :] = kernel
143
+ return conv3d
144
+
145
+ def _ssim_3d(img1, img2, max_value):
146
+ assert len(img1.shape) == 3 and len(img2.shape) == 3
147
+ """Calculate SSIM (structural similarity) for one channel images.
148
+
149
+ It is called by func:`calculate_ssim`.
150
+
151
+ Args:
152
+ img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
153
+ img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
154
+
155
+ Returns:
156
+ float: ssim result.
157
+ """
158
+ C1 = (0.01 * max_value) ** 2
159
+ C2 = (0.03 * max_value) ** 2
160
+ img1 = img1.astype(np.float64)
161
+ img2 = img2.astype(np.float64)
162
+
163
+ kernel = _generate_3d_gaussian_kernel().cuda()
164
+
165
+ img1 = torch.tensor(img1).float().cuda()
166
+ img2 = torch.tensor(img2).float().cuda()
167
+
168
+
169
+ mu1 = _3d_gaussian_calculator(img1, kernel)
170
+ mu2 = _3d_gaussian_calculator(img2, kernel)
171
+
172
+ mu1_sq = mu1 ** 2
173
+ mu2_sq = mu2 ** 2
174
+ mu1_mu2 = mu1 * mu2
175
+ sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
176
+ sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
177
+ sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2
178
+
179
+ ssim_map = ((2 * mu1_mu2 + C1) *
180
+ (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
181
+ (sigma1_sq + sigma2_sq + C2))
182
+ return float(ssim_map.mean())
183
+
184
+ def _ssim_cly(img1, img2):
185
+ assert len(img1.shape) == 2 and len(img2.shape) == 2
186
+ """Calculate SSIM (structural similarity) for one channel images.
187
+
188
+ It is called by func:`calculate_ssim`.
189
+
190
+ Args:
191
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
192
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
193
+
194
+ Returns:
195
+ float: ssim result.
196
+ """
197
+
198
+ C1 = (0.01 * 255)**2
199
+ C2 = (0.03 * 255)**2
200
+ img1 = img1.astype(np.float64)
201
+ img2 = img2.astype(np.float64)
202
+
203
+ kernel = cv2.getGaussianKernel(11, 1.5)
204
+ # print(kernel)
205
+ window = np.outer(kernel, kernel.transpose())
206
+
207
+ bt = cv2.BORDER_REPLICATE
208
+
209
+ mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
210
+ mu2 = cv2.filter2D(img2, -1, window,borderType=bt)
211
+
212
+ mu1_sq = mu1**2
213
+ mu2_sq = mu2**2
214
+ mu1_mu2 = mu1 * mu2
215
+ sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
216
+ sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
217
+ sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
218
+
219
+ ssim_map = ((2 * mu1_mu2 + C1) *
220
+ (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
221
+ (sigma1_sq + sigma2_sq + C2))
222
+ return ssim_map.mean()
223
+
224
+
225
+ def calculate_ssim(img1,
226
+ img2,
227
+ crop_border,
228
+ input_order='HWC',
229
+ test_y_channel=False):
230
+ """Calculate SSIM (structural similarity).
231
+
232
+ Ref:
233
+ Image quality assessment: From error visibility to structural similarity
234
+
235
+ The results are the same as that of the official released MATLAB code in
236
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
237
+
238
+ For three-channel images, SSIM is calculated for each channel and then
239
+ averaged.
240
+
241
+ Args:
242
+ img1 (ndarray): Images with range [0, 255].
243
+ img2 (ndarray): Images with range [0, 255].
244
+ crop_border (int): Cropped pixels in each edge of an image. These
245
+ pixels are not involved in the SSIM calculation.
246
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
247
+ Default: 'HWC'.
248
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
249
+
250
+ Returns:
251
+ float: ssim result.
252
+ """
253
+
254
+ assert img1.shape == img2.shape, (
255
+ f'Image shapes are different: {img1.shape}, {img2.shape}.')
256
+ if input_order not in ['HWC', 'CHW']:
257
+ raise ValueError(
258
+ f'Wrong input_order {input_order}. Supported input_orders are '
259
+ '"HWC" and "CHW"')
260
+
261
+ if type(img1) == torch.Tensor:
262
+ if len(img1.shape) == 4:
263
+ img1 = img1.squeeze(0)
264
+ img1 = img1.detach().cpu().numpy().transpose(1,2,0)
265
+ if type(img2) == torch.Tensor:
266
+ if len(img2.shape) == 4:
267
+ img2 = img2.squeeze(0)
268
+ img2 = img2.detach().cpu().numpy().transpose(1,2,0)
269
+
270
+ img1 = reorder_image(img1, input_order=input_order)
271
+ img2 = reorder_image(img2, input_order=input_order)
272
+
273
+ img1 = img1.astype(np.float64)
274
+ img2 = img2.astype(np.float64)
275
+
276
+ if crop_border != 0:
277
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
278
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
279
+
280
+ if test_y_channel:
281
+ img1 = to_y_channel(img1)
282
+ img2 = to_y_channel(img2)
283
+ return _ssim_cly(img1[..., 0], img2[..., 0])
284
+
285
+
286
+ ssims = []
287
+ # ssims_before = []
288
+
289
+ # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)
290
+ # print('.._skimage',
291
+ # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True))
292
+ max_value = 1 if img1.max() <= 1 else 255
293
+ with torch.no_grad():
294
+ final_ssim = _ssim_3d(img1, img2, max_value)
295
+ ssims.append(final_ssim)
296
+
297
+ # for i in range(img1.shape[2]):
298
+ # ssims_before.append(_ssim(img1, img2))
299
+
300
+ # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before))
301
+ # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False))
302
+
303
+ return np.array(ssims).mean()
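A minimal usage sketch for the SSIM routine above, assuming the repo root is on PYTHONPATH; note that with test_y_channel=False the 3D-Gaussian branch moves tensors to CUDA, so a GPU is assumed, and the random arrays below are placeholders for real restored/reference crops.

import numpy as np
from basicsr.metrics.psnr_ssim import calculate_ssim

img1 = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # restored image, HWC, [0, 255]
img2 = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # reference image
# crop_border=0 keeps the full frame; test_y_channel=True would instead score only the Y channel.
score = calculate_ssim(img1, img2, crop_border=0, input_order='HWC', test_y_channel=False)
print(f'SSIM: {score:.4f}')
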
basicsr/models/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ import importlib
2
+ from os import path as osp
3
+
4
+ from basicsr.utils import get_root_logger, scandir
5
+
6
+ # automatically scan and import model modules
7
+ # scan all the files under the 'models' folder and collect files ending with
8
+ # '_model.py'
9
+ model_folder = osp.dirname(osp.abspath(__file__))
10
+ model_filenames = [
11
+ osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12
+ if v.endswith('_model.py')
13
+ ]
14
+ # import all the model modules
15
+ _model_modules = [
16
+ importlib.import_module(f'basicsr.models.{file_name}')
17
+ for file_name in model_filenames
18
+ ]
19
+
20
+
21
+ def create_model(opt):
22
+ """Create model.
23
+
24
+ Args:
25
+ opt (dict): Configuration. It contains:
26
+ model_type (str): Model type.
27
+ """
28
+ model_type = opt['model_type']
29
+
30
+ # dynamic instantiation
31
+ for module in _model_modules:
32
+ model_cls = getattr(module, model_type, None)
33
+ if model_cls is not None:
34
+ break
35
+ if model_cls is None:
36
+ raise ValueError(f'Model {model_type} is not found.')
37
+
38
+ model = model_cls(opt)
39
+
40
+ logger = get_root_logger()
41
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
42
+ return model
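A self-contained toy showing the same name-to-class lookup that create_model performs over the scanned *_model.py modules; here the standard-library collections module stands in for the scanned model modules.

import importlib

modules = [importlib.import_module('collections')]  # stand-in for _model_modules

def lookup(cls_name):
    # walk the modules and return the first class whose name matches
    for module in modules:
        cls_ = getattr(module, cls_name, None)
        if cls_ is not None:
            return cls_
    raise ValueError(f'{cls_name} is not found.')

OrderedDict = lookup('OrderedDict')  # opt['model_type'] = 'ImageCleanModel' resolves the same way
print(OrderedDict(a=1))
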
basicsr/models/archs/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import importlib
2
+ from os import path as osp
3
+
4
+ from basicsr.utils import scandir
5
+
6
+ # automatically scan and import arch modules
7
+ # scan all the files under the 'archs' folder and collect files ending with
8
+ # '_arch.py'
9
+ arch_folder = osp.dirname(osp.abspath(__file__))
10
+ arch_filenames = [
11
+ osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
12
+ if v.endswith('_arch.py')
13
+ ]
14
+ # import all the arch modules
15
+ _arch_modules = [
16
+ importlib.import_module(f'basicsr.models.archs.{file_name}')
17
+ for file_name in arch_filenames
18
+ ]
19
+
20
+
21
+ def dynamic_instantiation(modules, cls_type, opt):
22
+ """Dynamically instantiate class.
23
+
24
+ Args:
25
+ modules (list[importlib modules]): List of modules from importlib
26
+ files.
27
+ cls_type (str): Class type.
28
+ opt (dict): Class initialization kwargs.
29
+
30
+ Returns:
31
+ class: Instantiated class.
32
+ """
33
+ for module in modules:
34
+ cls_ = getattr(module, cls_type, None)
35
+ if cls_ is not None:
36
+ break
37
+ if cls_ is None:
38
+ raise ValueError(f'{cls_type} is not found.')
39
+ return cls_(**opt)
40
+
41
+
42
+ def define_network(opt):
43
+ network_type = opt.pop('type')
44
+ net = dynamic_instantiation(_arch_modules, network_type, opt)
45
+ return net
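define_network expects the parsed network_g block: 'type' selects the class and every remaining key becomes a constructor kwarg. A standalone sketch of that contract with a stand-in class (ToyNet is hypothetical, not part of the repo):

class ToyNet:
    def __init__(self, dim=48, num_blocks=4):
        self.dim, self.num_blocks = dim, num_blocks

opt = {'type': 'ToyNet', 'dim': 64, 'num_blocks': 2}  # mirrors a YAML network_g block
network_type = opt.pop('type')                        # the class name is popped out ...
net = {'ToyNet': ToyNet}[network_type](**opt)         # ... and the remaining keys become kwargs
print(net.dim, net.num_blocks)                        # 64 2
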
basicsr/models/archs/arch_util.py ADDED
@@ -0,0 +1,255 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ from basicsr.utils import get_root_logger
9
+
10
+ # try:
11
+ # from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
12
+ # modulated_deform_conv)
13
+ # except ImportError:
14
+ # # print('Cannot import dcn. Ignore this warning if dcn is not used. '
15
+ # # 'Otherwise install BasicSR with compiling dcn.')
16
+ #
17
+
18
+ @torch.no_grad()
19
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
20
+ """Initialize network weights.
21
+
22
+ Args:
23
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
24
+ scale (float): Scale initialized weights, especially for residual
25
+ blocks. Default: 1.
26
+ bias_fill (float): The value to fill bias. Default: 0
27
+ kwargs (dict): Other arguments for initialization function.
28
+ """
29
+ if not isinstance(module_list, list):
30
+ module_list = [module_list]
31
+ for module in module_list:
32
+ for m in module.modules():
33
+ if isinstance(m, nn.Conv2d):
34
+ init.kaiming_normal_(m.weight, **kwargs)
35
+ m.weight.data *= scale
36
+ if m.bias is not None:
37
+ m.bias.data.fill_(bias_fill)
38
+ elif isinstance(m, nn.Linear):
39
+ init.kaiming_normal_(m.weight, **kwargs)
40
+ m.weight.data *= scale
41
+ if m.bias is not None:
42
+ m.bias.data.fill_(bias_fill)
43
+ elif isinstance(m, _BatchNorm):
44
+ init.constant_(m.weight, 1)
45
+ if m.bias is not None:
46
+ m.bias.data.fill_(bias_fill)
47
+
48
+
49
+ def make_layer(basic_block, num_basic_block, **kwarg):
50
+ """Make layers by stacking the same blocks.
51
+
52
+ Args:
53
+ basic_block (nn.module): nn.module class for basic block.
54
+ num_basic_block (int): number of blocks.
55
+
56
+ Returns:
57
+ nn.Sequential: Stacked blocks in nn.Sequential.
58
+ """
59
+ layers = []
60
+ for _ in range(num_basic_block):
61
+ layers.append(basic_block(**kwarg))
62
+ return nn.Sequential(*layers)
63
+
64
+
65
+ class ResidualBlockNoBN(nn.Module):
66
+ """Residual block without BN.
67
+
68
+ It has a style of:
69
+ ---Conv-ReLU-Conv-+-
70
+ |________________|
71
+
72
+ Args:
73
+ num_feat (int): Channel number of intermediate features.
74
+ Default: 64.
75
+ res_scale (float): Residual scale. Default: 1.
76
+ pytorch_init (bool): If set to True, use pytorch default init,
77
+ otherwise, use default_init_weights. Default: False.
78
+ """
79
+
80
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
81
+ super(ResidualBlockNoBN, self).__init__()
82
+ self.res_scale = res_scale
83
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
84
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
85
+ self.relu = nn.ReLU(inplace=True)
86
+
87
+ if not pytorch_init:
88
+ default_init_weights([self.conv1, self.conv2], 0.1)
89
+
90
+ def forward(self, x):
91
+ identity = x
92
+ out = self.conv2(self.relu(self.conv1(x)))
93
+ return identity + out * self.res_scale
94
+
95
+
96
+ class Upsample(nn.Sequential):
97
+ """Upsample module.
98
+
99
+ Args:
100
+ scale (int): Scale factor. Supported scales: 2^n and 3.
101
+ num_feat (int): Channel number of intermediate features.
102
+ """
103
+
104
+ def __init__(self, scale, num_feat):
105
+ m = []
106
+ if (scale & (scale - 1)) == 0: # scale = 2^n
107
+ for _ in range(int(math.log(scale, 2))):
108
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
109
+ m.append(nn.PixelShuffle(2))
110
+ elif scale == 3:
111
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
112
+ m.append(nn.PixelShuffle(3))
113
+ else:
114
+ raise ValueError(f'scale {scale} is not supported. '
115
+ 'Supported scales: 2^n and 3.')
116
+ super(Upsample, self).__init__(*m)
117
+
118
+
119
+ def flow_warp(x,
120
+ flow,
121
+ interp_mode='bilinear',
122
+ padding_mode='zeros',
123
+ align_corners=True):
124
+ """Warp an image or feature map with optical flow.
125
+
126
+ Args:
127
+ x (Tensor): Tensor with size (n, c, h, w).
128
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
129
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
130
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
131
+ Default: 'zeros'.
132
+ align_corners (bool): Before pytorch 1.3, the default value is
133
+ align_corners=True. After pytorch 1.3, the default value is
134
+ align_corners=False. Here, we use True as the default.
135
+
136
+ Returns:
137
+ Tensor: Warped image or feature map.
138
+ """
139
+ assert x.size()[-2:] == flow.size()[1:3]
140
+ _, _, h, w = x.size()
141
+ # create mesh grid
142
+ grid_y, grid_x = torch.meshgrid(
143
+ torch.arange(0, h).type_as(x),
144
+ torch.arange(0, w).type_as(x))
145
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
146
+ grid.requires_grad = False
147
+
148
+ vgrid = grid + flow
149
+ # scale grid to [-1,1]
150
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
151
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
152
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
153
+ output = F.grid_sample(
154
+ x,
155
+ vgrid_scaled,
156
+ mode=interp_mode,
157
+ padding_mode=padding_mode,
158
+ align_corners=align_corners)
159
+
160
+ # TODO, what if align_corners=False
161
+ return output
162
+
163
+
164
+ def resize_flow(flow,
165
+ size_type,
166
+ sizes,
167
+ interp_mode='bilinear',
168
+ align_corners=False):
169
+ """Resize a flow according to ratio or shape.
170
+
171
+ Args:
172
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
173
+ size_type (str): 'ratio' or 'shape'.
174
+ sizes (list[int | float]): the ratio for resizing or the final output
175
+ shape.
176
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
177
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
178
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
179
+ ratio > 1.0).
180
+ 2) The order of output_size should be [out_h, out_w].
181
+ interp_mode (str): The mode of interpolation for resizing.
182
+ Default: 'bilinear'.
183
+ align_corners (bool): Whether align corners. Default: False.
184
+
185
+ Returns:
186
+ Tensor: Resized flow.
187
+ """
188
+ _, _, flow_h, flow_w = flow.size()
189
+ if size_type == 'ratio':
190
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
191
+ elif size_type == 'shape':
192
+ output_h, output_w = sizes[0], sizes[1]
193
+ else:
194
+ raise ValueError(
195
+ f'Size type should be ratio or shape, but got type {size_type}.')
196
+
197
+ input_flow = flow.clone()
198
+ ratio_h = output_h / flow_h
199
+ ratio_w = output_w / flow_w
200
+ input_flow[:, 0, :, :] *= ratio_w
201
+ input_flow[:, 1, :, :] *= ratio_h
202
+ resized_flow = F.interpolate(
203
+ input=input_flow,
204
+ size=(output_h, output_w),
205
+ mode=interp_mode,
206
+ align_corners=align_corners)
207
+ return resized_flow
208
+
209
+
210
+ # TODO: may write a cpp file
211
+ def pixel_unshuffle(x, scale):
212
+ """ Pixel unshuffle.
213
+
214
+ Args:
215
+ x (Tensor): Input feature with shape (b, c, hh, hw).
216
+ scale (int): Downsample ratio.
217
+
218
+ Returns:
219
+ Tensor: the pixel unshuffled feature.
220
+ """
221
+ b, c, hh, hw = x.size()
222
+ out_channel = c * (scale**2)
223
+ assert hh % scale == 0 and hw % scale == 0
224
+ h = hh // scale
225
+ w = hw // scale
226
+ x_view = x.view(b, c, h, scale, w, scale)
227
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
228
+
229
+
230
+ # class DCNv2Pack(ModulatedDeformConvPack):
231
+ # """Modulated deformable conv for deformable alignment.
232
+ #
233
+ # Different from the official DCNv2Pack, which generates offsets and masks
234
+ # from the preceding features, this DCNv2Pack takes another different
235
+ # features to generate offsets and masks.
236
+ #
237
+ # Ref:
238
+ # Delving Deep into Deformable Alignment in Video Super-Resolution.
239
+ # """
240
+ #
241
+ # def forward(self, x, feat):
242
+ # out = self.conv_offset(feat)
243
+ # o1, o2, mask = torch.chunk(out, 3, dim=1)
244
+ # offset = torch.cat((o1, o2), dim=1)
245
+ # mask = torch.sigmoid(mask)
246
+ #
247
+ # offset_absmean = torch.mean(torch.abs(offset))
248
+ # if offset_absmean > 50:
249
+ # logger = get_root_logger()
250
+ # logger.warning(
251
+ # f'Offset abs mean is {offset_absmean}, larger than 50.')
252
+ #
253
+ # return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
254
+ # self.stride, self.padding, self.dilation,
255
+ # self.groups, self.deformable_groups)
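A quick shape check for pixel_unshuffle above (assuming the repo root is on PYTHONPATH): it trades spatial resolution for channels, acting as the inverse of nn.PixelShuffle.

import torch
from basicsr.models.archs.arch_util import pixel_unshuffle

x = torch.randn(1, 3, 8, 8)      # (b, c, hh, hw); hh and hw must be divisible by scale
y = pixel_unshuffle(x, scale=2)  # -> (1, 3 * 2**2, 4, 4)
print(y.shape)                   # torch.Size([1, 12, 4, 4])
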
basicsr/models/archs/restormer_arch.py ADDED
@@ -0,0 +1,527 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numbers
5
+ from torch import einsum
6
+
7
+ from einops import rearrange
8
+ from basicsr.utils.nano import psf2otf
9
+
10
+ try:
11
+ from flash_attn import flash_attn_func
12
+ except ImportError:
13
+ print("Flash attention is required")
14
+ raise NotImplementedError
15
+
16
+
17
+ def to_3d(x):
18
+ return rearrange(x, 'b c h w -> b (h w) c')
19
+
20
+ def to_4d(x,h,w):
21
+ return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
22
+
23
+ class BiasFree_LayerNorm(nn.Module):
24
+ def __init__(self, normalized_shape):
25
+ super(BiasFree_LayerNorm, self).__init__()
26
+ if isinstance(normalized_shape, numbers.Integral):
27
+ normalized_shape = (normalized_shape,)
28
+ normalized_shape = torch.Size(normalized_shape)
29
+
30
+ assert len(normalized_shape) == 1
31
+
32
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
33
+ self.normalized_shape = normalized_shape
34
+
35
+ def forward(self, x):
36
+ sigma = x.var(-1, keepdim=True, unbiased=False)
37
+ return x / torch.sqrt(sigma+1e-5) * self.weight
38
+
39
+
40
+ class WithBias_LayerNorm(nn.Module):
41
+ def __init__(self, normalized_shape):
42
+ super(WithBias_LayerNorm, self).__init__()
43
+ if isinstance(normalized_shape, numbers.Integral):
44
+ normalized_shape = (normalized_shape,)
45
+ normalized_shape = torch.Size(normalized_shape)
46
+
47
+ assert len(normalized_shape) == 1
48
+
49
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
50
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
51
+ self.normalized_shape = normalized_shape
52
+
53
+ def forward(self, x):
54
+ mu = x.mean(-1, keepdim=True)
55
+ sigma = x.var(-1, keepdim=True, unbiased=False)
56
+ return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
57
+
58
+
59
+ class LayerNorm(nn.Module):
60
+ def __init__(self, dim, LayerNorm_type):
61
+ super(LayerNorm, self).__init__()
62
+ if LayerNorm_type =='BiasFree':
63
+ self.body = BiasFree_LayerNorm(dim)
64
+ else:
65
+ self.body = WithBias_LayerNorm(dim)
66
+
67
+ def forward(self, x):
68
+ h, w = x.shape[-2:]
69
+ return to_4d(self.body(to_3d(x)), h, w)
70
+
71
+
72
+
73
+ ##########################################################################
74
+ ## Gated-Dconv Feed-Forward Network (GDFN)
75
+ class FeedForward(nn.Module):
76
+ def __init__(self, dim, ffn_expansion_factor, bias):
77
+ super(FeedForward, self).__init__()
78
+
79
+ hidden_features = int(dim*ffn_expansion_factor)
80
+
81
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
82
+
83
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
84
+
85
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
86
+
87
+ def forward(self, x):
88
+ x = self.project_in(x)
89
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
90
+ x = F.gelu(x1) * x2
91
+ x = self.project_out(x)
92
+ return x
93
+
94
+
95
+
96
+ ##########################################################################
97
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
98
+ class Attention(nn.Module):
99
+ def __init__(self, dim, num_heads, bias, ksize=0):
100
+ super(Attention, self).__init__()
101
+ self.num_heads = num_heads
102
+ self.ksize = ksize
103
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
104
+
105
+ self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
106
+ self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
107
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108
+ if ksize:
109
+ self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize-1) //2)
110
+
111
+
112
+ def forward(self, x):
113
+ b,c,h,w = x.shape
114
+
115
+ qkv = self.qkv_dwconv(self.qkv(x))
116
+ q,k,v = qkv.chunk(3, dim=1)
117
+
118
+ if self.ksize:
119
+ q = q - self.avg(q)
120
+
121
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
122
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
123
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
124
+
125
+ q = torch.nn.functional.normalize(q, dim=-1)
126
+ k = torch.nn.functional.normalize(k, dim=-1)
127
+
128
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
129
+ attn = attn.softmax(dim=-1)
130
+
131
+ out = (attn @ v)
132
+
133
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
134
+
135
+ out = self.project_out(out)
136
+ return out
137
+
138
+
139
+ ##########################################################################
140
+ ## Overlapped image patch embedding with 3x3 Conv
141
+ class OverlapPatchEmbed(nn.Module):
142
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
143
+ super(OverlapPatchEmbed, self).__init__()
144
+
145
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
146
+
147
+ def forward(self, x):
148
+ x = self.proj(x)
149
+
150
+ return x
151
+
152
+
153
+
154
+ ##########################################################################
155
+ ## Resizing modules
156
+ class Downsample(nn.Module):
157
+ def __init__(self, n_feat):
158
+ super(Downsample, self).__init__()
159
+
160
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
161
+ nn.PixelUnshuffle(2))
162
+
163
+ def forward(self, x):
164
+ return self.body(x)
165
+
166
+ class Upsample(nn.Module):
167
+ def __init__(self, n_feat):
168
+ super(Upsample, self).__init__()
169
+
170
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
171
+ nn.PixelShuffle(2))
172
+
173
+ def forward(self, x):
174
+ return self.body(x)
175
+
176
+
177
+ def to(x):
178
+ return {'device': x.device, 'dtype': x.dtype}
179
+
180
+ def pair(x):
181
+ return (x, x) if not isinstance(x, tuple) else x
182
+
183
+ def expand_dim(t, dim, k):
184
+ t = t.unsqueeze(dim = dim)
185
+ expand_shape = [-1] * len(t.shape)
186
+ expand_shape[dim] = k
187
+ return t.expand(*expand_shape)
188
+
189
+ def rel_to_abs(x):
190
+ b, l, m = x.shape
191
+ r = (m + 1) // 2
192
+
193
+ col_pad = torch.zeros((b, l, 1), **to(x))
194
+ x = torch.cat((x, col_pad), dim = 2)
195
+ flat_x = rearrange(x, 'b l c -> b (l c)')
196
+ flat_pad = torch.zeros((b, m - l), **to(x))
197
+ flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1)
198
+ final_x = flat_x_padded.reshape(b, l + 1, m)
199
+ final_x = final_x[:, :l, -r:]
200
+ return final_x
201
+
202
+ def relative_logits_1d(q, rel_k):
203
+ b, h, w, _ = q.shape
204
+ r = (rel_k.shape[0] + 1) // 2
205
+
206
+ logits = einsum('b x y d, r d -> b x y r', q, rel_k)
207
+ logits = rearrange(logits, 'b x y r -> (b x) y r')
208
+ logits = rel_to_abs(logits)
209
+
210
+ logits = logits.reshape(b, h, w, r)
211
+ logits = expand_dim(logits, dim = 2, k = r)
212
+ return logits
213
+
214
+
215
+ class RelPosEmb(nn.Module):
216
+ def __init__(
217
+ self,
218
+ block_size,
219
+ rel_size,
220
+ dim_head
221
+ ):
222
+ super().__init__()
223
+ height = width = rel_size
224
+ scale = dim_head ** -0.5
225
+
226
+ self.block_size = block_size
227
+ self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
228
+ self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
229
+
230
+ def forward(self, q):
231
+ block = self.block_size
232
+
233
+ q = rearrange(q, 'b (x y) c -> b x y c', x = block)
234
+ rel_logits_w = relative_logits_1d(q, self.rel_width)
235
+ rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')
236
+
237
+ q = rearrange(q, 'b x y d -> b y x d')
238
+ rel_logits_h = relative_logits_1d(q, self.rel_height)
239
+ rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
240
+ return rel_logits_w + rel_logits_h
241
+
242
+
243
+ ##########################################################################
244
+ ## Overlapping Cross-Attention (OCA)
245
+ class OCAB(nn.Module):
246
+ def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias, ksize=0):
247
+ super(OCAB, self).__init__()
248
+ self.num_spatial_heads = num_heads
249
+ self.dim = dim
250
+ self.window_size = window_size
251
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
252
+ self.dim_head = dim_head
253
+ self.inner_dim = self.dim_head * self.num_spatial_heads
254
+ self.scale = self.dim_head**-0.5
255
+ self.ksize = ksize
256
+
257
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
258
+ self.qkv = nn.Conv2d(self.dim, self.inner_dim*3, kernel_size=1, bias=bias)
259
+ self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
260
+ self.rel_pos_emb = RelPosEmb(
261
+ block_size = window_size,
262
+ rel_size = window_size + (self.overlap_win_size - window_size),
263
+ dim_head = self.dim_head
264
+ )
265
+ if ksize:
266
+ self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize-1) //2)
267
+
268
+ def forward(self, x):
269
+ b, c, h, w = x.shape
270
+
271
+ qkv = self.qkv(x)
272
+ qs, ks, vs = qkv.chunk(3, dim=1)
273
+
274
+ if self.ksize:
275
+ qs = qs - self.avg(qs)
276
+
277
+ # spatial attention
278
+ qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.window_size, p2 = self.window_size)
279
+ ks, vs = map(lambda t: self.unfold(t), (ks, vs))
280
+ ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.inner_dim), (ks, vs))
281
+
282
+ #split heads
283
+ qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head = self.num_spatial_heads), (qs, ks, vs))
284
+
285
+ # attention
286
+ qs = qs * self.scale
287
+ spatial_attn = (qs @ ks.transpose(-2, -1))
288
+ spatial_attn += self.rel_pos_emb(qs)
289
+ spatial_attn = spatial_attn.softmax(dim=-1)
290
+
291
+ out = (spatial_attn @ vs)
292
+
293
+ out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head = self.num_spatial_heads, h = h // self.window_size, w = w // self.window_size, p1 = self.window_size, p2 = self.window_size)
294
+
295
+ # merge spatial and channel
296
+ out = self.project_out(out)
297
+
298
+ return out
299
+
300
+
301
+ class AttentionFusion(nn.Module):
302
+ def __init__(self, dim, bias, channel_fusion):
303
+ super(AttentionFusion, self).__init__()
304
+
305
+ self.channel_fusion = channel_fusion
306
+ self.fusion = nn.Sequential(
307
+ nn.Conv2d(dim, dim // 2, kernel_size=1, bias=bias),
308
+ nn.GELU(),
309
+ nn.Conv2d(dim // 2, dim // 2, kernel_size=1, bias=bias)
310
+ )
311
+ self.dim = dim // 2
312
+
313
+ def forward(self, x):
314
+ fusion_map = self.fusion(x)
315
+ if self.channel_fusion:
316
+ weight = F.sigmoid(torch.mean(fusion_map, 1, True))
317
+ else:
318
+ weight = F.sigmoid(torch.mean(fusion_map, (2,3), True))
319
+ fused_feature = x[:, :self.dim] * weight + x[:, self.dim:] * (1-weight) # [:, :self.dim] == SA
320
+ return fused_feature
321
+
322
+
323
+
324
+ class Transformer_STAF(nn.Module):
325
+ def __init__(self, dim, window_size, overlap_ratio, num_channel_heads, num_spatial_heads, spatial_dim_head, ffn_expansion_factor, bias, LayerNorm_type, channel_fusion, query_ksize=0):
326
+ super(Transformer_STAF, self).__init__()
327
+
328
+ self.spatial_attn = OCAB(dim, window_size, overlap_ratio, num_spatial_heads, spatial_dim_head, bias, ksize=query_ksize)
329
+ self.channel_attn = Attention(dim, num_channel_heads, bias, ksize=query_ksize)
330
+
331
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
332
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
333
+ self.norm3 = LayerNorm(dim, LayerNorm_type)
334
+ self.norm4 = LayerNorm(dim, LayerNorm_type)
335
+
336
+ self.channel_ffn = FeedForward(dim, ffn_expansion_factor, bias)
337
+ self.spatial_ffn = FeedForward(dim, ffn_expansion_factor, bias)
338
+
339
+ self.fusion = AttentionFusion(dim*2, bias, channel_fusion)
340
+
341
+ def forward(self, x):
342
+ sa = x + self.spatial_attn(self.norm1(x))
343
+ sa = sa + self.spatial_ffn(self.norm2(sa))
344
+ ca = x + self.channel_attn(self.norm3(x))
345
+ ca = ca + self.channel_ffn(self.norm4(ca))
346
+ fused = self.fusion(torch.cat([sa, ca], 1))
347
+
348
+ return fused
349
+
350
+
351
+ class MAFG_CA(nn.Module):
352
+ def __init__(self, embed_dim, num_heads, M, window_size=0, eps=1e-6):
353
+ super().__init__()
354
+ self.M = M
355
+ self.Q_idx = M // 2
356
+ self.embed_dim = embed_dim
357
+ self.num_heads = num_heads
358
+ self.head_dim = embed_dim // num_heads
359
+ self.M = M
360
+ self.wsize = window_size
361
+
362
+ self.proj_high = nn.Conv2d(3, embed_dim, kernel_size=1)
363
+ self.proj_rgb = nn.Conv2d(embed_dim, 3, kernel_size=1)
364
+
365
+ self.norm = nn.LayerNorm(embed_dim, eps=eps)
366
+ self.qkv = nn.Linear(embed_dim, embed_dim*3, bias=False)
367
+ self.proj_out = nn.Linear(embed_dim, embed_dim, bias=False)
368
+ self.max_seq = 2**16-1
369
+
370
+ # window based sliding similar to OCAB
371
+ self.overlap_wsize = int(self.wsize * 0.5) + self.wsize
372
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_wsize, self.overlap_wsize), stride=window_size, padding=(self.overlap_wsize-self.wsize)//2)
373
+ self.scale = self.embed_dim ** -0.5
374
+ self.pos_emb_q = nn.Parameter(torch.zeros(self.wsize**2, embed_dim))
375
+ self.pos_emb_k = nn.Parameter(torch.zeros(self.overlap_wsize**2, embed_dim))
376
+ nn.init.trunc_normal_(self.pos_emb_q, std=0.02)
377
+ nn.init.trunc_normal_(self.pos_emb_k, std=0.02)
378
+
379
+ def forward(self, x):
380
+ x = self.proj_high(x)
381
+ BM,E,H,W = x.shape
382
+
383
+ x_seq = x.view(BM,E,-1).permute(0,2,1)
384
+ x_seq = self.norm(x_seq)
385
+ B = BM // self.M
386
+ QKV = self.qkv(x_seq)
387
+ QKV = QKV.view(BM, H, W, 3, -1).permute(3,0,4,1,2).contiguous()
388
+ Q,K,V = QKV[0], QKV[1], QKV[2]
389
+ Q_bm = Q.view(B, self.M, E, H,W)
390
+ _Q = Q_bm[:, self.Q_idx:self.Q_idx+1]
391
+ Q = torch.stack([__Q.repeat(self.M,1,1,1) for __Q in _Q]).view(BM,E,H,W)
392
+
393
+ Q = rearrange(Q, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.wsize, p2 = self.wsize)
394
+ K,V = map(lambda t: self.unfold(t), (K,V))
395
+ if K.shape[-1] > 10000: # Inference
396
+ b,_,pp = K.shape
397
+ K = K.view(b,self.embed_dim,-1,pp).permute(0,3,2,1).reshape(b*pp,-1,self.embed_dim)
398
+ V = V.view(b,self.embed_dim,-1,pp).permute(0,3,2,1).reshape(b*pp,-1,self.embed_dim)
399
+ else:
400
+ K,V = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.embed_dim), (K,V))
401
+
402
+ # Absolute positional embedding
403
+ Q = Q + self.pos_emb_q
404
+ K = K + self.pos_emb_k
405
+
406
+ s, eq, _ = Q.shape
407
+ _, ek, _ = K.shape
408
+ Q = Q.view(s, eq, self.num_heads,self.head_dim).half()
409
+ K = K.view(s, ek, self.num_heads,self.head_dim).half()
410
+ V = V.view(s, ek, self.num_heads,self.head_dim).half()
411
+ if s > self.max_seq: # maximum allowed sequence of flash attention
412
+ outs = []
413
+ sp = self.max_seq
414
+ _max = s // sp + 1
415
+ for i in range(_max):
416
+ outs.append(flash_attn_func(Q[i*sp: (i+1)*sp], K[i*sp: (i+1)*sp], V[i*sp: (i+1)*sp], causal=False))
417
+ out = torch.cat(outs).to(torch.float32)
418
+ else:
419
+ out = flash_attn_func(Q, K, V, causal=False).to(torch.float32)
420
+ out = rearrange(out, '(b nh nw) (ph pw) h d -> b (nh ph nw pw) (h d)', nh=H//self.wsize, nw=W//self.wsize, ph=self.wsize, pw=self.wsize)
421
+ out = self.proj_out(out)
422
+
423
+ mixed_feature = out.view(BM,H,W,E).permute(0,3,1,2).contiguous() + x
424
+ return self.proj_rgb(mixed_feature).reshape(B,-1,H,W)
425
+
426
+
427
+ ##########################################################################
428
+ ## Aberration Correction Transformers for Metalens
429
+ class ACFormer(nn.Module):
430
+ def __init__(self,
431
+ inp_channels=3,
432
+ out_channels=3,
433
+ dim = 48,
434
+ num_blocks = [4,6,6,8],
435
+ num_refinement_blocks = 4,
436
+ channel_heads = [1,2,4,8],
437
+ spatial_heads = [2,2,3,4],
438
+ overlap_ratio=[0.5, 0.5, 0.5, 0.5],
439
+ window_size = 8,
440
+ spatial_dim_head = 16,
441
+ bias = False,
442
+ ffn_expansion_factor = 2.66,
443
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
444
+ M=13,
445
+ ca_heads=2,
446
+ ca_dim=32,
447
+ window_size_ca=0,
448
+ query_ksize=None
449
+ ):
450
+
451
+ super(ACFormer, self).__init__()
452
+ self.center_idx = M // 2
453
+ self.ca = MAFG_CA(embed_dim=ca_dim, num_heads=ca_heads, M=M, window_size=window_size_ca)
454
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
455
+
456
+ self.encoder_level1 = nn.Sequential(*[Transformer_STAF(dim=dim, window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[0])])
457
+
458
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
459
+ self.encoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[1])])
460
+
461
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
462
+ self.encoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[2])])
463
+
464
+ self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
465
+ self.latent = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**3), window_size = window_size, overlap_ratio=overlap_ratio[3], num_channel_heads=channel_heads[3], num_spatial_heads=spatial_heads[3], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=query_ksize[0] if i % 2 == 1 else 0) for i in range(num_blocks[3])])
466
+
467
+ self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
468
+ self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
469
+ self.decoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[1] if i % 2 == 1 else 0) for i in range(num_blocks[2])])
470
+
471
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
472
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
473
+ self.decoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[2] if i % 2 == 1 else 0) for i in range(num_blocks[1])])
474
+
475
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
476
+
477
+ self.decoder_level1 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[3] if i % 2 == 1 else 0) for i in range(num_blocks[0])])
478
+
479
+ self.refinement = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[4] if i % 2 == 1 else 0) for i in range(num_refinement_blocks)])
480
+
481
+ self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
482
+
483
+ def forward(self, inp_img):
484
+ if inp_img.ndim == 5:
485
+ B,M,C,H,W = inp_img.shape
486
+ center_img = inp_img[:, self.center_idx]
487
+ inp_img = inp_img.view(B*M,C,H,W).contiguous()
488
+ else:
489
+ center_img = inp_img
490
+
491
+ if self.ca is None:
492
+ inp_enc_level1 = inp_img.view(B,M*C,H,W)
493
+ else:
494
+ inp_enc_level1 = self.ca(inp_img)
495
+
496
+ inp_enc_level1 = self.patch_embed(inp_enc_level1)
497
+
498
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
499
+
500
+ inp_enc_level2 = self.down1_2(out_enc_level1)
501
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
502
+
503
+ inp_enc_level3 = self.down2_3(out_enc_level2)
504
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
505
+
506
+ inp_enc_level4 = self.down3_4(out_enc_level3)
507
+ latent = self.latent(inp_enc_level4)
508
+
509
+ inp_dec_level3 = self.up4_3(latent)
510
+ inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
511
+ inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
512
+ out_dec_level3 = self.decoder_level3(inp_dec_level3)
513
+
514
+ inp_dec_level2 = self.up3_2(out_dec_level3)
515
+ inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
516
+ inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
517
+ out_dec_level2 = self.decoder_level2(inp_dec_level2)
518
+
519
+ inp_dec_level1 = self.up2_1(out_dec_level2)
520
+ inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
521
+ out_dec_level1 = self.decoder_level1(inp_dec_level1)
522
+
523
+
524
+ out_dec_level1 = self.refinement(out_dec_level1)
525
+ out_dec_level1 = self.output(out_dec_level1) + center_img
526
+
527
+ return out_dec_level1
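A hedged instantiation sketch for ACFormer; the values below are illustrative only (the real settings come from the training/test YAML), and flash_attn plus a CUDA GPU are required by MAFG_CA. Since MAFG_CA fuses the M focal-stack frames into 3*M channels before patch embedding, inp_channels is set to 3*M here.

import torch
from basicsr.models.archs.restormer_arch import ACFormer

M = 13                                     # number of focal-stack frames (assumed)
net = ACFormer(inp_channels=3 * M,         # fused channels fed to OverlapPatchEmbed
               out_channels=3,
               M=M,
               window_size_ca=8,           # assumed MAFG_CA window size
               query_ksize=[0, 0, 0, 0, 0]).cuda().eval()

x = torch.randn(1, M, 3, 128, 128).cuda()  # (B, M, C, H, W); H, W divisible by the window size
with torch.no_grad():
    y = net(x)
print(y.shape)                             # torch.Size([1, 3, 128, 128])
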
basicsr/models/base_model.py ADDED
@@ -0,0 +1,376 @@
1
+ import logging
2
+ import os
3
+ import torch
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
7
+
8
+ from basicsr.models import lr_scheduler as lr_scheduler
9
+ from basicsr.utils.dist_util import master_only
10
+
11
+ logger = logging.getLogger('basicsr')
12
+
13
+
14
+ class BaseModel():
15
+ """Base model."""
16
+
17
+ def __init__(self, opt):
18
+ self.opt = opt
19
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
20
+ self.is_train = opt['is_train']
21
+ self.schedulers = []
22
+ self.optimizers = []
23
+
24
+ def feed_data(self, data):
25
+ pass
26
+
27
+ def optimize_parameters(self):
28
+ pass
29
+
30
+ def get_current_visuals(self):
31
+ pass
32
+
33
+ def save(self, epoch, current_iter):
34
+ """Save networks and training state."""
35
+ pass
36
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True, psf=None, ks=None, val_conv=True):
37
+ """Validation function.
38
+
39
+ Args:
40
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
41
+ current_iter (int): Current iteration.
42
+ tb_logger (tensorboard logger): Tensorboard logger.
43
+ save_img (bool): Whether to save images. Default: False.
44
+ rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True
45
+ use_image (bool): Whether to compute metrics (PSNR, SSIM) on the saved images; if False, metrics are computed directly on the network's output. Default: True
46
+ """
47
+ if self.opt['dist']:
48
+ return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)
49
+ else:
50
+ return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)
51
+
52
+ def model_ema(self, decay=0.999):
53
+ net_g = self.get_bare_model(self.net_g)
54
+
55
+ net_g_params = dict(net_g.named_parameters())
56
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
57
+
58
+ for k in net_g_ema_params.keys():
59
+ net_g_ema_params[k].data.mul_(decay).add_(
60
+ net_g_params[k].data, alpha=1 - decay)
61
+
62
+ def get_current_log(self):
63
+ return self.log_dict
64
+
65
+ def model_to_device(self, net):
66
+ """Model to device. It also warps models with DistributedDataParallel
67
+ or DataParallel.
68
+
69
+ Args:
70
+ net (nn.Module)
71
+ """
72
+
73
+ net = net.to(self.device)
74
+ # if self.opt['dist']:
75
+ # find_unused_parameters = self.opt.get('find_unused_parameters',
76
+ # False)
77
+ # net = DistributedDataParallel(
78
+ # net,
79
+ # device_ids=[torch.cuda.current_device()],
80
+ # find_unused_parameters=find_unused_parameters)
81
+ # elif self.opt['num_gpu'] > 1:
82
+ # net = DataParallel(net)
83
+ return net
84
+
85
+ def setup_schedulers(self):
86
+ """Set up schedulers."""
87
+ train_opt = self.opt['train']
88
+ scheduler_type = train_opt['scheduler'].pop('type')
89
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
90
+ for optimizer in self.optimizers:
91
+ self.schedulers.append(
92
+ lr_scheduler.MultiStepRestartLR(optimizer,
93
+ **train_opt['scheduler']))
94
+ elif scheduler_type == 'CosineAnnealingRestartLR':
95
+ for optimizer in self.optimizers:
96
+ self.schedulers.append(
97
+ lr_scheduler.CosineAnnealingRestartLR(
98
+ optimizer, **train_opt['scheduler']))
99
+ elif scheduler_type == 'CosineAnnealingWarmupRestarts':
100
+ for optimizer in self.optimizers:
101
+ self.schedulers.append(
102
+ lr_scheduler.CosineAnnealingWarmupRestarts(
103
+ optimizer, **train_opt['scheduler']))
104
+ elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
105
+ for optimizer in self.optimizers:
106
+ self.schedulers.append(
107
+ lr_scheduler.CosineAnnealingRestartCyclicLR(
108
+ optimizer, **train_opt['scheduler']))
109
+ elif scheduler_type == 'TrueCosineAnnealingLR':
110
+ print('..', 'cosineannealingLR')
111
+ for optimizer in self.optimizers:
112
+ self.schedulers.append(
113
+ torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler']))
114
+ elif scheduler_type == 'CosineAnnealingLRWithRestart':
115
+ print('..', 'CosineAnnealingLR_With_Restart')
116
+ for optimizer in self.optimizers:
117
+ self.schedulers.append(
118
+ lr_scheduler.CosineAnnealingLRWithRestart(optimizer, **train_opt['scheduler']))
119
+ elif scheduler_type == 'LinearLR':
120
+ for optimizer in self.optimizers:
121
+ self.schedulers.append(
122
+ lr_scheduler.LinearLR(
123
+ optimizer, train_opt['total_iter']))
124
+ elif scheduler_type == 'VibrateLR':
125
+ for optimizer in self.optimizers:
126
+ self.schedulers.append(
127
+ lr_scheduler.VibrateLR(
128
+ optimizer, train_opt['total_iter']))
129
+ else:
130
+ raise NotImplementedError(
131
+ f'Scheduler {scheduler_type} is not implemented yet.')
132
+
133
+ def get_bare_model(self, net):
134
+ """Get bare model, especially under wrapping with
135
+ DistributedDataParallel or DataParallel.
136
+ """
137
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
138
+ net = net.module
139
+ return net
140
+
141
+ @master_only
142
+ def print_network(self, net):
143
+ """Print the str and parameter number of a network.
144
+
145
+ Args:
146
+ net (nn.Module)
147
+ """
148
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
149
+ net_cls_str = (f'{net.__class__.__name__} - '
150
+ f'{net.module.__class__.__name__}')
151
+ else:
152
+ net_cls_str = f'{net.__class__.__name__}'
153
+
154
+ net = self.get_bare_model(net)
155
+ net_str = str(net)
156
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
157
+
158
+ logger.info(
159
+ f'Network: {net_cls_str}, with parameters: {net_params:,d}')
160
+ logger.info(net_str)
161
+
162
+ def _set_lr(self, lr_groups_l):
163
+ """Set learning rate for warmup.
164
+
165
+ Args:
166
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
167
+ """
168
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
169
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
170
+ param_group['lr'] = lr
171
+
172
+ def _get_init_lr(self):
173
+ """Get the initial lr, which is set by the scheduler.
174
+ """
175
+ init_lr_groups_l = []
176
+ for optimizer in self.optimizers:
177
+ init_lr_groups_l.append(
178
+ [v['initial_lr'] for v in optimizer.param_groups])
179
+ return init_lr_groups_l
180
+
181
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
182
+ """Update learning rate.
183
+
184
+ Args:
185
+ current_iter (int): Current iteration.
186
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
187
+ Default: -1.
188
+ """
189
+ if current_iter > 1:
190
+ for scheduler in self.schedulers:
191
+ scheduler.step()
192
+ # set up warm-up learning rate
193
+ if current_iter < warmup_iter:
194
+ # get initial lr for each group
195
+ init_lr_g_l = self._get_init_lr()
196
+ # modify warming-up learning rates
197
+ # currently only support linearly warm up
198
+ warm_up_lr_l = []
199
+ for init_lr_g in init_lr_g_l:
200
+ warm_up_lr_l.append(
201
+ [v / warmup_iter * current_iter for v in init_lr_g])
202
+ # set learning rate
203
+ self._set_lr(warm_up_lr_l)
204
+
205
+ def get_current_learning_rate(self):
206
+ return [
207
+ param_group['lr']
208
+ for param_group in self.optimizers[0].param_groups
209
+ ]
210
+
211
+ @master_only
212
+ def save_network(self, net, net_label, current_iter, param_key='params'):
213
+ """Save networks.
214
+
215
+ Args:
216
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
217
+ net_label (str): Network label.
218
+ current_iter (int): Current iter number.
219
+ param_key (str | list[str]): The parameter key(s) to save network.
220
+ Default: 'params'.
221
+ """
222
+ if current_iter == -1:
223
+ current_iter = 'latest'
224
+ save_filename = f'{net_label}_{current_iter}.pth'
225
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
226
+
227
+ net = net if isinstance(net, list) else [net]
228
+ param_key = param_key if isinstance(param_key, list) else [param_key]
229
+ assert len(net) == len(
230
+ param_key), 'The lengths of net and param_key should be the same.'
231
+
232
+ save_dict = {}
233
+ for net_, param_key_ in zip(net, param_key):
234
+ net_ = self.get_bare_model(net_)
235
+ state_dict = net_.state_dict()
236
+ for key, param in state_dict.items():
237
+ if key.startswith('module.'): # remove unnecessary 'module.'
238
+ key = key[7:]
239
+ state_dict[key] = param.cpu()
240
+ save_dict[param_key_] = state_dict
241
+
242
+ torch.save(save_dict, save_path)
243
+
244
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
245
+ """Print keys with differnet name or different size when loading models.
246
+
247
+ 1. Print keys with differnet names.
248
+ 2. If strict=False, print the same key but with different tensor size.
249
+ It also ignore these keys with different sizes (not load).
250
+
251
+ Args:
252
+ crt_net (torch model): Current network.
253
+ load_net (dict): Loaded network.
254
+ strict (bool): Whether strictly loaded. Default: True.
255
+ """
256
+ crt_net = self.get_bare_model(crt_net)
257
+ crt_net = crt_net.state_dict()
258
+ crt_net_keys = set(crt_net.keys())
259
+ load_net_keys = set(load_net.keys())
260
+
261
+ if crt_net_keys != load_net_keys:
262
+ logger.warning('Current net - loaded net:')
263
+ for v in sorted(list(crt_net_keys - load_net_keys)):
264
+ logger.warning(f' {v}')
265
+ logger.warning('Loaded net - current net:')
266
+ for v in sorted(list(load_net_keys - crt_net_keys)):
267
+ logger.warning(f' {v}')
268
+
269
+ # check the size for the same keys
270
+ if not strict:
271
+ common_keys = crt_net_keys & load_net_keys
272
+ for k in common_keys:
273
+ if crt_net[k].size() != load_net[k].size():
274
+ logger.warning(
275
+ f'Size different, ignore [{k}]: crt_net: '
276
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
277
+ load_net[k + '.ignore'] = load_net.pop(k)
278
+
279
+ def load_network(self, net, load_path, strict=True, param_key='params'):
280
+ """Load network.
281
+
282
+ Args:
283
+ load_path (str): The path of networks to be loaded.
284
+ net (nn.Module): Network.
285
+ strict (bool): Whether strictly loaded.
286
+ param_key (str): The parameter key of loaded network. If set to
287
+ None, use the root 'path'.
288
+ Default: 'params'.
289
+ """
290
+ net = self.get_bare_model(net)
291
+ logger.info(
292
+ f'Loading {net.__class__.__name__} model from {load_path}.')
293
+ load_net = torch.load(
294
+ load_path, map_location=lambda storage, loc: storage)
295
+ if param_key is not None:
296
+ if param_key not in load_net and 'params' in load_net:
297
+ param_key = 'params'
298
+ logger.info('Loading: params_ema does not exist, use params.')
299
+ load_net = load_net[param_key]
300
+ print(' load net keys', load_net.keys())
301
+ # remove unnecessary 'module.'
302
+ for k, v in deepcopy(load_net).items():
303
+ if k.startswith('module.'):
304
+ load_net[k[7:]] = v
305
+ load_net.pop(k)
306
+ self._print_different_keys_loading(net, load_net, strict)
307
+ net.load_state_dict(load_net, strict=strict)
308
+
309
+ @master_only
310
+ def save_training_state(self, epoch, current_iter):
311
+ """Save training states during training, which will be used for
312
+ resuming.
313
+
314
+ Args:
315
+ epoch (int): Current epoch.
316
+ current_iter (int): Current iteration.
317
+ """
318
+ if current_iter != -1:
319
+ state = {
320
+ 'epoch': epoch,
321
+ 'iter': current_iter,
322
+ 'optimizers': [],
323
+ 'schedulers': []
324
+ }
325
+ for o in self.optimizers:
326
+ state['optimizers'].append(o.state_dict())
327
+ for s in self.schedulers:
328
+ state['schedulers'].append(s.state_dict())
329
+ save_filename = f'{current_iter}.state'
330
+ save_path = os.path.join(self.opt['path']['training_states'],
331
+ save_filename)
332
+ torch.save(state, save_path)
333
+
334
+ def resume_training(self, resume_state):
335
+ """Reload the optimizers and schedulers for resumed training.
336
+
337
+ Args:
338
+ resume_state (dict): Resume state.
339
+ """
340
+ resume_optimizers = resume_state['optimizers']
341
+ resume_schedulers = resume_state['schedulers']
342
+ assert len(resume_optimizers) == len(
343
+ self.optimizers), 'Wrong lengths of optimizers'
344
+ assert len(resume_schedulers) == len(
345
+ self.schedulers), 'Wrong lengths of schedulers'
346
+ for i, o in enumerate(resume_optimizers):
347
+ self.optimizers[i].load_state_dict(o)
348
+ for i, s in enumerate(resume_schedulers):
349
+ self.schedulers[i].load_state_dict(s)
350
+
351
+ def reduce_loss_dict(self, loss_dict):
352
+ """reduce loss dict.
353
+
354
+ In distributed training, it averages the losses among different GPUs .
355
+
356
+ Args:
357
+ loss_dict (OrderedDict): Loss dict.
358
+ """
359
+ with torch.no_grad():
360
+ if self.opt['dist']:
361
+ keys = []
362
+ losses = []
363
+ for name, value in loss_dict.items():
364
+ keys.append(name)
365
+ losses.append(value)
366
+ losses = torch.stack(losses, 0)
367
+ torch.distributed.reduce(losses, dst=0)
368
+ if self.opt['rank'] == 0:
369
+ losses /= self.opt['world_size']
370
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
371
+
372
+ log_dict = OrderedDict()
373
+ for name, value in loss_dict.items():
374
+ log_dict[name] = value.mean().item()
375
+
376
+ return log_dict
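A small numeric check of the linear warm-up rule in update_learning_rate above: each group's lr is scaled by current_iter / warmup_iter until warm-up ends (the values below are made up).

init_lr, warmup_iter = 3e-4, 2000
for current_iter in (1, 500, 2000):
    lr = init_lr / warmup_iter * current_iter if current_iter < warmup_iter else init_lr
    print(current_iter, f'{lr:.2e}')  # 1.50e-07, 7.50e-05, 3.00e-04
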
basicsr/models/image_restoration_model.py ADDED
@@ -0,0 +1,392 @@
1
+ import importlib
2
+ import torch
3
+ import os
4
+ import gc
5
+ import random
6
+ import torch.nn.functional as F
7
+
8
+ from collections import OrderedDict
9
+ from copy import deepcopy
10
+ from os import path as osp
11
+ from tqdm import tqdm
12
+ from functools import partial
13
+
14
+ from basicsr.models.archs import define_network
15
+ from basicsr.models.base_model import BaseModel
16
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
17
+ from basicsr.utils.nano import apply_conv_n_deconv
18
+ from basicsr.metrics.other_metrics import compute_img_metric
19
+
20
+ loss_module = importlib.import_module('basicsr.models.losses')
21
+ metric_module = importlib.import_module('basicsr.metrics')
22
+
23
+
24
+ class Mixing_Augment:
25
+ def __init__(self, mixup_beta, use_identity, device):
26
+ self.dist = torch.distributions.beta.Beta(torch.tensor([mixup_beta]), torch.tensor([mixup_beta]))
27
+ self.device = device
28
+
29
+ self.use_identity = use_identity
30
+
31
+ self.augments = [self.mixup]
32
+
33
+ def mixup(self, target, input_):
34
+ lam = self.dist.rsample((1,1)).item()
35
+
36
+ r_index = torch.randperm(target.size(0)).to(self.device)
37
+
38
+ target = lam * target + (1-lam) * target[r_index, :]
39
+ input_ = lam * input_ + (1-lam) * input_[r_index, :]
40
+
41
+ return target, input_
42
+
43
+ def __call__(self, target, input_):
44
+ if self.use_identity:
45
+ augment = random.randint(0, len(self.augments))
46
+ if augment < len(self.augments):
47
+ target, input_ = self.augments[augment](target, input_)
48
+ else:
49
+ augment = random.randint(0, len(self.augments)-1)
50
+ target, input_ = self.augments[augment](target, input_)
51
+ return target, input_
52
+
53
+ class ImageCleanModel(BaseModel):
54
+ """Base Deblur model for single image deblur."""
55
+
56
+ def __init__(self, opt):
57
+ super(ImageCleanModel, self).__init__(opt)
58
+
59
+ # define network
60
+
61
+ self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False)
62
+ if self.mixing_flag:
63
+ mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2)
64
+ use_identity = self.opt['train']['mixing_augs'].get('use_identity', False)
65
+ self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device)
66
+
67
+ self.net_g = define_network(deepcopy(opt['network_g']))
68
+ self.net_g = self.model_to_device(self.net_g)
69
+
70
+ # load pretrained models
71
+ load_path = self.opt['path'].get('pretrain_network_g', None)
72
+ if load_path is not None:
73
+ self.load_network(self.net_g, load_path,
74
+ self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params'))
75
+
76
+ if self.is_train:
77
+ self.init_training_settings()
78
+
79
+ def init_training_settings(self):
80
+ self.net_g.train()
81
+ train_opt = self.opt['train']
82
+
83
+ self.ema_decay = train_opt.get('ema_decay', 0)
84
+ if self.ema_decay > 0:
85
+ logger = get_root_logger()
86
+ logger.info(
87
+ f'Use Exponential Moving Average with decay: {self.ema_decay}')
88
+ # define network net_g with Exponential Moving Average (EMA)
89
+ # net_g_ema is used only for testing on one GPU and saving
90
+ # There is no need to wrap with DistributedDataParallel
91
+ self.net_g_ema = define_network(self.opt['network_g']).to(
92
+ self.device)
93
+ # load pretrained model
94
+ load_path = self.opt['path'].get('pretrain_network_g', None)
95
+ if load_path is not None:
96
+ self.load_network(self.net_g_ema, load_path,
97
+ self.opt['path'].get('strict_load_g',
98
+ True), 'params_ema')
99
+ else:
100
+ self.model_ema(0) # copy net_g weight
101
+ self.net_g_ema.eval()
102
+
103
+
104
+ # define losses
105
+ if train_opt.get('pixel_opt'):
106
+ pixel_type = train_opt['pixel_opt'].pop('type')
107
+ cri_pix_cls = getattr(loss_module, pixel_type)
108
+ self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
109
+ self.device)
110
+ else:
111
+ raise ValueError('pixel loss is None.')
112
+
113
+ # set up optimizers and schedulers
114
+ self.setup_optimizers()
115
+ self.setup_schedulers()
116
+
117
+
118
+ def setup_optimizers(self):
119
+ train_opt = self.opt['train']
120
+ optim_params = []
121
+
122
+ for k, v in self.net_g.named_parameters():
123
+ if v.requires_grad:
124
+ optim_params.append(v)
125
+ else:
126
+ logger = get_root_logger()
127
+ logger.warning(f'Params {k} will not be optimized.')
128
+
129
+ optim_type = train_opt['optim_g'].pop('type')
130
+ if optim_type == 'Adam':
131
+ self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
132
+ elif optim_type == 'AdamW':
133
+ self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
134
+ else:
135
+ raise NotImplementedError(
136
+ f'optimizer {optim_type} is not supported yet.')
137
+ self.optimizers.append(self.optimizer_g)
138
+
139
+
140
+ def feed_train_data(self, data):
141
+ self.lq = data['lq'].to(self.device)
142
+ if 'gt' in data:
143
+ self.gt = data['gt'].to(self.device)
144
+
145
+ if self.mixing_flag:
146
+ self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)
147
+
148
+ def feed_data(self, data, psf=None, ks=None, val_conv=True):
149
+ gt = data['gt'].to(self.device)
150
+ padding = data['padding']
151
+ padding = torch.stack(padding).T
152
+ otf = psf
153
+ M = ks.shape[1]
154
+ if val_conv: # Apply convolution on the fly (use gt img to create lr image)
155
+ lq, gt = apply_conv_n_deconv(gt, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=True)
156
+ self.lq = lq[None]
157
+ self.gt = gt[None] # TODO check dim: previously the tensor returned by the square helper was used as-is, but here the original gt is used directly, so the shape may differ. Merge with the branch below later.
158
+ # TODO this could be reduced to one line by selecting gt in the if/else above and applying deconv(gt) once.
159
+
160
+ else: # loaded npy for validation
161
+ lq = data['lq'].to(self.device)
162
+ lq, gt = apply_conv_n_deconv(lq, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=False)
163
+ self.lq = lq[None]
164
+ self.gt = gt
165
+
166
+
167
+ def optimize_parameters(self, current_iter):
168
+ self.optimizer_g.zero_grad()
169
+ preds = self.net_g(self.lq)
170
+ if not isinstance(preds, list):
171
+ preds = [preds]
172
+
173
+ self.output = preds[-1]
174
+
175
+ loss_dict = OrderedDict()
176
+ # pixel loss
177
+ l_pix = 0.
178
+ for pred in preds:
179
+ l_pix += self.cri_pix(pred, self.gt)
180
+
181
+ loss_dict['l_pix'] = l_pix
182
+
183
+ l_pix.backward()
184
+ if self.opt['train']['use_grad_clip']:
185
+ torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01)
186
+ self.optimizer_g.step()
187
+
188
+ self.log_dict = self.reduce_loss_dict(loss_dict)
189
+
190
+ if self.ema_decay > 0:
191
+ self.model_ema(decay=self.ema_decay)
192
+
193
+ def pad_test(self, window_size):
194
+ scale = self.opt.get('scale', 1)
195
+ mod_pad_h, mod_pad_w = 0, 0
196
+ h,w = self.lq.size()[-2:]
197
+ if h % window_size != 0:
198
+ mod_pad_h = window_size - h % window_size
199
+ if w % window_size != 0:
200
+ mod_pad_w = window_size - w % window_size
201
+ img = F.pad(self.lq[0], (0, mod_pad_w, 0, mod_pad_h), 'reflect')[None]
202
+ self.nonpad_test(img)
203
+ _, _, h, w = self.output.size()
204
+ self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
205
+
206
+ def nonpad_test(self, img=None):
207
+ if img is None:
208
+ img = self.lq
209
+ if hasattr(self, 'net_g_ema'):
210
+ self.net_g_ema.eval()
211
+ with torch.no_grad():
212
+ pred = self.net_g_ema(img)
213
+ if isinstance(pred, list):
214
+ pred = pred[-1]
215
+ self.output = pred
216
+ else:
217
+ self.net_g.eval()
218
+ with torch.no_grad():
219
+ pred = self.net_g(img)
220
+
221
+ if isinstance(pred, list):
222
+ pred = pred[-1]
223
+ self.output = pred
224
+ self.net_g.train()
225
+
226
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv):
227
+ if os.environ['LOCAL_RANK'] == '0':
228
+ return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)
229
+ else:
230
+ return 0.
231
+
232
+
233
+ def pre_process(self, padding_size):
234
+ # pad to multiplication of window_size
235
+ self.mod_pad_h, self.mod_pad_w = 0, 0
236
+ h,w = self.lq.size()[-2:] # BMCHW
237
+ if h % padding_size != 0:
238
+ self.mod_pad_h = padding_size - h % padding_size
239
+ if w % padding_size != 0:
240
+ self.mod_pad_w = padding_size - w % padding_size
241
+ self.lq = F.pad(self.lq[0], (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')[None]
242
+
243
+ def post_process(self):
244
+ _, _, h, w = self.output.size()
245
+ self.output = self.output[...,0:h - self.mod_pad_h, 0:w - self.mod_pad_w]
246
+
247
+ def nondist_validation(self, dataloader, current_iter, tb_logger,
248
+ save_img, rgb2bgr, use_image, psf, ks, val_conv):
249
+ dataset_name = dataloader.dataset.opt['name']
250
+ base_path = self.opt['path']['visualization']
251
+
252
+ with_metrics = self.opt['val'].get('metrics') is not None
253
+ if with_metrics:
254
+ self.metric_results = {
255
+ metric: 0
256
+ for metric in self.opt['val']['metrics'].keys()
257
+ }
258
+ if save_img:
259
+ cur_other_metrics = {'ssim': 0., 'lpips': 0.}
260
+ else:
261
+ cur_other_metrics = None
262
+
263
+ window_size = self.opt['val'].get('window_size', 0)
264
+
265
+ if window_size:
266
+ test = partial(self.pad_test, window_size)
267
+ else:
268
+ test = self.nonpad_test
269
+
270
+ cnt = 0
271
+
272
+ for idx, val_data in enumerate(tqdm(dataloader)):
273
+ img_name = osp.splitext(osp.basename(val_data['gt_path'][0]))[0]
274
+ self.feed_data(val_data, psf, ks, val_conv)
275
+ pad_for_OCB = self.opt['val'].get('padding')
276
+ if pad_for_OCB is not None:
277
+ self.pre_process(pad_for_OCB)
278
+
279
+ torch.cuda.empty_cache()
280
+ gc.collect()
281
+
282
+ test()
283
+
284
+ if pad_for_OCB is not None:
285
+ self.post_process()
286
+
287
+ if save_img and with_metrics and use_image:
288
+ visuals = self.get_current_visuals(to_cpu=False)
289
+ cur_other_metrics['ssim'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'ssim')
290
+ cur_other_metrics['lpips'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'lpips').item()
291
+
292
+ visuals = self.get_current_visuals()
293
+
294
+ sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
295
+ if 'gt' in visuals:
296
+ gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
297
+ del self.gt
298
+
299
+ # tentative for out of GPU memory
300
+ del self.lq
301
+ del self.output
302
+ torch.cuda.empty_cache()
303
+ gc.collect()
304
+
305
+ if save_img:
306
+ if self.opt['is_train']:
307
+ if 'eval_only' in self.opt['train']:
308
+ save_img_path = osp.join(base_path + self.opt['train']['eval_name'],
309
+ f'{img_name}_{current_iter}.png')
310
+ else:
311
+ save_img_path = osp.join(base_path,
312
+ f'{img_name}_{current_iter}.png')
313
+ else:
314
+ save_img_path = osp.join(
315
+ base_path,
316
+ f'{img_name}.png')
317
+ save_gt_img_path = osp.join(
318
+ base_path, dataset_name,
319
+ f'{img_name}_gt.png')
320
+
321
+ imwrite(sr_img, save_img_path)
322
+
323
+ if with_metrics:
324
+ # calculate metrics
325
+ opt_metric = deepcopy(self.opt['val']['metrics'])
326
+ if use_image:
327
+ for name, opt_ in opt_metric.items():
328
+ metric_type = opt_.pop('type')
329
+ self.metric_results[name] += getattr(
330
+ metric_module, metric_type)(sr_img, gt_img, **opt_)
331
+ else:
332
+ for name, opt_ in opt_metric.items():
333
+ metric_type = opt_.pop('type')
334
+ self.metric_results[name] += getattr(
335
+ metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)
336
+
337
+ cnt += 1
338
+
339
+
340
+ # tentative for out of GPU memory
341
+ torch.cuda.empty_cache()
342
+ gc.collect()
343
+
344
+ current_metric = 0.
345
+ if with_metrics:
346
+ for metric in self.metric_results.keys():
347
+ self.metric_results[metric] /= cnt
348
+ current_metric = self.metric_results[metric]
349
+ if save_img:
350
+ cur_other_metrics['ssim'] /= cnt
351
+ cur_other_metrics['lpips'] /= cnt
352
+
353
+ self._log_validation_metric_values(current_iter, dataset_name,
354
+ tb_logger)
355
+ return current_metric, cur_other_metrics
356
+
357
+
358
+ def _log_validation_metric_values(self, current_iter, dataset_name,
359
+ tb_logger):
360
+ log_str = f'Validation {dataset_name},\t'
361
+ for metric, value in self.metric_results.items():
362
+ log_str += f'\t # {metric}: {value:.4f}'
363
+ logger = get_root_logger()
364
+ logger.info(log_str)
365
+ if tb_logger:
366
+ for metric, value in self.metric_results.items():
367
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
368
+
369
+ def get_current_visuals(self, to_cpu=True):
370
+ if to_cpu:
371
+ out_dict = OrderedDict()
372
+ out_dict['lq'] = self.lq.detach().cpu()
373
+ out_dict['result'] = self.output.detach().cpu()
374
+ if hasattr(self, 'gt'):
375
+ out_dict['gt'] = self.gt.detach().cpu()
376
+ else:
377
+ out_dict = OrderedDict()
378
+ out_dict['lq'] = self.lq.detach()
379
+ out_dict['result'] = self.output.detach()
380
+ if hasattr(self, 'gt'):
381
+ out_dict['gt'] = self.gt.detach()
382
+ return out_dict
383
+
384
+ def save(self, epoch, current_iter):
385
+ if self.ema_decay > 0:
386
+ self.save_network([self.net_g, self.net_g_ema],
387
+ 'net_g',
388
+ current_iter,
389
+ param_key=['params', 'params_ema'])
390
+ else:
391
+ self.save_network(self.net_g, 'net_g', current_iter)
392
+ self.save_training_state(epoch, current_iter)
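Note: a minimal standalone sketch of what the Mixing_Augment.mixup step above does to a training pair; the batch size, crop size and the beta value 1.2 are illustrative assumptions, not values taken from the configs:

    import torch

    # Hypothetical batch of 4 RGB crops, (B, C, H, W); random values as placeholders.
    gt = torch.rand(4, 3, 128, 128)
    lq = torch.rand(4, 3, 128, 128)

    # Sample one mixing coefficient from Beta(1.2, 1.2), as Mixing_Augment does.
    dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2]))
    lam = dist.rsample((1, 1)).item()

    # Blend every sample with a randomly permuted partner; the same permutation
    # and weight are applied to gt and lq so input/target pairs stay aligned.
    r_index = torch.randperm(gt.size(0))
    gt_mix = lam * gt + (1 - lam) * gt[r_index]
    lq_mix = lam * lq + (1 - lam) * lq[r_index]
    print(lam, gt_mix.shape, lq_mix.shape)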
basicsr/models/losses/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss)
2
+
3
+ __all__ = [
4
+ 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
5
+ ]
basicsr/models/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
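As a quick illustration of the decorator, here is a hypothetical element-wise Charbonnier loss wrapped with weighted_loss (mirroring the variant that is commented out in losses.py below); the tensors are toy values:

    import torch
    from basicsr.models.losses.loss_util import weighted_loss

    @weighted_loss
    def charbonnier_loss(pred, target, eps=1e-12):
        # Element-wise loss only; weighting and reduction come from the decorator.
        return torch.sqrt((pred - target) ** 2 + eps)

    pred = torch.tensor([0., 2., 3.])
    target = torch.tensor([1., 1., 1.])
    print(charbonnier_loss(pred, target))                    # scalar, mean reduction
    print(charbonnier_loss(pred, target, reduction='none'))  # element-wise values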
basicsr/models/losses/losses.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ import numpy as np
5
+ from math import exp
6
+
7
+ from basicsr.models.losses.loss_util import weighted_loss
8
+
9
+ _reduction_modes = ['none', 'mean', 'sum']
10
+
11
+
12
+ @weighted_loss
13
+ def l1_loss(pred, target):
14
+ return F.l1_loss(pred, target, reduction='none')
15
+
16
+
17
+ @weighted_loss
18
+ def mse_loss(pred, target):
19
+ return F.mse_loss(pred, target, reduction='none')
20
+
21
+
22
+ # @weighted_loss
23
+ # def charbonnier_loss(pred, target, eps=1e-12):
24
+ # return torch.sqrt((pred - target)**2 + eps)
25
+
26
+
27
+ class L1Loss(nn.Module):
28
+ """L1 (mean absolute error, MAE) loss.
29
+
30
+ Args:
31
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
32
+ reduction (str): Specifies the reduction to apply to the output.
33
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
34
+ """
35
+
36
+ def __init__(self, loss_weight=1.0, reduction='mean'):
37
+ super(L1Loss, self).__init__()
38
+ if reduction not in ['none', 'mean', 'sum']:
39
+ raise ValueError(f'Unsupported reduction mode: {reduction}. '
40
+ f'Supported ones are: {_reduction_modes}')
41
+
42
+ self.loss_weight = loss_weight
43
+ self.reduction = reduction
44
+
45
+ def forward(self, pred, target, weight=None, **kwargs):
46
+ """
47
+ Args:
48
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
49
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
50
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
51
+ weights. Default: None.
52
+ """
53
+ return self.loss_weight * l1_loss(
54
+ pred, target, weight, reduction=self.reduction)
55
+
56
+ class MSELoss(nn.Module):
57
+ """MSE (L2) loss.
58
+
59
+ Args:
60
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
61
+ reduction (str): Specifies the reduction to apply to the output.
62
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
63
+ """
64
+
65
+ def __init__(self, loss_weight=1.0, reduction='mean'):
66
+ super(MSELoss, self).__init__()
67
+ if reduction not in ['none', 'mean', 'sum']:
68
+ raise ValueError(f'Unsupported reduction mode: {reduction}. '
69
+ f'Supported ones are: {_reduction_modes}')
70
+
71
+ self.loss_weight = loss_weight
72
+ self.reduction = reduction
73
+
74
+ def forward(self, pred, target, weight=None, **kwargs):
75
+ """
76
+ Args:
77
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
78
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
79
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
80
+ weights. Default: None.
81
+ """
82
+ return self.loss_weight * mse_loss(
83
+ pred, target, weight, reduction=self.reduction)
84
+
85
+ class PSNRLoss(nn.Module):
86
+
87
+ def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
88
+ super(PSNRLoss, self).__init__()
89
+ assert reduction == 'mean'
90
+ self.loss_weight = loss_weight
91
+ self.scale = 10 / np.log(10)
92
+ self.toY = toY
93
+ self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
94
+ self.first = True
95
+
96
+ def forward(self, pred, target):
97
+ assert len(pred.size()) == 4
98
+ if self.toY:
99
+ if self.first:
100
+ self.coef = self.coef.to(pred.device)
101
+ self.first = False
102
+
103
+ pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
104
+ target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
105
+
106
+ pred, target = pred / 255., target / 255.
107
+ pass
108
+ assert len(pred.size()) == 4
109
+
110
+ return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
111
+
112
+ class CharbonnierLoss(nn.Module):
113
+ """Charbonnier Loss (L1)"""
114
+
115
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
116
+ super(CharbonnierLoss, self).__init__()
117
+ self.eps = eps
118
+
119
+ def forward(self, x, y):
120
+ diff = x - y
121
+ # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
122
+ loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
123
+ return loss
124
+
125
+ class MS_SSIM(nn.Module):
126
+ def __init__(self, window_size=11, sigma=1.5, device="cuda"):
127
+ super(MS_SSIM, self).__init__()
128
+ self.device = device
129
+ self.channel = 3
130
+ self.sigma=sigma
131
+ self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
132
+ self.levels = len(self.weights)
133
+ self.window = self.create_window(window_size)
134
+
135
+ def create_window(self, window_size):
136
+ self.window_size = window_size
137
+ # 1D gaussian kernel
138
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * self.sigma ** 2)) for x in range(window_size)])
139
+ gauss = gauss / gauss.sum()
140
+
141
+ # 2D Gaussian window
142
+ _1D_window = gauss.unsqueeze(1)
143
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
144
+ return _2D_window.expand(self.channel, 1, window_size, window_size).contiguous().to(self.device)
145
+
146
+ def update_window_size(self, window_size):
147
+ self.window = self.create_window(window_size)
148
+
149
+ def ssim(self, img1, img2):
150
+ """Compute SSIM between two images."""
151
+ mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel)
152
+ mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel)
153
+
154
+ mu1_sq = mu1.pow(2)
155
+ mu2_sq = mu2.pow(2)
156
+ mu1_mu2 = mu1 * mu2
157
+
158
+ sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
159
+ sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
160
+ sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2
161
+
162
+ C1 = 0.01 ** 2
163
+ C2 = 0.03 ** 2
164
+
165
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
166
+
167
+ return ssim_map.mean()
168
+
169
+ def forward(self, pred, target):
170
+ msssim = []
171
+ for i in range(self.levels):
172
+ ssim_val = self.ssim(pred, target)
173
+ msssim.append(ssim_val * self.weights[i])
174
+ if i < self.levels - 1:
175
+ pred = F.avg_pool2d(pred, kernel_size=2, stride=2)
176
+ target = F.avg_pool2d(target, kernel_size=2, stride=2)
177
+
178
+ return torch.prod(torch.stack(msssim))
179
+
180
+
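A small usage sketch for the losses defined above, on random tensors; the shapes and loss settings are illustrative, not taken from the configs:

    import torch
    from basicsr.models.losses import L1Loss, CharbonnierLoss, PSNRLoss

    pred = torch.rand(2, 3, 64, 64)
    target = torch.rand(2, 3, 64, 64)

    print(L1Loss(loss_weight=1.0)(pred, target))    # mean absolute error
    print(CharbonnierLoss(eps=1e-3)(pred, target))  # smooth L1-like penalty
    print(PSNRLoss()(pred, target))                 # scaled log-MSE, minimized during training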
basicsr/models/lr_scheduler.py ADDED
@@ -0,0 +1,232 @@
1
+ import math
2
+ from collections import Counter
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+ import torch
5
+
6
+
7
+ class MultiStepRestartLR(_LRScheduler):
8
+ """ MultiStep with restarts learning rate scheme.
9
+
10
+ Args:
11
+ optimizer (torch.nn.optimizer): Torch optimizer.
12
+ milestones (list): Iterations that will decrease learning rate.
13
+ gamma (float): Decrease ratio. Default: 0.1.
14
+ restarts (list): Restart iterations. Default: [0].
15
+ restart_weights (list): Restart weights at each restart iteration.
16
+ Default: [1].
17
+ last_epoch (int): Used in _LRScheduler. Default: -1.
18
+ """
19
+
20
+ def __init__(self,
21
+ optimizer,
22
+ milestones,
23
+ gamma=0.1,
24
+ restarts=(0, ),
25
+ restart_weights=(1, ),
26
+ last_epoch=-1):
27
+ self.milestones = Counter(milestones)
28
+ self.gamma = gamma
29
+ self.restarts = restarts
30
+ self.restart_weights = restart_weights
31
+ assert len(self.restarts) == len(
32
+ self.restart_weights), 'restarts and their weights do not match.'
33
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34
+
35
+ def get_lr(self):
36
+ if self.last_epoch in self.restarts:
37
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38
+ return [
39
+ group['initial_lr'] * weight
40
+ for group in self.optimizer.param_groups
41
+ ]
42
+ if self.last_epoch not in self.milestones:
43
+ return [group['lr'] for group in self.optimizer.param_groups]
44
+ return [
45
+ group['lr'] * self.gamma**self.milestones[self.last_epoch]
46
+ for group in self.optimizer.param_groups
47
+ ]
48
+
49
+ class LinearLR(_LRScheduler):
50
+ """
51
+
52
+ Args:
53
+ optimizer (torch.nn.optimizer): Torch optimizer.
54
+ milestones (list): Iterations that will decrease learning rate.
55
+ gamma (float): Decrease ratio. Default: 0.1.
56
+ last_epoch (int): Used in _LRScheduler. Default: -1.
57
+ """
58
+
59
+ def __init__(self,
60
+ optimizer,
61
+ total_iter,
62
+ last_epoch=-1):
63
+ self.total_iter = total_iter
64
+ super(LinearLR, self).__init__(optimizer, last_epoch)
65
+
66
+ def get_lr(self):
67
+ process = self.last_epoch / self.total_iter
68
+ weight = (1 - process)
69
+ # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
70
+ return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
71
+
72
+ class VibrateLR(_LRScheduler):
73
+ """
74
+
75
+ Args:
76
+ optimizer (torch.nn.optimizer): Torch optimizer.
77
+ milestones (list): Iterations that will decrease learning rate.
78
+ gamma (float): Decrease ratio. Default: 0.1.
79
+ last_epoch (int): Used in _LRScheduler. Default: -1.
80
+ """
81
+
82
+ def __init__(self,
83
+ optimizer,
84
+ total_iter,
85
+ last_epoch=-1):
86
+ self.total_iter = total_iter
87
+ super(VibrateLR, self).__init__(optimizer, last_epoch)
88
+
89
+ def get_lr(self):
90
+ process = self.last_epoch / self.total_iter
91
+
92
+ f = 0.1
93
+ if process < 3 / 8:
94
+ f = 1 - process * 8 / 3
95
+ elif process < 5 / 8:
96
+ f = 0.2
97
+
98
+ T = self.total_iter // 80
99
+ Th = T // 2
100
+
101
+ t = self.last_epoch % T
102
+
103
+ f2 = t / Th
104
+ if t >= Th:
105
+ f2 = 2 - f2
106
+
107
+ weight = f * f2
108
+
109
+ if self.last_epoch < Th:
110
+ weight = max(0.1, weight)
111
+
112
+ # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
113
+ return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
114
+
115
+ def get_position_from_periods(iteration, cumulative_period):
116
+ """Get the position from a period list.
117
+
118
+ It will return the index of the right-closest number in the period list.
119
+ For example, the cumulative_period = [100, 200, 300, 400],
120
+ if iteration == 50, return 0;
121
+ if iteration == 210, return 2;
122
+ if iteration == 300, return 2.
123
+
124
+ Args:
125
+ iteration (int): Current iteration.
126
+ cumulative_period (list[int]): Cumulative period list.
127
+
128
+ Returns:
129
+ int: The position of the right-closest number in the period list.
130
+ """
131
+ for i, period in enumerate(cumulative_period):
132
+ if iteration <= period:
133
+ return i
134
+
135
+
136
+ class CosineAnnealingRestartLR(_LRScheduler):
137
+ """ Cosine annealing with restarts learning rate scheme.
138
+
139
+ An example of config:
140
+ periods = [10, 10, 10, 10]
141
+ restart_weights = [1, 0.5, 0.5, 0.5]
142
+ eta_min=1e-7
143
+
144
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
145
+ scheduler will restart with the weights in restart_weights.
146
+
147
+ Args:
148
+ optimizer (torch.nn.optimizer): Torch optimizer.
149
+ periods (list): Period for each cosine annealing cycle.
150
+ restart_weights (list): Restart weights at each restart iteration.
151
+ Default: [1].
152
+ eta_min (float): The minimum lr. Default: 0.
153
+ last_epoch (int): Used in _LRScheduler. Default: -1.
154
+ """
155
+
156
+ def __init__(self,
157
+ optimizer,
158
+ periods,
159
+ restart_weights=(1, ),
160
+ eta_min=0,
161
+ last_epoch=-1):
162
+ self.periods = periods
163
+ self.restart_weights = restart_weights
164
+ self.eta_min = eta_min
165
+ assert (len(self.periods) == len(self.restart_weights)
166
+ ), 'periods and restart_weights should have the same length.'
167
+ self.cumulative_period = [
168
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
169
+ ]
170
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
171
+
172
+ def get_lr(self):
173
+ idx = get_position_from_periods(self.last_epoch,
174
+ self.cumulative_period)
175
+ current_weight = self.restart_weights[idx]
176
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
177
+ current_period = self.periods[idx]
178
+
179
+ return [
180
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
181
+ (1 + math.cos(math.pi * (
182
+ (self.last_epoch - nearest_restart) / current_period)))
183
+ for base_lr in self.base_lrs
184
+ ]
185
+
186
+ class CosineAnnealingRestartCyclicLR(_LRScheduler):
187
+ """ Cosine annealing with restarts learning rate scheme.
188
+ An example of config:
189
+ periods = [10, 10, 10, 10]
190
+ restart_weights = [1, 0.5, 0.5, 0.5]
191
+ eta_min=1e-7
192
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
193
+ scheduler will restart with the weights in restart_weights.
194
+ Args:
195
+ optimizer (torch.nn.optimizer): Torch optimizer.
196
+ periods (list): Period for each cosine annealing cycle.
197
+ restart_weights (list): Restart weights at each restart iteration.
198
+ Default: [1].
199
+ eta_min (float): The minimum lr. Default: 0.
200
+ last_epoch (int): Used in _LRScheduler. Default: -1.
201
+ """
202
+
203
+ def __init__(self,
204
+ optimizer,
205
+ periods,
206
+ restart_weights=(1, ),
207
+ eta_mins=(0, ),
208
+ last_epoch=-1):
209
+ self.periods = periods
210
+ self.restart_weights = restart_weights
211
+ self.eta_mins = eta_mins
212
+ assert (len(self.periods) == len(self.restart_weights)
213
+ ), 'periods and restart_weights should have the same length.'
214
+ self.cumulative_period = [
215
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
216
+ ]
217
+ super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
218
+
219
+ def get_lr(self):
220
+ idx = get_position_from_periods(self.last_epoch,
221
+ self.cumulative_period)
222
+ current_weight = self.restart_weights[idx]
223
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
224
+ current_period = self.periods[idx]
225
+ eta_min = self.eta_mins[idx]
226
+
227
+ return [
228
+ eta_min + current_weight * 0.5 * (base_lr - eta_min) *
229
+ (1 + math.cos(math.pi * (
230
+ (self.last_epoch - nearest_restart) / current_period)))
231
+ for base_lr in self.base_lrs
232
+ ]
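To make the restart behaviour concrete, a minimal sketch that drives CosineAnnealingRestartLR with a dummy optimizer, using the example configuration from the docstring (the learning rate is an arbitrary illustrative value):

    import torch
    from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

    param = torch.nn.Parameter(torch.zeros(1))      # dummy parameter to build an optimizer
    optimizer = torch.optim.AdamW([param], lr=3e-4)

    # Four cosine cycles of 10 iterations each; later cycles restart at half weight.
    scheduler = CosineAnnealingRestartLR(
        optimizer, periods=[10, 10, 10, 10],
        restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

    for it in range(40):
        optimizer.step()
        scheduler.step()
        if it % 10 == 0:
            print(it, optimizer.param_groups[0]['lr'])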
basicsr/test.py ADDED
@@ -0,0 +1,142 @@
1
+ import argparse
2
+ import random
3
+ import torch
4
+ from os import path as osp
5
+
6
+ from basicsr.data import create_dataloader, create_dataset
7
+ from basicsr.models import create_model
8
+ from basicsr.utils import (check_resume, make_exp_dirs, mkdir_and_rename, set_random_seed)
9
+ from basicsr.utils.dist_util import get_dist_info, init_dist
10
+ from basicsr.utils.options import parse
11
+ from basicsr.utils.nano import psf2otf
12
+
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+
16
+ def parse_options(is_train=True):
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ '-opt', type=str, required=True, help='Path to option YAML file.')
20
+ parser.add_argument(
21
+ '--launcher',
22
+ choices=['none', 'pytorch', 'slurm'],
23
+ default='none',
24
+ help='job launcher')
25
+ parser.add_argument(
26
+ '--name',
27
+ default=None,
28
+ help='experiment name that overrides the one in the YAML file')
29
+ import sys
30
+ vv = sys.version_info.minor
31
+ parser.add_argument('--local-rank', type=int, default=0)
32
+ parser.add_argument('--local_rank', type=int, default=0)
33
+ args = parser.parse_args()
34
+ opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None)
35
+
36
+
37
+ # distributed settings
38
+ if args.launcher == 'none':
39
+ opt['dist'] = False
40
+ print('Disable distributed.', flush=True)
41
+ else:
42
+ opt['dist'] = True
43
+ if args.launcher == 'slurm' and 'dist_params' in opt:
44
+ init_dist(args.launcher, **opt['dist_params'])
45
+ else:
46
+ init_dist(args.launcher)
47
+ print('init dist .. ', args.launcher)
48
+
49
+ opt['rank'], opt['world_size'] = get_dist_info()
50
+
51
+ # random seed
52
+ seed = opt.get('manual_seed')
53
+ if seed is None:
54
+ seed = random.randint(1, 10000)
55
+ opt['manual_seed'] = seed
56
+ set_random_seed(seed + opt['rank'])
57
+
58
+ return opt
59
+
60
+
61
+ def main():
62
+ # parse options, set distributed setting, set random seed
63
+ opt = parse_options(is_train=True)
64
+ torch.backends.cudnn.benchmark = True
65
+
66
+ # automatic resume ..
67
+ state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
68
+ import os
69
+ try:
70
+ states = os.listdir(state_folder_path)
71
+ except:
72
+ states = []
73
+ resume_state = None
74
+ if len(states) > 0:
75
+ max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
76
+ resume_state = os.path.join(state_folder_path, max_state_file)
77
+ opt['path']['resume_state'] = resume_state
78
+
79
+ # load resume states if necessary
80
+ if opt['path'].get('resume_state'):
81
+ device_id = torch.cuda.current_device()
82
+ resume_state = torch.load(
83
+ opt['path']['resume_state'],
84
+ map_location=lambda storage, loc: storage.cuda(device_id))
85
+ else:
86
+ resume_state = None
87
+
88
+ # mkdir for experiments and logger
89
+ if resume_state is None:
90
+ make_exp_dirs(opt)
91
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
92
+ 'name'] and opt['rank'] == 0:
93
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
94
+
95
+
96
+ # define ks for Wiener filters
97
+ ks_params = opt['train'].get('ks', None)
98
+ if not ks_params:
99
+ raise NotImplementedError
100
+ M = ks_params['num']
101
+ ks = torch.logspace(ks_params['start'], ks_params['end'], M)
102
+ ks = ks.view(1,M,1,1,1,1).to("cuda")
103
+
104
+ val_conv = opt['val'].get("apply_conv", True)
105
+
106
+ # create model
107
+ if resume_state: # resume training
108
+ check_resume(opt, resume_state['iter'])
109
+ model = create_model(opt)
110
+ model.resume_training(resume_state) # handle optimizers and schedulers
111
+ current_iter = resume_state['iter']
112
+
113
+ else:
114
+ model = create_model(opt)
115
+ current_iter = 0
116
+
117
+ # load psf
118
+ psf = torch.tensor(np.load("./psf.npy")).to("cuda")
119
+ _,psf_h,psf_w,_ = psf.shape
120
+ otf = psf2otf(psf, h=psf_h*3, w=psf_w*3, permute=True)[None]
121
+
122
+ dataset_opt = opt['datasets']['val']
123
+
124
+ val_set = create_dataset(dataset_opt)
125
+ val_loader = create_dataloader(
126
+ val_set,
127
+ dataset_opt,
128
+ num_gpu=opt['num_gpu'],
129
+ dist=opt['dist'],
130
+ sampler=None,
131
+ seed=opt['manual_seed'])
132
+
133
+ print("Start validation on spatially varying aberrration")
134
+ rgb2bgr = opt['val'].get('rgb2bgr', True)
135
+ use_image = opt['val'].get('use_image', True)
136
+ psnr, others = model.validation(val_loader, current_iter, None, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
137
+ print("==================")
138
+ print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n")
139
+
140
+
141
+ if __name__ == '__main__':
142
+ main()
basicsr/train.py ADDED
@@ -0,0 +1,328 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import math
5
+ import random
6
+ import time
7
+ import torch
8
+ import gc
9
+ from os import path as osp
10
+
11
+ from basicsr.data import create_dataloader, create_dataset
12
+ from basicsr.data.data_sampler import EnlargedSampler
13
+ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
14
+ from basicsr.models import create_model
15
+ from basicsr.utils import (MessageLogger, check_resume, get_env_info,
16
+ get_root_logger, get_time_str, init_tb_logger,
17
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename,
18
+ set_random_seed)
19
+ from basicsr.utils.dist_util import get_dist_info, init_dist
20
+ from basicsr.utils.options import dict2str, parse
21
+ from basicsr.utils.nano import apply_conv_n_deconv, psf2otf
22
+
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ def parse_options(is_train=True):
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument(
29
+ '-opt', type=str, required=True, help='Path to option YAML file.')
30
+ parser.add_argument(
31
+ '--launcher',
32
+ choices=['none', 'pytorch', 'slurm'],
33
+ default='none',
34
+ help='job launcher')
35
+ parser.add_argument(
36
+ '--name',
37
+ default=None,
38
+ help='experiment name that overrides the one in the YAML file')
39
+ import sys
40
+ vv = sys.version_info.minor
41
+ parser.add_argument('--local-rank', type=int, default=0)
42
+ parser.add_argument('--local_rank', type=int, default=0)
43
+ args = parser.parse_args()
44
+ opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None)
45
+
46
+
47
+ # distributed settings
48
+ if args.launcher == 'none':
49
+ opt['dist'] = False
50
+ print('Disable distributed.', flush=True)
51
+ else:
52
+ opt['dist'] = True
53
+ if args.launcher == 'slurm' and 'dist_params' in opt:
54
+ init_dist(args.launcher, **opt['dist_params'])
55
+ else:
56
+ init_dist(args.launcher)
57
+ print('init dist .. ', args.launcher)
58
+
59
+ opt['rank'], opt['world_size'] = get_dist_info()
60
+
61
+ # random seed
62
+ seed = opt.get('manual_seed')
63
+ if seed is None:
64
+ seed = random.randint(1, 10000)
65
+ opt['manual_seed'] = seed
66
+ set_random_seed(seed + opt['rank'])
67
+
68
+ return opt
69
+
70
+
71
+ def init_loggers(opt):
72
+ log_file = osp.join(opt['path']['log'],
73
+ f"train_{opt['name']}_{get_time_str()}.log")
74
+ logger = get_root_logger(
75
+ logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
76
+ logger.info(get_env_info())
77
+ logger.info(dict2str(opt))
78
+
79
+ # initialize wandb logger before tensorboard logger to allow proper sync:
80
+ if (opt['logger'].get('wandb')
81
+ is not None) and (opt['logger']['wandb'].get('project')
82
+ is not None) and ('debug' not in opt['name']):
83
+ assert opt['logger'].get('use_tb_logger') is True, (
84
+ 'should turn on tensorboard when using wandb')
85
+ init_wandb_logger(opt)
86
+ tb_logger = None
87
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
88
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
89
+ return logger, tb_logger
90
+
91
+
92
+ def create_train_val_dataloader(opt, logger):
93
+ # create train and val dataloaders
94
+ for phase, dataset_opt in opt['datasets'].items():
95
+ if phase == 'train':
96
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
97
+ train_set = create_dataset(dataset_opt)
98
+ train_sampler = EnlargedSampler(train_set, opt['world_size'],
99
+ opt['rank'], dataset_enlarge_ratio)
100
+ train_loader = create_dataloader(
101
+ train_set,
102
+ dataset_opt,
103
+ num_gpu=opt['num_gpu'],
104
+ dist=opt['dist'],
105
+ sampler=train_sampler,
106
+ seed=opt['manual_seed'],
107
+ )
108
+
109
+ num_iter_per_epoch = math.ceil(
110
+ len(train_set) * dataset_enlarge_ratio /
111
+ (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
112
+ total_iters = int(opt['train']['total_iter'])
113
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
114
+ logger.info(
115
+ 'Training statistics:'
116
+ f'\n\tNumber of train images: {len(train_set)}'
117
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
118
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
119
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
120
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
121
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
122
+
123
+ elif phase == 'val':
124
+ val_set = create_dataset(dataset_opt)
125
+ val_loader = create_dataloader(
126
+ val_set,
127
+ dataset_opt,
128
+ num_gpu=opt['num_gpu'],
129
+ dist=opt['dist'],
130
+ sampler=None,
131
+ seed=opt['manual_seed'],
132
+ )
133
+ logger.info(
134
+ f'Number of val images/folders in {dataset_opt["name"]}: '
135
+ f'{len(val_set)}')
136
+
137
+ else:
138
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
139
+
140
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
141
+
142
+
143
+ def main():
144
+ # parse options, set distributed setting, set random seed
145
+ opt = parse_options(is_train=True)
146
+ torch.backends.cudnn.benchmark = True
147
+
148
+ # automatic resume ..
149
+ state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
150
+ import os
151
+ try:
152
+ states = os.listdir(state_folder_path)
153
+ except:
154
+ states = []
155
+ resume_state = None
156
+ if len(states) > 0:
157
+ max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
158
+ resume_state = os.path.join(state_folder_path, max_state_file)
159
+ opt['path']['resume_state'] = resume_state
160
+
161
+ # load resume states if necessary
162
+ if opt['path'].get('resume_state'):
163
+ device_id = torch.cuda.current_device()
164
+ resume_state = torch.load(
165
+ opt['path']['resume_state'],
166
+ map_location=lambda storage, loc: storage.cuda(device_id))
167
+ else:
168
+ resume_state = None
169
+
170
+ # mkdir for experiments and logger
171
+ if resume_state is None:
172
+ make_exp_dirs(opt)
173
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
174
+ 'name'] and opt['rank'] == 0:
175
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
176
+
177
+ # initialize loggers
178
+ logger, tb_logger = init_loggers(opt)
179
+
180
+ # define ks for Wiener filters
181
+ ks_params = opt['train'].get('ks', None)
182
+ if not ks_params:
183
+ raise NotImplementedError
184
+ M = ks_params['num']
185
+ ks = torch.logspace(ks_params['start'], ks_params['end'], M)
186
+ ks = ks.view(1,M,1,1,1,1).to("cuda")
187
+
188
+ # create model
189
+ if resume_state: # resume training
190
+ check_resume(opt, resume_state['iter'])
191
+ model = create_model(opt)
192
+ model.resume_training(resume_state) # handle optimizers and schedulers
193
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
194
+ f"iter: {resume_state['iter']}.")
195
+ start_epoch = resume_state['epoch']
196
+ current_iter = resume_state['iter']
197
+
198
+ else:
199
+ model = create_model(opt)
200
+ start_epoch = 0
201
+ current_iter = 0
202
+
203
+
204
+
205
+ # create train and validation dataloaders
206
+ result = create_train_val_dataloader(opt, logger)
207
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
208
+
209
+
210
+ # create message logger (formatted outputs)
211
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
212
+
213
+ # dataloader prefetcher
214
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
215
+ if prefetch_mode is None or prefetch_mode == 'cpu':
216
+ prefetcher = CPUPrefetcher(train_loader)
217
+ elif prefetch_mode == 'cuda':
218
+ prefetcher = CUDAPrefetcher(train_loader, opt)
219
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
220
+ if opt['datasets']['train'].get('pin_memory') is not True:
221
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
222
+ else:
223
+ raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
224
+ "Supported ones are: None, 'cuda', 'cpu'.")
225
+
226
+ # training
227
+ logger.info(
228
+ f'Start training from epoch: {start_epoch}, iter: {current_iter}')
229
+ data_time, iter_time = time.time(), time.time()
230
+ start_time = time.time()
231
+
232
+
233
+
234
+ epoch = start_epoch
235
+ pbar = tqdm(total = total_iters+1)
236
+ pbar.update(current_iter)
237
+
238
+ # load psf
239
+ psf = torch.tensor(np.load("./psf.npy")).to("cuda")
240
+ psf_n,psf_h,psf_w,_ = psf.shape
241
+ psf_n_row = int(psf_n ** 0.5)
242
+ sensor_h = opt['datasets']['train'].get('sensor_size')
243
+ otf = psf2otf(psf, h=psf_h*3, w=psf_w*3, permute=True)[None]
244
+
245
+
246
+ gt_size = opt['datasets']['train']['gt_size']
247
+ val_conv = opt['val'].get("apply_conv", True)
248
+
249
+
250
+ while current_iter <= total_iters:
251
+ train_sampler.set_epoch(epoch)
252
+ prefetcher.reset()
253
+ train_data = prefetcher.next()
254
+
255
+ while train_data is not None:
256
+ data_time = time.time() - data_time
257
+
258
+ gt = train_data['gt'].to("cuda") # B,C,H,H
259
+ padding = train_data['padding']
260
+ padding = torch.stack(padding).T
261
+ lq, gt = apply_conv_n_deconv(gt, otf, padding, M, gt_size, ks=ks, ph=psf_h, num_psf=psf_n_row, sensor_h=sensor_h)
262
+
263
+
264
+ # 3 H W . conv -> crop
265
+ current_iter += 1
266
+ if current_iter > total_iters:
267
+ break
268
+ # update learning rate
269
+ model.update_learning_rate(
270
+ current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
271
+
272
+
273
+ model.feed_train_data({'lq': lq, 'gt':gt})
274
+ model.optimize_parameters(current_iter)
275
+
276
+
277
+ iter_time = time.time() - iter_time
278
+
279
+ # log
280
+ if current_iter % opt['logger']['print_freq'] == 0:
281
+ log_vars = {'epoch': epoch, 'iter': current_iter}
282
+ log_vars.update({'lrs': model.get_current_learning_rate()})
283
+ log_vars.update({'time': iter_time, 'data_time': data_time})
284
+
285
+ log_vars.update(model.get_current_log())
286
+ msg_logger(log_vars)
287
+
288
+ # save models and training states
289
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
290
+ logger.info('Saving models and training states.')
291
+ model.save(epoch, current_iter)
292
+
293
+ # validation
294
+ if opt.get('val') is not None and ((current_iter % opt['val']['val_freq'] == 0)):
295
+ rgb2bgr = opt['val'].get('rgb2bgr', True)
296
+ # whether to use uint8 images to compute metrics
297
+ use_image = opt['val'].get('use_image', True)
298
+ model.validation(val_loader, current_iter, tb_logger, False, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
299
+ gc.collect()
300
+ torch.cuda.empty_cache()
301
+
302
+ data_time = time.time()
303
+ iter_time = time.time()
304
+ train_data = prefetcher.next()
305
+ pbar.update(1)
306
+ # end of iter
307
+ epoch += 1
308
+
309
+ # end of epoch
310
+
311
+ consumed_time = str(
312
+ datetime.timedelta(seconds=int(time.time() - start_time)))
313
+ logger.info(f'End of training. Time consumed: {consumed_time}')
314
+ logger.info('Save the latest model.')
315
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
316
+ if opt.get('val') is not None:
317
+ rgb2bgr = opt['val'].get('rgb2bgr', True)
318
+ use_image = opt['val'].get('use_image', True)
319
+ psnr, others = model.validation(val_loader, current_iter, tb_logger, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
320
+ print("==================")
321
+ print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n")
322
+
323
+ if tb_logger:
324
+ tb_logger.close()
325
+
326
+
327
+ if __name__ == '__main__':
328
+ main()
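For reference, the Wiener-filter weights ks built in both train.py and test.py are simply a log-spaced grid reshaped for broadcasting; a standalone sketch with illustrative start/end/num values (the real ones come from opt['train']['ks']):

    import torch

    start, end, M = -4, -1, 4          # hypothetical config values

    # M values log-spaced between 10**start and 10**end ...
    ks = torch.logspace(start, end, M)
    # ... reshaped to (1, M, 1, 1, 1, 1) so they broadcast over the M Wiener
    # deconvolutions stacked inside apply_conv_n_deconv.
    ks = ks.view(1, M, 1, 1, 1, 1)
    print(ks.flatten())  # tensor([1.0000e-04, 1.0000e-03, 1.0000e-02, 1.0000e-01])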
basicsr/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ from .file_client import FileClient
2
+ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3
+ from .logger import (MessageLogger, get_env_info, get_root_logger,
4
+ init_tb_logger, init_wandb_logger)
5
+ from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6
+ scandir, scandir_mv, scandir_mv_flat, scandir_SIDD, set_random_seed, sizeof_fmt)
7
+ from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8
+
9
+ __all__ = [
10
+ # file_client.py
11
+ 'FileClient',
12
+ # img_util.py
13
+ 'img2tensor',
14
+ 'tensor2img',
15
+ 'imfrombytes',
16
+ 'imwrite',
17
+ 'crop_border',
18
+ # logger.py
19
+ 'MessageLogger',
20
+ 'init_tb_logger',
21
+ 'init_wandb_logger',
22
+ 'get_root_logger',
23
+ 'get_env_info',
24
+ # misc.py
25
+ 'set_random_seed',
26
+ 'get_time_str',
27
+ 'mkdir_and_rename',
28
+ 'make_exp_dirs',
29
+ 'scandir',
30
+ 'scandir_mv',
31
+ 'scandir_mv_flat',
32
+ 'check_resume',
33
+ 'sizeof_fmt',
34
+ 'padding',
35
+ 'padding_DP',
36
+ 'imfrombytesDP',
37
+ 'create_lmdb_for_reds',
38
+ 'create_lmdb_for_gopro',
39
+ 'create_lmdb_for_rain13k',
40
+ # nano.py
41
+ 'psf2otf',
42
+ 'fft',
43
+ 'ifft',
44
+ 'get_edgetaper_weight',
45
+ ]
basicsr/utils/bundle_submissions.py ADDED
@@ -0,0 +1,108 @@
1
+ # Author: Tobias Plötz, TU Darmstadt ([email protected])
2
+
3
+ # This file is part of the implementation as described in the CVPR 2017 paper:
4
+ # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5
+ # Please see the file LICENSE.txt for the license governing this code.
6
+
7
+
8
+ import numpy as np
9
+ import scipy.io as sio
10
+ import os
11
+ import h5py
12
+
13
+ def bundle_submissions_raw(submission_folder,session):
14
+ '''
15
+ Bundles submission data for raw denoising
16
+
17
+ submission_folder Folder where denoised images reside
18
+
19
+ Output is written to <submission_folder>/bundled/. Please submit
20
+ the content of this folder.
21
+ '''
22
+
23
+ out_folder = os.path.join(submission_folder, session)
24
+ # out_folder = os.path.join(submission_folder, "bundled/")
25
+ try:
26
+ os.mkdir(out_folder)
27
+ except:pass
28
+
29
+ israw = True
30
+ eval_version="1.0"
31
+
32
+ for i in range(50):
33
+ Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in recent NumPy; use the builtin object dtype
34
+ for bb in range(20):
35
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
36
+ s = sio.loadmat(os.path.join(submission_folder,filename))
37
+ Idenoised_crop = s["Idenoised_crop"]
38
+ Idenoised[bb] = Idenoised_crop
39
+ filename = '%04d.mat'%(i+1)
40
+ sio.savemat(os.path.join(out_folder, filename),
41
+ {"Idenoised": Idenoised,
42
+ "israw": israw,
43
+ "eval_version": eval_version},
44
+ )
45
+
46
+ def bundle_submissions_srgb(submission_folder,session):
47
+ '''
48
+ Bundles submission data for sRGB denoising
49
+
50
+ submission_folder Folder where denoised images reside
51
+
52
+ Output is written to <submission_folder>/bundled/. Please submit
53
+ the content of this folder.
54
+ '''
55
+ out_folder = os.path.join(submission_folder, session)
56
+ # out_folder = os.path.join(submission_folder, "bundled/")
57
+ try:
58
+ os.mkdir(out_folder)
59
+ except:pass
60
+ israw = False
61
+ eval_version="1.0"
62
+
63
+ for i in range(50):
64
+ Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in recent NumPy; use the builtin object dtype
65
+ for bb in range(20):
66
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
67
+ s = sio.loadmat(os.path.join(submission_folder,filename))
68
+ Idenoised_crop = s["Idenoised_crop"]
69
+ Idenoised[bb] = Idenoised_crop
70
+ filename = '%04d.mat'%(i+1)
71
+ sio.savemat(os.path.join(out_folder, filename),
72
+ {"Idenoised": Idenoised,
73
+ "israw": israw,
74
+ "eval_version": eval_version},
75
+ )
76
+
77
+
78
+
79
+ def bundle_submissions_srgb_v1(submission_folder,session):
80
+ '''
81
+ Bundles submission data for sRGB denoising
82
+
83
+ submission_folder Folder where denoised images reside
84
+
85
+ Output is written to <submission_folder>/bundled/. Please submit
86
+ the content of this folder.
87
+ '''
88
+ out_folder = os.path.join(submission_folder, session)
89
+ # out_folder = os.path.join(submission_folder, "bundled/")
90
+ try:
91
+ os.mkdir(out_folder)
92
+ except:pass
93
+ israw = False
94
+ eval_version="1.0"
95
+
96
+ for i in range(50):
97
+ Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in recent NumPy; use the builtin object dtype
98
+ for bb in range(20):
99
+ filename = '%04d_%d.mat'%(i+1,bb+1)
100
+ s = sio.loadmat(os.path.join(submission_folder,filename))
101
+ Idenoised_crop = s["Idenoised_crop"]
102
+ Idenoised[bb] = Idenoised_crop
103
+ filename = '%04d.mat'%(i+1)
104
+ sio.savemat(os.path.join(out_folder, filename),
105
+ {"Idenoised": Idenoised,
106
+ "israw": israw,
107
+ "eval_version": eval_version},
108
+ )
basicsr/utils/create_lmdb.py ADDED
@@ -0,0 +1,124 @@
1
+ import argparse
2
+ from os import path as osp
3
+
4
+ from basicsr.utils import scandir
5
+ from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6
+
7
+ def prepare_keys(folder_path, suffix='png'):
8
+ """Prepare image path list and keys for DIV2K dataset.
9
+
10
+ Args:
11
+ folder_path (str): Folder path.
12
+
13
+ Returns:
14
+ list[str]: Image path list.
15
+ list[str]: Key list.
16
+ """
17
+ print('Reading image path list ...')
18
+ img_path_list = sorted(
19
+ list(scandir(folder_path, suffix=suffix, recursive=False)))
20
+ keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
21
+
22
+ return img_path_list, keys
23
+
24
+ def create_lmdb_for_reds():
25
+ folder_path = './datasets/REDS/val/sharp_300'
26
+ lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
27
+ img_path_list, keys = prepare_keys(folder_path, 'png')
28
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
29
+ #
30
+ folder_path = './datasets/REDS/val/blur_300'
31
+ lmdb_path = './datasets/REDS/val/blur_300.lmdb'
32
+ img_path_list, keys = prepare_keys(folder_path, 'jpg')
33
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
34
+
35
+ folder_path = './datasets/REDS/train/train_sharp'
36
+ lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
37
+ img_path_list, keys = prepare_keys(folder_path, 'png')
38
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
39
+
40
+ folder_path = './datasets/REDS/train/train_blur_jpeg'
41
+ lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
42
+ img_path_list, keys = prepare_keys(folder_path, 'jpg')
43
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44
+
45
+
46
+ def create_lmdb_for_gopro():
47
+ folder_path = './datasets/GoPro/train/blur_crops'
48
+ lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
49
+
50
+ img_path_list, keys = prepare_keys(folder_path, 'png')
51
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
52
+
53
+ folder_path = './datasets/GoPro/train/sharp_crops'
54
+ lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
55
+
56
+ img_path_list, keys = prepare_keys(folder_path, 'png')
57
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
58
+
59
+ folder_path = './datasets/GoPro/test/target'
60
+ lmdb_path = './datasets/GoPro/test/target.lmdb'
61
+
62
+ img_path_list, keys = prepare_keys(folder_path, 'png')
63
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
64
+
65
+ folder_path = './datasets/GoPro/test/input'
66
+ lmdb_path = './datasets/GoPro/test/input.lmdb'
67
+
68
+ img_path_list, keys = prepare_keys(folder_path, 'png')
69
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70
+
71
+ def create_lmdb_for_rain13k():
72
+ folder_path = './datasets/Rain13k/train/input'
73
+ lmdb_path = './datasets/Rain13k/train/input.lmdb'
74
+
75
+ img_path_list, keys = prepare_keys(folder_path, 'jpg')
76
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
77
+
78
+ folder_path = './datasets/Rain13k/train/target'
79
+ lmdb_path = './datasets/Rain13k/train/target.lmdb'
80
+
81
+ img_path_list, keys = prepare_keys(folder_path, 'jpg')
82
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83
+
84
+ def create_lmdb_for_SIDD():
85
+ folder_path = './datasets/SIDD/train/input_crops'
86
+ lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
87
+
88
+ img_path_list, keys = prepare_keys(folder_path, 'PNG')
89
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
90
+
91
+ folder_path = './datasets/SIDD/train/gt_crops'
92
+ lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
93
+
94
+ img_path_list, keys = prepare_keys(folder_path, 'PNG')
95
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
96
+
97
+ #for val
98
+ folder_path = './datasets/SIDD/val/input_crops'
99
+ lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
100
+ mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
101
+ if not osp.exists(folder_path):
102
+ os.makedirs(folder_path)
103
+ assert osp.exists(mat_path)
104
+ data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
105
+ N, B, H ,W, C = data.shape
106
+ data = data.reshape(N*B, H, W, C)
107
+ for i in tqdm(range(N*B)):
108
+ cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
109
+ img_path_list, keys = prepare_keys(folder_path, 'png')
110
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
111
+
112
+ folder_path = './datasets/SIDD/val/gt_crops'
113
+ lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
114
+ mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
115
+ if not osp.exists(folder_path):
116
+ os.makedirs(folder_path)
117
+ assert osp.exists(mat_path)
118
+ data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
119
+ N, B, H ,W, C = data.shape
120
+ data = data.reshape(N*B, H, W, C)
121
+ for i in tqdm(range(N*B)):
122
+ cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
123
+ img_path_list, keys = prepare_keys(folder_path, 'png')
124
+ make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
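The same two-step pattern extends to any custom paired dataset. A minimal sketch, assuming a hypothetical folder layout (both paths below are placeholders):

# Hypothetical paths for a custom dataset; adjust to your own layout.
folder_path = './datasets/MyData/train/input'
lmdb_path = './datasets/MyData/train/input.lmdb'

# prepare_keys lists the images and derives lmdb keys from the file names;
# make_lmdb_from_imgs then encodes each image and writes it plus meta_info.txt.
img_path_list, keys = prepare_keys(folder_path, 'png')
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)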
basicsr/utils/dist_util.py ADDED
@@ -0,0 +1,83 @@
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ def init_dist(launcher, backend='nccl', **kwargs):
11
+ if mp.get_start_method(allow_none=True) is None:
12
+ mp.set_start_method('spawn')
13
+ if launcher == 'pytorch':
14
+ _init_dist_pytorch(backend, **kwargs)
15
+ elif launcher == 'slurm':
16
+ _init_dist_slurm(backend, **kwargs)
17
+ else:
18
+ raise ValueError(f'Invalid launcher type: {launcher}')
19
+
20
+
21
+ def _init_dist_pytorch(backend, **kwargs):
22
+ rank = int(os.environ['RANK'])
23
+ num_gpus = torch.cuda.device_count()
24
+ torch.cuda.set_device(rank % num_gpus)
25
+ dist.init_process_group(backend=backend, **kwargs)
26
+
27
+
28
+ def _init_dist_slurm(backend, port=None):
29
+ """Initialize slurm distributed training environment.
30
+
31
+ If argument ``port`` is not specified, then the master port will be system
32
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33
+ environment variable, then a default port ``29500`` will be used.
34
+
35
+ Args:
36
+ backend (str): Backend of torch.distributed.
37
+ port (int, optional): Master port. Defaults to None.
38
+ """
39
+ proc_id = int(os.environ['SLURM_PROCID'])
40
+ ntasks = int(os.environ['SLURM_NTASKS'])
41
+ node_list = os.environ['SLURM_NODELIST']
42
+ num_gpus = torch.cuda.device_count()
43
+ torch.cuda.set_device(proc_id % num_gpus)
44
+ addr = subprocess.getoutput(
45
+ f'scontrol show hostname {node_list} | head -n1')
46
+ # specify master port
47
+ if port is not None:
48
+ os.environ['MASTER_PORT'] = str(port)
49
+ elif 'MASTER_PORT' in os.environ:
50
+ pass # use MASTER_PORT in the environment variable
51
+ else:
52
+ # 29500 is torch.distributed default port
53
+ os.environ['MASTER_PORT'] = '29500'
54
+ os.environ['MASTER_ADDR'] = addr
55
+ os.environ['WORLD_SIZE'] = str(ntasks)
56
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57
+ os.environ['RANK'] = str(proc_id)
58
+ dist.init_process_group(backend=backend)
59
+
60
+
61
+ def get_dist_info():
62
+ if dist.is_available():
63
+ initialized = dist.is_initialized()
64
+ else:
65
+ initialized = False
66
+ if initialized:
67
+ rank = dist.get_rank()
68
+ world_size = dist.get_world_size()
69
+ else:
70
+ rank = 0
71
+ world_size = 1
72
+ return rank, world_size
73
+
74
+
75
+ def master_only(func):
76
+
77
+ @functools.wraps(func)
78
+ def wrapper(*args, **kwargs):
79
+ rank, _ = get_dist_info()
80
+ if rank == 0:
81
+ return func(*args, **kwargs)
82
+
83
+ return wrapper
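A minimal usage sketch for the helpers in dist_util.py, assuming the script is started by a torch.distributed launcher (e.g. torchrun) that exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT:

from basicsr.utils.dist_util import init_dist, get_dist_info, master_only

init_dist('pytorch', backend='nccl')      # reads RANK and binds each process to a GPU
rank, world_size = get_dist_info()

@master_only
def log_once(msg):
    # Executes only on rank 0; on other ranks the wrapper returns None.
    print(msg)

log_once(f'initialized rank {rank} of {world_size}')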
basicsr/utils/download_util.py ADDED
@@ -0,0 +1,70 @@
1
+ import math
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ from .misc import sizeof_fmt
6
+
7
+
8
+ def download_file_from_google_drive(file_id, save_path):
9
+ """Download files from google drive.
10
+
11
+ Ref:
12
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13
+
14
+ Args:
15
+ file_id (str): File id.
16
+ save_path (str): Save path.
17
+ """
18
+
19
+ session = requests.Session()
20
+ URL = 'https://docs.google.com/uc?export=download'
21
+ params = {'id': file_id}
22
+
23
+ response = session.get(URL, params=params, stream=True)
24
+ token = get_confirm_token(response)
25
+ if token:
26
+ params['confirm'] = token
27
+ response = session.get(URL, params=params, stream=True)
28
+
29
+ # get file size
30
+ response_file_size = session.get(
31
+ URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32
+ if 'Content-Range' in response_file_size.headers:
33
+ file_size = int(
34
+ response_file_size.headers['Content-Range'].split('/')[1])
35
+ else:
36
+ file_size = None
37
+
38
+ save_response_content(response, save_path, file_size)
39
+
40
+
41
+ def get_confirm_token(response):
42
+ for key, value in response.cookies.items():
43
+ if key.startswith('download_warning'):
44
+ return value
45
+ return None
46
+
47
+
48
+ def save_response_content(response,
49
+ destination,
50
+ file_size=None,
51
+ chunk_size=32768):
52
+ if file_size is not None:
53
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54
+
55
+ readable_file_size = sizeof_fmt(file_size)
56
+ else:
57
+ pbar = None
58
+
59
+ with open(destination, 'wb') as f:
60
+ downloaded_size = 0
61
+ for chunk in response.iter_content(chunk_size):
62
+ downloaded_size += chunk_size
63
+ if pbar is not None:
64
+ pbar.update(1)
65
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66
+ f'/ {readable_file_size}')
67
+ if chunk: # filter out keep-alive new chunks
68
+ f.write(chunk)
69
+ if pbar is not None:
70
+ pbar.close()
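A usage sketch for download_util.py; 'FILE_ID' and the save path are placeholders:

from basicsr.utils.download_util import download_file_from_google_drive

# Streams the file in 32 KB chunks and shows a tqdm progress bar when the size is known.
download_file_from_google_drive('FILE_ID', './experiments/pretrained/model.pth')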
basicsr/utils/face_util.py ADDED
@@ -0,0 +1,217 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from skimage import transform as trans
6
+
7
+ from basicsr.utils import imwrite
8
+
9
+ try:
10
+ import dlib
11
+ except ImportError:
12
+ print('Please install dlib before testing face restoration. '
13
+ 'Reference: https://github.com/davisking/dlib')
14
+
15
+
16
+ class FaceRestorationHelper(object):
17
+ """Helper for the face restoration pipeline."""
18
+
19
+ def __init__(self, upscale_factor, face_size=512):
20
+ self.upscale_factor = upscale_factor
21
+ self.face_size = (face_size, face_size)
22
+
23
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
24
+ self.face_template = np.array([[686.77227723, 488.62376238],
25
+ [586.77227723, 493.59405941],
26
+ [337.91089109, 488.38613861],
27
+ [437.95049505, 493.51485149],
28
+ [513.58415842, 678.5049505]])
29
+ self.face_template = self.face_template / (1024 // face_size)
30
+ # for estimation the 2D similarity transformation
31
+ self.similarity_trans = trans.SimilarityTransform()
32
+
33
+ self.all_landmarks_5 = []
34
+ self.all_landmarks_68 = []
35
+ self.affine_matrices = []
36
+ self.inverse_affine_matrices = []
37
+ self.cropped_faces = []
38
+ self.restored_faces = []
39
+ self.save_png = True
40
+
41
+ def init_dlib(self, detection_path, landmark5_path, landmark68_path):
42
+ """Initialize the dlib detectors and predictors."""
43
+ self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
44
+ self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
45
+ self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
46
+
47
+ def free_dlib_gpu_memory(self):
48
+ del self.face_detector
49
+ del self.shape_predictor_5
50
+ del self.shape_predictor_68
51
+
52
+ def read_input_image(self, img_path):
53
+ # self.input_img is Numpy array, (h, w, c) with RGB order
54
+ self.input_img = dlib.load_rgb_image(img_path)
55
+
56
+ def detect_faces(self,
57
+ img_path,
58
+ upsample_num_times=1,
59
+ only_keep_largest=False):
60
+ """
61
+ Args:
62
+ img_path (str): Image path.
63
+ upsample_num_times (int): Upsamples the image before running the
64
+ face detector
65
+
66
+ Returns:
67
+ int: Number of detected faces.
68
+ """
69
+ self.read_input_image(img_path)
70
+ det_faces = self.face_detector(self.input_img, upsample_num_times)
71
+ if len(det_faces) == 0:
72
+ print('No face detected. Try to increase upsample_num_times.')
73
+ else:
74
+ if only_keep_largest:
75
+ print('Detected several faces; only keeping the largest one.')
76
+ face_areas = []
77
+ for i in range(len(det_faces)):
78
+ face_area = (det_faces[i].rect.right() -
79
+ det_faces[i].rect.left()) * (
80
+ det_faces[i].rect.bottom() -
81
+ det_faces[i].rect.top())
82
+ face_areas.append(face_area)
83
+ largest_idx = face_areas.index(max(face_areas))
84
+ self.det_faces = [det_faces[largest_idx]]
85
+ else:
86
+ self.det_faces = det_faces
87
+ return len(self.det_faces)
88
+
89
+ def get_face_landmarks_5(self):
90
+ for face in self.det_faces:
91
+ shape = self.shape_predictor_5(self.input_img, face.rect)
92
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
93
+ self.all_landmarks_5.append(landmark)
94
+ return len(self.all_landmarks_5)
95
+
96
+ def get_face_landmarks_68(self):
97
+ """Get 68 densemarks for cropped images.
98
+
99
+ Should only have one face at most in the cropped image.
100
+ """
101
+ num_detected_face = 0
102
+ for idx, face in enumerate(self.cropped_faces):
103
+ # face detection
104
+ det_face = self.face_detector(face, 1) # TODO: can we remove it?
105
+ if len(det_face) == 0:
106
+ print(f'Cannot find faces in cropped image with index {idx}.')
107
+ self.all_landmarks_68.append(None)
108
+ else:
109
+ if len(det_face) > 1:
110
+ print('Detected several faces in the cropped image; using the '
111
+ 'largest one. Note that this will also cause overlap '
112
+ 'during paste_faces_to_input_image.')
113
+ face_areas = []
114
+ for i in range(len(det_face)):
115
+ face_area = (det_face[i].rect.right() -
116
+ det_face[i].rect.left()) * (
117
+ det_face[i].rect.bottom() -
118
+ det_face[i].rect.top())
119
+ face_areas.append(face_area)
120
+ largest_idx = face_areas.index(max(face_areas))
121
+ face_rect = det_face[largest_idx].rect
122
+ else:
123
+ face_rect = det_face[0].rect
124
+ shape = self.shape_predictor_68(face, face_rect)
125
+ landmark = np.array([[part.x, part.y]
126
+ for part in shape.parts()])
127
+ self.all_landmarks_68.append(landmark)
128
+ num_detected_face += 1
129
+
130
+ return num_detected_face
131
+
132
+ def warp_crop_faces(self,
133
+ save_cropped_path=None,
134
+ save_inverse_affine_path=None):
135
+ """Get affine matrix, warp and cropped faces.
136
+
137
+ Also get inverse affine matrix for post-processing.
138
+ """
139
+ for idx, landmark in enumerate(self.all_landmarks_5):
140
+ # use 5 landmarks to get affine matrix
141
+ self.similarity_trans.estimate(landmark, self.face_template)
142
+ affine_matrix = self.similarity_trans.params[0:2, :]
143
+ self.affine_matrices.append(affine_matrix)
144
+ # warp and crop faces
145
+ cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
146
+ self.face_size)
147
+ self.cropped_faces.append(cropped_face)
148
+ # save the cropped face
149
+ if save_cropped_path is not None:
150
+ path, ext = os.path.splitext(save_cropped_path)
151
+ if self.save_png:
152
+ save_path = f'{path}_{idx:02d}.png'
153
+ else:
154
+ save_path = f'{path}_{idx:02d}{ext}'
155
+
156
+ imwrite(
157
+ cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
158
+
159
+ # get inverse affine matrix
160
+ self.similarity_trans.estimate(self.face_template,
161
+ landmark * self.upscale_factor)
162
+ inverse_affine = self.similarity_trans.params[0:2, :]
163
+ self.inverse_affine_matrices.append(inverse_affine)
164
+ # save inverse affine matrices
165
+ if save_inverse_affine_path is not None:
166
+ path, _ = os.path.splitext(save_inverse_affine_path)
167
+ save_path = f'{path}_{idx:02d}.pth'
168
+ torch.save(inverse_affine, save_path)
169
+
170
+ def add_restored_face(self, face):
171
+ self.restored_faces.append(face)
172
+
173
+ def paste_faces_to_input_image(self, save_path):
174
+ # operate in the BGR order
175
+ input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
176
+ h, w, _ = input_img.shape
177
+ h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
178
+ # simply resize the background
179
+ upsample_img = cv2.resize(input_img, (w_up, h_up))
180
+ assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
181
+ 'length of restored_faces and affine_matrices are different.')
182
+ for restored_face, inverse_affine in zip(self.restored_faces,
183
+ self.inverse_affine_matrices):
184
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine,
185
+ (w_up, h_up))
186
+ mask = np.ones((*self.face_size, 3), dtype=np.float32)
187
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
188
+ # remove the black borders
189
+ inv_mask_erosion = cv2.erode(
190
+ inv_mask,
191
+ np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
192
+ np.uint8))
193
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
194
+ total_face_area = np.sum(inv_mask_erosion) // 3
195
+ # compute the fusion edge based on the area of face
196
+ w_edge = int(total_face_area**0.5) // 20
197
+ erosion_radius = w_edge * 2
198
+ inv_mask_center = cv2.erode(
199
+ inv_mask_erosion,
200
+ np.ones((erosion_radius, erosion_radius), np.uint8))
201
+ blur_size = w_edge * 2
202
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
203
+ (blur_size + 1, blur_size + 1), 0)
204
+ upsample_img = inv_soft_mask * inv_restored_remove_border + (
205
+ 1 - inv_soft_mask) * upsample_img
206
+ if self.save_png:
207
+ save_path = save_path.replace('.jpg',
208
+ '.png').replace('.jpeg', '.png')
209
+ imwrite(upsample_img.astype(np.uint8), save_path)
210
+
211
+ def clean_all(self):
212
+ self.all_landmarks_5 = []
213
+ self.all_landmarks_68 = []
214
+ self.restored_faces = []
215
+ self.affine_matrices = []
216
+ self.cropped_faces = []
217
+ self.inverse_affine_matrices = []
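A sketch of the intended call order for FaceRestorationHelper; the image path and the three dlib model paths are placeholders, and the restoration network itself is omitted:

from basicsr.utils.face_util import FaceRestorationHelper

helper = FaceRestorationHelper(upscale_factor=2, face_size=512)
helper.init_dlib('detector.dat', 'landmark5.dat', 'landmark68.dat')

if helper.detect_faces('input.jpg', only_keep_largest=True) > 0:
    helper.get_face_landmarks_5()
    helper.warp_crop_faces(save_cropped_path='cropped/face.png')
    for face in helper.cropped_faces:
        helper.add_restored_face(face)  # a real pipeline would run a restoration model here
    helper.paste_faces_to_input_image('restored/output.png')
helper.clean_all()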
basicsr/utils/file_client.py ADDED
@@ -0,0 +1,186 @@
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
+ class BaseStorageBackend(metaclass=ABCMeta):
6
+ """Abstract class of storage backends.
7
+
8
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
9
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10
+ as texts.
11
+ """
12
+
13
+ @abstractmethod
14
+ def get(self, filepath):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def get_text(self, filepath):
19
+ pass
20
+
21
+
22
+ class MemcachedBackend(BaseStorageBackend):
23
+ """Memcached storage backend.
24
+
25
+ Attributes:
26
+ server_list_cfg (str): Config file for memcached server list.
27
+ client_cfg (str): Config file for memcached client.
28
+ sys_path (str | None): Additional path to be appended to `sys.path`.
29
+ Default: None.
30
+ """
31
+
32
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33
+ if sys_path is not None:
34
+ import sys
35
+ sys.path.append(sys_path)
36
+ try:
37
+ import mc
38
+ except ImportError:
39
+ raise ImportError(
40
+ 'Please install memcached to enable MemcachedBackend.')
41
+
42
+ self.server_list_cfg = server_list_cfg
43
+ self.client_cfg = client_cfg
44
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
45
+ self.client_cfg)
46
+ # mc.pyvector serves as a pointer to a memory cache
47
+ self._mc_buffer = mc.pyvector()
48
+
49
+ def get(self, filepath):
50
+ filepath = str(filepath)
51
+ import mc
52
+ self._client.Get(filepath, self._mc_buffer)
53
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
54
+ return value_buf
55
+
56
+ def get_text(self, filepath):
57
+ raise NotImplementedError
58
+
59
+
60
+ class HardDiskBackend(BaseStorageBackend):
61
+ """Raw hard disks storage backend."""
62
+
63
+ def get(self, filepath):
64
+ filepath = str(filepath)
65
+ with open(filepath, 'rb') as f:
66
+ value_buf = f.read()
67
+ return value_buf
68
+
69
+ def get_text(self, filepath):
70
+ filepath = str(filepath)
71
+ with open(filepath, 'r') as f:
72
+ value_buf = f.read()
73
+ return value_buf
74
+
75
+
76
+ class LmdbBackend(BaseStorageBackend):
77
+ """Lmdb storage backend.
78
+
79
+ Args:
80
+ db_paths (str | list[str]): Lmdb database paths.
81
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82
+ readonly (bool, optional): Lmdb environment parameter. If True,
83
+ disallow any write operations. Default: True.
84
+ lock (bool, optional): Lmdb environment parameter. If False, when
85
+ concurrent access occurs, do not lock the database. Default: False.
86
+ readahead (bool, optional): Lmdb environment parameter. If False,
87
+ disable the OS filesystem readahead mechanism, which may improve
88
+ random read performance when a database is larger than RAM.
89
+ Default: False.
90
+
91
+ Attributes:
92
+ db_paths (list): Lmdb database path.
93
+ _client (list): A list of several lmdb envs.
94
+ """
95
+
96
+ def __init__(self,
97
+ db_paths,
98
+ client_keys='default',
99
+ readonly=True,
100
+ lock=False,
101
+ readahead=False,
102
+ **kwargs):
103
+ try:
104
+ import lmdb
105
+ except ImportError:
106
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
107
+
108
+ if isinstance(client_keys, str):
109
+ client_keys = [client_keys]
110
+
111
+ if isinstance(db_paths, list):
112
+ self.db_paths = [str(v) for v in db_paths]
113
+ elif isinstance(db_paths, str):
114
+ self.db_paths = [str(db_paths)]
115
+ assert len(client_keys) == len(self.db_paths), (
116
+ 'client_keys and db_paths should have the same length, '
117
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
118
+
119
+ self._client = {}
120
+
121
+ for client, path in zip(client_keys, self.db_paths):
122
+ self._client[client] = lmdb.open(
123
+ path,
124
+ readonly=readonly,
125
+ lock=lock,
126
+ readahead=readahead,
127
+ map_size=8*1024*10485760,
128
+ # max_readers=1,
129
+ **kwargs)
130
+
131
+ def get(self, filepath, client_key):
132
+ """Get values according to the filepath from one lmdb named client_key.
133
+
134
+ Args:
135
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136
+ client_key (str): Used for distinguishing different lmdb envs.
137
+ """
138
+ filepath = str(filepath)
139
+ assert client_key in self._client, (f'client_key {client_key} is not '
140
+ 'in lmdb clients.')
141
+ client = self._client[client_key]
142
+ with client.begin(write=False) as txn:
143
+ value_buf = txn.get(filepath.encode('ascii'))
144
+ return value_buf
145
+
146
+ def get_text(self, filepath):
147
+ raise NotImplementedError
148
+
149
+
150
+ class FileClient(object):
151
+ """A general file client to access files in different backend.
152
+
153
+ The client loads a file or text in a specified backend from its path
154
+ and returns it as a binary file. It can also register other backend
155
+ accessors with a given name and backend class.
156
+
157
+ Attributes:
158
+ backend (str): The storage backend type. Options are "disk",
159
+ "memcached" and "lmdb".
160
+ client (:obj:`BaseStorageBackend`): The backend object.
161
+ """
162
+
163
+ _backends = {
164
+ 'disk': HardDiskBackend,
165
+ 'memcached': MemcachedBackend,
166
+ 'lmdb': LmdbBackend,
167
+ }
168
+
169
+ def __init__(self, backend='disk', **kwargs):
170
+ if backend not in self._backends:
171
+ raise ValueError(
172
+ f'Backend {backend} is not supported. Currently supported ones'
173
+ f' are {list(self._backends.keys())}')
174
+ self.backend = backend
175
+ self.client = self._backends[backend](**kwargs)
176
+
177
+ def get(self, filepath, client_key='default'):
178
+ # client_key is used only for lmdb, where different fileclients have
179
+ # different lmdb environments.
180
+ if self.backend == 'lmdb':
181
+ return self.client.get(filepath, client_key)
182
+ else:
183
+ return self.client.get(filepath)
184
+
185
+ def get_text(self, filepath):
186
+ return self.client.get_text(filepath)
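A sketch of reading image bytes through FileClient with the disk and lmdb backends (the paths, database and key below are placeholders); the bytes are typically decoded with imfrombytes from img_util.py:

from basicsr.utils.file_client import FileClient
from basicsr.utils.img_util import imfrombytes

disk_client = FileClient('disk')
img = imfrombytes(disk_client.get('datasets/example/0001.png'), float32=True)

lmdb_client = FileClient('lmdb', db_paths=['datasets/example.lmdb'], client_keys=['gt'])
img_gt = imfrombytes(lmdb_client.get('0001', 'gt'), float32=True)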
basicsr/utils/flow_util.py ADDED
@@ -0,0 +1,180 @@
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+
6
+
7
+ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8
+ """Read an optical flow map.
9
+
10
+ Args:
11
+ flow_path (ndarray or str): Flow path.
12
+ quantize (bool): whether to read quantized pair, if set to True,
13
+ remaining args will be passed to :func:`dequantize_flow`.
14
+ concat_axis (int): The axis that dx and dy are concatenated,
15
+ can be either 0 or 1. Ignored if quantize is False.
16
+
17
+ Returns:
18
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
19
+ """
20
+ if quantize:
21
+ assert concat_axis in [0, 1]
22
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23
+ if cat_flow.ndim != 2:
24
+ raise IOError(f'{flow_path} is not a valid quantized flow file, '
25
+ f'its dimension is {cat_flow.ndim}.')
26
+ assert cat_flow.shape[concat_axis] % 2 == 0
27
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
29
+ else:
30
+ with open(flow_path, 'rb') as f:
31
+ try:
32
+ header = f.read(4).decode('utf-8')
33
+ except Exception:
34
+ raise IOError(f'Invalid flow file: {flow_path}')
35
+ else:
36
+ if header != 'PIEH':
37
+ raise IOError(f'Invalid flow file: {flow_path}, '
38
+ 'header does not contain PIEH')
39
+
40
+ w = np.fromfile(f, np.int32, 1).squeeze()
41
+ h = np.fromfile(f, np.int32, 1).squeeze()
42
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43
+
44
+ return flow.astype(np.float32)
45
+
46
+
47
+ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48
+ """Write optical flow to file.
49
+
50
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
51
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52
+ will be concatenated horizontally into a single image if quantize is True.)
53
+
54
+ Args:
55
+ flow (ndarray): (h, w, 2) array of optical flow.
56
+ filename (str): Output filepath.
57
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
58
+ images. If set to True, remaining args will be passed to
59
+ :func:`quantize_flow`.
60
+ concat_axis (int): The axis that dx and dy are concatenated,
61
+ can be either 0 or 1. Ignored if quantize is False.
62
+ """
63
+ if not quantize:
64
+ with open(filename, 'wb') as f:
65
+ f.write('PIEH'.encode('utf-8'))
66
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67
+ flow = flow.astype(np.float32)
68
+ flow.tofile(f)
69
+ f.flush()
70
+ else:
71
+ assert concat_axis in [0, 1]
72
+ dx, dy = quantize_flow(flow, *args, **kwargs)
73
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
74
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
75
+ cv2.imwrite(filename, dxdy)
76
+
77
+
78
+ def quantize_flow(flow, max_val=0.02, norm=True):
79
+ """Quantize flow to [0, 255].
80
+
81
+ After this step, the size of flow will be much smaller, and can be
82
+ dumped as jpeg images.
83
+
84
+ Args:
85
+ flow (ndarray): (h, w, 2) array of optical flow.
86
+ max_val (float): Maximum value of flow, values beyond
87
+ [-max_val, max_val] will be truncated.
88
+ norm (bool): Whether to divide flow values by image width/height.
89
+
90
+ Returns:
91
+ tuple[ndarray]: Quantized dx and dy.
92
+ """
93
+ h, w, _ = flow.shape
94
+ dx = flow[..., 0]
95
+ dy = flow[..., 1]
96
+ if norm:
97
+ dx = dx / w # avoid inplace operations
98
+ dy = dy / h
99
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100
+ flow_comps = [
101
+ quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102
+ ]
103
+ return tuple(flow_comps)
104
+
105
+
106
+ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107
+ """Recover from quantized flow.
108
+
109
+ Args:
110
+ dx (ndarray): Quantized dx.
111
+ dy (ndarray): Quantized dy.
112
+ max_val (float): Maximum value used when quantizing.
113
+ denorm (bool): Whether to multiply flow values with width/height.
114
+
115
+ Returns:
116
+ ndarray: Dequantized flow.
117
+ """
118
+ assert dx.shape == dy.shape
119
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
120
+
121
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
122
+
123
+ if denorm:
124
+ dx *= dx.shape[1]
125
+ dy *= dx.shape[0]
126
+ flow = np.dstack((dx, dy))
127
+ return flow
128
+
129
+
130
+ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
131
+ """Quantize an array of (-inf, inf) to [0, levels-1].
132
+
133
+ Args:
134
+ arr (ndarray): Input array.
135
+ min_val (scalar): Minimum value to be clipped.
136
+ max_val (scalar): Maximum value to be clipped.
137
+ levels (int): Quantization levels.
138
+ dtype (np.type): The type of the quantized array.
139
+
140
+ Returns:
141
+ tuple: Quantized array.
142
+ """
143
+ if not (isinstance(levels, int) and levels > 1):
144
+ raise ValueError(
145
+ f'levels must be a positive integer, but got {levels}')
146
+ if min_val >= max_val:
147
+ raise ValueError(
148
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
149
+
150
+ arr = np.clip(arr, min_val, max_val) - min_val
151
+ quantized_arr = np.minimum(
152
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
153
+
154
+ return quantized_arr
155
+
156
+
157
+ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
158
+ """Dequantize an array.
159
+
160
+ Args:
161
+ arr (ndarray): Input array.
162
+ min_val (scalar): Minimum value to be clipped.
163
+ max_val (scalar): Maximum value to be clipped.
164
+ levels (int): Quantization levels.
165
+ dtype (np.type): The type of the dequantized array.
166
+
167
+ Returns:
168
+ tuple: Dequantized array.
169
+ """
170
+ if not (isinstance(levels, int) and levels > 1):
171
+ raise ValueError(
172
+ f'levels must be a positive integer, but got {levels}')
173
+ if min_val >= max_val:
174
+ raise ValueError(
175
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
176
+
177
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
178
+ min_val) / levels + min_val
179
+
180
+ return dequantized_arr
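A small round-trip sketch for the flow helpers; the flow field below is a random toy example:

import numpy as np
from basicsr.utils.flow_util import flowwrite, flowread, quantize_flow, dequantize_flow

flow = (np.random.randn(64, 64, 2) * 0.01).astype(np.float32)  # toy flow field

flowwrite(flow, 'example.flo')                           # lossless .flo file
flow_back = flowread('example.flo')

dx, dy = quantize_flow(flow, max_val=0.02, norm=True)    # 255-level quantization (lossy)
flow_q = dequantize_flow(dx, dy, max_val=0.02, denorm=True)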
basicsr/utils/img_util.py ADDED
@@ -0,0 +1,216 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
10
+ """Numpy array to tensor.
11
+
12
+ Args:
13
+ imgs (list[ndarray] | ndarray): Input images.
14
+ bgr2rgb (bool): Whether to change bgr to rgb.
15
+ float32 (bool): Whether to change to float32.
16
+
17
+ Returns:
18
+ list[tensor] | tensor: Tensor images. If returned results only have
19
+ one element, just return tensor.
20
+ """
21
+
22
+ def _totensor(img, bgr2rgb, float32):
23
+ if img.shape[2] == 3 and bgr2rgb:
24
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25
+ img = torch.from_numpy(img.transpose(2, 0, 1))
26
+ if float32:
27
+ img = img.float()
28
+ return img
29
+
30
+ if isinstance(imgs, list):
31
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
32
+ else:
33
+ return _totensor(imgs, bgr2rgb, float32)
34
+
35
+
36
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37
+ """Convert torch Tensors into image numpy arrays.
38
+
39
+ After clamping to [min, max], values will be normalized to [0, 1].
40
+
41
+ Args:
42
+ tensor (Tensor or list[Tensor]): Accept shapes:
43
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44
+ 2) 3D Tensor of shape (3/1 x H x W);
45
+ 3) 2D Tensor of shape (H x W).
46
+ Tensor channel should be in RGB order.
47
+ rgb2bgr (bool): Whether to change rgb to bgr.
48
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
49
+ to uint8 type with range [0, 255]; otherwise, float type with
50
+ range [0, 1]. Default: ``np.uint8``.
51
+ min_max (tuple[int]): min and max values for clamp.
52
+
53
+ Returns:
54
+ (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of
55
+ shape (H x W). The channel order is BGR.
56
+ """
57
+ if not (torch.is_tensor(tensor) or
58
+ (isinstance(tensor, list)
59
+ and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(
61
+ f'tensor or list of tensors expected, got {type(tensor)}')
62
+
63
+ if torch.is_tensor(tensor):
64
+ tensor = [tensor]
65
+ result = []
66
+ for _tensor in tensor:
67
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69
+
70
+ n_dim = _tensor.dim()
71
+ if n_dim == 4:
72
+ img_np = make_grid(
73
+ _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74
+ normalize=False).numpy()
75
+ img_np = img_np.transpose(1, 2, 0)
76
+ if rgb2bgr:
77
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78
+ elif n_dim == 3:
79
+ img_np = _tensor.numpy()
80
+ img_np = img_np.transpose(1, 2, 0)
81
+ if img_np.shape[2] == 1: # gray image
82
+ img_np = np.squeeze(img_np, axis=2)
83
+ else:
84
+ if rgb2bgr:
85
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86
+ elif n_dim == 2:
87
+ img_np = _tensor.numpy()
88
+ else:
89
+ raise TypeError('Only support 4D, 3D or 2D tensor. '
90
+ f'But received with dimension: {n_dim}')
91
+ if out_type == np.uint8:
92
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93
+ img_np = (img_np * 255.0).round()
94
+ img_np = img_np.astype(out_type)
95
+ result.append(img_np)
96
+ if len(result) == 1:
97
+ result = result[0]
98
+ return result
99
+
100
+
101
+ def imfrombytes(content, flag='color', float32=False):
102
+ """Read an image from bytes.
103
+
104
+ Args:
105
+ content (bytes): Image bytes got from files or other streams.
106
+ flag (str): Flags specifying the color type of a loaded image,
107
+ candidates are `color`, `grayscale` and `unchanged`.
108
+ float32 (bool): Whether to change to float32. If True, will also normalize
109
+ to [0, 1]. Default: False.
110
+
111
+ Returns:
112
+ ndarray: Loaded image array.
113
+ """
114
+ img_np = np.frombuffer(content, np.uint8)
115
+ imread_flags = {
116
+ 'color': cv2.IMREAD_COLOR,
117
+ 'grayscale': cv2.IMREAD_GRAYSCALE,
118
+ 'unchanged': cv2.IMREAD_UNCHANGED
119
+ }
120
+ if img_np is None:
121
+ raise Exception('None .. !!!')
122
+ img = cv2.imdecode(img_np, imread_flags[flag])
123
+ if float32:
124
+ img = img.astype(np.float32) / 255.
125
+ return img
126
+
127
+ def imfrombytesDP(content, flag='color', float32=False):
128
+ """Read an image from bytes.
129
+
130
+ Args:
131
+ content (bytes): Image bytes got from files or other streams.
132
+ flag (str): Flags specifying the color type of a loaded image,
133
+ candidates are `color`, `grayscale` and `unchanged`.
134
+ float32 (bool): Whether to change to float32. If True, will also normalize
135
+ to [0, 1]. Default: False.
136
+
137
+ Returns:
138
+ ndarray: Loaded image array.
139
+ """
140
+ img_np = np.frombuffer(content, np.uint8)
141
+ if img_np is None:
142
+ raise Exception('None .. !!!')
143
+ img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED)
144
+ if float32:
145
+ img = img.astype(np.float32) / 65535.
146
+ return img
147
+
148
+ def padding(img_gt, gt_size):
149
+ h, w, _ = img_gt.shape
150
+
151
+ h_pad = max(0, gt_size - h)
152
+ w_pad = max(0, gt_size - w)
153
+
154
+ if h_pad == 0 and w_pad == 0:
155
+ return img_gt
156
+
157
+ img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
158
+ if img_gt.ndim == 2:
159
+ img_gt = np.expand_dims(img_gt, axis=2)
160
+ return img_gt
161
+
162
+ def padding_DP(img_lqL, img_lqR, img_gt, gt_size):
163
+ h, w, _ = img_gt.shape
164
+
165
+ h_pad = max(0, gt_size - h)
166
+ w_pad = max(0, gt_size - w)
167
+
168
+ if h_pad == 0 and w_pad == 0:
169
+ return img_lqL, img_lqR, img_gt
170
+
171
+ img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
172
+ img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
173
+ img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
174
+ # print('img_lq', img_lq.shape, img_gt.shape)
175
+ return img_lqL, img_lqR, img_gt
176
+
177
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
178
+ """Write image to file.
179
+
180
+ Args:
181
+ img (ndarray): Image array to be written.
182
+ file_path (str): Image file path.
183
+ params (None or list): Same as opencv's :func:`imwrite` interface.
184
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
185
+ whether to create it automatically.
186
+
187
+ Returns:
188
+ bool: Successful or not.
189
+ """
190
+ if auto_mkdir:
191
+ dir_name = os.path.abspath(os.path.dirname(file_path))
192
+ os.makedirs(dir_name, exist_ok=True)
193
+ return cv2.imwrite(file_path, img, params)
194
+
195
+
196
+ def crop_border(imgs, crop_border):
197
+ """Crop borders of images.
198
+
199
+ Args:
200
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
201
+ crop_border (int): Crop border for each end of height and width.
202
+
203
+ Returns:
204
+ list[ndarray]: Cropped images.
205
+ """
206
+ if crop_border == 0:
207
+ return imgs
208
+ else:
209
+ if isinstance(imgs, list):
210
+ return [
211
+ v[crop_border:-crop_border, crop_border:-crop_border, ...]
212
+ for v in imgs
213
+ ]
214
+ else:
215
+ return imgs[crop_border:-crop_border, crop_border:-crop_border,
216
+ ...]
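A typical conversion chain with the image utilities above, using a random toy image (the output path is a placeholder):

import numpy as np
from basicsr.utils.img_util import img2tensor, tensor2img, padding, imwrite

img = np.random.rand(100, 120, 3).astype(np.float32)   # toy BGR image in [0, 1]
img = padding(img, gt_size=128)                         # reflect-pad up to 128 x 128
tensor = img2tensor(img, bgr2rgb=True, float32=True)    # (3, 128, 128) RGB tensor
restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))
imwrite(restored, 'results/example.png')                # parent folder is created automatically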
basicsr/utils/lmdb_util.py ADDED
@@ -0,0 +1,208 @@
1
+ import cv2
2
+ import lmdb
3
+ import sys
4
+ from multiprocessing import Pool
5
+ from os import path as osp
6
+ from tqdm import tqdm
7
+
8
+
9
+ def make_lmdb_from_imgs(data_path,
10
+ lmdb_path,
11
+ img_path_list,
12
+ keys,
13
+ batch=5000,
14
+ compress_level=1,
15
+ multiprocessing_read=False,
16
+ n_thread=40,
17
+ map_size=None):
18
+ """Make lmdb from images.
19
+
20
+ Contents of lmdb. The file structure is:
21
+ example.lmdb
22
+ ├── data.mdb
23
+ ├── lock.mdb
24
+ ├── meta_info.txt
25
+
26
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
27
+ https://lmdb.readthedocs.io/en/release/ for more details.
28
+
29
+ The meta_info.txt is a specified txt file to record the meta information
30
+ of our datasets. It will be automatically created when preparing
31
+ datasets by our provided dataset tools.
32
+ Each line in the txt file records 1) image name (with extension),
33
+ 2) image shape, and 3) compression level, separated by a white space.
34
+
35
+ For example, the meta information could be:
36
+ `000_00000000.png (720,1280,3) 1`, which means:
37
+ 1) image name (with extension): 000_00000000.png;
38
+ 2) image shape: (720,1280,3);
39
+ 3) compression level: 1
40
+
41
+ We use the image name without extension as the lmdb key.
42
+
43
+ If `multiprocessing_read` is True, it will read all the images to memory
44
+ using multiprocessing. Thus, your server needs to have enough memory.
45
+
46
+ Args:
47
+ data_path (str): Data path for reading images.
48
+ lmdb_path (str): Lmdb save path.
49
+ img_path_list (list[str]): Image path list.
50
+ keys (list[str]): Used for lmdb keys.
51
+ batch (int): After processing batch images, lmdb commits.
52
+ Default: 5000.
53
+ compress_level (int): Compress level when encoding images. Default: 1.
54
+ multiprocessing_read (bool): Whether use multiprocessing to read all
55
+ the images to memory. Default: False.
56
+ n_thread (int): For multiprocessing.
57
+ map_size (int | None): Map size for lmdb env. If None, use the
58
+ estimated size from images. Default: None
59
+ """
60
+
61
+ assert len(img_path_list) == len(keys), (
62
+ 'img_path_list and keys should have the same length, '
63
+ f'but got {len(img_path_list)} and {len(keys)}')
64
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65
+ print(f'Total images: {len(img_path_list)}')
66
+ if not lmdb_path.endswith('.lmdb'):
67
+ raise ValueError("lmdb_path must end with '.lmdb'.")
68
+ if osp.exists(lmdb_path):
69
+ print(f'Folder {lmdb_path} already exists. Exit.')
70
+ sys.exit(1)
71
+
72
+ if multiprocessing_read:
73
+ # read all the images to memory (multiprocessing)
74
+ dataset = {} # use dict to keep the order for multiprocessing
75
+ shapes = {}
76
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77
+ pbar = tqdm(total=len(img_path_list), unit='image')
78
+
79
+ def callback(arg):
80
+ """get the image data and update pbar."""
81
+ key, dataset[key], shapes[key] = arg
82
+ pbar.update(1)
83
+ pbar.set_description(f'Read {key}')
84
+
85
+ pool = Pool(n_thread)
86
+ for path, key in zip(img_path_list, keys):
87
+ pool.apply_async(
88
+ read_img_worker,
89
+ args=(osp.join(data_path, path), key, compress_level),
90
+ callback=callback)
91
+ pool.close()
92
+ pool.join()
93
+ pbar.close()
94
+ print(f'Finish reading {len(img_path_list)} images.')
95
+
96
+ # create lmdb environment
97
+ if map_size is None:
98
+ # obtain data size for one image
99
+ img = cv2.imread(
100
+ osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101
+ _, img_byte = cv2.imencode(
102
+ '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103
+ data_size_per_img = img_byte.nbytes
104
+ print('Data size per image is: ', data_size_per_img)
105
+ data_size = data_size_per_img * len(img_path_list)
106
+ map_size = data_size * 10
107
+
108
+ env = lmdb.open(lmdb_path, map_size=map_size)
109
+
110
+ # write data to lmdb
111
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
112
+ txn = env.begin(write=True)
113
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115
+ pbar.update(1)
116
+ pbar.set_description(f'Write {key}')
117
+ key_byte = key.encode('ascii')
118
+ if multiprocessing_read:
119
+ img_byte = dataset[key]
120
+ h, w, c = shapes[key]
121
+ else:
122
+ _, img_byte, img_shape = read_img_worker(
123
+ osp.join(data_path, path), key, compress_level)
124
+ h, w, c = img_shape
125
+
126
+ txn.put(key_byte, img_byte)
127
+ # write meta information
128
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129
+ if idx % batch == 0:
130
+ txn.commit()
131
+ txn = env.begin(write=True)
132
+ pbar.close()
133
+ txn.commit()
134
+ env.close()
135
+ txt_file.close()
136
+ print('\nFinish writing lmdb.')
137
+
138
+
139
+ def read_img_worker(path, key, compress_level):
140
+ """Read image worker.
141
+
142
+ Args:
143
+ path (str): Image path.
144
+ key (str): Image key.
145
+ compress_level (int): Compress level when encoding images.
146
+
147
+ Returns:
148
+ str: Image key.
149
+ byte: Image byte.
150
+ tuple[int]: Image shape.
151
+ """
152
+
153
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154
+ if img.ndim == 2:
155
+ h, w = img.shape
156
+ c = 1
157
+ else:
158
+ h, w, c = img.shape
159
+ _, img_byte = cv2.imencode('.png', img,
160
+ [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161
+ return (key, img_byte, (h, w, c))
162
+
163
+
164
+ class LmdbMaker():
165
+ """LMDB Maker.
166
+
167
+ Args:
168
+ lmdb_path (str): Lmdb save path.
169
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170
+ batch (int): After processing batch images, lmdb commits.
171
+ Default: 5000.
172
+ compress_level (int): Compress level when encoding images. Default: 1.
173
+ """
174
+
175
+ def __init__(self,
176
+ lmdb_path,
177
+ map_size=1024**4,
178
+ batch=5000,
179
+ compress_level=1):
180
+ if not lmdb_path.endswith('.lmdb'):
181
+ raise ValueError("lmdb_path must end with '.lmdb'.")
182
+ if osp.exists(lmdb_path):
183
+ print(f'Folder {lmdb_path} already exists. Exit.')
184
+ sys.exit(1)
185
+
186
+ self.lmdb_path = lmdb_path
187
+ self.batch = batch
188
+ self.compress_level = compress_level
189
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
190
+ self.txn = self.env.begin(write=True)
191
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
192
+ self.counter = 0
193
+
194
+ def put(self, img_byte, key, img_shape):
195
+ self.counter += 1
196
+ key_byte = key.encode('ascii')
197
+ self.txn.put(key_byte, img_byte)
198
+ # write meta information
199
+ h, w, c = img_shape
200
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
201
+ if self.counter % self.batch == 0:
202
+ self.txn.commit()
203
+ self.txn = self.env.begin(write=True)
204
+
205
+ def close(self):
206
+ self.txn.commit()
207
+ self.env.close()
208
+ self.txt_file.close()
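Besides make_lmdb_from_imgs, the incremental LmdbMaker can be driven directly; a sketch with placeholder paths (the target .lmdb folder must not exist yet):

from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

maker = LmdbMaker('./datasets/example.lmdb', batch=5000, compress_level=1)
key, img_byte, shape = read_img_worker('datasets/example/0001.png', '0001', compress_level=1)
maker.put(img_byte, key, shape)
maker.close()  # commits the pending transaction and closes the env and meta_info.txt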
basicsr/utils/logger.py ADDED
@@ -0,0 +1,175 @@
1
+ import datetime
2
+ import logging
3
+ import time
4
+
5
+ from .dist_util import get_dist_info, master_only
6
+
7
+ initialized_logger = {}
8
+
9
+
10
+ class MessageLogger():
11
+ """Message logger for printing.
12
+
13
+ Args:
14
+ opt (dict): Config. It contains the following keys:
15
+ name (str): Exp name.
16
+ logger (dict): Contains 'print_freq' (int) for logger interval.
17
+ train (dict): Contains 'total_iter' (int) for total iters.
18
+ use_tb_logger (bool): Use tensorboard logger.
19
+ start_iter (int): Start iter. Default: 1.
20
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
21
+ """
22
+
23
+ def __init__(self, opt, start_iter=1, tb_logger=None):
24
+ self.exp_name = opt['name']
25
+ self.interval = opt['logger']['print_freq']
26
+ self.start_iter = start_iter
27
+ self.max_iters = opt['train']['total_iter']
28
+ self.use_tb_logger = opt['logger']['use_tb_logger']
29
+ self.tb_logger = tb_logger
30
+ self.start_time = time.time()
31
+ self.logger = get_root_logger()
32
+
33
+ @master_only
34
+ def __call__(self, log_vars):
35
+ """Format logging message.
36
+
37
+ Args:
38
+ log_vars (dict): It contains the following keys:
39
+ epoch (int): Epoch number.
40
+ iter (int): Current iter.
41
+ lrs (list): List for learning rates.
42
+
43
+ time (float): Iter time.
44
+ data_time (float): Data time for each iter.
45
+ """
46
+ # epoch, iter, learning rates
47
+ epoch = log_vars.pop('epoch')
48
+ current_iter = log_vars.pop('iter')
49
+ lrs = log_vars.pop('lrs')
50
+
51
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
52
+ for v in lrs:
53
+ message += f'{v:.3e},'
54
+ message += ')] '
55
+
56
+ # time and estimated time
57
+ if 'time' in log_vars.keys():
58
+ iter_time = log_vars.pop('time')
59
+ data_time = log_vars.pop('data_time')
60
+
61
+ total_time = time.time() - self.start_time
62
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
63
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
64
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
65
+ message += f'[eta: {eta_str}, '
66
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
67
+
68
+ # other items, especially losses
69
+ for k, v in log_vars.items():
70
+ message += f'{k}: {v:.4e} '
71
+ # tensorboard logger
72
+ if self.use_tb_logger and 'debug' not in self.exp_name:
73
+ if k.startswith('l_'):
74
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
75
+ else:
76
+ self.tb_logger.add_scalar(k, v, current_iter)
77
+ self.logger.info(message)
78
+
79
+
80
+ @master_only
81
+ def init_tb_logger(log_dir):
82
+ from torch.utils.tensorboard import SummaryWriter
83
+ tb_logger = SummaryWriter(log_dir=log_dir)
84
+ return tb_logger
85
+
86
+
87
+ @master_only
88
+ def init_wandb_logger(opt):
89
+ """We now only use wandb to sync tensorboard log."""
90
+ import wandb
91
+ logger = logging.getLogger('basicsr')
92
+
93
+ project = opt['logger']['wandb']['project']
94
+ resume_id = opt['logger']['wandb'].get('resume_id')
95
+ if resume_id:
96
+ wandb_id = resume_id
97
+ resume = 'allow'
98
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
99
+ else:
100
+ wandb_id = wandb.util.generate_id()
101
+ resume = 'never'
102
+
103
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
104
+
105
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
106
+
107
+
108
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
109
+ """Get the root logger.
110
+
111
+ The logger will be initialized if it has not been initialized. By default a
112
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
113
+ also be added.
114
+
115
+ Args:
116
+ logger_name (str): root logger name. Default: 'basicsr'.
117
+ log_file (str | None): The log filename. If specified, a FileHandler
118
+ will be added to the root logger.
119
+ log_level (int): The root logger level. Note that only the process of
120
+ rank 0 is affected, while other processes will set the level to
121
+ "Error" and be silent most of the time.
122
+
123
+ Returns:
124
+ logging.Logger: The root logger.
125
+ """
126
+ logger = logging.getLogger(logger_name)
127
+ # if the logger has been initialized, just return it
128
+ if logger_name in initialized_logger:
129
+ return logger
130
+
131
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
132
+ stream_handler = logging.StreamHandler()
133
+ stream_handler.setFormatter(logging.Formatter(format_str))
134
+ logger.addHandler(stream_handler)
135
+ logger.propagate = False
136
+ rank, _ = get_dist_info()
137
+ if rank != 0:
138
+ logger.setLevel('ERROR')
139
+ elif log_file is not None:
140
+ logger.setLevel(log_level)
141
+ # add file handler
142
+ file_handler = logging.FileHandler(log_file, 'w')
143
+ file_handler.setFormatter(logging.Formatter(format_str))
144
+ file_handler.setLevel(log_level)
145
+ logger.addHandler(file_handler)
146
+ initialized_logger[logger_name] = True
147
+ return logger
148
+
149
+
150
+ def get_env_info():
151
+ """Get environment information.
152
+
153
+ Currently, only log the software version.
154
+ """
155
+ import torch
156
+ import torchvision
157
+
158
+ from basicsr.version import __version__
159
+ msg = r"""
160
+ ____ _ _____ ____
161
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
162
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
163
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
164
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
165
+ ______ __ __ __ __
166
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
167
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
168
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
169
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
170
+ """
171
+ msg += ('\nVersion Information: '
172
+ f'\n\tBasicSR: {__version__}'
173
+ f'\n\tPyTorch: {torch.__version__}'
174
+ f'\n\tTorchVision: {torchvision.__version__}')
175
+ return msg
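A minimal sketch of the logger utilities; pass log_file to additionally write to disk (its directory must already exist):

from basicsr.utils.logger import get_root_logger, get_env_info

logger = get_root_logger()      # StreamHandler only; ranks other than 0 are silenced to ERROR
logger.info(get_env_info())
logger.info('Training started.')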
basicsr/utils/matlab_functions.py ADDED
@@ -0,0 +1,361 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def cubic(x):
7
+ """cubic function used for calculate_weights_indices."""
8
+ absx = torch.abs(x)
9
+ absx2 = absx**2
10
+ absx3 = absx**3
11
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
13
+ 2) * (((absx > 1) *
14
+ (absx <= 2)).type_as(absx))
15
+
16
+
17
+ def calculate_weights_indices(in_length, out_length, scale, kernel,
18
+ kernel_width, antialiasing):
19
+ """Calculate weights and indices, used for imresize function.
20
+
21
+ Args:
22
+ in_length (int): Input length.
23
+ out_length (int): Output length.
24
+ scale (float): Scale factor.
25
+ kernel_width (int): Kernel width.
26
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
27
+ """
28
+
29
+ if (scale < 1) and antialiasing:
30
+ # Use a modified kernel (larger kernel width) to simultaneously
31
+ # interpolate and antialias
32
+ kernel_width = kernel_width / scale
33
+
34
+ # Output-space coordinates
35
+ x = torch.linspace(1, out_length, out_length)
36
+
37
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
38
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
39
+ # space maps to 1.5 in input space.
40
+ u = x / scale + 0.5 * (1 - 1 / scale)
41
+
42
+ # What is the left-most pixel that can be involved in the computation?
43
+ left = torch.floor(u - kernel_width / 2)
44
+
45
+ # What is the maximum number of pixels that can be involved in the
46
+ # computation? Note: it's OK to use an extra pixel here; if the
47
+ # corresponding weights are all zero, it will be eliminated at the end
48
+ # of this function.
49
+ p = math.ceil(kernel_width) + 2
50
+
51
+ # The indices of the input pixels involved in computing the k-th output
52
+ # pixel are in row k of the indices matrix.
53
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
54
+ 0, p - 1, p).view(1, p).expand(out_length, p)
55
+
56
+ # The weights used to compute the k-th output pixel are in row k of the
57
+ # weights matrix.
58
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
59
+
60
+ # apply cubic kernel
61
+ if (scale < 1) and antialiasing:
62
+ weights = scale * cubic(distance_to_center * scale)
63
+ else:
64
+ weights = cubic(distance_to_center)
65
+
66
+ # Normalize the weights matrix so that each row sums to 1.
67
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
68
+ weights = weights / weights_sum.expand(out_length, p)
69
+
70
+ # If a column in weights is all zero, get rid of it. only consider the
71
+ # first and last column.
72
+ weights_zero_tmp = torch.sum((weights == 0), 0)
73
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
74
+ indices = indices.narrow(1, 1, p - 2)
75
+ weights = weights.narrow(1, 1, p - 2)
76
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
77
+ indices = indices.narrow(1, 0, p - 2)
78
+ weights = weights.narrow(1, 0, p - 2)
79
+ weights = weights.contiguous()
80
+ indices = indices.contiguous()
81
+ sym_len_s = -indices.min() + 1
82
+ sym_len_e = indices.max() - in_length
83
+ indices = indices + sym_len_s - 1
84
+ return weights, indices, int(sym_len_s), int(sym_len_e)
85
+
86
+
87
+ @torch.no_grad()
88
+ def imresize(img, scale, antialiasing=True):
89
+ """imresize function same as MATLAB.
90
+
91
+ It now only supports bicubic.
92
+ The same scale applies for both height and width.
93
+
94
+ Args:
95
+ img (Tensor | Numpy array):
96
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
97
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
98
+ scale (float): Scale factor. The same scale applies for both height
99
+ and width.
100
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
101
+ Default: True.
102
+
103
+ Returns:
104
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
105
+ """
106
+ if type(img).__module__ == np.__name__: # numpy type
107
+ numpy_type = True
108
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
109
+ else:
110
+ numpy_type = False
111
+
112
+ in_c, in_h, in_w = img.size()
113
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
114
+ kernel_width = 4
115
+ kernel = 'cubic'
116
+
117
+ # get weights and indices
118
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
119
+ in_h, out_h, scale, kernel, kernel_width, antialiasing)
120
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
121
+ in_w, out_w, scale, kernel, kernel_width, antialiasing)
122
+ # process H dimension
123
+ # symmetric copying
124
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
125
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
126
+
127
+ sym_patch = img[:, :sym_len_hs, :]
128
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
129
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
130
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
131
+
132
+ sym_patch = img[:, -sym_len_he:, :]
133
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
135
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
136
+
137
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
138
+ kernel_width = weights_h.size(1)
139
+ for i in range(out_h):
140
+ idx = int(indices_h[i][0])
141
+ for j in range(in_c):
142
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
143
+ 0, 1).mv(weights_h[i])
144
+
145
+ # process W dimension
146
+ # symmetric copying
147
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
148
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
149
+
150
+ sym_patch = out_1[:, :, :sym_len_ws]
151
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
152
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
153
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
154
+
155
+ sym_patch = out_1[:, :, -sym_len_we:]
156
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
157
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
158
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
159
+
160
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
161
+ kernel_width = weights_w.size(1)
162
+ for i in range(out_w):
163
+ idx = int(indices_w[i][0])
164
+ for j in range(in_c):
165
+ out_2[j, :, i] = out_1_aug[j, :,
166
+ idx:idx + kernel_width].mv(weights_w[i])
167
+
168
+ if numpy_type:
169
+ out_2 = out_2.numpy().transpose(1, 2, 0)
170
+ return out_2
171
+
172
+
173
+ def rgb2ycbcr(img, y_only=False):
174
+ """Convert a RGB image to YCbCr image.
175
+
176
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
177
+ It implements the ITU-R BT.601 conversion for standard-definition
178
+ television. See more details in
179
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
180
+
181
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
182
+ In OpenCV, it implements a JPEG conversion. See more details in
183
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
184
+
185
+ Args:
186
+ img (ndarray): The input image. It accepts:
187
+ 1. np.uint8 type with range [0, 255];
188
+ 2. np.float32 type with range [0, 1].
189
+ y_only (bool): Whether to only return Y channel. Default: False.
190
+
191
+ Returns:
192
+ ndarray: The converted YCbCr image. The output image has the same type
193
+ and range as input image.
194
+ """
195
+ img_type = img.dtype
196
+ img = _convert_input_type_range(img)
197
+ if y_only:
198
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
199
+ else:
200
+ out_img = np.matmul(
201
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
202
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
203
+ out_img = _convert_output_type_range(out_img, img_type)
204
+ return out_img
205
+
206
+
207
+ def bgr2ycbcr(img, y_only=False):
208
+ """Convert a BGR image to YCbCr image.
209
+
210
+ The bgr version of rgb2ycbcr.
211
+ It implements the ITU-R BT.601 conversion for standard-definition
212
+ television. See more details in
213
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
214
+
215
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
216
+ In OpenCV, it implements a JPEG conversion. See more details in
217
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
218
+
219
+ Args:
220
+ img (ndarray): The input image. It accepts:
221
+ 1. np.uint8 type with range [0, 255];
222
+ 2. np.float32 type with range [0, 1].
223
+ y_only (bool): Whether to only return Y channel. Default: False.
224
+
225
+ Returns:
226
+ ndarray: The converted YCbCr image. The output image has the same type
227
+ and range as input image.
228
+ """
229
+ img_type = img.dtype
230
+ img = _convert_input_type_range(img)
231
+ if y_only:
232
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
233
+ else:
234
+ out_img = np.matmul(
235
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
236
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
237
+ out_img = _convert_output_type_range(out_img, img_type)
238
+ return out_img
239
+
240
+
241
+ def ycbcr2rgb(img):
242
+ """Convert a YCbCr image to RGB image.
243
+
244
+ This function produces the same results as Matlab's ycbcr2rgb function.
245
+ It implements the ITU-R BT.601 conversion for standard-definition
246
+ television. See more details in
247
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
248
+
249
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
250
+ In OpenCV, it implements a JPEG conversion. See more details in
251
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
252
+
253
+ Args:
254
+ img (ndarray): The input image. It accepts:
255
+ 1. np.uint8 type with range [0, 255];
256
+ 2. np.float32 type with range [0, 1].
257
+
258
+ Returns:
259
+ ndarray: The converted RGB image. The output image has the same type
260
+ and range as input image.
261
+ """
262
+ img_type = img.dtype
263
+ img = _convert_input_type_range(img) * 255
264
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
265
+ [0, -0.00153632, 0.00791071],
266
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
267
+ -222.921, 135.576, -276.836
268
+ ] # noqa: E126
269
+ out_img = _convert_output_type_range(out_img, img_type)
270
+ return out_img
271
+
272
+
273
+ def ycbcr2bgr(img):
274
+ """Convert a YCbCr image to BGR image.
275
+
276
+ The bgr version of ycbcr2rgb.
277
+ It implements the ITU-R BT.601 conversion for standard-definition
278
+ television. See more details in
279
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
280
+
281
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
282
+ In OpenCV, it implements a JPEG conversion. See more details in
283
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
284
+
285
+ Args:
286
+ img (ndarray): The input image. It accepts:
287
+ 1. np.uint8 type with range [0, 255];
288
+ 2. np.float32 type with range [0, 1].
289
+
290
+ Returns:
291
+ ndarray: The converted BGR image. The output image has the same type
292
+ and range as input image.
293
+ """
294
+ img_type = img.dtype
295
+ img = _convert_input_type_range(img) * 255
296
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
297
+ [0.00791071, -0.00153632, 0],
298
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
299
+ -276.836, 135.576, -222.921
300
+ ] # noqa: E126
301
+ out_img = _convert_output_type_range(out_img, img_type)
302
+ return out_img
303
+
304
+
305
+ def _convert_input_type_range(img):
306
+ """Convert the type and range of the input image.
307
+
308
+ It converts the input image to np.float32 type and range of [0, 1].
309
+ It is mainly used for pre-processing the input image in colorspace
310
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
311
+
312
+ Args:
313
+ img (ndarray): The input image. It accepts:
314
+ 1. np.uint8 type with range [0, 255];
315
+ 2. np.float32 type with range [0, 1].
316
+
317
+ Returns:
318
+ (ndarray): The converted image with type of np.float32 and range of
319
+ [0, 1].
320
+ """
321
+ img_type = img.dtype
322
+ img = img.astype(np.float32)
323
+ if img_type == np.float32:
324
+ pass
325
+ elif img_type == np.uint8:
326
+ img /= 255.
327
+ else:
328
+ raise TypeError('The img type should be np.float32 or np.uint8, '
329
+ f'but got {img_type}')
330
+ return img
331
+
332
+
333
+ def _convert_output_type_range(img, dst_type):
334
+ """Convert the type and range of the image according to dst_type.
335
+
336
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
337
+ images will be converted to np.uint8 type with range [0, 255]. If
338
+ `dst_type` is np.float32, it converts the image to np.float32 type with
339
+ range [0, 1].
340
+ It is mainly used for post-processing images in colorspace conversion
341
+ functions such as rgb2ycbcr and ycbcr2rgb.
342
+
343
+ Args:
344
+ img (ndarray): The image to be converted with np.float32 type and
345
+ range [0, 255].
346
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
347
+ converts the image to np.uint8 type with range [0, 255]. If
348
+ dst_type is np.float32, it converts the image to np.float32 type
349
+ with range [0, 1].
350
+
351
+ Returns:
352
+ (ndarray): The converted image with desired type and range.
353
+ """
354
+ if dst_type not in (np.uint8, np.float32):
355
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
356
+ f'but got {dst_type}')
357
+ if dst_type == np.uint8:
358
+ img = img.round()
359
+ else:
360
+ img /= 255.
361
+ return img.astype(dst_type)
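The colorspace helpers above are type- and range-preserving: uint8 inputs in [0, 255] come back as uint8, float32 inputs in [0, 1] come back as float32, and `y_only=True` returns just the luma channel (as typically used for Y-channel PSNR/SSIM). A minimal usage sketch, assuming `basicsr` is installed so these functions import from `basicsr.utils.matlab_functions` (the random image is made up for illustration):

```
import numpy as np
from basicsr.utils.matlab_functions import rgb2ycbcr, ycbcr2rgb

# float32 RGB image in [0, 1]
img = np.random.rand(64, 64, 3).astype(np.float32)
y = rgb2ycbcr(img, y_only=True)   # (64, 64) float32 luma channel
ycbcr = rgb2ycbcr(img)            # (64, 64, 3) float32 YCbCr
rgb_back = ycbcr2rgb(ycbcr)       # round trip back to float32 RGB

# uint8 RGB image in [0, 255]: the result is rounded and returned as uint8
img8 = (img * 255).astype(np.uint8)
ycbcr8 = rgb2ycbcr(img8)
```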
basicsr/utils/misc.py ADDED
@@ -0,0 +1,266 @@
1
+ import numpy as np
2
+ import os
3
+ import random
4
+ import time
5
+ import torch
6
+ from os import path as osp
7
+
8
+ from .dist_util import master_only
9
+ from .logger import get_root_logger
10
+
11
+
12
+ def set_random_seed(seed):
13
+ """Set random seeds."""
14
+ random.seed(seed)
15
+ np.random.seed(seed)
16
+ torch.manual_seed(seed)
17
+ torch.cuda.manual_seed(seed)
18
+ torch.cuda.manual_seed_all(seed)
19
+
20
+
21
+ def get_time_str():
22
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
23
+
24
+
25
+ def mkdir_and_rename(path):
26
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
27
+
28
+ Args:
29
+ path (str): Folder path.
30
+ """
31
+ # if osp.exists(path):
32
+ # new_name = path + '_archived_' + get_time_str()
33
+ # print(f'Path already exists. Rename it to {new_name}', flush=True)
34
+ # os.rename(path, new_name)
35
+ os.makedirs(path, exist_ok=True)
36
+
37
+
38
+ @master_only
39
+ def make_exp_dirs(opt):
40
+ """Make dirs for experiments."""
41
+ path_opt = opt['path'].copy()
42
+ if opt['is_train']:
43
+ mkdir_and_rename(path_opt.pop('experiments_root'))
44
+ else:
45
+ mkdir_and_rename(path_opt.pop('results_root'))
46
+ for key, path in path_opt.items():
47
+ if ('strict_load' not in key) and ('pretrain_network'
48
+ not in key) and ('resume'
49
+ not in key):
50
+ os.makedirs(path, exist_ok=True)
51
+
52
+
53
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
54
+ """Scan a directory to find the interested files.
55
+
56
+ Args:
57
+ dir_path (str): Path of the directory.
58
+ suffix (str | tuple(str), optional): File suffix that we are
59
+ interested in. Default: None.
60
+ recursive (bool, optional): If set to True, recursively scan the
61
+ directory. Default: False.
62
+ full_path (bool, optional): If set to True, include the dir_path.
63
+ Default: False.
64
+
65
+ Returns:
66
+ A generator for all the files of interest, yielded as relative paths.
67
+ """
68
+
69
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70
+ raise TypeError('"suffix" must be a string or tuple of strings')
71
+
72
+ root = dir_path
73
+
74
+ def _scandir(dir_path, suffix, recursive):
75
+ for entry in os.scandir(dir_path):
76
+ if not entry.name.startswith('.') and entry.is_file():
77
+ if full_path:
78
+ return_path = entry.path
79
+ else:
80
+ return_path = osp.relpath(entry.path, root)
81
+
82
+ if suffix is None:
83
+ yield return_path
84
+ elif return_path.endswith(suffix):
85
+ yield return_path
86
+ else:
87
+ if recursive:
88
+ yield from _scandir(
89
+ entry.path, suffix=suffix, recursive=recursive)
90
+ else:
91
+ continue
92
+
93
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
94
+
95
+ def scandir_mv(dir_path, suffix=None, recursive=False, full_path=False, lq=True):
96
+ """Scan a directory to find the interested files.
97
+
98
+ Args:
99
+ dir_path (str): Path of the directory.
100
+ suffix (str | tuple(str), optional): File suffix that we are
101
+ interested in. Default: None.
102
+ recursive (bool, optional): If set to True, recursively scan the
103
+ directory. Default: False.
104
+ full_path (bool, optional): If set to True, include the dir_path.
105
+ Default: False.
106
+
107
+ Returns:
108
+ A list of per-scene 'images_4' subdirectory paths under dir_path.
109
+ """
110
+
111
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
112
+ raise TypeError('"suffix" must be a string or tuple of strings')
113
+
114
+ root = dir_path
115
+ _type = "no_noise" if lq else "images"
116
+
117
+
118
+ # The 1K and 3K sets have not been generated yet, so assume only up to the 2K set is used.
119
+ def _scandir(dir_path, suffix, recursive):
120
+ folders = os.listdir(dir_path)
121
+ all_files = []
122
+ for folder in folders: # tag
123
+ all_files.append(osp.join(dir_path, folder, "images_4")) # e.g., <dir_path>/<scene>/images_4
124
+
125
+ # The block below is for the case of using all of the 1K/2K/3K sets.
126
+ # subfolders = os.listdir(osp.join(dir_path, folder)) # images4
127
+ # for subfolder in subfolders:
128
+ # all_files.append(osp.join(dir_path, folder, subfolder, _type)) # ~~train/46/0398fdk3/no_noise
129
+
130
+
131
+ return all_files
132
+ return _scandir(dir_path, suffix, recursive)
133
+
134
+ def scandir_mv_flat(dir_path, suffix=None, recursive=False, full_path=False, lq=True):
135
+ """Scan a directory to find the interested files.
136
+
137
+ Args:
138
+ dir_path (str): Path of the directory.
139
+ suffix (str | tuple(str), optional): File suffix that we are
140
+ interested in. Default: None.
141
+ recursive (bool, optional): If set to True, recursively scan the
142
+ directory. Default: False.
143
+ full_path (bool, optional): If set to True, include the dir_path.
144
+ Default: False.
145
+
146
+ Returns:
147
+ A generator for all the files of interest, yielded as relative paths.
148
+ """
149
+
150
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
151
+ raise TypeError('"suffix" must be a string or tuple of strings')
152
+
153
+ root = dir_path
154
+ _type = "no_noise" if lq else "images"
155
+
156
+ def _scandir(dir_path, suffix, recursive):
157
+ for entry in os.scandir(dir_path):
158
+ if not entry.name.startswith('.') and entry.is_file():
159
+ if full_path:
160
+ return_path = entry.path
161
+ else:
162
+ return_path = osp.relpath(entry.path, root)
163
+
164
+ if suffix is None:
165
+ yield return_path
166
+ elif return_path.endswith(suffix):
167
+ yield return_path
168
+ else:
169
+ if recursive:
170
+ if entry.name in ["both_noises", "gaussian_only", "images", "no_noise", "no_noise_BGR", "poisson_only", "sparse"]:
171
+ if entry.name != _type:
172
+ continue
173
+ yield from _scandir(
174
+ entry.path, suffix=suffix, recursive=recursive)
175
+ else:
176
+ continue
177
+
178
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
179
+
180
+
181
+ def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
182
+ """Scan a directory to find the interested files.
183
+
184
+ Args:
185
+ dir_path (str): Path of the directory.
186
+ keywords (str | tuple(str), optional): File keywords that we are
187
+ interested in. Default: None.
188
+ recursive (bool, optional): If set to True, recursively scan the
189
+ directory. Default: False.
190
+ full_path (bool, optional): If set to True, include the dir_path.
191
+ Default: False.
192
+
193
+ Returns:
194
+ A generator for all the files of interest, yielded as relative paths.
195
+ """
196
+
197
+ if (keywords is not None) and not isinstance(keywords, (str, tuple)):
198
+ raise TypeError('"keywords" must be a string or tuple of strings')
199
+
200
+ root = dir_path
201
+
202
+ def _scandir(dir_path, keywords, recursive):
203
+ for entry in os.scandir(dir_path):
204
+ if not entry.name.startswith('.') and entry.is_file():
205
+ if full_path:
206
+ return_path = entry.path
207
+ else:
208
+ return_path = osp.relpath(entry.path, root)
209
+
210
+ if keywords is None:
211
+ yield return_path
212
+ elif return_path.find(keywords) >= 0:  # match the keyword anywhere, including position 0
213
+ yield return_path
214
+ else:
215
+ if recursive:
216
+ yield from _scandir(
217
+ entry.path, keywords=keywords, recursive=recursive)
218
+ else:
219
+ continue
220
+
221
+ return _scandir(dir_path, keywords=keywords, recursive=recursive)
222
+
223
+ def check_resume(opt, resume_iter):
224
+ """Check resume states and pretrain_network paths.
225
+
226
+ Args:
227
+ opt (dict): Options.
228
+ resume_iter (int): Resume iteration.
229
+ """
230
+ logger = get_root_logger()
231
+ if opt['path']['resume_state']:
232
+ # get all the networks
233
+ networks = [key for key in opt.keys() if key.startswith('network_')]
234
+ flag_pretrain = False
235
+ for network in networks:
236
+ if opt['path'].get(f'pretrain_{network}') is not None:
237
+ flag_pretrain = True
238
+ if flag_pretrain:
239
+ logger.warning(
240
+ 'pretrain_network path will be ignored during resuming.')
241
+ # set pretrained model paths
242
+ for network in networks:
243
+ name = f'pretrain_{network}'
244
+ basename = network.replace('network_', '')
245
+ if opt['path'].get('ignore_resume_networks') is None or (
246
+ basename not in opt['path']['ignore_resume_networks']):
247
+ opt['path'][name] = osp.join(
248
+ opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
249
+ logger.info(f"Set {name} to {opt['path'][name]}")
250
+
251
+
252
+ def sizeof_fmt(size, suffix='B'):
253
+ """Get human readable file size.
254
+
255
+ Args:
256
+ size (int): File size.
257
+ suffix (str): Suffix. Default: 'B'.
258
+
259
+ Return:
260
+ str: Formatted file size.
261
+ """
262
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
263
+ if abs(size) < 1024.0:
264
+ return f'{size:3.1f} {unit}{suffix}'
265
+ size /= 1024.0
266
+ return f'{size:3.1f} Y{suffix}'
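`scandir` returns a lazy generator rather than a list, yielding paths relative to the scanned root unless `full_path=True`. A minimal sketch combining it with `sizeof_fmt` (the dataset directory below is hypothetical):

```
import os
from basicsr.utils.misc import scandir, sizeof_fmt

root = 'datasets/train_gt'  # hypothetical folder of ground-truth images
# Recursively yield .png files relative to `root` and report their sizes.
for rel_path in scandir(root, suffix='.png', recursive=True):
    full_path = os.path.join(root, rel_path)
    print(rel_path, sizeof_fmt(os.path.getsize(full_path)))
```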
basicsr/utils/nano.py ADDED
@@ -0,0 +1,250 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from torch.distributions.poisson import Poisson
5
+ import random
6
+
7
+
8
+ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
9
+ target_width, is_batch):
10
+ # BHWC -> BHWC
11
+ cropped = image[:, offset_height: offset_height + target_height, offset_width: offset_width + target_width, :]
12
+
13
+ if not is_batch:
14
+ cropped = cropped[0]
15
+
16
+ return cropped
17
+
18
+ def crop_to_bounding_box_list(image, offset_height, offset_width, target_height,
19
+ target_width):
20
+ # HWC
21
+ cropped = [_image[offset_height: offset_height + target_height, offset_width: offset_width + target_width, :] for _image in image]
22
+
23
+ return cropped
24
+
25
+ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
26
+ target_width, is_batch):
27
+ _,height,width,_ = image.shape
28
+ after_padding_width = target_width - offset_width - width
29
+ after_padding_height = target_height - offset_height - height
30
+
31
+ paddings = (0, 0, offset_width, after_padding_width, offset_height, after_padding_height, 0, 0)
32
+
33
+ padded = torch.nn.functional.pad(image, paddings)
34
+ if not is_batch:
35
+ padded = padded[0]
36
+
37
+ return padded
38
+
39
+ def resize_with_crop_or_pad_torch(image, target_height, target_width):
40
+ # BHWC -> BHWC
41
+
42
+ is_batch = True
43
+ if image.ndim == 3:
44
+ is_batch = False
45
+ image = image[None] # 1HWC
46
+
47
+ def max_(x, y):
48
+ return max(x, y)
49
+
50
+ def min_(x, y):
51
+ return min(x, y)
52
+
53
+ def equal_(x, y):
54
+ return x == y
55
+
56
+ _, height, width, _ = image.shape
57
+ width_diff = target_width - width
58
+ offset_crop_width = max_(-width_diff // 2, 0)
59
+ offset_pad_width = max_(width_diff // 2, 0)
60
+
61
+ height_diff = target_height - height
62
+ offset_crop_height = max_(-height_diff // 2, 0)
63
+ offset_pad_height = max_(height_diff // 2, 0)
64
+
65
+ # Maybe crop if needed.
66
+ cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
67
+ min_(target_height, height),
68
+ min_(target_width, width), is_batch)
69
+
70
+ # Maybe pad if needed.
71
+ if not is_batch and cropped.ndim == 3:
72
+ cropped = cropped[None]
73
+ resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
74
+ target_height, target_width, is_batch)
75
+
76
+ return resized
77
+
78
+
79
+
80
+ def psf2otf(psf, h=None, w=None, permute=False):
81
+ '''
82
+ psf: (B,) x H x W x C tensor; the leading batch dimension is optional
83
+ '''
84
+ if h is not None:
85
+ psf = resize_with_crop_or_pad_torch(psf, h, w)
86
+ if permute:
87
+ if psf.ndim == 3:
88
+ psf = psf.permute(2,0,1) # HWC -> CHW
89
+ else:
90
+ psf = psf.permute(0,3,1,2) # BHWC -> BCHW
91
+ psf = psf.to(torch.complex64)
92
+ psf = torch.fft.fftshift(psf, dim=(-1,-2))
93
+ otf = torch.fft.fft2(psf)
94
+ return otf
95
+
96
+ def fft(img): # CHW
97
+ img = img.to(torch.complex64)
98
+ Fimg = torch.fft.fft2(img)
99
+ return Fimg
100
+
101
+ def ifft(Fimg):
102
+ img = torch.abs(torch.fft.ifft2(Fimg)).to(torch.float32)
103
+ return img
104
+
105
+
106
+ def create_contrast_mask(image):
107
+ return 1 - torch.mean(image, dim=(-1,-2), keepdim=True) # (B), C,1,1
108
+
109
+ def apply_tikhonov(lr_img, psf, K, norm=True, otf=None):
110
+ h,w = lr_img.shape[-2:]
111
+ if otf is None:
112
+ psf_norm = resize_with_crop_or_pad_torch(psf, h, w)
113
+ if norm:
114
+ psf_norm = psf_norm / psf_norm.sum((0, 1))
115
+ otf = psf2otf(psf_norm, h, w, permute=True)
116
+
117
+ otf = otf[:,None,...] # B,1,C,H,W
118
+ contrast_mask = create_contrast_mask(lr_img)[:,None,...] # B,1,C,1,1
119
+ K_adjusted = K * contrast_mask # B,M,C,1,1
120
+ tikhonov_filter = torch.conj(otf) / (torch.abs(otf) ** 2 + K_adjusted) # B,M,C,H,W
121
+ lr_fft = fft(lr_img)[:,None,...] # B,1,C,H,W
122
+ deconvolved_fft = lr_fft * tikhonov_filter
123
+ deconvolved_image = torch.fft.ifft2(deconvolved_fft).real
124
+ deconvolved_image = torch.clamp(deconvolved_image, min=0.0, max=1.0)
125
+
126
+ return deconvolved_image # B,M,C,H,W
127
+
128
+
129
+ def add_noise_all_new(image, poss=4e-5, gaus=1e-5):
130
+ p = Poisson(image / poss)
131
+ sampled = p.sample((1,))[0]
132
+ poss_img = sampled * poss
133
+ gauss_noise = torch.randn_like(image) * gaus
134
+ noised_img = poss_img + gauss_noise
135
+
136
+ noised_img = torch.clamp(noised_img, 0.0, 1.0)
137
+
138
+ return noised_img
139
+
140
+
141
+ def apply_convolution(image, psf, pad):
142
+ '''
143
+ input: hr img (b,c,h,w, [0,1])
144
+ output: noised lr img (b,c,h+P,w+P, [0,1])
145
+ '''
146
+
147
+ # metalens simulation
148
+ image = F.pad(image, (pad, pad, pad, pad))
149
+ h,w = image.shape[-2:]
150
+ psf_norm = resize_with_crop_or_pad_torch(psf, h, w)
151
+ otf = psf2otf(psf_norm, h, w, permute=True)
152
+ lr_img = fft(image) * otf
153
+ lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)
154
+
155
+ # noise addition
156
+ noised_img = add_noise_all_new(lr_img)
157
+
158
+ return noised_img, otf
159
+
160
+ def apply_conv_n_deconv(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=True, conv=True):
161
+ '''
162
+ input: hr img (b,c,h,w)
163
+ otf: 1,N,C,H,W
164
+ output: noised lr img (N,c,h,w)
165
+ '''
166
+
167
+ b,_,_,_ = image.shape
168
+ if conv:
169
+ img_patch = F.unfold(image, kernel_size=ph*3, stride=ph).view(b,3,ph*3,ph*3,num_psf**2).permute(0,4,1,2,3).contiguous() # B,N,C,H,W
170
+
171
+ # metalens simulation
172
+ lr_img = fft(img_patch) * otf
173
+ lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)
174
+
175
+ # noise addtion
176
+ lr_img = add_noise_all_new(lr_img)
177
+
178
+ else: # load convolved image for validation
179
+ b = 1
180
+ lr_img = image
181
+
182
+ # apply deconvolution
183
+ if ks is not None:
184
+ lr_img = apply_tikhonov(lr_img, None, ks, otf=otf) # B,M,N,C,405,405
185
+ lr_img = lr_img[..., ph:-ph, ph:-ph] # BMNCHW
186
+ lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h)
187
+ else:
188
+ lr_img = lr_img[..., ph:-ph, ph:-ph] # BNCHW
189
+ lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h)
190
+
191
+ lq_patches = []
192
+ gt_patches = []
193
+ for i in range(b):
194
+ cur = lr_img[i] # (M),C,H,W
195
+ cur_gt = image[i]
196
+
197
+ # remove padding for lq and gt
198
+ pt,pb,pl,pr = padding[i]
199
+ if pb and pt:
200
+ cur = cur[...,pt: -pb, :]
201
+ cur_gt = cur_gt[...,pt+ph: -(pb+ph), ph:-ph]
202
+ elif pl and pr:
203
+ cur = cur[...,pl:-pr]
204
+ cur_gt = cur_gt[...,ph:-ph, pl+ph: -(pr+ph)]
205
+ else:
206
+ cur_gt = cur_gt[...,ph:-ph, ph: -ph]
207
+ h,w = cur.shape[-2:]
208
+
209
+ # randomly crop patch for training
210
+ if crop: # train
211
+ top = random.randint(0, h - psize)
212
+ left = random.randint(0, w - psize)
213
+ lq_patches.append(cur[..., top:top + psize, left:left + psize])
214
+ gt_patches.append(cur_gt[..., top:top + psize, left:left + psize])
215
+ if crop: # training
216
+ lq_patches = torch.stack(lq_patches)
217
+ gt_patches = torch.stack(gt_patches)
218
+ else: # validation
219
+ return cur, cur_gt
220
+
221
+ return lq_patches, gt_patches # B,(M),C,H,W
222
+
223
+
224
+ def apply_convolution_square_val(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=False):
225
+ '''
226
+ merge to above one.
227
+ image = lr_image
228
+ '''
229
+ lr_img = image
230
+ b = 1
231
+ if M: # apply deconvolution
232
+ lr_img = apply_tikhonov(lr_img, None, ks, otf=otf) # B,M,N,C,H,W
233
+ lr_img = lr_img[..., ph:-ph, ph:-ph] # B,M,N,C,H,W
234
+ lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0,1,4,2,5,3,6).reshape(b,M,3,sensor_h,sensor_h)
235
+ else:
236
+ lr_img = lr_img[..., ph:-ph, ph:-ph] # B,N,C,H,W
237
+ lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0,3,1,4,2,5).reshape(b,3,sensor_h,sensor_h)
238
+
239
+
240
+ for i in range(b):
241
+ cur = lr_img[i] # (M),C,H,W
242
+
243
+ # remove padding for lq and gt
244
+ pt,pb,pl,pr = padding[i]
245
+ if pb and pt:
246
+ cur = cur[...,pt: -pb, :]
247
+ elif pl and pr:
248
+ cur = cur[...,pl:-pr]
249
+
250
+ return cur
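nano.py implements the metalens simulation (zero-padding, FFT-domain convolution with the PSF, Poisson-Gaussian noise) and a contrast-weighted Tikhonov deconvolution. A minimal sketch of that pipeline; the tensor sizes, `pad`, and regularization values are made up for illustration, and the `H x W x 3` layout of `psf.npy` is an assumption:

```
import numpy as np
import torch
from basicsr.utils.nano import apply_convolution, apply_tikhonov

psf = torch.from_numpy(np.load('psf.npy')).float()  # assumed layout: H x W x 3
hr = torch.rand(2, 3, 256, 256)                      # clean batch in [0, 1], B x C x H x W

# Forward model: zero-pad, convolve with the PSF in the Fourier domain,
# then add Poisson-Gaussian sensor noise.
lr, otf = apply_convolution(hr, psf, pad=16)         # lr: B x C x (H + 2*pad) x (W + 2*pad)

# Tikhonov deconvolution with M = 2 candidate regularization strengths; each is
# scaled inside apply_tikhonov by a per-image contrast mask before filtering.
K = torch.tensor([1e-3, 1e-2]).view(1, 2, 1, 1, 1)   # broadcasts to B x M x C x 1 x 1
deconv = apply_tikhonov(lr, None, K, otf=otf[None])  # B x M x C x H' x W'
```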
basicsr/utils/options.py ADDED
@@ -0,0 +1,112 @@
1
+ import yaml
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+
5
+
6
+ def ordered_yaml():
7
+ """Support OrderedDict for yaml.
8
+
9
+ Returns:
10
+ yaml Loader and Dumper.
11
+ """
12
+ try:
13
+ from yaml import CDumper as Dumper
14
+ from yaml import CLoader as Loader
15
+ except ImportError:
16
+ from yaml import Dumper, Loader
17
+
18
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19
+
20
+ def dict_representer(dumper, data):
21
+ return dumper.represent_dict(data.items())
22
+
23
+ def dict_constructor(loader, node):
24
+ return OrderedDict(loader.construct_pairs(node))
25
+
26
+ Dumper.add_representer(OrderedDict, dict_representer)
27
+ Loader.add_constructor(_mapping_tag, dict_constructor)
28
+ return Loader, Dumper
29
+
30
+
31
+ def parse(opt_path, is_train=True, name=None):
32
+ """Parse option file.
33
+
34
+ Args:
35
+ opt_path (str): Option file path.
36
+ is_train (bool): Whether this is a training run. Default: True.
37
+
38
+ Returns:
39
+ (dict): Options.
40
+ """
41
+ with open(opt_path, mode='r') as f:
42
+ Loader, _ = ordered_yaml()
43
+ opt = yaml.load(f, Loader=Loader)
44
+
45
+ opt['is_train'] = is_train
46
+ if name is not None:
47
+ opt['name'] = name
48
+
49
+ # datasets
50
+ for phase, dataset in opt['datasets'].items():
51
+ # for several datasets, e.g., test_1, test_2
52
+ phase = phase.split('_')[0]
53
+ dataset['phase'] = phase
54
+ if 'scale' in opt:
55
+ dataset['scale'] = opt['scale']
56
+ if dataset.get('dataroot_gt') is not None:
57
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
58
+ if dataset.get('dataroot_lq') is not None:
59
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
60
+
61
+ # paths
62
+ for key, val in opt['path'].items():
63
+ if (val is not None) and ('resume_state' in key
64
+ or 'pretrain_network' in key):
65
+ opt['path'][key] = osp.expanduser(val)
66
+ opt['path']['root'] = osp.abspath(
67
+ osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
68
+ if is_train:
69
+ experiments_root = osp.join(opt['path']['root'], 'experiments',
70
+ opt['name'])
71
+ opt['path']['experiments_root'] = experiments_root
72
+ opt['path']['models'] = osp.join(experiments_root, 'models')
73
+ opt['path']['training_states'] = osp.join(experiments_root,
74
+ 'training_states')
75
+ opt['path']['log'] = experiments_root
76
+ opt['path']['visualization'] = osp.join(experiments_root,
77
+ 'visualization')
78
+
79
+ # change some options for debug mode
80
+ if 'debug' in opt['name']:
81
+ if 'val' in opt:
82
+ opt['val']['val_freq'] = 8
83
+ opt['logger']['print_freq'] = 1
84
+ opt['logger']['save_checkpoint_freq'] = 8
85
+ else: # test
86
+ results_root = osp.join(opt['path']['root'], 'results', opt['name'])
87
+ opt['path']['results_root'] = results_root
88
+ opt['path']['log'] = results_root
89
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
90
+
91
+ return opt
92
+
93
+
94
+ def dict2str(opt, indent_level=1):
95
+ """dict to string for printing options.
96
+
97
+ Args:
98
+ opt (dict): Option dict.
99
+ indent_level (int): Indent level. Default: 1.
100
+
101
+ Return:
102
+ (str): Option string for printing.
103
+ """
104
+ msg = '\n'
105
+ for k, v in opt.items():
106
+ if isinstance(v, dict):
107
+ msg += ' ' * (indent_level * 2) + k + ':['
108
+ msg += dict2str(v, indent_level + 1)
109
+ msg += ' ' * (indent_level * 2) + ']\n'
110
+ else:
111
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
112
+ return msg
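A minimal sketch of loading an option file with these helpers; the YAML path is the test option file shipped in this repo, the run name is hypothetical, and we assume the file defines the usual `datasets` and `path` sections that `parse` expects:

```
from basicsr.utils.options import parse, dict2str

opt = parse('Aberration_Correction/Options/Test_Aberration_Transformers.yml',
            is_train=False, name='my_test_run')  # hypothetical run name
print(dict2str(opt))                # pretty-print the resolved option tree
print(opt['path']['results_root'])  # resolves to results/<name> under the repo root
```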
basicsr/version.py ADDED
@@ -0,0 +1,5 @@
1
+ # GENERATED VERSION FILE
2
+ # TIME: Fri Mar 21 07:59:14 2025
3
+ __version__ = '1.2.0+5ea673c'
4
+ short_version = '1.2.0'
5
+ version_info = (1, 2, 0)
experiments/pretrained/models/net_g_100000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cc95533ca8a4dfdcfad5de2973346ad6b699c6abaf4e7e9d0de77007c4b855f
3
+ size 116763496
experiments/pretrained/training_states/100000.state ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edb3104cc8f57a1100b4f0e3d87814a74b2c0fd1ed24a86d69b917b0e1973d2b
3
+ size 233563982
psf.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:337461630addd8dcc48a0293678b5ef75d9c35a5c7b6a0524154d2e8540741a8
3
+ size 17714828
readme.md ADDED
@@ -0,0 +1,73 @@
1
+ # Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging
2
+
3
+ Byeonghyeon Lee, Youbin Kim, Yongjae Jo, Hyunsu Kim, Hyemi Park, Yangkyu Kim, Debabrata Mandal, Praneeth Chakravarthula, Inki Kim, and Eunbyung Park
4
+
5
+ [Project Page](https://benhenryl.github.io/Metalens-Transformer/) &nbsp; [Paper](https://arxiv.org/abs/2412.04591)
6
+
7
+
8
+ We ran the experiments in the following environment:
9
+ ```
10
+ - ubuntu: 20.04
11
+ - python: 3.10.13
12
+ - cuda: 11.8
13
+ - pytorch: 2.2.0
14
+ - GPU: 4x A6000 ada
15
+ ```
16
+
17
+ Our code is based on [Restormer](https://github.com/swz30/Restormer), [X-Restormer](https://github.com/Andrew0613/X-Restormer), and [Neural Nano Optics](https://github.com/princeton-computational-imaging/Neural_Nano-Optics). We appreciate their work.
18
+
19
+ ## 1. Environment Setting
20
+ ### 1-1. Pytorch
21
+ Note: pytorch >= 2.2.0 is required for Flash Attention.
22
+
23
+ ### 1-2. [Flash Attention](https://github.com/Dao-AILab/flash-attention)
24
+ Note: only Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100) are currently supported.
25
+ ```
26
+ pip install packaging ninja
27
+ pip install flash-attn --no-build-isolation
28
+ ```
29
+
30
+ ### 1-3. Other required packages
31
+ ```
32
+ pip install -r requirements.txt
33
+ ```
34
+
35
+ ### 1-4. Basicsr
36
+ ```
37
+ python setup.py develop --no_cuda_ext
38
+ ```
39
+
40
+ ## 2. Dataset & Pre-trained weights
41
+ You can download the train/test dataset [here](https://drive.google.com/drive/folders/1e2wJwmcjXFvblVs0l5OXwpIkTqxd1Fhq?usp=drive_link) and the pre-trained weights [here](https://drive.google.com/drive/folders/1q5pKE1Z0RJjHVmJlNq7nPSWcaGd9bDb7?usp=drive_link).
42
+ Please move the pre-trained weights to experiments/.
43
+ Note: The model creates aberrated images on the fly using clean (gt) images during training.
44
+ During validation, the aberrated images are also produced in the same manner, so their noise realizations can differ from those used in our own validation.
45
+ Since the same noise distributions are used, the difference in results is negligible; however, if you need a precise comparison with the validation set used in our experiments, please contact us.
46
+
47
+
48
+ ## 3. Training
49
+ Please set the dataset path in ```./Aberration_Correction/Options/Train_Aberration_Transformers.yml```
50
+ ```
51
+ bash train.sh GPU_IDS FOLDER_NAME
52
+ // ex. bash train.sh 0,1,2,3 training
53
+ // where it uses gpu 0 to 3 and make a directory experiments/training where log, weights and others will be stored.
54
+ ```
55
+
56
+ ## 4. Inference
57
+ Please set the dataset path in ```./Aberration_Correction/Options/Test_Aberration_Transformers.yml```
58
+ If you want to run inference using the pre-trained model, you can use the command
59
+ ```
60
+ bash test.sh GPU_ID FOLDER_NAME
61
+ // ex. bash test.sh 0 pretrained
62
+ ```
63
+ Or you can designate the FOLDER_NAME with your weight path.
64
+
65
+ ## BibTeX
66
+ ```
67
+ @article{lee2024aberration,
68
+ title={Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging},
69
+ author={Lee, Byeonghyeon and Kim, Youbin and Jo, Yongjae and Kim, Hyunsu and Park, Hyemi and Kim, Yangkyu and Mandal, Debabrata and Chakravarthula, Praneeth and Kim, Inki and Park, Eunbyung},
70
+ journal={arXiv preprint arXiv:2412.04591},
71
+ year={2024}
72
+ }
73
+ ```