Upload 53 files
Upload codes and weights
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- Aberration_Correction/Options/Test_Aberration_Transformers.yml +75 -0
- Aberration_Correction/Options/Train_Aberration_Transformers.yml +141 -0
- Aberration_Correction/utils.py +90 -0
- VERSION +1 -0
- basicsr/data/__init__.py +126 -0
- basicsr/data/data_sampler.py +49 -0
- basicsr/data/data_util.py +15 -0
- basicsr/data/paired_image_dataset.py +156 -0
- basicsr/data/prefetch_dataloader.py +126 -0
- basicsr/data/transforms.py +167 -0
- basicsr/metrics/__init__.py +4 -0
- basicsr/metrics/fid.py +102 -0
- basicsr/metrics/metric_util.py +47 -0
- basicsr/metrics/niqe.py +205 -0
- basicsr/metrics/niqe_pris_params.npz +3 -0
- basicsr/metrics/other_metrics.py +88 -0
- basicsr/metrics/psnr_ssim.py +303 -0
- basicsr/models/__init__.py +42 -0
- basicsr/models/archs/__init__.py +45 -0
- basicsr/models/archs/arch_util.py +255 -0
- basicsr/models/archs/restormer_arch.py +527 -0
- basicsr/models/base_model.py +376 -0
- basicsr/models/image_restoration_model.py +392 -0
- basicsr/models/losses/__init__.py +5 -0
- basicsr/models/losses/loss_util.py +95 -0
- basicsr/models/losses/losses.py +180 -0
- basicsr/models/lr_scheduler.py +232 -0
- basicsr/test.py +142 -0
- basicsr/train.py +328 -0
- basicsr/utils/__init__.py +45 -0
- basicsr/utils/bundle_submissions.py +108 -0
- basicsr/utils/create_lmdb.py +124 -0
- basicsr/utils/dist_util.py +83 -0
- basicsr/utils/download_util.py +70 -0
- basicsr/utils/face_util.py +217 -0
- basicsr/utils/file_client.py +186 -0
- basicsr/utils/flow_util.py +180 -0
- basicsr/utils/img_util.py +216 -0
- basicsr/utils/lmdb_util.py +208 -0
- basicsr/utils/logger.py +175 -0
- basicsr/utils/matlab_functions.py +361 -0
- basicsr/utils/misc.py +266 -0
- basicsr/utils/nano.py +250 -0
- basicsr/utils/options.py +112 -0
- basicsr/version.py +5 -0
- experiments/pretrained/models/net_g_100000.pth +3 -0
- experiments/pretrained/training_states/100000.state +3 -0
- psf.npy +3 -0
- readme.md +73 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+experiments/pretrained/training_states/100000.state filter=lfs diff=lfs merge=lfs -text
Aberration_Correction/Options/Test_Aberration_Transformers.yml
ADDED
@@ -0,0 +1,75 @@
+# general settings
+name: sample_test
+# name: batch8
+model_type: ImageCleanModel
+scale: 1
+num_gpu: 4  # set num_gpu: 0 for cpu mode
+manual_seed: 100
+
+# dataset and data loader settings
+datasets:
+  val:
+    name: ValSet
+    type: Dataset_PaddedImage  # Use Dataset_PaddedImage_npy to load pre-convolved (lq) images; set dataroot_lq in that case.
+    dataroot_gt: PATH_TO_TEST_SET  # TODO
+    io_backend:
+      type: disk
+
+    sensor_size: 1215
+    psf_size: 135
+
+# network structures
+network_g:
+  type: ACFormer
+  inp_channels: 39
+  out_channels: 3
+  dim: 48
+  num_blocks: [2,4,4,4]
+  num_refinement_blocks: 4
+  channel_heads: [1,2,4,8]
+  spatial_heads: [1,2,4,8]
+  overlap_ratio: [0.5,0.5,0.5,0.5]
+  window_size: 8
+  spatial_dim_head: 16
+  ffn_expansion_factor: 2.66
+  bias: False
+  LayerNorm_type: WithBias
+  ca_dim: 32
+  ca_heads: 2
+  M: 13
+  window_size_ca: 8
+  query_ksize: [15,11,7,3,3]
+
+# path
+path:
+  pretrain_network_g: ~
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  ks:
+    start: -2
+    end: -5
+    num: 13
+
+# validation settings
+val:
+  window_size: 8
+  save_img: true
+  rgb2bgr: true
+  use_image: true
+  max_minibatch: 8
+  padding: 64
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 0
+      test_y_channel: true
+
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29502
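Running inference with this file goes through the bundled test entry point. A launch sketch, assuming the stock BasicSR -opt interface and that PATH_TO_TEST_SET has been filled in:

python basicsr/test.py -opt Aberration_Correction/Options/Test_Aberration_Transformers.yml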
Aberration_Correction/Options/Train_Aberration_Transformers.yml
ADDED
@@ -0,0 +1,141 @@
+# general settings
+name: sample_test
+# name: batch8
+model_type: ImageCleanModel
+scale: 1
+num_gpu: 4  # set num_gpu: 0 for cpu mode
+manual_seed: 100
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: TrainSet
+    type: Dataset_PaddedImage  # makes the lq image from the gt image on the fly
+    dataroot_gt: PATH_TO_TRAIN_SET  # TODO
+
+    filename_tmpl: '{}'
+    io_backend:
+      type: disk
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 8  # 8
+    batch_size_per_gpu: 2  # 8
+
+    gt_size: 256
+
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+    sensor_size: 1215
+    psf_size: 135
+
+  val:
+    name: ValSet
+    type: Dataset_PaddedImage
+    dataroot_gt: PATH_TO_TEST_SET  # TODO
+    io_backend:
+      type: disk
+
+    sensor_size: 1215
+    psf_size: 135
+
+# network structures
+network_g:
+  type: ACFormer
+  inp_channels: 39
+  out_channels: 3
+  dim: 48
+  num_blocks: [2,4,4,4]
+  num_refinement_blocks: 4
+  channel_heads: [1,2,4,8]
+  spatial_heads: [1,2,4,8]
+  overlap_ratio: [0.5,0.5,0.5,0.5]
+  window_size: 8
+  spatial_dim_head: 16
+  ffn_expansion_factor: 2.66
+  bias: False
+  LayerNorm_type: WithBias
+  ca_dim: 32
+  ca_heads: 2
+  M: 13
+  window_size_ca: 8
+  query_ksize: [15,11,7,3,3]
+
+# path
+path:
+  pretrain_network_g: ~
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  eval_only: True
+  eval_name: Sample_data
+  real_psf: True
+  grid: True
+  total_iter: 100000
+  warmup_iter: -1  # no warm up
+  use_grad_clip: true
+  contrast_tik: 2
+  sensor_height: 1215
+
+  scheduler:
+    type: CosineAnnealingRestartCyclicLR
+    periods: [92000, 208000]
+    restart_weights: [1,1]
+    eta_mins: [0.0003,0.000001]
+
+  mixing_augs:
+    mixup: false
+    mixup_beta: 1.2
+    use_identity: true
+
+  optim_g:
+    type: AdamW
+    lr: !!float 3e-4
+    weight_decay: !!float 1e-4
+    betas: [0.9, 0.999]
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1
+    reduction: mean
+
+  ks:
+    start: -2
+    end: -5
+    num: 13
+
+
+# validation settings
+val:
+  window_size: 8
+  val_freq: !!float 1e8  # deactivated
+  save_img: false
+  rgb2bgr: true
+  use_image: true
+  max_minibatch: 8
+  padding: 64
+  apply_conv: True  # Apply convolution to the GT image to create the lq image. Set False when loading .npy data (already aberrated)
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 0
+      test_y_channel: true
+
+# logging settings
+logger:
+  print_freq: 500
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29502
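Training runs through the bundled basicsr/train.py. A launch sketch for the 4-GPU setting above, assuming the stock BasicSR distributed launcher (with --master_port matching dist_params.port):

python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 basicsr/train.py -opt Aberration_Correction/Options/Train_Aberration_Transformers.yml --launcher pytorch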
Aberration_Correction/utils.py
ADDED
@@ -0,0 +1,90 @@
+## Restormer: Efficient Transformer for High-Resolution Image Restoration
+## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
+## https://arxiv.org/abs/2111.09881
+
+import numpy as np
+import os
+import cv2
+import math
+
+def calculate_psnr(img1, img2, border=0):
+    # img1 and img2 have range [0, 255]
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+    '''Calculate SSIM.
+    Produces the same outputs as MATLAB's implementation.
+    img1, img2: range [0, 255]
+    '''
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    if img1.ndim == 2:
+        return ssim(img1, img2)
+    elif img1.ndim == 3:
+        if img1.shape[2] == 3:
+            ssims = []
+            for i in range(3):
+                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+            return np.array(ssims).mean()
+        elif img1.shape[2] == 1:
+            return ssim(np.squeeze(img1), np.squeeze(img2))
+    else:
+        raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                            (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+def load_img(filepath):
+    return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
+
+def save_img(filepath, img):
+    cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+def load_gray_img(filepath):
+    return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
+
+def save_gray_img(filepath, img):
+    cv2.imwrite(filepath, img)
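A minimal usage sketch for these helpers; the file paths are hypothetical, and the import assumes the repository root is on PYTHONPATH:

from Aberration_Correction.utils import calculate_psnr, calculate_ssim, load_img

# load_img returns RGB uint8 in [0, 255], matching what the metrics expect
restored = load_img('results/sample_restored.png')   # hypothetical path
reference = load_img('data/sample_gt.png')           # hypothetical path

print('PSNR: %.2f dB' % calculate_psnr(restored, reference, border=0))
print('SSIM: %.4f' % calculate_ssim(restored, reference, border=0))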
VERSION
ADDED
@@ -0,0 +1 @@
+1.2.0
basicsr/data/__init__.py
ADDED
@@ -0,0 +1,126 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+
+__all__ = ['create_dataset', 'create_dataloader']
+
+# automatically scan and import dataset modules
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [
+    osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
+    if v.endswith('_dataset.py')
+]
+# import all the dataset modules
+_dataset_modules = [
+    importlib.import_module(f'basicsr.data.{file_name}')
+    for file_name in dataset_filenames
+]
+
+
+def create_dataset(dataset_opt, mv=False):
+    """Create dataset.
+
+    Args:
+        dataset_opt (dict): Configuration for dataset. It contains:
+            name (str): Dataset name.
+            type (str): Dataset type.
+    """
+    dataset_type = dataset_opt['type']
+    # dynamic instantiation
+    for module in _dataset_modules:
+        dataset_cls = getattr(module, dataset_type, None)
+        if dataset_cls is not None:
+            break
+    if dataset_cls is None:
+        raise ValueError(f'Dataset {dataset_type} is not found.')
+
+    dataset = dataset_cls(dataset_opt)
+
+    logger = get_root_logger()
+    logger.info(
+        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
+        'is created.')
+    return dataset
+
+
+def create_dataloader(dataset,
+                      dataset_opt,
+                      num_gpu=1,
+                      dist=False,
+                      sampler=None,
+                      seed=None):
+    """Create dataloader.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset.
+        dataset_opt (dict): Dataset options. It contains the following keys:
+            phase (str): 'train' or 'val'.
+            num_worker_per_gpu (int): Number of workers for each GPU.
+            batch_size_per_gpu (int): Training batch size for each GPU.
+        num_gpu (int): Number of GPUs. Used only in the train phase.
+            Default: 1.
+        dist (bool): Whether in distributed training. Used only in the train
+            phase. Default: False.
+        sampler (torch.utils.data.sampler): Data sampler. Default: None.
+        seed (int | None): Seed. Default: None
+    """
+    phase = dataset_opt['phase']
+    rank, _ = get_dist_info()
+    if phase == 'train':
+        if dist:  # distributed training
+            batch_size = dataset_opt['batch_size_per_gpu']
+            num_workers = dataset_opt['num_worker_per_gpu']
+        else:  # non-distributed training
+            multiplier = 1 if num_gpu == 0 else num_gpu
+            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+        dataloader_args = dict(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            sampler=sampler,
+            drop_last=True)
+
+        if sampler is None:
+            dataloader_args['shuffle'] = True
+        dataloader_args['worker_init_fn'] = partial(
+            worker_init_fn, num_workers=num_workers, rank=rank,
+            seed=seed) if seed is not None else None
+    elif phase in ['val', 'test', 'val20']:  # validation
+        dataloader_args = dict(
+            dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    else:
+        raise ValueError(f'Wrong dataset phase: {phase}. '
+                         "Supported ones are 'train', 'val' and 'test'.")
+
+    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+    prefetch_mode = dataset_opt.get('prefetch_mode')
+    if prefetch_mode == 'cpu':  # CPUPrefetcher
+        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+        logger = get_root_logger()
+        logger.info(f'Use {prefetch_mode} prefetch dataloader: '
+                    f'num_prefetch_queue = {num_prefetch_queue}')
+        return PrefetchDataLoader(
+            num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+    else:
+        # prefetch_mode=None: Normal dataloader
+        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+        return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # Set the worker seed to num_workers * rank + worker_id + seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
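A sketch of how these factories are typically driven; the option dict below is hypothetical and would normally come from the parsed yml (keys taken from the configs above):

from basicsr.data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'ValSet',
    'type': 'Dataset_PaddedImage',
    'dataroot_gt': 'datasets/val',        # hypothetical path
    'io_backend': {'type': 'disk'},
    'sensor_size': 1215,
    'psf_size': 135,
    'scale': 1,
    'phase': 'val',                       # val phase -> batch_size=1 loader
}
dataset = create_dataset(dataset_opt)
loader = create_dataloader(dataset, dataset_opt, num_gpu=1, dist=False)
for batch in loader:
    print(batch['gt'].shape)              # (1, 3, 1485, 1485) after padding
    break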
basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,49 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+
+    Modified from torch.utils.data.distributed.DistributedSampler.
+    Supports enlarging the dataset for iteration-based training, saving
+    time when restarting the dataloader after each epoch.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset used for sampling.
+        num_replicas (int | None): Number of processes participating in
+            the training. It is usually the world_size.
+        rank (int | None): Rank of the current process within num_replicas.
+        ratio (int): Enlarging ratio. Default: 1.
+    """
+
+    def __init__(self, dataset, num_replicas, rank, ratio=1):
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = math.ceil(
+            len(self.dataset) * ratio / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = torch.randperm(self.total_size, generator=g).tolist()
+
+        dataset_size = len(self.dataset)
+        indices = [v % dataset_size for v in indices]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
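A small sketch of the enlarging behavior; anything with __len__ suffices here:

from basicsr.data.data_sampler import EnlargedSampler

dataset = list(range(100))                 # stand-in dataset
sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=10)
sampler.set_epoch(0)                       # re-seeds the deterministic shuffle
indices = list(sampler)
assert len(indices) == sampler.num_samples == 500   # ceil(100 * 10 / 2)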
basicsr/data/data_util.py
ADDED
@@ -0,0 +1,15 @@
+import cv2
+cv2.setNumThreads(1)
+from os import path as osp
+from basicsr.utils import scandir
+
+
+def paths_from_folder(folder, key):
+    gt_paths = list(scandir(folder))
+    paths = []
+    for idx in range(len(gt_paths)):
+        gt_path = gt_paths[idx]
+        gt_path = osp.join(folder, gt_path)
+        paths.append(
+            dict([(f'{key}_path', gt_path)]))
+    return paths
basicsr/data/paired_image_dataset.py
ADDED
@@ -0,0 +1,156 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paths_from_folder
+from basicsr.utils import FileClient, imfrombytes, img2tensor, padding
+from natsort import natsorted
+import random
+import numpy as np
+import torch
+import cv2
+import os
+import random
+
+
+class Dataset_PaddedImage(data.Dataset):
+    """Padded image dataset for image restoration.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwargs.
+            gt_size (int): Cropped patch size for gt patches.
+            scale (bool): Scale, which will be added automatically.
+            phase (str): 'train' or 'val'.
+    """
+
+    def __init__(self, opt):
+        super(Dataset_PaddedImage, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.paths = paths_from_folder(self.gt_folder, 'gt')
+
+        self.sensor_size = opt['sensor_size']
+        self.psf_size = opt['psf_size']
+        self.padded_size = self.sensor_size + 2 * self.psf_size
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        index = index % len(self.paths)
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        try:
+            img_gt = imfrombytes(img_bytes, float32=True)
+        except Exception:
+            raise Exception("gt path {} not working".format(gt_path))
+
+
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # padding
+            img_gt = padding(img_gt, gt_size)  # h,w,c
+        orig_h, orig_w, _ = img_gt.shape
+
+        # Fit one axis to the sensor height (width)
+        longer = max(orig_h, orig_w)
+        scale = float(longer / self.sensor_size)
+        resolution = (int(orig_w / scale), int(orig_h / scale))
+        img_gt = cv2.resize(img_gt, resolution, interpolation=cv2.INTER_LINEAR)  # sensor_size,x,3 or y,sensor_size,3 where x,y <= sensor_size
+
+        resized_h, resized_w, _ = img_gt.shape
+        # add padding
+        pad_h = self.padded_size - resized_h
+        pad_w = self.padded_size - resized_w
+        pad_l = pad_r = pad_w // 2
+        if pad_w % 2:
+            pad_r += 1
+        pad_t = pad_b = pad_h // 2
+        if pad_h % 2:
+            pad_b += 1
+        img_gt = np.pad(img_gt, ((pad_t, pad_b), (pad_l, pad_r), (0,0)))  # padded_size,padded_size,3
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor(img_gt, bgr2rgb=True,
+                            float32=True)
+
+        return {
+            'gt': img_gt,
+            'gt_path': gt_path,
+            'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size)
+        }
+
+    def __len__(self):
+        return len(self.paths)
+
+class Dataset_PaddedImage_npy(data.Dataset):
+    # validation only
+    def __init__(self, opt):
+        super(Dataset_PaddedImage_npy, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+        self.lq_paths = natsorted(os.listdir(self.lq_folder))
+        self.gt_paths = natsorted(os.listdir(self.gt_folder))
+
+        self.sensor_size = opt['sensor_size']
+        self.psf_size = opt['psf_size']
+        self.padded_size = self.sensor_size + 2 * self.psf_size
+
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(
+                self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        index = index % len(self.gt_paths)
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = f"{self.gt_folder}/{self.gt_paths[index]}"
+        lq_path = f"{self.lq_folder}/{self.lq_paths[index]}"
+        assert os.path.basename(gt_path).split(".")[0] == os.path.basename(lq_path).split(".")[0]
+
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        try:
+            img_gt = imfrombytes(img_bytes, float32=True)
+        except Exception:
+            raise Exception("gt path {} not working".format(gt_path))
+
+        img_lq = torch.tensor(np.load(lq_path))  # 1,1,81,3,405,405
+
+        resized_h, resized_w, _ = img_gt.shape
+        pad_h = self.padded_size - resized_h
+        pad_w = self.padded_size - resized_w
+        pad_l = pad_r = pad_w // 2
+        if pad_w % 2:
+            pad_r += 1
+        pad_t = pad_b = pad_h // 2
+        if pad_h % 2:
+            pad_b += 1
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor(img_gt, bgr2rgb=True,
+                            float32=True)
+
+        return {
+            'gt': img_gt,
+            'lq': img_lq,
+            'lq_path': lq_path,
+            'gt_path': gt_path,
+            'padding': (pad_t-self.psf_size, pad_b-self.psf_size, pad_l-self.psf_size, pad_r-self.psf_size)
+        }
+
+    def __len__(self):
+        return len(self.gt_paths)
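A quick numeric sketch of the resize-and-pad geometry implemented above, for a hypothetical 2430x1822 ground-truth image:

sensor_size, psf_size = 1215, 135
padded_size = sensor_size + 2 * psf_size              # 1485

orig_h, orig_w = 2430, 1822
scale = max(orig_h, orig_w) / sensor_size             # 2.0: longer axis fits the sensor
resized_h, resized_w = int(orig_h / scale), int(orig_w / scale)   # 1215, 911

pad_h, pad_w = padded_size - resized_h, padded_size - resized_w   # 270, 574
pad_t = pad_b = pad_h // 2                            # 135 each
pad_l = pad_r = pad_w // 2                            # 287 each (pad_w is even)
# The returned 'padding' entry subtracts psf_size, i.e. the margin beyond the
# PSF border: (0, 0, 152, 152) for this image.
print(pad_t - psf_size, pad_b - psf_size, pad_l - psf_size, pad_r - psf_size)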
basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,126 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+    """A general prefetch generator.
+
+    Ref:
+    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+    Args:
+        generator: Python generator.
+        num_prefetch_queue (int): Size of the prefetch queue.
+    """
+
+    def __init__(self, generator, num_prefetch_queue):
+        threading.Thread.__init__(self)
+        self.queue = Queue.Queue(num_prefetch_queue)
+        self.generator = generator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class PrefetchDataLoader(DataLoader):
+    """Prefetch version of dataloader.
+
+    Ref:
+    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+    TODO:
+    Needs testing on single GPU and DDP (multi-GPU). There is a known issue
+    in DDP.
+
+    Args:
+        num_prefetch_queue (int): Size of the prefetch queue.
+        kwargs (dict): Other arguments for the dataloader.
+    """
+
+    def __init__(self, num_prefetch_queue, **kwargs):
+        self.num_prefetch_queue = num_prefetch_queue
+        super(PrefetchDataLoader, self).__init__(**kwargs)
+
+    def __iter__(self):
+        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+    """CPU prefetcher.
+
+    Args:
+        loader: Dataloader.
+    """
+
+    def __init__(self, loader):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+
+    def next(self):
+        try:
+            return next(self.loader)
+        except StopIteration:
+            return None
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+    """CUDA prefetcher.
+
+    Ref:
+    https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+    Args:
+        loader: Dataloader.
+        opt (dict): Options.
+    """
+
+    def __init__(self, loader, opt):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+        self.opt = opt
+        self.stream = torch.cuda.Stream()
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.preload()
+
+    def preload(self):
+        try:
+            self.batch = next(self.loader)  # self.batch is a dict
+        except StopIteration:
+            self.batch = None
+            return None
+        # put tensors on the gpu
+        with torch.cuda.stream(self.stream):
+            for k, v in self.batch.items():
+                if torch.is_tensor(v):
+                    self.batch[k] = self.batch[k].to(
+                        device=self.device, non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        self.preload()
+        return batch
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+        self.preload()
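A sketch of the CUDAPrefetcher consumption pattern (requires a CUDA device; the toy dataset is hypothetical but yields the dict batches the prefetcher expects):

import torch
from torch.utils.data import DataLoader, Dataset
from basicsr.data.prefetch_dataloader import CUDAPrefetcher

class ToyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {'gt': torch.randn(3, 16, 16)}

loader = DataLoader(ToyDataset(), batch_size=2)
prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})
for epoch in range(2):
    prefetcher.reset()                     # rewind for the new epoch
    batch = prefetcher.next()
    while batch is not None:               # next() yields None when exhausted
        gt = batch['gt']                   # already on the GPU, moved on a side stream
        batch = prefetcher.next()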
basicsr/data/transforms.py
ADDED
@@ -0,0 +1,167 @@
+import cv2
+import random
+import numpy as np
+
+def mod_crop(img, scale):
+    """Mod crop images, used during testing.
+
+    Args:
+        img (ndarray): Input image.
+        scale (int): Scale factor.
+
+    Returns:
+        ndarray: Result image.
+    """
+    img = img.copy()
+    if img.ndim in (2, 3):
+        h, w = img.shape[0], img.shape[1]
+        h_remainder, w_remainder = h % scale, w % scale
+        img = img[:h - h_remainder, :w - w_remainder, ...]
+    else:
+        raise ValueError(f'Wrong img ndim: {img.ndim}.')
+    return img
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+    We use vertical flip and transpose for the rotation implementation.
+    All the images in the list use the same augmentation.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+        hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+            ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+        return_status (bool): Return the status of flip and rotation.
+            Default: False.
+
+    Returns:
+        list[ndarray] | ndarray: Augmented images and flows. If the returned
+            results only have one element, just return ndarray.
+
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        if return_status:
+            return imgs, (hflip, vflip, rot90)
+        else:
+            return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+    """Rotate image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees. Positive values mean
+            counter-clockwise rotation.
+        center (tuple[int]): Rotation center. If the center is None,
+            initialize it as the center of the image. Default: None.
+        scale (float): Isotropic scale factor. Default: 1.0.
+    """
+    (h, w) = img.shape[:2]
+
+    if center is None:
+        center = (w // 2, h // 2)
+
+    matrix = cv2.getRotationMatrix2D(center, angle, scale)
+    rotated_img = cv2.warpAffine(img, matrix, (w, h))
+    return rotated_img
+
+def data_augmentation(image, mode):
+    """
+    Performs data augmentation of the input image.
+    Input:
+        image: a cv2 (OpenCV) image
+        mode: int. Choice of transformation to apply to the image
+            0 - no transformation
+            1 - flip up and down
+            2 - rotate counterclockwise by 90 degrees
+            3 - rotate 90 degrees and flip up and down
+            4 - rotate 180 degrees
+            5 - rotate 180 degrees and flip
+            6 - rotate 270 degrees
+            7 - rotate 270 degrees and flip
+    """
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        # flip up and down
+        out = np.flipud(image)
+    elif mode == 2:
+        # rotate counterclockwise by 90 degrees
+        out = np.rot90(image)
+    elif mode == 3:
+        # rotate 90 degrees and flip up and down
+        out = np.rot90(image)
+        out = np.flipud(out)
+    elif mode == 4:
+        # rotate 180 degrees
+        out = np.rot90(image, k=2)
+    elif mode == 5:
+        # rotate 180 degrees and flip
+        out = np.rot90(image, k=2)
+        out = np.flipud(out)
+    elif mode == 6:
+        # rotate 270 degrees
+        out = np.rot90(image, k=3)
+    elif mode == 7:
+        # rotate 270 degrees and flip
+        out = np.rot90(image, k=3)
+        out = np.flipud(out)
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out
+
+def random_augmentation(*args):
+    out = []
+    flag_aug = random.randint(0, 7)
+    for data in args:
+        if type(data) == list:
+            out.append([data_augmentation(_data, flag_aug).copy() for _data in data])
+        else:
+            out.append(data_augmentation(data, flag_aug).copy())
+    return out
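A short sketch of the paired augmentations; both arrays receive the same random transform:

import numpy as np
from basicsr.data.transforms import augment, random_augmentation

gt = np.random.rand(64, 64, 3).astype(np.float32)
lq = np.random.rand(64, 64, 3).astype(np.float32)

gt_a, lq_a = augment([gt, lq], hflip=True, rotation=True)   # shared flip/rot90
gt_b, lq_b = random_augmentation(gt, lq)                    # one of 8 dihedral modes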
basicsr/metrics/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .niqe import calculate_niqe
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
basicsr/metrics/fid.py
ADDED
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy import linalg
+from tqdm import tqdm
+
+from basicsr.models.archs.inception import InceptionV3
+
+
+def load_patched_inception_v3(device='cuda',
+                              resize_input=True,
+                              normalize_input=False):
+    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
+    # does resize the input.
+    inception = InceptionV3([3],
+                            resize_input=resize_input,
+                            normalize_input=normalize_input)
+    inception = nn.DataParallel(inception).eval().to(device)
+    return inception
+
+
+@torch.no_grad()
+def extract_inception_features(data_generator,
+                               inception,
+                               len_generator=None,
+                               device='cuda'):
+    """Extract inception features.
+
+    Args:
+        data_generator (generator): A data generator.
+        inception (nn.Module): Inception model.
+        len_generator (int): Length of the data_generator to show the
+            progressbar. Default: None.
+        device (str): Device. Default: cuda.
+
+    Returns:
+        Tensor: Extracted features.
+    """
+    if len_generator is not None:
+        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
+    else:
+        pbar = None
+    features = []
+
+    for data in data_generator:
+        if pbar:
+            pbar.update(1)
+        data = data.to(device)
+        feature = inception(data)[0].view(data.shape[0], -1)
+        features.append(feature.to('cpu'))
+    if pbar:
+        pbar.close()
+    features = torch.cat(features, 0)
+    return features
+
+
+def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+    Stable version by Dougal J. Sutherland.
+
+    Args:
+        mu1 (np.array): The sample mean over activations.
+        sigma1 (np.array): The covariance matrix over activations for
+            generated samples.
+        mu2 (np.array): The sample mean over activations, precalculated on a
+            representative data set.
+        sigma2 (np.array): The covariance matrix over activations,
+            precalculated on a representative data set.
+
+    Returns:
+        float: The Frechet Distance.
+    """
+    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, (
+        'Two covariances have different dimensions')
+
+    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
+
+    # Product might be almost singular
+    if not np.isfinite(cov_sqrt).all():
+        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
+              'of cov estimates')
+        offset = np.eye(sigma1.shape[0]) * eps
+        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(cov_sqrt):
+        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
+            m = np.max(np.abs(cov_sqrt.imag))
+            raise ValueError(f'Imaginary component {m}')
+        cov_sqrt = cov_sqrt.real
+
+    mean_diff = mu1 - mu2
+    mean_norm = mean_diff @ mean_diff
+    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
+    fid = mean_norm + trace
+
+    return fid
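A sketch of computing FID from feature sets; the activations here are synthetic stand-ins, whereas in practice they come from extract_inception_features on real and generated images:

import numpy as np
from basicsr.metrics.fid import calculate_fid

feats_fake = np.random.randn(1000, 64) + 0.1
feats_real = np.random.randn(1000, 64)

mu1, sigma1 = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)
mu2, sigma2 = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
print('FID:', calculate_fid(mu1, sigma1, mu2, sigma2))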
basicsr/metrics/metric_util.py
ADDED
@@ -0,0 +1,47 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order has no
+            effect. Default: 'HWC'.
+
+    Returns:
+        ndarray: Reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            "'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
+
+
+def to_y_channel(img):
+    """Change to the Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without rounding.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
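A minimal sketch of the two helpers together:

import numpy as np
from basicsr.metrics.metric_util import reorder_image, to_y_channel

chw = np.random.randint(0, 256, (3, 32, 32)).astype(np.float64)  # CHW, BGR
hwc = reorder_image(chw, input_order='CHW')    # -> (32, 32, 3)
y = to_y_channel(hwc)                          # -> (32, 32, 1), float in [0, 255]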
basicsr/metrics/niqe.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.ndimage.filters import convolve
|
| 5 |
+
from scipy.special import gamma
|
| 6 |
+
|
| 7 |
+
from basicsr.metrics.metric_util import reorder_image, to_y_channel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def estimate_aggd_param(block):
|
| 11 |
+
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
block (ndarray): 2D Image block.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
|
| 18 |
+
distribution (Estimating the parames in Equation 7 in the paper).
|
| 19 |
+
"""
|
| 20 |
+
block = block.flatten()
|
| 21 |
+
gam = np.arange(0.2, 10.001, 0.001) # len = 9801
|
| 22 |
+
gam_reciprocal = np.reciprocal(gam)
|
| 23 |
+
r_gam = np.square(gamma(gam_reciprocal * 2)) / (
|
| 24 |
+
gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
|
| 25 |
+
|
| 26 |
+
left_std = np.sqrt(np.mean(block[block < 0]**2))
|
| 27 |
+
right_std = np.sqrt(np.mean(block[block > 0]**2))
|
| 28 |
+
gammahat = left_std / right_std
|
| 29 |
+
rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
|
| 30 |
+
rhatnorm = (rhat * (gammahat**3 + 1) *
|
| 31 |
+
(gammahat + 1)) / ((gammahat**2 + 1)**2)
|
| 32 |
+
array_position = np.argmin((r_gam - rhatnorm)**2)
|
| 33 |
+
|
| 34 |
+
alpha = gam[array_position]
|
| 35 |
+
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
| 36 |
+
beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
| 37 |
+
return (alpha, beta_l, beta_r)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compute_feature(block):
|
| 41 |
+
"""Compute features.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
block (ndarray): 2D Image block.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
list: Features with length of 18.
|
| 48 |
+
"""
|
| 49 |
+
feat = []
|
| 50 |
+
alpha, beta_l, beta_r = estimate_aggd_param(block)
|
| 51 |
+
feat.extend([alpha, (beta_l + beta_r) / 2])
|
| 52 |
+
|
| 53 |
+
# distortions disturb the fairly regular structure of natural images.
|
| 54 |
+
# This deviation can be captured by analyzing the sample distribution of
|
| 55 |
+
# the products of pairs of adjacent coefficients computed along
|
| 56 |
+
# horizontal, vertical and diagonal orientations.
|
| 57 |
+
shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
|
| 58 |
+
for i in range(len(shifts)):
|
| 59 |
+
shifted_block = np.roll(block, shifts[i], axis=(0, 1))
|
| 60 |
+
alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
|
| 61 |
+
# Eq. 8
|
| 62 |
+
mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
|
| 63 |
+
feat.extend([alpha, mean, beta_l, beta_r])
|
| 64 |
+
return feat
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def niqe(img,
|
| 68 |
+
mu_pris_param,
|
| 69 |
+
cov_pris_param,
|
| 70 |
+
gaussian_window,
|
| 71 |
+
block_size_h=96,
|
| 72 |
+
block_size_w=96):
|
| 73 |
+
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
|
| 74 |
+
|
    Ref: Making a "Completely Blind" Image Quality Analyzer.
    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.

    For good performance, the official implementation advises dividing the
    distorted image into patches of the same size as those used for the
    construction of the multivariate Gaussian model.

    Args:
        img (ndarray): Input image whose quality needs to be computed. The
            image must be a gray or Y (of YCbCr) image with shape (h, w).
            Range [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
            image.
        block_size_h (int): Height of the blocks into which the image is
            divided. Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks into which the image is
            divided. Default: 96 (the official recommended value).
    """
    assert img.ndim == 2, (
        'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image so that it divides evenly into blocks
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(
            np.abs(
                convolve(np.square(img), gaussian_window, mode='nearest') -
                np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_normalized = (img - mu) / (sigma + 1)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block
                block = img_normalized[idx_h * block_size_h //
                                       scale:(idx_h + 1) * block_size_h //
                                       scale, idx_w * block_size_w //
                                       scale:(idx_w + 1) * block_size_w //
                                       scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))
        # TODO: MATLAB bicubic downsample with anti-aliasing.
        # For simplicity, we use opencv instead, which will result in
        # a slight difference.
        if scale == 1:
            h, w = img.shape
            img = cv2.resize(
                img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = np.nanmean(distparam, axis=0)
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param),
        np.transpose((mu_pris_param - mu_distparam)))
    quality = np.sqrt(quality)

    return quality


def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    Ref: Making a "Completely Blind" Image Quality Analyzer.
    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    We use the official params estimated from the pristine dataset.
    We use the recommended block size (96, 96) without overlaps.

    Args:
        img (ndarray): Input image whose quality needs to be computed.
            The input image must be in range [0, 255] with float/int type.
            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
            If the input order is 'HWC' or 'CHW', it will be converted to gray
            or Y (of YCbCr) image according to the ``convert_to`` argument.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
            Default: 'HWC'.
        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
            Default: 'y'.

    Returns:
        float: NIQE result.
    """

    # we use the official params estimated from the pristine dataset.
    niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
    mu_pris_param = niqe_pris_params['mu_pris_param']
    cov_pris_param = niqe_pris_params['cov_pris_param']
    gaussian_window = niqe_pris_params['gaussian_window']

    img = img.astype(np.float32)
    if input_order != 'HW':
        img = reorder_image(img, input_order=input_order)
        if convert_to == 'y':
            img = to_y_channel(img)
        elif convert_to == 'gray':
            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
        img = np.squeeze(img)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)

    return niqe_result
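A minimal usage sketch for calculate_niqe (assuming it is run from the repository root so the np.load path above resolves; 'input.png' is a hypothetical test image):

import cv2
from basicsr.metrics.niqe import calculate_niqe

img = cv2.imread('input.png')  # hypothetical path; BGR, uint8, HWC
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(f'NIQE: {score:.4f}')  # lower scores indicate more natural images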
basicsr/metrics/niqe_pris_params.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
size 11850
basicsr/metrics/other_metrics.py
ADDED
@@ -0,0 +1,88 @@
import torch
import numpy as np
import os
from PIL import Image
from natsort import natsorted
from glob import glob
from skimage import metrics
import torch.hub
from lpips.lpips import LPIPS
from tqdm import tqdm


photometric = {
    "mse": None,
    "ssim": None,
    "psnr": None,
    "lpips": None
}

def psnr(img1, img2):
    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
                       metric="mse", mask=None):
    """
    im1t, im2t: torch tensors with batched image shape, range from (0, 1)
    """
    if metric not in photometric.keys():
        raise RuntimeError(f"img_utils:: metric {metric} not recognized")
    if photometric[metric] is None:
        if metric == "mse":
            photometric[metric] = metrics.mean_squared_error
        elif metric == "ssim":
            photometric[metric] = metrics.structural_similarity
        elif metric == "psnr":
            photometric[metric] = metrics.peak_signal_noise_ratio
        elif metric == "lpips":
            photometric[metric] = LPIPS().cpu()

    # convert from [0, 1] to [-1, 1]
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)

    if im1t.dim() == 3:
        im1t = im1t.unsqueeze(0)
        im2t = im2t.unsqueeze(0)
    im1t = im1t.detach().cpu()
    im2t = im2t.detach().cpu()

    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)  # BCHW
        im2t = im2t.permute(0, 3, 1, 2)

    im1 = im1t.permute(0, 2, 3, 1).numpy()
    im2 = im2t.permute(0, 2, 3, 1).numpy()
    batchsz, hei, wid, _ = im1.shape
    values = []

    for i in range(batchsz):
        if metric in ["mse", "psnr"]:
            if mask is not None:
                im1 = im1 * mask[i]
                im2 = im2 * mask[i]
            value = photometric[metric](
                im1[i], im2[i]
            )
            if mask is not None:
                hei, wid, _ = im1[i].shape
                pixelnum = mask[i, ..., 0].sum()
                value = value - 10 * np.log10(hei * wid / pixelnum)
        elif metric in ["ssim"]:
            value, ssimmap = photometric["ssim"](
                im1[i], im2[i], multichannel=True, full=True
            )
            if mask is not None:
                value = (ssimmap * mask[i]).sum() / mask[i].sum()
        elif metric in ["lpips"]:
            value = photometric[metric](
                im1t[i:i + 1], im2t[i:i + 1]
            )
        else:
            raise NotImplementedError
        values.append(value)

    return sum(values) / len(values)
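A minimal usage sketch for compute_img_metric (the random tensors are placeholders for restored/ground-truth images in [0, 1]; the lpips package imported above must be installed):

import torch
from basicsr.metrics.other_metrics import compute_img_metric

pred = torch.rand(1, 3, 64, 64)  # placeholder batch, BCHW, range [0, 1]
gt = torch.rand(1, 3, 64, 64)
psnr_val = compute_img_metric(pred, gt, metric='psnr')
lpips_val = compute_img_metric(pred, gt, metric='lpips')  # first call lazily builds the LPIPS net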
basicsr/metrics/psnr_ssim.py
ADDED
@@ -0,0 +1,303 @@
import cv2
import numpy as np

from basicsr.metrics.metric_util import reorder_image, to_y_channel
import skimage.metrics
import torch


def calculate_psnr(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')
    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    max_value = 1. if img1.max() <= 1 else 255.
    return 20. * np.log10(max_value / np.sqrt(mse))


def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

def prepare_for_ssim(img, k):
    import torch
    with torch.no_grad():
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        img = conv(img)

        img = img.squeeze(0).squeeze(0)
        img = img[0::k, 0::k]
    return img.detach().cpu().numpy()

def prepare_for_ssim_rgb(img, k):
    import torch
    with torch.no_grad():
        img = torch.from_numpy(img).float()  # HxWx3

        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        new_img = []

        for i in range(3):
            new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])

    return torch.stack(new_img, dim=2).detach().cpu().numpy()

def _3d_gaussian_calculator(img, conv3d):
    out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
    return out

def _generate_3d_gaussian_kernel():
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    kernel_3 = cv2.getGaussianKernel(11, 1.5)
    kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
    conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
    conv3d.weight.requires_grad = False
    conv3d.weight[0, 0, :, :, :] = kernel
    return conv3d

def _ssim_3d(img1, img2, max_value):
    assert len(img1.shape) == 3 and len(img2.shape) == 3
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    C1 = (0.01 * max_value) ** 2
    C2 = (0.03 * max_value) ** 2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    kernel = _generate_3d_gaussian_kernel().cuda()

    img1 = torch.tensor(img1).float().cuda()
    img2 = torch.tensor(img2).float().cuda()


    mu1 = _3d_gaussian_calculator(img1, kernel)
    mu2 = _3d_gaussian_calculator(img2, kernel)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
    sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return float(ssim_map.mean())

def _ssim_cly(img1, img2):
    assert len(img1.shape) == 2 and len(img2.shape) == 2
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    kernel = cv2.getGaussianKernel(11, 1.5)
    # print(kernel)
    window = np.outer(kernel, kernel.transpose())

    bt = cv2.BORDER_REPLICATE

    mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
    mu2 = cv2.filter2D(img2, -1, window, borderType=bt)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)
        return _ssim_cly(img1[..., 0], img2[..., 0])


    ssims = []
    # ssims_before = []

    # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)
    # print('.._skimage',
    #       skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True))
    max_value = 1 if img1.max() <= 1 else 255
    with torch.no_grad():
        final_ssim = _ssim_3d(img1, img2, max_value)
        ssims.append(final_ssim)

    # for i in range(img1.shape[2]):
    #     ssims_before.append(_ssim(img1, img2))

    # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before))
    #     ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False))

    return np.array(ssims).mean()
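A minimal usage sketch for calculate_psnr and calculate_ssim (the arrays are placeholders; note that _ssim_3d above calls .cuda(), so this SSIM path as written needs a GPU):

import numpy as np
from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

img1 = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)  # placeholder HWC, [0, 255]
img2 = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)
print(calculate_psnr(img1, img2, crop_border=4))
print(calculate_ssim(img1, img2, crop_border=4))  # 3D Gaussian-window SSIM, runs on CUDA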
basicsr/models/__init__.py
ADDED
@@ -0,0 +1,42 @@
import importlib
from os import path as osp

from basicsr.utils import get_root_logger, scandir

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
    if v.endswith('_model.py')
]
# import all the model modules
_model_modules = [
    importlib.import_module(f'basicsr.models.{file_name}')
    for file_name in model_filenames
]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
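A minimal sketch of how create_model resolves a class (the option dict here is a hypothetical placeholder; in practice the full dict is parsed from one of the YAML files in this commit):

from basicsr.models import create_model

opt = {
    'model_type': 'ImageCleanModel',  # must match a class name in some *_model.py under basicsr/models
    # ... remaining keys expected by that model class ...
}
model = create_model(opt)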
basicsr/models/archs/__init__.py
ADDED
@@ -0,0 +1,45 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
    if v.endswith('_arch.py')
]
# import all the arch modules
_arch_modules = [
    importlib.import_module(f'basicsr.models.archs.{file_name}')
    for file_name in arch_filenames
]


def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.
    """
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)


def define_network(opt):
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net
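A minimal sketch of define_network: it pops 'type' from the dict and forwards every remaining key as a constructor kwarg. The kwargs shown are illustrative for the ACFormer arch defined later in this commit, not the trained configuration:

from basicsr.models.archs import define_network

net_opt = {'type': 'ACFormer', 'inp_channels': 3, 'out_channels': 3, 'dim': 48,
           'window_size_ca': 8, 'query_ksize': [3]}  # hypothetical values; the training YAML supplies the real ones
net = define_network(net_opt)  # note: define_network mutates net_opt by popping 'type'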
basicsr/models/archs/arch_util.py
ADDED
@@ -0,0 +1,255 @@
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.utils import get_root_logger

# try:
#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
#                                         modulated_deform_conv)
# except ImportError:
#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
#     #       'Otherwise install BasicSR with compiling dcn.')
#

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


# class DCNv2Pack(ModulatedDeformConvPack):
#     """Modulated deformable conv for deformable alignment.
#
#     Different from the official DCNv2Pack, which generates offsets and masks
#     from the preceding features, this DCNv2Pack takes another different
#     features to generate offsets and masks.
#
#     Ref:
#         Delving Deep into Deformable Alignment in Video Super-Resolution.
#     """
#
#     def forward(self, x, feat):
#         out = self.conv_offset(feat)
#         o1, o2, mask = torch.chunk(out, 3, dim=1)
#         offset = torch.cat((o1, o2), dim=1)
#         mask = torch.sigmoid(mask)
#
#         offset_absmean = torch.mean(torch.abs(offset))
#         if offset_absmean > 50:
#             logger = get_root_logger()
#             logger.warning(
#                 f'Offset abs mean is {offset_absmean}, larger than 50.')
#
#         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
#                                      self.stride, self.padding, self.dilation,
#                                      self.groups, self.deformable_groups)
basicsr/models/archs/restormer_arch.py
ADDED
@@ -0,0 +1,527 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from torch import einsum

from einops import rearrange
from basicsr.utils.nano import psf2otf

try:
    from flash_attn import flash_attn_func
except ImportError:
    print("Flash attention is required")
    raise NotImplementedError


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x
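# The GDFN expands to two hidden branches (project_in doubles the width so a
# single depthwise 3x3 conv serves both), then gates one branch with GELU of
# the other: out = project_out(gelu(x1) * x2). With dim=48 and
# ffn_expansion_factor=2.66 (the defaults used further below),
# hidden_features = int(48 * 2.66) = 127, so project_in emits 254 channels
# that chunk() splits into two 127-channel halves.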
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias, ksize=0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.ksize = ksize
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        if ksize:
            self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize - 1) // 2)


    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        if self.ksize:
            q = q - self.avg(q)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
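# MDTA applies attention across the channel dimension rather than across
# pixels: q @ k^T above is a (c/head) x (c/head) map per head, so the cost
# grows linearly with spatial size (h*w) instead of quadratically as in
# vanilla spatial self-attention. When ksize > 0, q = q - avg(q) subtracts a
# local box-filtered mean from the queries, i.e. the query is high-pass
# filtered before attention.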
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)
def to(x):
    return {'device': x.device, 'dtype': x.dtype}

def pair(x):
    return (x, x) if not isinstance(x, tuple) else x

def expand_dim(t, dim, k):
    t = t.unsqueeze(dim=dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim=2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x

def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim=2, k=r)
    return logits


class RelPosEmb(nn.Module):
    def __init__(
        self,
        block_size,
        rel_size,
        dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5

        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (x y) c -> b x y c', x=block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h
##########################################################################
## Overlapping Cross-Attention (OCA)
class OCAB(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias, ksize=0):
        super(OCAB, self).__init__()
        self.num_spatial_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        self.dim_head = dim_head
        self.inner_dim = self.dim_head * self.num_spatial_heads
        self.scale = self.dim_head**-0.5
        self.ksize = ksize

        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size - window_size) // 2)
        self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
        self.rel_pos_emb = RelPosEmb(
            block_size=window_size,
            rel_size=window_size + (self.overlap_win_size - window_size),
            dim_head=self.dim_head
        )
        if ksize:
            self.avg = torch.nn.AvgPool2d(kernel_size=ksize, stride=1, padding=(ksize - 1) // 2)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv(x)
        qs, ks, vs = qkv.chunk(3, dim=1)

        if self.ksize:
            qs = qs - self.avg(qs)

        # spatial attention
        qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        ks, vs = map(lambda t: self.unfold(t), (ks, vs))
        ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))

        # split heads
        qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads), (qs, ks, vs))

        # attention
        qs = qs * self.scale
        spatial_attn = (qs @ ks.transpose(-2, -1))
        spatial_attn += self.rel_pos_emb(qs)
        spatial_attn = spatial_attn.softmax(dim=-1)

        out = (spatial_attn @ vs)

        out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads, h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)

        # merge spatial and channel
        out = self.project_out(out)

        return out
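# In OCAB, queries come from non-overlapping window_size x window_size
# windows, while keys/values are unfolded from enlarged windows of size
# overlap_win_size = window_size * (1 + overlap_ratio) around each query
# window (stride = window_size), so every window can attend to its
# surrounding context. RelPosEmb adds a relative position bias to the
# attention logits before the softmax.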
class AttentionFusion(nn.Module):
    def __init__(self, dim, bias, channel_fusion):
        super(AttentionFusion, self).__init__()

        self.channel_fusion = channel_fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(dim, dim // 2, kernel_size=1, bias=bias),
            nn.GELU(),
            nn.Conv2d(dim // 2, dim // 2, kernel_size=1, bias=bias)
        )
        self.dim = dim // 2

    def forward(self, x):
        fusion_map = self.fusion(x)
        if self.channel_fusion:
            weight = F.sigmoid(torch.mean(fusion_map, 1, True))
        else:
            weight = F.sigmoid(torch.mean(fusion_map, (2, 3), True))
        fused_feature = x[:, :self.dim] * weight + x[:, self.dim:] * (1 - weight)  # [:, :self.dim] == SA
        return fused_feature
class Transformer_STAF(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_channel_heads, num_spatial_heads, spatial_dim_head, ffn_expansion_factor, bias, LayerNorm_type, channel_fusion, query_ksize=0):
        super(Transformer_STAF, self).__init__()

        self.spatial_attn = OCAB(dim, window_size, overlap_ratio, num_spatial_heads, spatial_dim_head, bias, ksize=query_ksize)
        self.channel_attn = Attention(dim, num_channel_heads, bias, ksize=query_ksize)

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.norm3 = LayerNorm(dim, LayerNorm_type)
        self.norm4 = LayerNorm(dim, LayerNorm_type)

        self.channel_ffn = FeedForward(dim, ffn_expansion_factor, bias)
        self.spatial_ffn = FeedForward(dim, ffn_expansion_factor, bias)

        self.fusion = AttentionFusion(dim * 2, bias, channel_fusion)

    def forward(self, x):
        sa = x + self.spatial_attn(self.norm1(x))
        sa = sa + self.spatial_ffn(self.norm2(sa))
        ca = x + self.channel_attn(self.norm3(x))
        ca = ca + self.channel_ffn(self.norm4(ca))
        fused = self.fusion(torch.cat([sa, ca], 1))

        return fused
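# Transformer_STAF runs the spatial branch (OCAB) and the channel branch
# (MDTA Attention) in parallel, each as a pre-norm residual block followed by
# its own GDFN. The two dim-channel outputs are concatenated and
# AttentionFusion blends them with a learned sigmoid gate: per pixel when
# channel_fusion=True (mean over channels) or per channel when
# channel_fusion=False (mean over H, W).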
class MAFG_CA(nn.Module):
    def __init__(self, embed_dim, num_heads, M, window_size=0, eps=1e-6):
        super().__init__()
        self.M = M
        self.Q_idx = M // 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.wsize = window_size

        self.proj_high = nn.Conv2d(3, embed_dim, kernel_size=1)
        self.proj_rgb = nn.Conv2d(embed_dim, 3, kernel_size=1)

        self.norm = nn.LayerNorm(embed_dim, eps=eps)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj_out = nn.Linear(embed_dim, embed_dim, bias=False)
        self.max_seq = 2**16 - 1

        # window based sliding similar to OCAB
        self.overlap_wsize = int(self.wsize * 0.5) + self.wsize
        self.unfold = nn.Unfold(kernel_size=(self.overlap_wsize, self.overlap_wsize), stride=window_size, padding=(self.overlap_wsize - self.wsize) // 2)
        self.scale = self.embed_dim ** -0.5
        self.pos_emb_q = nn.Parameter(torch.zeros(self.wsize**2, embed_dim))
        self.pos_emb_k = nn.Parameter(torch.zeros(self.overlap_wsize**2, embed_dim))
        nn.init.trunc_normal_(self.pos_emb_q, std=0.02)
        nn.init.trunc_normal_(self.pos_emb_k, std=0.02)

    def forward(self, x):
        x = self.proj_high(x)
        BM, E, H, W = x.shape

        x_seq = x.view(BM, E, -1).permute(0, 2, 1)
        x_seq = self.norm(x_seq)
        B = BM // self.M
        QKV = self.qkv(x_seq)
        QKV = QKV.view(BM, H, W, 3, -1).permute(3, 0, 4, 1, 2).contiguous()
        Q, K, V = QKV[0], QKV[1], QKV[2]
        Q_bm = Q.view(B, self.M, E, H, W)
        _Q = Q_bm[:, self.Q_idx:self.Q_idx + 1]
        Q = torch.stack([__Q.repeat(self.M, 1, 1, 1) for __Q in _Q]).view(BM, E, H, W)

        Q = rearrange(Q, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.wsize, p2=self.wsize)
        K, V = map(lambda t: self.unfold(t), (K, V))
        if K.shape[-1] > 10000:  # Inference
            b, _, pp = K.shape
            K = K.view(b, self.embed_dim, -1, pp).permute(0, 3, 2, 1).reshape(b * pp, -1, self.embed_dim)
            V = V.view(b, self.embed_dim, -1, pp).permute(0, 3, 2, 1).reshape(b * pp, -1, self.embed_dim)
        else:
            K, V = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.embed_dim), (K, V))

        # Absolute positional embedding
        Q = Q + self.pos_emb_q
        K = K + self.pos_emb_k

        s, eq, _ = Q.shape
        _, ek, _ = K.shape
        Q = Q.view(s, eq, self.num_heads, self.head_dim).half()
        K = K.view(s, ek, self.num_heads, self.head_dim).half()
        V = V.view(s, ek, self.num_heads, self.head_dim).half()
        if s > self.max_seq:  # maximum allowed sequence of flash attention
            outs = []
            sp = self.max_seq
            _max = s // sp + 1
            for i in range(_max):
                outs.append(flash_attn_func(Q[i * sp: (i + 1) * sp], K[i * sp: (i + 1) * sp], V[i * sp: (i + 1) * sp], causal=False))
            out = torch.cat(outs).to(torch.float32)
        else:
            out = flash_attn_func(Q, K, V, causal=False).to(torch.float32)
        out = rearrange(out, '(b nh nw) (ph pw) h d -> b (nh ph nw pw) (h d)', nh=H // self.wsize, nw=W // self.wsize, ph=self.wsize, pw=self.wsize)
        out = self.proj_out(out)

        mixed_feature = out.view(BM, H, W, E).permute(0, 3, 1, 2).contiguous() + x
        return self.proj_rgb(mixed_feature).reshape(B, -1, H, W)
| 427 |
+
##########################################################################
|
| 428 |
+
## Aberration Correction Transformers for Metalens
|
| 429 |
+
class ACFormer(nn.Module):
|
| 430 |
+
def __init__(self,
|
| 431 |
+
inp_channels=3,
|
| 432 |
+
out_channels=3,
|
| 433 |
+
dim = 48,
|
| 434 |
+
num_blocks = [4,6,6,8],
|
| 435 |
+
num_refinement_blocks = 4,
|
| 436 |
+
channel_heads = [1,2,4,8],
|
| 437 |
+
spatial_heads = [2,2,3,4],
|
| 438 |
+
overlap_ratio=[0.5, 0.5, 0.5, 0.5],
|
| 439 |
+
window_size = 8,
|
| 440 |
+
spatial_dim_head = 16,
|
| 441 |
+
bias = False,
|
| 442 |
+
ffn_expansion_factor = 2.66,
|
| 443 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 444 |
+
M=13,
|
| 445 |
+
ca_heads=2,
|
| 446 |
+
ca_dim=32,
|
| 447 |
+
window_size_ca=0,
|
| 448 |
+
query_ksize=None
|
| 449 |
+
):
|
| 450 |
+
|
| 451 |
+
super(ACFormer, self).__init__()
|
| 452 |
+
self.center_idx = M // 2
|
| 453 |
+
self.ca = MAFG_CA(embed_dim=ca_dim, num_heads=ca_heads, M=M, window_size=window_size_ca)
|
| 454 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
| 455 |
+
|
| 456 |
+
self.encoder_level1 = nn.Sequential(*[Transformer_STAF(dim=dim, window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[0])])
|
| 457 |
+
|
| 458 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
| 459 |
+
self.encoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[1])])
|
| 460 |
+
|
| 461 |
+
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
|
| 462 |
+
self.encoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=0) for i in range(num_blocks[2])])
|
| 463 |
+
|
| 464 |
+
self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
|
| 465 |
+
self.latent = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**3), window_size = window_size, overlap_ratio=overlap_ratio[3], num_channel_heads=channel_heads[3], num_spatial_heads=spatial_heads[3], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=False, query_ksize=query_ksize[0] if i % 2 == 1 else 0) for i in range(num_blocks[3])])
|
| 466 |
+
|
| 467 |
+
self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
|
| 468 |
+
self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
|
| 469 |
+
self.decoder_level3 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**2), window_size = window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[1] if i % 2 == 1 else 0) for i in range(num_blocks[2])])
|
| 470 |
+
|
| 471 |
+
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
|
| 472 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
|
| 473 |
+
self.decoder_level2 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[2] if i % 2 == 1 else 0) for i in range(num_blocks[1])])
|
| 474 |
+
|
| 475 |
+
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
|
| 476 |
+
|
| 477 |
+
self.decoder_level1 = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[3] if i % 2 == 1 else 0) for i in range(num_blocks[0])])
|
| 478 |
+
|
| 479 |
+
self.refinement = nn.Sequential(*[Transformer_STAF(dim=int(dim*2**1), window_size = window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head = spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, channel_fusion=True, query_ksize=query_ksize[4] if i % 2 == 1 else 0) for i in range(num_refinement_blocks)])
|
| 480 |
+
|
| 481 |
+
self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
| 482 |
+
|
| 483 |
+
def forward(self, inp_img):
|
| 484 |
+
if inp_img.ndim == 5:
|
| 485 |
+
B,M,C,H,W = inp_img.shape
|
| 486 |
+
center_img = inp_img[:, self.center_idx]
|
| 487 |
+
inp_img = inp_img.view(B*M,C,H,W).contiguous()
|
| 488 |
+
else:
|
| 489 |
+
center_img = inp_img
|
| 490 |
+
|
| 491 |
+
if self.ca is None:
|
| 492 |
+
inp_enc_level1 = inp_img.view(B,M*C,H,W)
|
| 493 |
+
else:
|
| 494 |
+
inp_enc_level1 = self.ca(inp_img)
|
| 495 |
+
|
| 496 |
+
inp_enc_level1 = self.patch_embed(inp_enc_level1)
|
| 497 |
+
|
| 498 |
+
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
| 499 |
+
|
| 500 |
+
inp_enc_level2 = self.down1_2(out_enc_level1)
|
| 501 |
+
out_enc_level2 = self.encoder_level2(inp_enc_level2)
|
| 502 |
+
|
| 503 |
+
inp_enc_level3 = self.down2_3(out_enc_level2)
|
| 504 |
+
out_enc_level3 = self.encoder_level3(inp_enc_level3)
|
| 505 |
+
|
| 506 |
+
inp_enc_level4 = self.down3_4(out_enc_level3)
|
| 507 |
+
latent = self.latent(inp_enc_level4)
|
| 508 |
+
|
| 509 |
+
inp_dec_level3 = self.up4_3(latent)
|
| 510 |
+
inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
|
| 511 |
+
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
|
| 512 |
+
out_dec_level3 = self.decoder_level3(inp_dec_level3)
|
| 513 |
+
|
| 514 |
+
inp_dec_level2 = self.up3_2(out_dec_level3)
|
| 515 |
+
inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
|
| 516 |
+
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
|
| 517 |
+
out_dec_level2 = self.decoder_level2(inp_dec_level2)
|
| 518 |
+
|
| 519 |
+
inp_dec_level1 = self.up2_1(out_dec_level2)
|
| 520 |
+
inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
|
| 521 |
+
out_dec_level1 = self.decoder_level1(inp_dec_level1)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
out_dec_level1 = self.refinement(out_dec_level1)
|
| 525 |
+
out_dec_level1 = self.output(out_dec_level1) + center_img
|
| 526 |
+
|
| 527 |
+
return out_dec_level1
|
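An end-to-end sketch of the ACFormer interface (illustrative, not from the upload): inp_channels must equal M*3 here because MAFG_CA returns the M fused RGB views concatenated along channels, and the query_ksize values below are placeholders rather than the trained configuration (the real values come from the training YAML). A flash-attn-enabled GPU is assumed.

net = ACFormer(inp_channels=39, M=13, ca_dim=32, ca_heads=2,
               window_size_ca=8, query_ksize=[0, 0, 0, 0, 0]).cuda()
x = torch.randn(1, 13, 3, 128, 128, device='cuda')   # (B, M, C, H, W) stack
with torch.no_grad():
    y = net(x)
print(y.shape)   # torch.Size([1, 3, 128, 128]) -- the restored center view
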
basicsr/models/base_model.py
ADDED
@@ -0,0 +1,376 @@
import logging
import os
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel

from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils.dist_util import master_only

logger = logging.getLogger('basicsr')


class BaseModel():
    """Base model."""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True, psf=None, ks=None, val_conv=True):
        """Validation function.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
            rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True.
            use_image (bool): Whether to use saved images to compute metrics
                (PSNR, SSIM); if not, use the network output directly. Default: True.
        """
        if self.opt['dist']:
            return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)
        else:
            return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.net_g)

        net_g_params = dict(net_g.named_parameters())
        net_g_ema_params = dict(self.net_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].data.mul_(decay).add_(
                net_g_params[k].data, alpha=1 - decay)

    def get_current_log(self):
        return self.log_dict

    def model_to_device(self, net):
        """Model to device. It also wraps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        # if self.opt['dist']:
        #     find_unused_parameters = self.opt.get('find_unused_parameters',
        #                                           False)
        #     net = DistributedDataParallel(
        #         net,
        #         device_ids=[torch.cuda.current_device()],
        #         find_unused_parameters=find_unused_parameters)
        # elif self.opt['num_gpu'] > 1:
        #     net = DataParallel(net)
        return net

    def setup_schedulers(self):
        """Set up schedulers."""
        train_opt = self.opt['train']
        scheduler_type = train_opt['scheduler'].pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepRestartLR(optimizer,
                                                    **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingRestartLR(
                        optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingWarmupRestarts':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingWarmupRestarts(
                        optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingRestartCyclicLR(
                        optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'TrueCosineAnnealingLR':
            print('..', 'cosineannealingLR')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingLRWithRestart':
            print('..', 'CosineAnnealingLR_With_Restart')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLRWithRestart(optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'LinearLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.LinearLR(
                        optimizer, train_opt['total_iter']))
        elif scheduler_type == 'VibrateLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.VibrateLR(
                        optimizer, train_opt['total_iter']))
        else:
            raise NotImplementedError(
                f'Scheduler {scheduler_type} is not implemented yet.')

    def get_bare_model(self, net):
        """Get the bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    @master_only
    def print_network(self, net):
        """Print the str and parameter number of a network.

        Args:
            net (nn.Module)
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = (f'{net.__class__.__name__} - '
                           f'{net.module.__class__.__name__}')
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net = self.get_bare_model(net)
        net_str = str(net)
        net_params = sum(map(lambda x: x.numel(), net.parameters()))

        logger.info(
            f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(net_str)

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warmup.

        Args:
            lr_groups_l (list): List for lr_groups, each for an optimizer.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler."""
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append(
                [v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up warm-up learning rate
        if current_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            # currently only linear warm-up is supported
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append(
                    [v / warmup_iter * current_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        return [
            param_group['lr']
            for param_group in self.optimizers[0].param_groups
        ]

    @master_only
    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(
            param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            for key, param in state_dict.items():
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        torch.save(save_dict, save_path)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different names or different sizes when loading models.

        1. Print keys with different names.
        2. If strict=False, print the same key but with a different tensor size.
           It also ignores these keys with different sizes (does not load them).

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                logger.warning(f'  {v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                logger.warning(f'  {v}')

        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(
                        f'Size different, ignore [{k}]: crt_net: '
                        f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            load_path (str): The path of the network to be loaded.
            net (nn.Module): Network.
            strict (bool): Whether strictly loaded.
            param_key (str): The parameter key of the loaded network. If set to
                None, use the root 'path'. Default: 'params'.
        """
        net = self.get_bare_model(net)
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}.')
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                param_key = 'params'
                logger.info('Loading: params_ema does not exist, use params.')
            load_net = load_net[param_key]
        print(' load net keys', load_net.keys())
        # remove unnecessary 'module.'
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

    @master_only
    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for
        resuming.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter != -1:
            state = {
                'epoch': epoch,
                'iter': current_iter,
                'optimizers': [],
                'schedulers': []
            }
            for o in self.optimizers:
                state['optimizers'].append(o.state_dict())
            for s in self.schedulers:
                state['schedulers'].append(s.state_dict())
            save_filename = f'{current_iter}.state'
            save_path = os.path.join(self.opt['path']['training_states'],
                                     save_filename)
            torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """Reduce loss dict.

        In distributed training, it averages the losses among different GPUs.

        Args:
            loss_dict (OrderedDict): Loss dict.
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()

            return log_dict
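model_ema above implements the standard exponential moving average update, p_ema <- decay * p_ema + (1 - decay) * p. A self-contained sketch of the same update on toy modules (not the project's networks):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
net_ema = nn.Linear(4, 4)
net_ema.load_state_dict(net.state_dict())   # model_ema(0) performs this copy

decay = 0.999
with torch.no_grad():
    for p_ema, p in zip(net_ema.parameters(), net.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)   # same in-place update as model_ema
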
basicsr/models/image_restoration_model.py
ADDED
@@ -0,0 +1,392 @@
import importlib
import torch
import os
import gc
import random
import torch.nn.functional as F

from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from tqdm import tqdm
from functools import partial

from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.nano import apply_conv_n_deconv
from basicsr.metrics.other_metrics import compute_img_metric

loss_module = importlib.import_module('basicsr.models.losses')
metric_module = importlib.import_module('basicsr.metrics')


class Mixing_Augment:
    def __init__(self, mixup_beta, use_identity, device):
        self.dist = torch.distributions.beta.Beta(torch.tensor([mixup_beta]), torch.tensor([mixup_beta]))
        self.device = device

        self.use_identity = use_identity

        self.augments = [self.mixup]

    def mixup(self, target, input_):
        lam = self.dist.rsample((1, 1)).item()

        r_index = torch.randperm(target.size(0)).to(self.device)

        target = lam * target + (1 - lam) * target[r_index, :]
        input_ = lam * input_ + (1 - lam) * input_[r_index, :]

        return target, input_

    def __call__(self, target, input_):
        if self.use_identity:
            # randint is inclusive, so index == len(self.augments) means identity (no augmentation)
            augment = random.randint(0, len(self.augments))
            if augment < len(self.augments):
                target, input_ = self.augments[augment](target, input_)
        else:
            augment = random.randint(0, len(self.augments) - 1)
            target, input_ = self.augments[augment](target, input_)
        return target, input_


class ImageCleanModel(BaseModel):
    """Base Deblur model for single image deblur."""

    def __init__(self, opt):
        super(ImageCleanModel, self).__init__(opt)

        # define network
        self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False)
        if self.mixing_flag:
            mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2)
            use_identity = self.opt['train']['mixing_augs'].get('use_identity', False)
            self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device)

        self.net_g = define_network(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(self.net_g)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g, load_path,
                              self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params'))

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(
                f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving;
            # there is no need to wrap it with DistributedDataParallel
            self.net_g_ema = define_network(self.opt['network_g']).to(
                self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path,
                                  self.opt['path'].get('strict_load_g',
                                                       True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weights
            self.net_g_ema.eval()

        # define losses
        if train_opt.get('pixel_opt'):
            pixel_type = train_opt['pixel_opt'].pop('type')
            cri_pix_cls = getattr(loss_module, pixel_type)
            self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
                self.device)
        else:
            raise ValueError('pixel loss is None.')

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        optim_params = []

        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')

        optim_type = train_opt['optim_g'].pop('type')
        if optim_type == 'Adam':
            self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
        elif optim_type == 'AdamW':
            self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')
        self.optimizers.append(self.optimizer_g)

    def feed_train_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

        if self.mixing_flag:
            self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)

    def feed_data(self, data, psf=None, ks=None, val_conv=True):
        gt = data['gt'].to(self.device)
        padding = data['padding']
        padding = torch.stack(padding).T
        otf = psf
        M = ks.shape[1]
        if val_conv:  # apply convolution on the fly (use the gt image to create the lr image)
            lq, gt = apply_conv_n_deconv(gt, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=True)
            self.lq = lq[None]
            self.gt = gt[None]  # TODO: check dims. Previously the value returned by the square crop was used as-is; now the original gt is used directly, so the shapes may differ. Merge with the branch below later.
            # TODO: deconv(gt) could be handled in one line by selecting gt in an if/else above.
        else:  # load pre-convolved npy for validation
            lq = data['lq'].to(self.device)
            lq, gt = apply_conv_n_deconv(lq, otf, padding, M, 0, ks=ks, ph=135, num_psf=9, sensor_h=1215, crop=False, conv=False)
            self.lq = lq[None]
            self.gt = gt

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        preds = self.net_g(self.lq)
        if not isinstance(preds, list):
            preds = [preds]

        self.output = preds[-1]

        loss_dict = OrderedDict()
        # pixel loss
        l_pix = 0.
        for pred in preds:
            l_pix += self.cri_pix(pred, self.gt)

        loss_dict['l_pix'] = l_pix

        l_pix.backward()
        if self.opt['train']['use_grad_clip']:
            torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01)
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def pad_test(self, window_size):
        scale = self.opt.get('scale', 1)
        mod_pad_h, mod_pad_w = 0, 0
        h, w = self.lq.size()[-2:]
        if h % window_size != 0:
            mod_pad_h = window_size - h % window_size
        if w % window_size != 0:
            mod_pad_w = window_size - w % window_size
        img = F.pad(self.lq[0], (0, mod_pad_w, 0, mod_pad_h), 'reflect')[None]
        self.nonpad_test(img)
        _, _, h, w = self.output.size()
        self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]

    def nonpad_test(self, img=None):
        if img is None:
            img = self.lq
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                pred = self.net_g_ema(img)
            if isinstance(pred, list):
                pred = pred[-1]
            self.output = pred
        else:
            self.net_g.eval()
            with torch.no_grad():
                pred = self.net_g(img)
            if isinstance(pred, list):
                pred = pred[-1]
            self.output = pred
            self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv):
        if os.environ['LOCAL_RANK'] == '0':
            return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image, psf, ks, val_conv)
        else:
            return 0.

    def pre_process(self, padding_size):
        # pad to a multiple of padding_size
        self.mod_pad_h, self.mod_pad_w = 0, 0
        h, w = self.lq.size()[-2:]  # BMCHW
        if h % padding_size != 0:
            self.mod_pad_h = padding_size - h % padding_size
        if w % padding_size != 0:
            self.mod_pad_w = padding_size - w % padding_size
        self.lq = F.pad(self.lq[0], (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')[None]

    def post_process(self):
        _, _, h, w = self.output.size()
        self.output = self.output[..., 0:h - self.mod_pad_h, 0:w - self.mod_pad_w]

    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img, rgb2bgr, use_image, psf, ks, val_conv):
        dataset_name = dataloader.dataset.opt['name']
        base_path = self.opt['path']['visualization']

        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {
                metric: 0
                for metric in self.opt['val']['metrics'].keys()
            }
        if save_img:
            cur_other_metrics = {'ssim': 0., 'lpips': 0.}
        else:
            cur_other_metrics = None

        window_size = self.opt['val'].get('window_size', 0)

        if window_size:
            test = partial(self.pad_test, window_size)
        else:
            test = self.nonpad_test

        cnt = 0

        for idx, val_data in enumerate(tqdm(dataloader)):
            img_name = osp.splitext(osp.basename(val_data['gt_path'][0]))[0]
            self.feed_data(val_data, psf, ks, val_conv)
            pad_for_OCB = self.opt['val'].get('padding')
            if pad_for_OCB is not None:
                self.pre_process(pad_for_OCB)

            torch.cuda.empty_cache()
            gc.collect()

            test()

            if pad_for_OCB is not None:
                self.post_process()

            if save_img and with_metrics and use_image:
                visuals = self.get_current_visuals(to_cpu=False)
                cur_other_metrics['ssim'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'ssim')
                cur_other_metrics['lpips'] += compute_img_metric(visuals['result'][0], visuals['gt'][0], 'lpips').item()

            visuals = self.get_current_visuals()

            sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
                del self.gt

            # tentative fix for running out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()
            gc.collect()

            if save_img:
                if self.opt['is_train']:
                    if 'eval_only' in self.opt['train']:
                        save_img_path = osp.join(base_path + self.opt['train']['eval_name'],
                                                 f'{img_name}_{current_iter}.png')
                    else:
                        save_img_path = osp.join(base_path,
                                                 f'{img_name}_{current_iter}.png')
                else:
                    save_img_path = osp.join(
                        base_path,
                        f'{img_name}.png')
                    save_gt_img_path = osp.join(
                        base_path, dataset_name,
                        f'{img_name}_gt.png')

                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                if use_image:
                    for name, opt_ in opt_metric.items():
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(sr_img, gt_img, **opt_)
                else:
                    for name, opt_ in opt_metric.items():
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)

            cnt += 1

        # tentative fix for running out of GPU memory
        torch.cuda.empty_cache()
        gc.collect()

        current_metric = 0.
        if with_metrics:
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= cnt
                current_metric = self.metric_results[metric]
            if save_img:
                cur_other_metrics['ssim'] /= cnt
                cur_other_metrics['lpips'] /= cnt

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
        return current_metric, cur_other_metrics

    def _log_validation_metric_values(self, current_iter, dataset_name,
                                      tb_logger):
        log_str = f'Validation {dataset_name},\t'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self, to_cpu=True):
        out_dict = OrderedDict()
        if to_cpu:
            out_dict['lq'] = self.lq.detach().cpu()
            out_dict['result'] = self.output.detach().cpu()
            if hasattr(self, 'gt'):
                out_dict['gt'] = self.gt.detach().cpu()
        else:
            out_dict['lq'] = self.lq.detach()
            out_dict['result'] = self.output.detach()
            if hasattr(self, 'gt'):
                out_dict['gt'] = self.gt.detach()
        return out_dict

    def save(self, epoch, current_iter):
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema],
                              'net_g',
                              current_iter,
                              param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)
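pad_test and pre_process/post_process share one pattern: reflect-pad H and W up to a multiple of the window (or OCB padding) size so the window-attention blocks divide evenly, run the network, then crop the output back. A standalone sketch with toy sizes (the network call is a stand-in):

import torch
import torch.nn.functional as F

window_size = 8
x = torch.randn(1, 3, 135, 241)                     # H, W not multiples of 8
pad_h = (window_size - x.shape[-2] % window_size) % window_size
pad_w = (window_size - x.shape[-1] % window_size) % window_size
x_pad = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')   # -> (1, 3, 136, 248)
out = x_pad                                         # stand-in for the network call
out = out[..., :x.shape[-2], :x.shape[-1]]          # crop back to the input size
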
basicsr/models/losses/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss)

__all__ = [
    'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
]
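These exports are what ImageCleanModel resolves by name: pixel_opt.type from the training YAML is looked up on this module with getattr (see init_training_settings above). A sketch with illustrative option values:

import importlib

loss_module = importlib.import_module('basicsr.models.losses')
pixel_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}  # illustrative values
cri_pix = getattr(loss_module, pixel_opt.pop('type'))(**pixel_opt)
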
basicsr/models/losses/loss_util.py
ADDED
@@ -0,0 +1,95 @@
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over the weighted region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper
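The decorator is how new pixel losses are meant to be added (compare the commented-out charbonnier_loss in losses.py): write only the element-wise term, and the wrapper supplies weight and reduction handling. A sketch:

import torch
from basicsr.models.losses.loss_util import weighted_loss

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    # element-wise term only; weighting/reduction come from the decorator
    return torch.sqrt((pred - target) ** 2 + eps)

pred = torch.zeros(1, 3, 4, 4)
target = torch.ones(1, 3, 4, 4)
print(charbonnier_loss(pred, target))                   # scalar mean over all elements
print(charbonnier_loss(pred, target, reduction='sum'))  # summed instead
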
basicsr/models/losses/losses.py
ADDED
@@ -0,0 +1,180 @@
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from math import exp
|
| 6 |
+
|
| 7 |
+
from basicsr.models.losses.loss_util import weighted_loss
|
| 8 |
+
|
| 9 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@weighted_loss
|
| 13 |
+
def l1_loss(pred, target):
|
| 14 |
+
return F.l1_loss(pred, target, reduction='none')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@weighted_loss
|
| 18 |
+
def mse_loss(pred, target):
|
| 19 |
+
return F.mse_loss(pred, target, reduction='none')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# @weighted_loss
|
| 23 |
+
# def charbonnier_loss(pred, target, eps=1e-12):
|
| 24 |
+
# return torch.sqrt((pred - target)**2 + eps)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class L1Loss(nn.Module):
|
| 28 |
+
"""L1 (mean absolute error, MAE) loss.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
| 32 |
+
reduction (str): Specifies the reduction to apply to the output.
|
| 33 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
| 37 |
+
super(L1Loss, self).__init__()
|
| 38 |
+
if reduction not in ['none', 'mean', 'sum']:
|
| 39 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. '
|
| 40 |
+
f'Supported ones are: {_reduction_modes}')
|
| 41 |
+
|
| 42 |
+
self.loss_weight = loss_weight
|
| 43 |
+
self.reduction = reduction
|
| 44 |
+
|
| 45 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
| 49 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
| 50 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
| 51 |
+
weights. Default: None.
|
| 52 |
+
"""
|
| 53 |
+
return self.loss_weight * l1_loss(
|
| 54 |
+
pred, target, weight, reduction=self.reduction)
|
| 55 |
+
|
| 56 |
+
class MSELoss(nn.Module):
|
| 57 |
+
"""MSE (L2) loss.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
|
| 61 |
+
reduction (str): Specifies the reduction to apply to the output.
|
| 62 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
| 66 |
+
super(MSELoss, self).__init__()
|
| 67 |
+
if reduction not in ['none', 'mean', 'sum']:
|
| 68 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. '
|
| 69 |
+
f'Supported ones are: {_reduction_modes}')
|
| 70 |
+
|
| 71 |
+
self.loss_weight = loss_weight
|
| 72 |
+
self.reduction = reduction
|
| 73 |
+
|
| 74 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
| 75 |
+
"""
|
| 76 |
+
Args:
|
| 77 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
| 78 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
| 79 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
| 80 |
+
weights. Default: None.
|
| 81 |
+
"""
|
| 82 |
+
return self.loss_weight * mse_loss(
|
| 83 |
+
pred, target, weight, reduction=self.reduction)
|
| 84 |
+
|
| 85 |
+
class PSNRLoss(nn.Module):
|
| 86 |
+
|
| 87 |
+
def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
|
| 88 |
+
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False

            # RGB -> Y (ITU-R BT.601), then rescale back to [0, 1]
            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.
        assert len(pred.size()) == 4

        return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class MS_SSIM(nn.Module):
    def __init__(self, window_size=11, sigma=1.5, device="cuda"):
        super(MS_SSIM, self).__init__()
        self.device = device
        self.channel = 3
        self.sigma = sigma
        self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        self.levels = len(self.weights)
        self.window = self.create_window(window_size)

    def create_window(self, window_size):
        self.window_size = window_size
        # 1D Gaussian kernel
        gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * self.sigma ** 2)) for x in range(window_size)])
        gauss = gauss / gauss.sum()

        # 2D Gaussian window
        _1D_window = gauss.unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        return _2D_window.expand(self.channel, 1, window_size, window_size).contiguous().to(self.device)

    def update_window_size(self, window_size):
        self.window = self.create_window(window_size)

    def ssim(self, img1, img2):
        """Compute SSIM between two images."""
        mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel)
        mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        return ssim_map.mean()

    def forward(self, pred, target):
        # product of weighted per-scale SSIM values over a dyadic pyramid
        # (a simplified multi-scale SSIM variant)
        msssim = []
        for i in range(self.levels):
            ssim_val = self.ssim(pred, target)
            msssim.append(ssim_val * self.weights[i])
            if i < self.levels - 1:
                pred = F.avg_pool2d(pred, kernel_size=2, stride=2)
                target = F.avg_pool2d(target, kernel_size=2, stride=2)

        return torch.prod(torch.stack(msssim))
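
Since `self.scale = 10 / np.log(10)`, the value PSNRLoss returns is `loss_weight * 10 * log10(MSE)`, i.e. the negative PSNR in dB for images in [0, 1], so minimizing it maximizes PSNR. A minimal sanity-check sketch of that identity (not part of the upload; assumes only torch and numpy):

import torch
import numpy as np

pred, target = torch.rand(2, 3, 16, 16), torch.rand(2, 3, 16, 16)
mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))

loss = (10 / np.log(10)) * torch.log(mse + 1e-8).mean()  # what PSNRLoss computes (loss_weight=1, toY=False)
psnr = -10 * torch.log10(mse + 1e-8).mean()              # conventional PSNR in dB
assert torch.allclose(loss, -psnr, atol=1e-4)
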
basicsr/models/lr_scheduler.py
ADDED
@@ -0,0 +1,232 @@
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler
import torch


class MultiStepRestartLR(_LRScheduler):
    """ MultiStep with restarts learning rate scheme.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        milestones (list): Iterations that will decrease learning rate.
        gamma (float): Decrease ratio. Default: 0.1.
        restarts (list): Restart iterations. Default: [0].
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 milestones,
                 gamma=0.1,
                 restarts=(0, ),
                 restart_weights=(1, ),
                 last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.restarts = restarts
        self.restart_weights = restart_weights
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch in self.restarts:
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [
                group['initial_lr'] * weight
                for group in self.optimizer.param_groups
            ]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [
            group['lr'] * self.gamma**self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]


class LinearLR(_LRScheduler):
    """Linearly decay the learning rate to zero over ``total_iter`` iterations.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        total_iter (int): Total number of training iterations.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 total_iter,
                 last_epoch=-1):
        self.total_iter = total_iter
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        process = self.last_epoch / self.total_iter
        weight = (1 - process)
        # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
        return [weight * group['initial_lr'] for group in self.optimizer.param_groups]


class VibrateLR(_LRScheduler):
    """Oscillating learning rate schedule with a piecewise decaying envelope.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        total_iter (int): Total number of training iterations.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 total_iter,
                 last_epoch=-1):
        self.total_iter = total_iter
        super(VibrateLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        process = self.last_epoch / self.total_iter

        # envelope: linear decay over the first 3/8 of training, then plateaus
        f = 0.1
        if process < 3 / 8:
            f = 1 - process * 8 / 3
        elif process < 5 / 8:
            f = 0.2

        # triangular oscillation with period T
        T = self.total_iter // 80
        Th = T // 2

        t = self.last_epoch % T

        f2 = t / Th
        if t >= Th:
            f2 = 2 - f2

        weight = f * f2

        if self.last_epoch < Th:
            weight = max(0.1, weight)

        # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
        return [weight * group['initial_lr'] for group in self.optimizer.param_groups]


def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i


class CosineAnnealingRestartLR(_LRScheduler):
    """ Cosine annealing with restarts learning rate scheme.

    An example of config:
    periods = [10, 10, 10, 10]
    restart_weights = [1, 0.5, 0.5, 0.5]
    eta_min = 1e-7

    It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th
    iterations, the scheduler will restart with the weights in
    restart_weights.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        eta_min (float): The minimum lr. Default: 0.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 periods,
                 restart_weights=(1, ),
                 eta_min=0,
                 last_epoch=-1):
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_min = eta_min
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        self.cumulative_period = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]
        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        idx = get_position_from_periods(self.last_epoch,
                                        self.cumulative_period)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
        current_period = self.periods[idx]

        return [
            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
            (1 + math.cos(math.pi * (
                (self.last_epoch - nearest_restart) / current_period)))
            for base_lr in self.base_lrs
        ]


class CosineAnnealingRestartCyclicLR(_LRScheduler):
    """ Cosine annealing with restarts learning rate scheme.
    An example of config:
    periods = [10, 10, 10, 10]
    restart_weights = [1, 0.5, 0.5, 0.5]
    eta_mins = [1e-7, 1e-7, 1e-7, 1e-7]
    It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th
    iterations, the scheduler will restart with the weights in
    restart_weights.
    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        eta_mins (list): The minimum lr for each cycle. Default: (0, ).
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 periods,
                 restart_weights=(1, ),
                 eta_mins=(0, ),
                 last_epoch=-1):
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_mins = eta_mins
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        self.cumulative_period = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]
        super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        idx = get_position_from_periods(self.last_epoch,
                                        self.cumulative_period)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
        current_period = self.periods[idx]
        eta_min = self.eta_mins[idx]

        return [
            eta_min + current_weight * 0.5 * (base_lr - eta_min) *
            (1 + math.cos(math.pi * (
                (self.last_epoch - nearest_restart) / current_period)))
            for base_lr in self.base_lrs
        ]
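
A minimal usage sketch for these schedulers, not part of the diff, with a placeholder model and the config values from the CosineAnnealingRestartLR docstring; BasicSR-style training steps these schedulers once per iteration:

import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

model = torch.nn.Linear(4, 4)  # placeholder module
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer,
    periods=[10, 10, 10, 10],
    restart_weights=[1, 0.5, 0.5, 0.5],
    eta_min=1e-7)

for _ in range(40):
    optimizer.step()   # a real iteration would run forward/backward first
    scheduler.step()   # stepped per iteration; lr restarts at iters 10, 20, 30
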
basicsr/test.py
ADDED
@@ -0,0 +1,142 @@
import argparse
import random
import torch
from os import path as osp

from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.utils import (check_resume, make_exp_dirs, mkdir_and_rename, set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import parse
from basicsr.utils.nano import psf2otf

import numpy as np
from tqdm import tqdm


def parse_options(is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--name',
        default=None,
        help='experiment name (overrides the name in the YAML file)')
    # accept both spellings of the local rank flag, since different
    # torch.distributed launcher versions pass it differently
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
        print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt


def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=True)
    torch.backends.cudnn.benchmark = True

    # automatic resume: pick the newest .state file if one exists
    state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
    import os
    try:
        states = os.listdir(state_folder_path)
    except FileNotFoundError:
        states = []
    resume_state = None
    if len(states) > 0:
        max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
        resume_state = os.path.join(state_folder_path, max_state_file)
        opt['path']['resume_state'] = resume_state

    # load resume states if necessary
    if opt['path'].get('resume_state'):
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
    else:
        resume_state = None

    # mkdir for experiments and logger
    if resume_state is None:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
                'name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join('tb_logger', opt['name']))

    # define ks for Wiener filters
    ks_params = opt['train'].get('ks', None)
    if not ks_params:
        raise NotImplementedError
    M = ks_params['num']
    ks = torch.logspace(ks_params['start'], ks_params['end'], M)
    ks = ks.view(1, M, 1, 1, 1, 1).to("cuda")

    val_conv = opt['val'].get("apply_conv", True)

    # create model
    if resume_state:  # resume training
        check_resume(opt, resume_state['iter'])
        model = create_model(opt)
        model.resume_training(resume_state)  # handle optimizers and schedulers
        current_iter = resume_state['iter']

    else:
        model = create_model(opt)
        current_iter = 0

    # load psf
    psf = torch.tensor(np.load("./psf.npy")).to("cuda")
    _, psf_h, psf_w, _ = psf.shape
    otf = psf2otf(psf, h=psf_h * 3, w=psf_w * 3, permute=True)[None]

    dataset_opt = opt['datasets']['val']

    val_set = create_dataset(dataset_opt)
    val_loader = create_dataloader(
        val_set,
        dataset_opt,
        num_gpu=opt['num_gpu'],
        dist=opt['dist'],
        sampler=None,
        seed=opt['manual_seed'])

    print("Start validation on spatially varying aberration")
    rgb2bgr = opt['val'].get('rgb2bgr', True)
    use_image = opt['val'].get('use_image', True)
    psnr, others = model.validation(val_loader, current_iter, None, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
    print("==================")
    print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n")


if __name__ == '__main__':
    main()
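
Both test.py and train.py resume from the newest checkpoint by stripping the 6-character `.state` suffix from each file name and taking the numeric maximum, so `100000.state` beats `8000.state` even though it sorts first lexicographically. A standalone sketch of that selection logic, with hypothetical file names:

states = ['2000.state', '100000.state', '8000.state']
latest = '{}.state'.format(max(int(x[0:-6]) for x in states))  # x[0:-6] drops '.state'
assert latest == '100000.state'
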
basicsr/train.py
ADDED
@@ -0,0 +1,328 @@
import argparse
import datetime
import logging
import math
import random
import time
import torch
import gc
from os import path as osp

from basicsr.data import create_dataloader, create_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import create_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
                           get_root_logger, get_time_str, init_tb_logger,
                           init_wandb_logger, make_exp_dirs, mkdir_and_rename,
                           set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
from basicsr.utils.nano import apply_conv_n_deconv, psf2otf

import numpy as np
from tqdm import tqdm


def parse_options(is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--name',
        default=None,
        help='experiment name (overrides the name in the YAML file)')
    # accept both spellings of the local rank flag, since different
    # torch.distributed launcher versions pass it differently
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
        print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt


def init_loggers(opt):
    log_file = osp.join(opt['path']['log'],
                        f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # initialize wandb logger before tensorboard logger to allow proper sync:
    if (opt['logger'].get('wandb')
            is not None) and (opt['logger']['wandb'].get('project')
                              is not None) and ('debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, (
            'should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
    return logger, tb_logger


def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = create_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'],
                                            opt['rank'], dataset_enlarge_ratio)
            train_loader = create_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'],
            )

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio /
                (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info(
                'Training statistics:'
                f'\n\tNumber of train images: {len(train_set)}'
                f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                f'\n\tWorld size (gpu number): {opt["world_size"]}'
                f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')

        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(
                val_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=None,
                seed=opt['manual_seed'],
            )
            logger.info(
                f'Number of val images/folders in {dataset_opt["name"]}: '
                f'{len(val_set)}')

        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loader, total_epochs, total_iters


def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=True)
    torch.backends.cudnn.benchmark = True

    # automatic resume: pick the newest .state file if one exists
    state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
    import os
    try:
        states = os.listdir(state_folder_path)
    except FileNotFoundError:
        states = []
    resume_state = None
    if len(states) > 0:
        max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
        resume_state = os.path.join(state_folder_path, max_state_file)
        opt['path']['resume_state'] = resume_state

    # load resume states if necessary
    if opt['path'].get('resume_state'):
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
    else:
        resume_state = None

    # mkdir for experiments and logger
    if resume_state is None:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
                'name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join('tb_logger', opt['name']))

    # initialize loggers
    logger, tb_logger = init_loggers(opt)

    # define ks for Wiener filters
    ks_params = opt['train'].get('ks', None)
    if not ks_params:
        raise NotImplementedError
    M = ks_params['num']
    ks = torch.logspace(ks_params['start'], ks_params['end'], M)
    ks = ks.view(1, M, 1, 1, 1, 1).to("cuda")

    # create model
    if resume_state:  # resume training
        check_resume(opt, resume_state['iter'])
        model = create_model(opt)
        model.resume_training(resume_state)  # handle optimizers and schedulers
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
                    f"iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']

    else:
        model = create_model(opt)
        start_epoch = 0
        current_iter = 0

    # create train and validation dataloaders
    result = create_train_val_dataloader(opt, logger)
    train_loader, train_sampler, val_loader, total_epochs, total_iters = result

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # dataloader prefetcher
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu':
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        prefetcher = CUDAPrefetcher(train_loader, opt)
        logger.info(f'Use {prefetch_mode} prefetch dataloader')
        if opt['datasets']['train'].get('pin_memory') is not True:
            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
    else:
        raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. '
                         "Supported ones are: None, 'cuda', 'cpu'.")

    # training
    logger.info(
        f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    data_time, iter_time = time.time(), time.time()
    start_time = time.time()

    epoch = start_epoch
    pbar = tqdm(total=total_iters + 1)
    pbar.update(current_iter)

    # load psf
    psf = torch.tensor(np.load("./psf.npy")).to("cuda")
    psf_n, psf_h, psf_w, _ = psf.shape
    psf_n_row = int(psf_n ** 0.5)
    sensor_h = opt['datasets']['train'].get('sensor_size')
    otf = psf2otf(psf, h=psf_h * 3, w=psf_w * 3, permute=True)[None]

    gt_size = opt['datasets']['train']['gt_size']
    val_conv = opt['val'].get("apply_conv", True)

    while current_iter <= total_iters:
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_time = time.time() - data_time

            gt = train_data['gt'].to("cuda")  # B,C,H,H (square crops)
            padding = train_data['padding']
            padding = torch.stack(padding).T
            lq, gt = apply_conv_n_deconv(gt, otf, padding, M, gt_size, ks=ks, ph=psf_h, num_psf=psf_n_row, sensor_h=sensor_h)

            # 3 H W . conv -> crop
            current_iter += 1
            if current_iter > total_iters:
                break
            # update learning rate
            model.update_learning_rate(
                current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))

            model.feed_train_data({'lq': lq, 'gt': gt})
            model.optimize_parameters(current_iter)

            iter_time = time.time() - iter_time

            # log
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_time, 'data_time': data_time})

                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            # save models and training states
            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(epoch, current_iter)

            # validation
            if opt.get('val') is not None and ((current_iter % opt['val']['val_freq'] == 0)):
                rgb2bgr = opt['val'].get('rgb2bgr', True)
                # whether to use uint8 images to compute metrics
                use_image = opt['val'].get('use_image', True)
                model.validation(val_loader, current_iter, tb_logger, False, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
                gc.collect()
                torch.cuda.empty_cache()

            data_time = time.time()
            iter_time = time.time()
            train_data = prefetcher.next()
            pbar.update(1)
        # end of iter
        epoch += 1

    # end of epoch

    consumed_time = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'End of training. Time consumed: {consumed_time}')
    logger.info('Save the latest model.')
    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
    if opt.get('val') is not None:
        rgb2bgr = opt['val'].get('rgb2bgr', True)
        use_image = opt['val'].get('use_image', True)
        psnr, others = model.validation(val_loader, current_iter, tb_logger, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
        print("==================")
        print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n")

    if tb_logger:
        tb_logger.close()


if __name__ == '__main__':
    main()
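
The `ks` tensor built above is a log-spaced grid of M Wiener-filter regularization constants, reshaped so it broadcasts against the batched tensors inside `apply_conv_n_deconv`. A shape sketch, with hypothetical `start`/`end`/`num` values standing in for the `train: ks:` options:

import torch

start, end, M = -3, 0, 8             # hypothetical ks: start / end / num values
ks = torch.logspace(start, end, M)   # M values from 10**start to 10**end
ks = ks.view(1, M, 1, 1, 1, 1)       # broadcastable over 6-D (B, M, ...) stacks
print(ks.shape)                      # torch.Size([1, 8, 1, 1, 1, 1])
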
basicsr/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
from .file_client import FileClient
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
from .logger import (MessageLogger, get_env_info, get_root_logger,
                     init_tb_logger, init_wandb_logger)
from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
                   scandir, scandir_mv, scandir_mv_flat, scandir_SIDD, set_random_seed, sizeof_fmt)
from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
# the nano.py names are listed in __all__ below, so import them here as well
# (assumes basicsr/utils/nano.py defines them, as __all__ already claims)
from .nano import psf2otf, fft, ifft, get_edgetaper_weight

__all__ = [
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    # logger.py
    'MessageLogger',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'scandir_mv',
    'scandir_mv_flat',
    'check_resume',
    'sizeof_fmt',
    'padding',
    'padding_DP',
    'imfrombytesDP',
    'create_lmdb_for_reds',
    'create_lmdb_for_gopro',
    'create_lmdb_for_rain13k',
    # nano.py
    'psf2otf',
    'fft',
    'ifft',
    'get_edgetaper_weight',
]
basicsr/utils/bundle_submissions.py
ADDED
@@ -0,0 +1,108 @@
# Author: Tobias Plötz, TU Darmstadt ([email protected])

# This file is part of the implementation as described in the CVPR 2017 paper:
# Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
# Please see the file LICENSE.txt for the license governing this code.


import numpy as np
import scipy.io as sio
import os
import h5py


def bundle_submissions_raw(submission_folder, session):
    '''
    Bundles submission data for raw denoising

    submission_folder Folder where denoised images reside

    Output is written to <submission_folder>/bundled/. Please submit
    the content of this folder.
    '''

    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    try:
        os.mkdir(out_folder)
    except OSError:
        pass

    israw = True
    eval_version = "1.0"

    for i in range(50):
        # dtype=object replaces np.object, which was removed in NumPy >= 1.24
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%02d.mat' % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )


def bundle_submissions_srgb(submission_folder, session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside

    Output is written to <submission_folder>/bundled/. Please submit
    the content of this folder.
    '''
    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    try:
        os.mkdir(out_folder)
    except OSError:
        pass
    israw = False
    eval_version = "1.0"

    for i in range(50):
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%02d.mat' % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )


def bundle_submissions_srgb_v1(submission_folder, session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside

    Output is written to <submission_folder>/bundled/. Please submit
    the content of this folder.
    '''
    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    try:
        os.mkdir(out_folder)
    except OSError:
        pass
    israw = False
    eval_version = "1.0"

    for i in range(50):
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%d.mat' % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
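
All three bundlers share one shape: read 20 denoised crops per image for 50 images (files `0001_01.mat` through `0050_20.mat`, each holding an `Idenoised_crop` array) and write one bundled `.mat` per image into `<submission_folder>/<session>/`. A hedged usage sketch with a hypothetical results folder:

from basicsr.utils.bundle_submissions import bundle_submissions_srgb

bundle_submissions_srgb('./results/dnd_srgb/', session='bundled')
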
basicsr/utils/create_lmdb.py
ADDED
@@ -0,0 +1,124 @@
import argparse
import os
from os import path as osp

# imports needed by create_lmdb_for_SIDD (missing in the original file)
import cv2
import scipy.io as scio
from tqdm import tqdm

from basicsr.utils import scandir
from basicsr.utils.lmdb_util import make_lmdb_from_imgs


def prepare_keys(folder_path, suffix='png'):
    """Prepare image path list and keys for DIV2K dataset.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list.
    """
    print('Reading image path list ...')
    img_path_list = sorted(
        list(scandir(folder_path, suffix=suffix, recursive=False)))
    keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]

    return img_path_list, keys


def create_lmdb_for_reds():
    folder_path = './datasets/REDS/val/sharp_300'
    lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
    #
    folder_path = './datasets/REDS/val/blur_300'
    lmdb_path = './datasets/REDS/val/blur_300.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'jpg')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/REDS/train/train_sharp'
    lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/REDS/train/train_blur_jpeg'
    lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'jpg')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)


def create_lmdb_for_gopro():
    folder_path = './datasets/GoPro/train/blur_crops'
    lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/GoPro/train/sharp_crops'
    lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/GoPro/test/target'
    lmdb_path = './datasets/GoPro/test/target.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/GoPro/test/input'
    lmdb_path = './datasets/GoPro/test/input.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)


def create_lmdb_for_rain13k():
    folder_path = './datasets/Rain13k/train/input'
    lmdb_path = './datasets/Rain13k/train/input.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'jpg')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/Rain13k/train/target'
    lmdb_path = './datasets/Rain13k/train/target.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'jpg')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)


def create_lmdb_for_SIDD():
    folder_path = './datasets/SIDD/train/input_crops'
    lmdb_path = './datasets/SIDD/train/input_crops.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'PNG')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/SIDD/train/gt_crops'
    lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'

    img_path_list, keys = prepare_keys(folder_path, 'PNG')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # for val
    folder_path = './datasets/SIDD/val/input_crops'
    lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
    mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
    if not osp.exists(folder_path):
        os.makedirs(folder_path)
    assert osp.exists(mat_path)
    data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
    N, B, H, W, C = data.shape
    data = data.reshape(N * B, H, W, C)
    for i in tqdm(range(N * B)):
        cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i, ...], cv2.COLOR_RGB2BGR))
    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/SIDD/val/gt_crops'
    lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
    mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
    if not osp.exists(folder_path):
        os.makedirs(folder_path)
    assert osp.exists(mat_path)
    data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
    N, B, H, W, C = data.shape
    data = data.reshape(N * B, H, W, C)
    for i in tqdm(range(N * B)):
        cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i, ...], cv2.COLOR_RGB2BGR))
    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
basicsr/utils/dist_util.py
ADDED
@@ -0,0 +1,83 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
basicsr/utils/download_util.py
ADDED
@@ -0,0 +1,70 @@
import math
import requests
from tqdm import tqdm

from .misc import sizeof_fmt


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()
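
A usage sketch for the Google Drive helper; the file id and destination below are placeholders, and since Google's confirm-token flow changes over time this code path may need adjusting for current Drive responses:

from basicsr.utils.download_util import download_file_from_google_drive

download_file_from_google_drive('FILE_ID', './model.pth')  # 'FILE_ID' is a placeholder
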
basicsr/utils/face_util.py
ADDED
@@ -0,0 +1,217 @@
import cv2
import numpy as np
import os
import torch
from skimage import transform as trans

from basicsr.utils import imwrite

try:
    import dlib
except ImportError:
    print('Please install dlib before testing face restoration. '
          'Reference: https://github.com/davisking/dlib')


class FaceRestorationHelper(object):
    """Helper for the face restoration pipeline."""

    def __init__(self, upscale_factor, face_size=512):
        self.upscale_factor = upscale_factor
        self.face_size = (face_size, face_size)

        # standard 5 landmarks for FFHQ faces with 1024 x 1024
        self.face_template = np.array([[686.77227723, 488.62376238],
                                       [586.77227723, 493.59405941],
                                       [337.91089109, 488.38613861],
                                       [437.95049505, 493.51485149],
                                       [513.58415842, 678.5049505]])
        self.face_template = self.face_template / (1024 // face_size)
        # for estimating the 2D similarity transformation
        self.similarity_trans = trans.SimilarityTransform()

        self.all_landmarks_5 = []
        self.all_landmarks_68 = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.save_png = True

    def init_dlib(self, detection_path, landmark5_path, landmark68_path):
        """Initialize the dlib detectors and predictors."""
        self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)

    def free_dlib_gpu_memory(self):
        del self.face_detector
        del self.shape_predictor_5
        del self.shape_predictor_68

    def read_input_image(self, img_path):
        # self.input_img is a Numpy array, (h, w, c) with RGB order
        self.input_img = dlib.load_rgb_image(img_path)

    def detect_faces(self,
                     img_path,
                     upsample_num_times=1,
                     only_keep_largest=False):
        """
        Args:
            img_path (str): Image path.
            upsample_num_times (int): Upsamples the image before running the
                face detector.

        Returns:
            int: Number of detected faces.
        """
        self.read_input_image(img_path)
        det_faces = self.face_detector(self.input_img, upsample_num_times)
        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
        else:
            if only_keep_largest:
                print('Detect several faces and only keep the largest.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() -
                                 det_faces[i].rect.left()) * (
                                     det_faces[i].rect.bottom() -
                                     det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces
        return len(self.det_faces)

    def get_face_landmarks_5(self):
        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)
        return len(self.all_landmarks_5)

    def get_face_landmarks_68(self):
        """Get 68 densemarks for cropped images.

        Should only have one face at most in the cropped image.
        """
        num_detected_face = 0
        for idx, face in enumerate(self.cropped_faces):
            # face detection
            det_face = self.face_detector(face, 1)  # TODO: can we remove it?
            if len(det_face) == 0:
                print(f'Cannot find faces in cropped image with index {idx}.')
                self.all_landmarks_68.append(None)
            else:
                if len(det_face) > 1:
                    print('Detect several faces in the cropped face. Use the '
                          'largest one. Note that it will also cause overlap '
                          'during paste_faces_to_input_image.')
                    face_areas = []
                    for i in range(len(det_face)):
                        face_area = (det_face[i].rect.right() -
                                     det_face[i].rect.left()) * (
                                         det_face[i].rect.bottom() -
                                         det_face[i].rect.top())
                        face_areas.append(face_area)
                    largest_idx = face_areas.index(max(face_areas))
                    face_rect = det_face[largest_idx].rect
                else:
                    face_rect = det_face[0].rect
                shape = self.shape_predictor_68(face, face_rect)
                landmark = np.array([[part.x, part.y]
                                     for part in shape.parts()])
                self.all_landmarks_68.append(landmark)
                num_detected_face += 1

        return num_detected_face

    def warp_crop_faces(self,
                        save_cropped_path=None,
                        save_inverse_affine_path=None):
        """Get affine matrix, warp and cropped faces.

        Also get inverse affine matrix for post-processing.
        """
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            self.similarity_trans.estimate(landmark, self.face_template)
affine_matrix = self.similarity_trans.params[0:2, :]
|
| 143 |
+
self.affine_matrices.append(affine_matrix)
|
| 144 |
+
# warp and crop faces
|
| 145 |
+
cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
|
| 146 |
+
self.face_size)
|
| 147 |
+
self.cropped_faces.append(cropped_face)
|
| 148 |
+
# save the cropped face
|
| 149 |
+
if save_cropped_path is not None:
|
| 150 |
+
path, ext = os.path.splitext(save_cropped_path)
|
| 151 |
+
if self.save_png:
|
| 152 |
+
save_path = f'{path}_{idx:02d}.png'
|
| 153 |
+
else:
|
| 154 |
+
save_path = f'{path}_{idx:02d}{ext}'
|
| 155 |
+
|
| 156 |
+
imwrite(
|
| 157 |
+
cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
|
| 158 |
+
|
| 159 |
+
# get inverse affine matrix
|
| 160 |
+
self.similarity_trans.estimate(self.face_template,
|
| 161 |
+
landmark * self.upscale_factor)
|
| 162 |
+
inverse_affine = self.similarity_trans.params[0:2, :]
|
| 163 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
| 164 |
+
# save inverse affine matrices
|
| 165 |
+
if save_inverse_affine_path is not None:
|
| 166 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
| 167 |
+
save_path = f'{path}_{idx:02d}.pth'
|
| 168 |
+
torch.save(inverse_affine, save_path)
|
| 169 |
+
|
| 170 |
+
def add_restored_face(self, face):
|
| 171 |
+
self.restored_faces.append(face)
|
| 172 |
+
|
| 173 |
+
def paste_faces_to_input_image(self, save_path):
|
| 174 |
+
# operate in the BGR order
|
| 175 |
+
input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
|
| 176 |
+
h, w, _ = input_img.shape
|
| 177 |
+
h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
|
| 178 |
+
# simply resize the background
|
| 179 |
+
upsample_img = cv2.resize(input_img, (w_up, h_up))
|
| 180 |
+
assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
|
| 181 |
+
'length of restored_faces and affine_matrices are different.')
|
| 182 |
+
for restored_face, inverse_affine in zip(self.restored_faces,
|
| 183 |
+
self.inverse_affine_matrices):
|
| 184 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine,
|
| 185 |
+
(w_up, h_up))
|
| 186 |
+
mask = np.ones((*self.face_size, 3), dtype=np.float32)
|
| 187 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
| 188 |
+
# remove the black borders
|
| 189 |
+
inv_mask_erosion = cv2.erode(
|
| 190 |
+
inv_mask,
|
| 191 |
+
np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
|
| 192 |
+
np.uint8))
|
| 193 |
+
inv_restored_remove_border = inv_mask_erosion * inv_restored
|
| 194 |
+
total_face_area = np.sum(inv_mask_erosion) // 3
|
| 195 |
+
# compute the fusion edge based on the area of face
|
| 196 |
+
w_edge = int(total_face_area**0.5) // 20
|
| 197 |
+
erosion_radius = w_edge * 2
|
| 198 |
+
inv_mask_center = cv2.erode(
|
| 199 |
+
inv_mask_erosion,
|
| 200 |
+
np.ones((erosion_radius, erosion_radius), np.uint8))
|
| 201 |
+
blur_size = w_edge * 2
|
| 202 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
|
| 203 |
+
(blur_size + 1, blur_size + 1), 0)
|
| 204 |
+
upsample_img = inv_soft_mask * inv_restored_remove_border + (
|
| 205 |
+
1 - inv_soft_mask) * upsample_img
|
| 206 |
+
if self.save_png:
|
| 207 |
+
save_path = save_path.replace('.jpg',
|
| 208 |
+
'.png').replace('.jpeg', '.png')
|
| 209 |
+
imwrite(upsample_img.astype(np.uint8), save_path)
|
| 210 |
+
|
| 211 |
+
def clean_all(self):
|
| 212 |
+
self.all_landmarks_5 = []
|
| 213 |
+
self.all_landmarks_68 = []
|
| 214 |
+
self.restored_faces = []
|
| 215 |
+
self.affine_matrices = []
|
| 216 |
+
self.cropped_faces = []
|
| 217 |
+
self.inverse_affine_matrices = []
|
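The helper is driven in a fixed order: detect, landmarks, warp/crop, restore, paste back. A minimal sketch of that pipeline, assuming the dlib model files have been downloaded separately; the model paths, image paths, and the restore_model call below are placeholders, not files shipped in this repo:

from basicsr.utils.face_util import FaceRestorationHelper

helper = FaceRestorationHelper(upscale_factor=2, face_size=512)
# hypothetical dlib model paths; download the detectors/predictors yourself
helper.init_dlib('mmod_human_face_detector.dat',
                 'shape_predictor_5_face_landmarks.dat',
                 'shape_predictor_68_face_landmarks.dat')
num_faces = helper.detect_faces('input.jpg', only_keep_largest=True)
helper.get_face_landmarks_5()
helper.warp_crop_faces(save_cropped_path='cropped/face.png')
for cropped_face in helper.cropped_faces:
    restored = restore_model(cropped_face)  # placeholder restoration network
    helper.add_restored_face(restored)
helper.paste_faces_to_input_image('output.png')
helper.clean_all()  # reset internal lists before the next image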
basicsr/utils/file_client.py
ADDED
@@ -0,0 +1,186 @@

# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
from abc import ABCMeta, abstractmethod


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disk storage backend."""

    def get(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'r') as f:
            value_buf = f.read()
        return value_buf


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database paths.
        _client (dict): A dict of several lmdb envs, keyed by client key.
    """

    def __init__(self,
                 db_paths,
                 client_keys='default',
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), (
            'client_keys and db_paths should have the same length, '
            f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}

        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(
                path,
                readonly=readonly,
                lock=lock,
                readahead=readahead,
                map_size=8 * 1024 * 10485760,
                # max_readers=1,
                **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not '
                                            'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class FileClient(object):
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary buffer. It can also register other backend
    accessors with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(
                f'Backend {backend} is not supported. Currently supported ones'
                f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        # client_key is used only for lmdb, where different fileclients have
        # different lmdb environments.
        if self.backend == 'lmdb':
            return self.client.get(filepath, client_key)
        else:
            return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)
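A short sketch of how FileClient is typically used with the disk and lmdb backends; the dataset paths and keys below are illustrative, not files in this repo:

from basicsr.utils.file_client import FileClient

# Disk backend: get() returns the raw byte stream of the file.
disk_client = FileClient(backend='disk')
img_bytes = disk_client.get('datasets/example/0001.png')

# Lmdb backend: extra kwargs are forwarded to LmdbBackend; filepath is the
# lmdb key (not a path on disk) and client_key selects the environment.
lmdb_client = FileClient(
    backend='lmdb',
    db_paths=['datasets/lq.lmdb', 'datasets/gt.lmdb'],
    client_keys=['lq', 'gt'])
lq_bytes = lmdb_client.get('0001', client_key='lq')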
basicsr/utils/flow_util.py
ADDED
@@ -0,0 +1,180 @@

# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py  # noqa: E501
import cv2
import numpy as np
import os


def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_path (ndarray or str): Flow path.
        quantize (bool): Whether to read a quantized pair. If set to True,
            remaining args will be passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow represented as a (h, w, 2) numpy array.
    """
    if quantize:
        assert concat_axis in [0, 1]
        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if cat_flow.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, '
                          f'its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        flow = dequantize_flow(dx, dy, *args, **kwargs)
    else:
        with open(flow_path, 'rb') as f:
            try:
                header = f.read(4).decode('utf-8')
            except Exception:
                raise IOError(f'Invalid flow file: {flow_path}')
            else:
                if header != 'PIEH':
                    raise IOError(f'Invalid flow file: {flow_path}, '
                                  'header does not contain PIEH')

            w = np.fromfile(f, np.int32, 1).squeeze()
            h = np.fromfile(f, np.int32, 1).squeeze()
            flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))

    return flow.astype(np.float32)


def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    If the flow is not quantized, it will be saved as a .flo file losslessly;
    otherwise it is saved as a jpeg image, which is lossy but of much smaller
    size. (dx and dy will be concatenated into a single image if quantize is
    True.)

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow and save it to 2 jpeg
            images. If set to True, remaining args will be passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.
    """
    if not quantize:
        with open(filename, 'wb') as f:
            f.write('PIEH'.encode('utf-8'))
            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
            flow = flow.astype(np.float32)
            flow.tofile(f)
            f.flush()
    else:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        # make sure the parent directory exists, then write dx/dy as one image
        os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
        cv2.imwrite(filename, dxdy)


def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize flow to [0, 255].

    After this step, the size of flow will be much smaller, and can be
    dumped as jpeg images.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Maximum value of flow, values beyond
            [-max_val, max_val] will be truncated.
        norm (bool): Whether to divide flow values by image width/height.

    Returns:
        tuple[ndarray]: Quantized dx and dy.
    """
    h, w, _ = flow.shape
    dx = flow[..., 0]
    dy = flow[..., 1]
    if norm:
        dx = dx / w  # avoid inplace operations
        dy = dy / h
    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
    flow_comps = [
        quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
    ]
    return tuple(flow_comps)


def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover from quantized flow.

    Args:
        dx (ndarray): Quantized dx.
        dy (ndarray): Quantized dy.
        max_val (float): Maximum value used when quantizing.
        denorm (bool): Whether to multiply flow values with width/height.

    Returns:
        ndarray: Dequantized flow.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]

    if denorm:
        dx *= dx.shape[1]
        dy *= dy.shape[0]
    flow = np.dstack((dx, dy))
    return flow


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    quantized_arr = np.minimum(
        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                   min_val) / levels + min_val

    return dequantized_arr
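The quantize/dequantize pair is deliberately lossy: 255 uint8 levels are spread over [-max_val, max_val], so the reconstruction error per component is bounded by half a quantization step (0.04 / 255, roughly 1.6e-4, with the defaults). A minimal round-trip sketch on a toy flow field:

import numpy as np
from basicsr.utils.flow_util import quantize_flow, dequantize_flow

# toy (h, w, 2) flow; values stay inside [-max_val, max_val] so clipping
# does not distort the comparison
flow = np.random.uniform(-0.01, 0.01, size=(4, 6, 2)).astype(np.float32)
dx, dy = quantize_flow(flow, max_val=0.02, norm=False)      # two uint8 maps
recovered = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
assert recovered.shape == flow.shape
assert np.max(np.abs(recovered - flow)) < 2e-4              # lossy, but close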
basicsr/utils/img_util.py
ADDED
@@ -0,0 +1,216 @@

import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If the returned results only
            have one element, just return the tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accepted shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channels should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): Output type. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
            shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. '
                            f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result


def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {
        'color': cv2.IMREAD_COLOR,
        'grayscale': cv2.IMREAD_GRAYSCALE,
        'unchanged': cv2.IMREAD_UNCHANGED
    }
    if img_np is None:
        raise ValueError('Received empty image bytes.')
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img


def imfrombytesDP(content, flag='color', float32=False):
    """Read a 16-bit image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize to [0, 1] (assumes 16-bit input, hence the division
            by 65535). Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    if img_np is None:
        raise ValueError('Received empty image bytes.')
    img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED)
    if float32:
        img = img.astype(np.float32) / 65535.
    return img


def padding(img_gt, gt_size):
    h, w, _ = img_gt.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_gt

    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    if img_gt.ndim == 2:
        img_gt = np.expand_dims(img_gt, axis=2)
    return img_gt


def padding_DP(img_lqL, img_lqR, img_gt, gt_size):
    h, w, _ = img_gt.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_lqL, img_lqR, img_gt

    img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    # print('img_lq', img_lq.shape, img_gt.shape)
    return img_lqL, img_lqR, img_gt


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)


def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray]: Cropped images.
    """
    if crop_border == 0:
        return imgs
    else:
        if isinstance(imgs, list):
            return [
                v[crop_border:-crop_border, crop_border:-crop_border, ...]
                for v in imgs
            ]
        else:
            return imgs[crop_border:-crop_border, crop_border:-crop_border,
                        ...]
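img2tensor and tensor2img are inverses up to clamping and rounding, which makes the usual inference round trip short. A sketch, assuming a restoration network called model; the file paths and the model call are placeholders:

import cv2
from basicsr.utils.img_util import img2tensor, tensor2img

# BGR uint8 HWC image -> float RGB CHW tensor in [0, 1]
img = cv2.imread('input.png').astype('float32') / 255.
tensor = img2tensor(img, bgr2rgb=True, float32=True)
out = model(tensor.unsqueeze(0))                    # placeholder network, BCHW
result = tensor2img(out, rgb2bgr=True, min_max=(0, 1))  # uint8 BGR HWC
cv2.imwrite('output.png', result)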
basicsr/utils/lmdb_util.py
ADDED
@@ -0,0 +1,208 @@

import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm


def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1) image name (with extension),
    2) image shape, and 3) compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether to use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), (
        'img_path_list and keys should have the same length, '
        f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(
                read_img_worker,
                args=(osp.join(data_path, path), key, compress_level),
                callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(
            osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode(
            '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
        pbar.update(1)
        pbar.set_description(f'Write {key}')
        key_byte = key.encode('ascii')
        if multiprocessing_read:
            img_byte = dataset[key]
            h, w, c = shapes[key]
        else:
            _, img_byte, img_shape = read_img_worker(
                osp.join(data_path, path), key, compress_level)
            h, w, c = img_shape

        txn.put(key_byte, img_byte)
        # write meta information
        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
        if idx % batch == 0:
            txn.commit()
            txn = env.begin(write=True)
    pbar.close()
    txn.commit()
    env.close()
    txt_file.close()
    print('\nFinish writing lmdb.')


def read_img_worker(path, key, compress_level):
    """Read image worker.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img,
                               [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))


class LmdbMaker:
    """LMDB Maker.

    Args:
        lmdb_path (str): Lmdb save path.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self,
                 lmdb_path,
                 map_size=1024**4,
                 batch=5000,
                 compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
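LmdbMaker is the incremental counterpart of make_lmdb_from_imgs: it commits every `batch` puts and writes meta_info.txt next to data.mdb. A minimal sketch, with illustrative paths and keys:

from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

maker = LmdbMaker('datasets/example.lmdb', batch=5000, compress_level=1)
for i, path in enumerate(['0001.png', '0002.png']):   # illustrative files
    key, img_byte, shape = read_img_worker(path, f'{i:04d}', compress_level=1)
    maker.put(img_byte, key, shape)
maker.close()  # flushes the final transaction and the meta file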
basicsr/utils/logger.py
ADDED
@@ -0,0 +1,175 @@

import datetime
import logging
import time

from .dist_util import get_dist_info, master_only

initialized_logger = {}


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
                   f'iter:{current_iter:8,d}, lr:(')
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        if 'time' in log_vars.keys():
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger
            if self.use_tb_logger and 'debug' not in self.exp_name:
                if k.startswith('l_'):
                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
                else:
                    self.tb_logger.add_scalar(k, v, current_iter)
        self.logger.info(message)


@master_only
def init_tb_logger(log_dir):
    from torch.utils.tensorboard import SummaryWriter
    tb_logger = SummaryWriter(log_dir=log_dir)
    return tb_logger


@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = logging.getLogger('basicsr')

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if resume_id:
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id = wandb.util.generate_id()
        resume = 'never'

    wandb.init(
        id=wandb_id,
        resume=resume,
        name=opt['name'],
        config=opt,
        project=project,
        sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')


def get_root_logger(logger_name='basicsr',
                    log_level=logging.INFO,
                    log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    elif log_file is not None:
        logger.setLevel(log_level)
        # add file handler
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger


def get_env_info():
    """Get environment information.

    Currently, only log the software version.
    """
    import torch
    import torchvision

    from basicsr.version import __version__
    msg = r"""
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    """
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
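get_root_logger configures the shared 'basicsr' logger once; later calls return the same instance without adding duplicate handlers, and non-rank-0 processes are silenced. A minimal sketch, with an illustrative log path:

import logging
from basicsr.utils.logger import get_root_logger, get_env_info

logger = get_root_logger(log_level=logging.INFO,
                         log_file='experiments/example/train.log')
logger.info(get_env_info())   # prints the BasicSR banner and versions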
basicsr/utils/matlab_functions.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import numpy as np
import torch


def cubic(x):
    """cubic function used for calculate_weights_indices."""
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
                                      2) * (((absx > 1) *
                                             (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel type. Currently unused; imresize always passes
            'cubic'.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
        0, p - 1, p).view(1, p).expand(out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. Only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
        in_h, out_h, scale, kernel, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
        in_w, out_w, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_w[i])

    if numpy_type:
        out_2 = out_2.numpy().transpose(1, 2, 0)
    return out_2


def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                              [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [
                                  -222.921, 135.576, -276.836
                              ]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                              [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [
                                  -276.836, 135.576, -222.921
                              ]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)
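For orientation, a minimal usage sketch of the resize and color-conversion helpers above (not part of the upload; the random input is a stand-in):
```
# MATLAB-compatible bicubic downscaling, then BT.601 luma, as is typical
# before computing Y-channel PSNR/SSIM.
import numpy as np
from basicsr.utils.matlab_functions import imresize, rgb2ycbcr

img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC, float32 in [0, 1]
small = imresize(img, 0.5, antialiasing=True)       # -> (32, 32, 3) ndarray
y = rgb2ycbcr(small, y_only=True)                   # float32 BT.601 luma
```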
basicsr/utils/misc.py
ADDED
@@ -0,0 +1,266 @@
import numpy as np
import os
import random
import time
import torch
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger


def set_random_seed(seed):
    """Set random seeds."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_time_str():
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    Args:
        path (str): Folder path.
    """
    # if osp.exists(path):
    #     new_name = path + '_archived_' + get_time_str()
    #     print(f'Path already exists. Rename it to {new_name}', flush=True)
    #     os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments."""
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        if ('strict_load' not in key) and ('pretrain_network'
                                           not in key) and ('resume'
                                                            not in key):
            os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the files of interest, with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def scandir_mv(dir_path, suffix=None, recursive=False, full_path=False, lq=True):
    """Scan a directory to find the files of interest.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the files of interest, with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path
    _type = "no_noise" if lq else "images"

    # Assume data is available only up to 2K, since the 1K and 3K sets have
    # not been generated yet.
    def _scandir(dir_path, suffix, recursive):
        folders = os.listdir(dir_path)
        all_files = []
        for folder in folders:  # tag
            all_files.append(osp.join(dir_path, folder, "images_4"))  # e.g., ~~train/46/0398fdk3/no_noise

        # The following is for the case where all of 1K, 2K, and 3K are used.
        # subfolders = os.listdir(osp.join(dir_path, folder))  # images4
        # for subfolder in subfolders:
        #     all_files.append(osp.join(dir_path, folder, subfolder, _type))  # ~~train/46/0398fdk3/no_noise

        return all_files
    return _scandir(dir_path, suffix, recursive)

def scandir_mv_flat(dir_path, suffix=None, recursive=False, full_path=False, lq=True):
    """Scan a directory to find the files of interest.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the files of interest, with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path
    _type = "no_noise" if lq else "images"

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    if entry.name in ["both_noises", "gaussian_only", "images", "no_noise", "no_noise_BGR", "poisson_only", "sparse"]:
                        if entry.name != _type:
                            continue
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): File keywords that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the files of interest, with relative paths.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, keywords, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if keywords is None:
                    yield return_path
                elif return_path.find(keywords) > 0:
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, keywords=keywords, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, keywords=keywords, recursive=recursive)

def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]
        flag_pretrain = False
        for network in networks:
            if opt['path'].get(f'pretrain_{network}') is not None:
                flag_pretrain = True
        if flag_pretrain:
            logger.warning(
                'pretrain_network path will be ignored during resuming.')
        # set pretrained model paths
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if opt['path'].get('ignore_resume_networks') is None or (
                    basename not in opt['path']['ignore_resume_networks']):
                opt['path'][name] = osp.join(
                    opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                logger.info(f"Set {name} to {opt['path'][name]}")


def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
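A brief sketch of `scandir` in use (the dataset path below is hypothetical):
```
# Recursively collect all PNG files under a dataset root as full paths.
from basicsr.utils.misc import scandir

paths = sorted(scandir('datasets/train/gt', suffix='.png',
                       recursive=True, full_path=True))
print(f'found {len(paths)} images')
```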
basicsr/utils/nano.py
ADDED
@@ -0,0 +1,250 @@
import torch
import torch.nn.functional as F
import numpy as np
from torch.distributions.poisson import Poisson
import random


def crop_to_bounding_box(image, offset_height, offset_width, target_height,
                         target_width, is_batch):
    # BHWC -> BHWC
    cropped = image[:, offset_height: offset_height + target_height, offset_width: offset_width + target_width, :]

    if not is_batch:
        cropped = cropped[0]

    return cropped

def crop_to_bounding_box_list(image, offset_height, offset_width, target_height,
                              target_width):
    # HWC
    cropped = [_image[offset_height: offset_height + target_height, offset_width: offset_width + target_width, :] for _image in image]

    return cropped

def pad_to_bounding_box(image, offset_height, offset_width, target_height,
                        target_width, is_batch):
    _, height, width, _ = image.shape
    after_padding_width = target_width - offset_width - width
    after_padding_height = target_height - offset_height - height

    paddings = (0, 0, offset_width, after_padding_width, offset_height, after_padding_height, 0, 0)

    padded = torch.nn.functional.pad(image, paddings)
    if not is_batch:
        padded = padded[0]

    return padded

def resize_with_crop_or_pad_torch(image, target_height, target_width):
    # BHWC -> BHWC

    is_batch = True
    if image.ndim == 3:
        is_batch = False
        image = image[None]  # 1HWC

    def max_(x, y):
        return max(x, y)

    def min_(x, y):
        return min(x, y)

    def equal_(x, y):
        return x == y

    _, height, width, _ = image.shape
    width_diff = target_width - width
    offset_crop_width = max_(-width_diff // 2, 0)
    offset_pad_width = max_(width_diff // 2, 0)

    height_diff = target_height - height
    offset_crop_height = max_(-height_diff // 2, 0)
    offset_pad_height = max_(height_diff // 2, 0)

    # Maybe crop if needed.
    cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
                                   min_(target_height, height),
                                   min_(target_width, width), is_batch)

    # Maybe pad if needed.
    if not is_batch and cropped.ndim == 3:
        cropped = cropped[None]
    resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
                                  target_height, target_width, is_batch)

    return resized


def psf2otf(psf, h=None, w=None, permute=False):
    '''
    psf = (b) h,w,c
    '''
    if h is not None:
        psf = resize_with_crop_or_pad_torch(psf, h, w)
    if permute:
        if psf.ndim == 3:
            psf = psf.permute(2, 0, 1)  # HWC -> CHW
        else:
            psf = psf.permute(0, 3, 1, 2)  # BHWC -> BCHW
    psf = psf.to(torch.complex64)
    psf = torch.fft.fftshift(psf, dim=(-1, -2))
    otf = torch.fft.fft2(psf)
    return otf

def fft(img):  # CHW
    img = img.to(torch.complex64)
    Fimg = torch.fft.fft2(img)
    return Fimg

def ifft(Fimg):
    img = torch.abs(torch.fft.ifft2(Fimg)).to(torch.float32)
    return img


def create_contrast_mask(image):
    return 1 - torch.mean(image, dim=(-1, -2), keepdim=True)  # (B),C,1,1

def apply_tikhonov(lr_img, psf, K, norm=True, otf=None):
    h, w = lr_img.shape[-2:]
    if otf is None:
        psf_norm = resize_with_crop_or_pad_torch(psf, h, w)
        if norm:
            psf_norm = psf_norm / psf_norm.sum((0, 1))
        otf = psf2otf(psf_norm, h, w, permute=True)

    otf = otf[:, None, ...]  # B,1,C,H,W
    contrast_mask = create_contrast_mask(lr_img)[:, None, ...]  # B,1,C,1,1
    K_adjusted = K * contrast_mask  # B,M,C,1,1
    tikhonov_filter = torch.conj(otf) / (torch.abs(otf) ** 2 + K_adjusted)  # B,M,C,H,W
    lr_fft = fft(lr_img)[:, None, ...]  # B,1,C,H,W
    deconvolved_fft = lr_fft * tikhonov_filter
    deconvolved_image = torch.fft.ifft2(deconvolved_fft).real
    deconvolved_image = torch.clamp(deconvolved_image, min=0.0, max=1.0)

    return deconvolved_image  # B,M,C,H,W


def add_noise_all_new(image, poss=4e-5, gaus=1e-5):
    p = Poisson(image / poss)
    sampled = p.sample((1,))[0]
    poss_img = sampled * poss
    gauss_noise = torch.randn_like(image) * gaus
    noised_img = poss_img + gauss_noise

    noised_img = torch.clamp(noised_img, 0.0, 1.0)

    return noised_img


def apply_convolution(image, psf, pad):
    '''
    input: hr img (b,c,h,w, [0,1])
    output: noised lr img (b,c,h+P,w+P, [0,1])
    '''

    # metalens simulation
    image = F.pad(image, (pad, pad, pad, pad))
    h, w = image.shape[-2:]
    psf_norm = resize_with_crop_or_pad_torch(psf, h, w)
    otf = psf2otf(psf_norm, h, w, permute=True)
    lr_img = fft(image) * otf
    lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)

    # noise addition
    noised_img = add_noise_all_new(lr_img)

    return noised_img, otf

def apply_conv_n_deconv(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=True, conv=True):
    '''
    input: hr img (b,c,h,w)
    otf: 1,N,C,H,W
    output: noised lr img (N,c,h,w)
    '''

    b, _, _, _ = image.shape
    if conv:
        img_patch = F.unfold(image, kernel_size=ph*3, stride=ph).view(b, 3, ph*3, ph*3, num_psf**2).permute(0, 4, 1, 2, 3).contiguous()  # B,N,C,H,W

        # metalens simulation
        lr_img = fft(img_patch) * otf
        lr_img = torch.clamp(ifft(lr_img), min=1e-20, max=1.0)

        # noise addition
        lr_img = add_noise_all_new(lr_img)

    else:  # load convolved image for validation
        b = 1
        lr_img = image

    # apply deconvolution
    if ks is not None:
        lr_img = apply_tikhonov(lr_img, None, ks, otf=otf)  # B,M,N,C,405,405
        lr_img = lr_img[..., ph:-ph, ph:-ph]  # BMNCHW
        lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0, 1, 4, 2, 5, 3, 6).reshape(b, M, 3, sensor_h, sensor_h)
    else:
        lr_img = lr_img[..., ph:-ph, ph:-ph]  # BNCHW
        lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0, 3, 1, 4, 2, 5).reshape(b, 3, sensor_h, sensor_h)

    lq_patches = []
    gt_patches = []
    for i in range(b):
        cur = lr_img[i]  # (M),C,H,W
        cur_gt = image[i]

        # remove padding for lq and gt
        pt, pb, pl, pr = padding[i]
        if pb and pt:
            cur = cur[..., pt: -pb, :]
            cur_gt = cur_gt[..., pt+ph: -(pb+ph), ph:-ph]
        elif pl and pr:
            cur = cur[..., pl:-pr]
            cur_gt = cur_gt[..., ph:-ph, pl+ph: -(pr+ph)]
        else:
            cur_gt = cur_gt[..., ph:-ph, ph:-ph]
        h, w = cur.shape[-2:]

        # randomly crop patch for training
        if crop:  # train
            top = random.randint(0, h - psize)
            left = random.randint(0, w - psize)
            lq_patches.append(cur[..., top:top + psize, left:left + psize])
            gt_patches.append(cur_gt[..., top:top + psize, left:left + psize])
    if crop:  # training
        lq_patches = torch.stack(lq_patches)
        gt_patches = torch.stack(gt_patches)
    else:  # validation
        return cur, cur_gt

    return lq_patches, gt_patches  # B,(M),C,H,W


def apply_convolution_square_val(image, otf, padding, M, psize, ks=None, ph=135, num_psf=9, sensor_h=1215, crop=False):
    '''
    TODO: merge with apply_conv_n_deconv above.
    image = lr_image
    '''
    lr_img = image
    b = 1
    if M:  # apply deconvolution
        lr_img = apply_tikhonov(lr_img, None, ks, otf=otf)  # B,M,N,C,H,W
        lr_img = lr_img[..., ph:-ph, ph:-ph]  # B,M,N,C,H,W
        lr_img = lr_img.view(b, M, num_psf, num_psf, 3, ph, ph).permute(0, 1, 4, 2, 5, 3, 6).reshape(b, M, 3, sensor_h, sensor_h)
    else:
        lr_img = lr_img[..., ph:-ph, ph:-ph]  # B,N,C,H,W
        lr_img = lr_img.view(b, num_psf, num_psf, 3, ph, ph).permute(0, 3, 1, 4, 2, 5).reshape(b, 3, sensor_h, sensor_h)

    for i in range(b):
        cur = lr_img[i]  # (M),C,H,W

        # remove padding for lq and gt
        pt, pb, pl, pr = padding[i]
        if pb and pt:
            cur = cur[..., pt: -pb, :]
        elif pl and pr:
            cur = cur[..., pl:-pr]

    return cur
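`apply_tikhonov` above is a frequency-domain, Tikhonov-regularized (Wiener-style) deconvolution. A minimal sketch of driving it directly, with a random PSF and an illustrative regularization strength (shapes follow the in-code comments):
```
import torch
from basicsr.utils.nano import psf2otf, fft, ifft, apply_tikhonov

B, C, H, W = 1, 3, 128, 128
image = torch.rand(B, C, H, W)                 # BCHW, [0, 1]
psf = torch.rand(H, W, C)
psf = psf / psf.sum(dim=(0, 1), keepdim=True)  # unit-energy PSF per channel

otf = psf2otf(psf, H, W, permute=True)         # C,H,W complex OTF
blurred = torch.clamp(ifft(fft(image) * otf), 0.0, 1.0)

K = torch.full((B, 1, C, 1, 1), 1e-2)          # regularization strength (M=1)
restored = apply_tikhonov(blurred, None, K, otf=otf[None])  # B,M,C,H,W
```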
basicsr/utils/options.py
ADDED
@@ -0,0 +1,112 @@
import yaml
from collections import OrderedDict
from os import path as osp


def ordered_yaml():
    """Support OrderedDict for yaml.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper


def parse(opt_path, is_train=True, name=None):
    """Parse option file.

    Args:
        opt_path (str): Option file path.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.
    """
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train
    if name is not None:
        opt['name'] = name

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for several datasets, e.g., test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key
                                  or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)
    opt['path']['root'] = osp.abspath(
        osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
    if is_train:
        experiments_root = osp.join(opt['path']['root'], 'experiments',
                                    opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root,
                                                  'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root,
                                                'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(opt['path']['root'], 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt


def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    msg = '\n'
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_level * 2) + k + ':['
            msg += dict2str(v, indent_level + 1)
            msg += ' ' * (indent_level * 2) + ']\n'
        else:
            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
    return msg
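A short sketch of how `parse` and `dict2str` combine (assuming the YAML defines the usual `path:` section, as the option files in this upload do):
```
from basicsr.utils.options import parse, dict2str

opt = parse('Aberration_Correction/Options/Test_Aberration_Transformers.yml',
            is_train=False, name='sample_test')
print(dict2str(opt))
```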
basicsr/version.py
ADDED
@@ -0,0 +1,5 @@
# GENERATED VERSION FILE
# TIME: Fri Mar 21 07:59:14 2025
__version__ = '1.2.0+5ea673c'
short_version = '1.2.0'
version_info = (1, 2, 0)
experiments/pretrained/models/net_g_100000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cc95533ca8a4dfdcfad5de2973346ad6b699c6abaf4e7e9d0de77007c4b855f
size 116763496
experiments/pretrained/training_states/100000.state
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:edb3104cc8f57a1100b4f0e3d87814a74b2c0fd1ed24a86d69b917b0e1973d2b
size 233563982
psf.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:337461630addd8dcc48a0293678b5ef75d9c35a5c7b6a0524154d2e8540741a8
size 17714828
readme.md
ADDED
@@ -0,0 +1,73 @@
# Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging

Byeonghyeon Lee, Youbin Kim, Yongjae Jo, Hyunsu Kim, Hyemi Park, Yangkyu Kim, Debabrata Mandal, Praneeth Chakravarthula, Inki Kim, and Eunbyung Park

[Project Page](https://benhenryl.github.io/Metalens-Transformer/) [Paper](https://arxiv.org/abs/2412.04591)


We ran the experiments in the following environment:
```
- ubuntu: 20.04
- python: 3.10.13
- cuda: 11.8
- pytorch: 2.2.0
- GPU: 4x A6000 ada
```

Our code is based on [Restormer](https://github.com/swz30/Restormer), [X-Restormer](https://github.com/Andrew0613/X-Restormer), and [Neural Nano Optics](https://github.com/princeton-computational-imaging/Neural_Nano-Optics). We appreciate their work.

## 1. Environment Setting
### 1-1. Pytorch
Note: pytorch >= 2.2.0 is required for Flash Attention.

### 1-2. [Flash Attention](https://github.com/Dao-AILab/flash-attention)
cf. Only Ampere, Ada, and Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100) are supported at the moment.
```
pip install packaging ninja
pip install flash-attn --no-build-isolation
```

### 1-3. Other required packages
```
pip install -r requirements.txt
```

### 1-4. Basicsr
```
python setup.py develop --no_cuda_ext
```

## 2. Dataset & Pre-trained weights
You can download the train/test dataset [here](https://drive.google.com/drive/folders/1e2wJwmcjXFvblVs0l5OXwpIkTqxd1Fhq?usp=drive_link) and the pre-trained weights [here](https://drive.google.com/drive/folders/1q5pKE1Z0RJjHVmJlNq7nPSWcaGd9bDb7?usp=drive_link).
Please move the pre-trained weights to experiments/.
Note: The model creates aberrated images on the fly from clean (gt) images during training.
For validation, it produces the aberrated images in the same manner, so the aberrated images can carry different noise from what we used in our validation (a minimal simulation sketch is shown below).
The difference in the results is negligible since the same noise distributions are used, but if you want a precise comparison with the validation set used in our experiments, please contact us.
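A minimal sketch of this on-the-fly simulation, using `apply_convolution` from ```basicsr/utils/nano.py``` (the PSF below is a random stand-in, not the actual ```psf.npy```, whose layout differs):
```
import torch
from basicsr.utils.nano import apply_convolution

gt = torch.rand(1, 3, 256, 256)      # clean (gt) batch, BCHW in [0, 1]
psf = torch.rand(64, 64, 3)          # stand-in PSF, HWC layout
psf = psf / psf.sum(dim=(0, 1))      # normalize PSF energy per channel
aberrated, otf = apply_convolution(gt, psf, pad=32)  # blurred + noised input
```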


## 3. Training
Please set the dataset path in ```./Aberration_Correction/Options/Train_Aberration_Transformers.yml```
```
bash train.sh GPU_IDS FOLDER_NAME
// ex. bash train.sh 0,1,2,3 training
// uses GPUs 0 to 3 and creates a directory experiments/training where the log, weights, and other outputs will be stored.
```

## 4. Inference
Please set the dataset path in ```./Aberration_Correction/Options/Test_Aberration_Transformers.yml```
If you want to run inference with the pre-trained model, you can use the command
```
bash test.sh GPU_ID FOLDER_NAME
// ex. bash test.sh 0 pretrained
```
Or you can set FOLDER_NAME to your own weight path.

## BibTeX
```
@article{lee2024aberration,
  title={Aberration Correcting Vision Transformers for High-Fidelity Metalens Imaging},
  author={Lee, Byeonghyeon and Kim, Youbin and Jo, Yongjae and Kim, Hyunsu and Park, Hyemi and Kim, Yangkyu and Mandal, Debabrata and Chakravarthula, Praneeth and Kim, Inki and Park, Eunbyung},
  journal={arXiv preprint arXiv:2412.04591},
  year={2024}
}
```