hongw.qin committed on
Commit 3c0e74c · 1 Parent(s): 54abac0

upload models
.gitignore ADDED
@@ -0,0 +1,6 @@
+ .vscode/
+ .venv/
+ *.pyc
+ __pycache__/
+ outputs/
+ datasets/
README.md ADDED
@@ -0,0 +1,166 @@
+ ---
+ license: apache-2.0
+ tags:
+ - RyzenAI
+ - Int8 quantization
+ - Single Image Super Resolution
+ - SESR
+ - ONNX
+ - Computer Vision
+ metrics:
+ - PSNR
+ - MS_SSIM
+ - FID
+ ---
+
+ # SESR for 2x Single Image Super Resolution
+
+ We provide 2x super-resolution models at a resolution of 256x256.
+
+ The model was introduced in the paper _Collapsible Linear Blocks for Super-Efficient Super Resolution_ by Bhardwaj et al. The official code for this work is available at [sesr](https://github.com/ARM-software/sesr).
+
+ We have developed a modified version optimized for [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).
+
+ ## Model description
+
+ SESR is based on linear overparameterization of CNNs and defines an efficient model architecture for single image super resolution (SISR).
+
+ ## Intended uses & limitations
+
+ You can use this model for single image super resolution tasks. See the [model hub](https://huggingface.co/models?search=amd/ryzenai-sesr) for all available models.
+
+ ## How to use
+
+ ### Installation
+
+ ```bash
+ # inference only
+ pip install -r requirements-infer.txt
+ # inference & evaluation
+ pip install -r requirements-eval.txt
+ ```
+
+ ### Data Preparation (optional: for evaluation)
+
+ Run `python download_edsr_benchmark.py` to automatically download and extract the EDSR benchmark dataset into the `datasets` directory. After it completes, your `datasets` folder should have the following structure:
+
+ ```plain
+ datasets/edsr_benchmark
+ ├── B100
+ │   ├── HR
+ │   │   ├── 3096.png
+ │   │   ├── ...
+ │   └── LR_bicubic/X2
+ │       ├── 3096x4.png
+ │       ├── ...
+ └── Set5
+     ├── HR
+     │   ├── baby.png
+     │   ├── ...
+     └── LR_bicubic/X2
+         ├── babyx4.png
+         ├── ...
+ ```
+
+ ### Test & Evaluation
+
+ - **Run inference on images**
+
+ ```bash
+ python onnx_inference.py --onnx sesr_nchw_fp32.onnx --input /Path/To/Image --out-dir outputs
+ python onnx_inference.py --onnx sesr_nchw_int8.onnx --input /Path/To/Image --out-dir outputs
+ ```
+
+ _Arguments:_
+
+ `--input`: Accepts either a single image file path or a directory path. If it is a file, the script processes that image only. If it is a directory, the script recursively scans for `.png`, `.jpg`, and `.jpeg` files and processes all of them.
+
+ `--out-dir`: Output directory where the restored images will be saved.
+
+ - **Evaluate the quantized model**
+
+ _Arguments:_
+
+ `--onnx`: Path to the ONNX model file.
+
+ `--hq-dir`: Directory containing high-quality (ground-truth) images.
+
+ `--lq-dir`: Directory containing low-quality (input) images.
+
+ `--out-dir`: Output directory where evaluation results will be saved.
+
+ `--max-samples`: (Optional) Limit the number of samples to evaluate. Useful for debugging. If not specified, all samples are evaluated.
+
+ `-clean`: (Optional) If specified, the generated super-resolution images are deleted after evaluation to save disk space.
+
+ ```bash
+ # ===================== eval int8 =====================
+ python onnx_eval.py \
+ --onnx sesr_nchw_int8.onnx \
+ --hq-dir datasets/edsr_benchmark/Set5/HR \
+ --lq-dir datasets/edsr_benchmark/Set5/LR_bicubic/X2 \
+ --out-dir outputs/Set5 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_int8.onnx \
+ --hq-dir datasets/edsr_benchmark/Set14/HR \
+ --lq-dir datasets/edsr_benchmark/Set14/LR_bicubic/X2 \
+ --out-dir outputs/Set14 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_int8.onnx \
+ --hq-dir datasets/edsr_benchmark/B100/HR \
+ --lq-dir datasets/edsr_benchmark/B100/LR_bicubic/X2 \
+ --out-dir outputs/B100 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_int8.onnx \
+ --hq-dir datasets/edsr_benchmark/Urban100/HR \
+ --lq-dir datasets/edsr_benchmark/Urban100/LR_bicubic/X2 \
+ --out-dir outputs/Urban100 -clean
+
+
+ # ===================== eval fp32 =====================
+ python onnx_eval.py \
+ --onnx sesr_nchw_fp32.onnx \
+ --hq-dir datasets/edsr_benchmark/Set5/HR \
+ --lq-dir datasets/edsr_benchmark/Set5/LR_bicubic/X2 \
+ --out-dir outputs/Set5 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_fp32.onnx \
+ --hq-dir datasets/edsr_benchmark/Set14/HR \
+ --lq-dir datasets/edsr_benchmark/Set14/LR_bicubic/X2 \
+ --out-dir outputs/Set14 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_fp32.onnx \
+ --hq-dir datasets/edsr_benchmark/B100/HR \
+ --lq-dir datasets/edsr_benchmark/B100/LR_bicubic/X2 \
+ --out-dir outputs/B100 -clean
+
+ python onnx_eval.py \
+ --onnx sesr_nchw_fp32.onnx \
+ --hq-dir datasets/edsr_benchmark/Urban100/HR \
+ --lq-dir datasets/edsr_benchmark/Urban100/LR_bicubic/X2 \
+ --out-dir outputs/Urban100 -clean
+ ```
+
+ ### Performance
+
+ | Model      |         | Set5       |        |         | Set14      |        |         | B100       |        |         | Urban100   |        |
+ | :--------- | ------- | ---------- | ------ | ------- | ---------- | ------ | ------- | ---------- | ------ | ------- | ---------- | ------ |
+ |            | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑) | MS_SSIM(↑) | FID(↓) |
+ | sesr(fp32) | 35.65   | 0.9971     | 26.46  | 30.98   | 0.9935     | 17.69  | 30.23   | 0.9921     | 17.00  | 28.84   | 0.9929     | 0.25   |
+ | sesr(int8) | 34.65   | 0.9952     | 28.37  | 30.46   | 0.9916     | 20.70  | 29.80   | 0.9900     | 19.38  | 28.25   | 0.9906     | 1.47   |
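The PSNR figures above are computed by pyiqa with `test_y_channel=True`. For intuition only, PSNR reduces to a log of the mean squared error; a minimal plain-RGB sketch (the `psnr_uint8` helper is ours, not part of this repo, and skips the Y-channel conversion used in the table):

```python
import numpy as np

def psnr_uint8(a: np.ndarray, b: np.ndarray) -> float:
    """PSNR between two uint8 images of identical shape (plain RGB)."""
    assert a.shape == b.shape and a.dtype == np.uint8 and b.dtype == np.uint8
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(255.0**2 / mse)

# toy example: constant offset of 1 everywhere -> mse = 1
a = np.zeros((4, 4, 3), dtype=np.uint8)
b = np.ones((4, 4, 3), dtype=np.uint8)
print(round(psnr_uint8(a, b), 2))  # 48.13
```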
+
+ ---
+
+ ```bibtex
+ @article{bhardwaj2021collapsible,
+   title={Collapsible Linear Blocks for Super-Efficient Super Resolution},
+   author={Bhardwaj, Kartikeya and Milosavljevic, Milos and O'Neil, Liam and Gope, Dibakar and Matas, Ramon and Chalfin, Alex and Suda, Naveen and Meng, Lingchuan and Loh, Danny},
+   journal={arXiv preprint arXiv:2103.09404},
+   year={2021}
+ }
+ ```
download_edsr_benchmark.py ADDED
@@ -0,0 +1,76 @@
+ from pathlib import Path
+ from urllib.request import urlretrieve
+ import tarfile
+ import shutil
+
+ from tqdm import tqdm
+
+ ITEMS = [
+     {
+         "url": "https://cv.snu.ac.kr/research/EDSR/benchmark.tar",
+         "name": "EDSR_benchmark",
+     },
+ ]
+
+
+ def download_with_progress(url: str, out_path: Path) -> None:
+     """Download url to out_path with a tqdm progress bar."""
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+
+     print(f"Downloading {url} -> {out_path}")
+
+     bar = None
+     last_b = 0
+
+     def reporthook(b: int, bsize: int, tsize: int):
+         nonlocal bar, last_b
+         if bar is None:
+             total = tsize if tsize > 0 else None
+             bar = tqdm(
+                 total=total,
+                 unit="B",
+                 unit_scale=True,
+                 unit_divisor=1024,
+                 desc=out_path.name,
+                 dynamic_ncols=True,
+             )
+         delta_blocks = b - last_b
+         if delta_blocks > 0:
+             bar.update(delta_blocks * bsize)
+         last_b = b
+
+     try:
+         urlretrieve(url, out_path, reporthook=reporthook)
+     finally:
+         if bar is not None:
+             bar.close()
+
+
+ def extract_tar_flatten(tar_path: Path, dest_dir: Path) -> None:
+     dest_dir.mkdir(parents=True, exist_ok=True)
+     print(f"Extracting {tar_path} -> {dest_dir} (flatten top folder)")
+
+     with tarfile.open(tar_path, "r") as tf:
+         tf.extractall(dest_dir)
+
+     # rename benchmark -> edsr_benchmark
+     print("Renaming benchmark -> edsr_benchmark")
+     shutil.move(str(dest_dir / "benchmark"), str(dest_dir / "edsr_benchmark"))
+
+
+ def main() -> None:
+     base = Path(__file__).resolve().parent
+     root = base / "datasets"
+     root.mkdir(parents=True, exist_ok=True)
+
+     for it in ITEMS:
+         tar_path = root / f"{it['name']}.tar"
+
+         download_with_progress(it["url"], tar_path)
+         extract_tar_flatten(tar_path, root)
+
+     print("All done.")
+
+
+ if __name__ == "__main__":
+     main()
onnx_eval.py ADDED
@@ -0,0 +1,202 @@
+ import sys
+ import json
+ from pathlib import Path
+
+ sys.path.insert(0, Path(__file__).parent.as_posix())
+
+ import cv2
+ import pyiqa
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from onnx_runner import OnnxRunner
+
+
+ def collect_common_image_pairs(
+     lq_dir: Path, hq_dir: Path
+ ) -> tuple[list[Path], list[Path]]:
+     exts = {".png", ".jpg", ".jpeg"}
+
+     def is_img(p: Path) -> bool:
+         return p.is_file() and p.suffix.lower() in exts
+
+     hq_map = {p.stem: p for p in hq_dir.iterdir() if is_img(p)}
+     hq_names = sorted(hq_map.keys())
+
+     lq_files = [p for p in lq_dir.iterdir() if is_img(p)]
+
+     lq_paths: list[Path] = []
+     hq_paths: list[Path] = []
+     for base in hq_names:
+         # try an exact stem match first
+         best_lq = next((p for p in lq_files if p.stem == base), None)
+
+         # then fall back to a prefix match (e.g. HR "baby" -> LR "babyx2")
+         if best_lq is None:
+             best_lq = next(
+                 (
+                     p
+                     for p in lq_files
+                     if p.stem.startswith(base) and len(p.stem) > len(base)
+                 ),
+                 None,
+             )
+
+         if best_lq is not None:  # matched
+             hq_paths.append(hq_map[base])
+             lq_paths.append(best_lq)
+
+     return lq_paths, hq_paths
+
+
+ def align_shape_by_crop(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
+     """Crop both images to their common top-left region so the shapes match."""
+     if sr_bgr.shape != hq_bgr.shape:
+         min_h = min(sr_bgr.shape[0], hq_bgr.shape[0])
+         min_w = min(sr_bgr.shape[1], hq_bgr.shape[1])
+
+         sr_bgr = sr_bgr[:min_h, :min_w]
+         hq_bgr = hq_bgr[:min_h, :min_w]
+
+     return sr_bgr, hq_bgr
+
+
+ def gen_sr_images(
+     hq_dir: Path, lq_dir: Path, out_dir: Path, onnx_path: Path, max_samples: int
+ ):
+     out_dir.mkdir(exist_ok=True, parents=True)
+
+     onnx_runner = OnnxRunner(onnx_path, sr_scale=2, tile_overlap=8)
+
+     lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)
+
+     if max_samples is not None:
+         lq_paths = lq_paths[: max(max_samples, 1)]
+         hq_paths = hq_paths[: max(max_samples, 1)]
+
+     sr_paths = []
+     for i in tqdm(range(len(lq_paths)), desc="generating"):
+         lq_img_path = lq_paths[i]
+         lq_bgr = cv2.imread(lq_img_path.as_posix(), cv2.IMREAD_COLOR)
+         assert lq_bgr is not None
+         sr_bgr = onnx_runner.run(lq_bgr)
+
+         hq_img_path = hq_paths[i]
+         hq_bgr = cv2.imread(hq_img_path.as_posix(), cv2.IMREAD_COLOR)
+
+         aligned_sr_bgr, aligned_hq_bgr = align_shape_by_crop(sr_bgr, hq_bgr)
+         if aligned_hq_bgr.shape != hq_bgr.shape:
+             # note: overwrites the ground-truth image on disk so that the
+             # directory-level FID below compares images of identical size
+             cv2.imwrite(hq_img_path.as_posix(), aligned_hq_bgr)
+
+         out_path = out_dir / f"{lq_img_path.stem}.png"
+         cv2.imwrite(out_path.as_posix(), aligned_sr_bgr)
+
+         sr_paths.append(out_path)
+
+     return hq_paths, sr_paths
+
+
+ def eval_metrics(
+     hq_paths: list[Path],
+     sr_paths: list[Path],
+     hq_dir: Path,
+     sr_dir: Path,
+     device: torch.device | None = None,
+ ) -> dict[str, float]:
+     assert len(hq_paths) == len(sr_paths)
+
+     device = device or (
+         torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     )
+
+     # full-reference metrics: (sr, ref)
+     psnr_metric = pyiqa.create_metric("psnr", device=device, test_y_channel=True)
+     ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device, test_y_channel=True)
+     # distribution metric computed over the two directories
+     fid_metric = pyiqa.create_metric("fid")
+
+     with torch.inference_mode():
+         psnr_vals = []
+         ms_ssim_vals = []
+         for sr_p, hq_p in zip(sr_paths, hq_paths):
+             sr_p = sr_p.as_posix()
+             hq_p = hq_p.as_posix()
+             psnr_vals.append(psnr_metric(sr_p, hq_p).detach())
+             ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach())
+
+         psnr = torch.stack(psnr_vals).mean().item()
+         ms_ssim = torch.stack(ms_ssim_vals).mean().item()
+
+         fid = fid_metric(
+             sr_dir.as_posix(),
+             hq_dir.as_posix(),
+             mode="clean",
+             batch_size=1,
+             num_workers=0,
+         ).item()
+
+     return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}
+
+
+ def main(args):
+     onnx_path = Path(args.onnx)
+     hq_dir = Path(args.hq_dir)
+     lq_dir = Path(args.lq_dir)
+     out_dir = Path(args.out_dir)
+
+     assert onnx_path.suffix == ".onnx"
+     assert lq_dir.is_dir(), f"{lq_dir} is not a dir!"
+     assert hq_dir.is_dir(), f"{hq_dir} is not a dir!"
+
+     sr_dir = out_dir / "sr"
+     hq_paths, sr_paths = gen_sr_images(
+         hq_dir, lq_dir, sr_dir, onnx_path, args.max_samples
+     )
+
+     scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)
+
+     summary = {
+         "onnx": onnx_path.as_posix(),
+         "psnr": scores["psnr"],
+         "ms_ssim": scores["ms_ssim"],
+         "fid": scores["fid"],
+     }
+
+     out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
+     with open(out_file, "w") as f:
+         json.dump(summary, f, indent=2)
+     dataset_name = hq_dir.parent.name
+     print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
+     print(
+         f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
+     )
+     print(f"result saved to {out_file}")
+
+     if args.clean:
+         import shutil
+
+         print(f"cleaning SR output dir: {sr_dir}")
+         shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)
+
+
+ if __name__ == "__main__":
+     from argparse import ArgumentParser
+
+     parser = ArgumentParser()
+     parser.add_argument("--onnx", type=str, required=True)
+     parser.add_argument("--hq-dir", type=str, required=True)
+     parser.add_argument("--lq-dir", type=str, required=True)
+     parser.add_argument("--out-dir", type=str, default="outputs")
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="limit the number of evaluated samples (debug only); None means unlimited",
+     )
+     parser.add_argument(
+         "-clean",
+         action="store_true",
+         default=False,
+         help="delete the generated SR images when finished",
+     )
+     main(parser.parse_args())
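The exact-stem-then-prefix matching in `collect_common_image_pairs` above is what lets an HR image like `baby.png` pair with an LR file such as `babyx2.png`. A standalone sketch of the same rule on bare name stems (no filesystem access; the `pair_stems` helper is ours, for illustration only):

```python
def pair_stems(hq_stems: list[str], lq_stems: list[str]) -> dict[str, str]:
    """Mirror collect_common_image_pairs: exact stem match first, then fall
    back to an LR stem that strictly extends the HR stem. Unmatched HR
    stems are dropped, as in the script above."""
    pairs = {}
    for base in sorted(hq_stems):
        match = next((s for s in lq_stems if s == base), None)
        if match is None:
            match = next(
                (s for s in lq_stems if s.startswith(base) and len(s) > len(base)),
                None,
            )
        if match is not None:
            pairs[base] = match
    return pairs

print(pair_stems(["baby", "bird"], ["babyx2", "bird", "head"]))
# {'baby': 'babyx2', 'bird': 'bird'}
```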
onnx_inference.py ADDED
@@ -0,0 +1,53 @@
+ from pathlib import Path
+ import sys
+
+ sys.path.insert(0, Path(__file__).parent.as_posix())
+
+ import cv2
+ from onnx_runner import OnnxRunner
+
+
+ def main(args):
+     onnx_path = Path(args.onnx)
+     input_path = Path(args.input)
+     out_dir = Path(args.out_dir)
+
+     assert onnx_path.suffix == ".onnx"
+
+     onnx_runner = OnnxRunner(onnx_path, sr_scale=2, debug=False)
+
+     if input_path.is_file():
+         input_images_path = [input_path]
+     else:
+         input_images_path = sorted(
+             p
+             for p in input_path.rglob("*")
+             if p.suffix.lower() in (".png", ".jpg", ".jpeg")
+         )
+
+     out_dir.mkdir(exist_ok=True, parents=True)
+     for input_img_path in input_images_path:
+         input_img_path: Path
+
+         img_bgr = cv2.imread(input_img_path.as_posix(), cv2.IMREAD_COLOR)
+         assert img_bgr is not None
+         sr_img_bgr = onnx_runner.run(img_bgr)
+
+         out_path = out_dir / f"{input_img_path.stem}.png"
+         cv2.imwrite(out_path.as_posix(), sr_img_bgr)
+         print(f"saved {out_path}")
+
+
+ if __name__ == "__main__":
+     from argparse import ArgumentParser
+
+     parser = ArgumentParser()
+     parser.add_argument("--onnx", type=str, required=True)
+     parser.add_argument("--input", type=str, required=True)
+     parser.add_argument("--out-dir", type=str, required=True)
+
+     main(parser.parse_args())
onnx_runner.py ADDED
@@ -0,0 +1,275 @@
+ import math
+ from pathlib import Path
+
+ import numpy as np
+ import onnxruntime as ort
+
+ __all__ = [
+     "OnnxRunner",
+ ]
+
+
+ def split_into_tiles_with_context(
+     img_chw: np.ndarray,
+     patch_size_hw: tuple[int, int],
+     overlap: int,
+ ):
+     """
+     Args:
+         img_chw: (C, H, W)
+         patch_size_hw: (ph, pw) size of each patch.
+         overlap: overlap between neighboring patches.
+
+     Returns:
+         tiles_chw: list[np.ndarray], each tile of shape (C, ph, pw)
+         orig_hw: (H, W)
+         padded_hw: (H_pad, W_pad)
+     """
+     assert img_chw.ndim == 3
+     C, H, W = img_chw.shape
+     ph, pw = patch_size_hw
+
+     assert 2 * overlap < ph and 2 * overlap < pw, "2*overlap must be < patch_size"
+
+     # core region size (patch minus the overlap context on both sides)
+     core_h = ph - 2 * overlap
+     core_w = pw - 2 * overlap
+     assert core_h > 0 and core_w > 0
+
+     # compute how many tiles are required
+     n_tiles_h = math.ceil(H / core_h)
+     n_tiles_w = math.ceil(W / core_w)
+
+     # size after center padding
+     H_pad = n_tiles_h * core_h
+     W_pad = n_tiles_w * core_w
+
+     # first padding: make the padded image divisible by the core size
+     pad_h = H_pad - H
+     pad_w = W_pad - W
+     img_pad = np.pad(
+         img_chw,
+         pad_width=((0, 0), (0, pad_h), (0, pad_w)),
+         mode="reflect",
+     )  # (C, H_pad, W_pad)
+
+     # second padding: add reflected context at the boundaries
+     big_pad = np.pad(
+         img_pad,
+         pad_width=((0, 0), (overlap, overlap), (overlap, overlap)),
+         mode="reflect",
+     )  # (C, H_pad+2o, W_pad+2o)
+
+     tiles = []
+     for iy in range(n_tiles_h):
+         for ix in range(n_tiles_w):
+             y0 = iy * core_h
+             x0 = ix * core_w
+             tile = big_pad[:, y0 : y0 + ph, x0 : x0 + pw]
+             tiles.append(tile)
+
+     return tiles, (H, W), (H_pad, W_pad)
+
+
+ def merge_tiles_with_context(
+     tiles_chw: list[np.ndarray],
+     orig_hw: tuple[int, int],
+     padded_hw: tuple[int, int],
+     overlap: int,
+ ) -> np.ndarray:
+     """
+     Args:
+         tiles_chw: tiles produced by split_into_tiles_with_context.
+         orig_hw: original image size.
+         padded_hw: center-padded image size.
+         overlap: overlap between neighboring patches.
+
+     Returns:
+         img_chw: (C, H, W)
+     """
+     assert len(tiles_chw) > 0
+     C, ph, pw = tiles_chw[0].shape
+     H, W = orig_hw
+     H_pad, W_pad = padded_hw
+
+     assert 2 * overlap < ph and 2 * overlap < pw
+     core_h = ph - 2 * overlap
+     core_w = pw - 2 * overlap
+     n_h = H_pad // core_h
+     n_w = W_pad // core_w
+     assert n_h * n_w == len(tiles_chw), "tiles != padded_hw"
+
+     img_pad_recon = np.zeros((C, H_pad, W_pad), dtype=tiles_chw[0].dtype)
+
+     idx = 0
+     for iy in range(n_h):
+         for ix in range(n_w):
+             cy = iy * core_h
+             cx = ix * core_w
+
+             tile = tiles_chw[idx]
+             core = tile[:, overlap : overlap + core_h, overlap : overlap + core_w]
+             img_pad_recon[:, cy : cy + core_h, cx : cx + core_w] = core
+             idx += 1
+
+     img_recon = img_pad_recon[:, :H, :W]
+     return np.ascontiguousarray(img_recon)
+
+
+ def parse_input_shape_fmt(input_shape):
+     """Parse whether the input shape is NCHW or NHWC.
+     We assume the channel dimension is smaller than the H and W dimensions.
+     """
+     assert len(input_shape) == 4
+
+     c1, c2, c3 = input_shape[1:]
+
+     if c1 < min(c2, c3):  # c1 is the channel dimension
+         return "nchw"
+     elif c3 < min(c1, c2):  # c3 is the channel dimension
+         return "nhwc"
+     else:
+         raise ValueError(f"can not parse input format for shape: {input_shape}")
+
+
+ def is_channel_last(img: np.ndarray):
+     return img.shape[2] < min(img.shape[0], img.shape[1])
+
+
+ def preprocess(img_bgr: np.ndarray):
+     """Convert a BGR channel-last uint8 image to an RGB channel-first fp32 image."""
+     img_rgb = img_bgr[..., ::-1]  # bgr -> rgb
+     img_chw = np.transpose(img_rgb, [2, 0, 1])  # hwc -> chw
+     img_chw = np.float32(img_chw)  # uint8 -> fp32
+
+     return np.ascontiguousarray(img_chw)
+
+
+ def postprocess(pred_chw: np.ndarray):
+     """Convert an RGB channel-first fp32 image to a BGR channel-last uint8 image."""
+     uint8_chw = pred_chw.clip(0, 255).astype(np.uint8)  # fp32 -> uint8
+     img_rgb = np.transpose(uint8_chw, [1, 2, 0])  # chw -> hwc
+     img_bgr = img_rgb[..., ::-1]  # rgb -> bgr
+
+     return np.ascontiguousarray(img_bgr)
+
+
+ class OnnxRunner:
+     """Single Image Super Resolution ONNX runner."""
+
+     def __init__(self, onnx_path, sr_scale, tile_overlap: int = 8, debug=False):
+         if "CUDAExecutionProvider" in ort.get_available_providers():
+             providers = ["CUDAExecutionProvider"]
+         else:
+             providers = ["CPUExecutionProvider"]
+
+         ort_session = ort.InferenceSession(str(onnx_path), providers=providers)
+
+         input0 = ort_session.get_inputs()[0]
+         self.input_name = input0.name
+         self.input_shape = tuple(input0.shape)
+         self.input_format = parse_input_shape_fmt(input0.shape)
+         self.ort_session = ort_session
+         self.sr_scale = sr_scale
+         self.tile_overlap = max(tile_overlap, 0)
+         self.debug = debug
+
+         if self.input_format == "nchw":
+             self._in_h, self._in_w = self.input_shape[2:]
+         else:  # nhwc
+             self._in_h, self._in_w = self.input_shape[1:3]
+
+         if debug:
+             self._dbg_out_dir = Path(__file__).parent / "outputs"
+             self._dbg_out_dir.mkdir(exist_ok=True, parents=True)
+
+     def _save_dbg_img(self, savename, img):
+         if not self.debug:
+             return
+         import cv2
+
+         cv2.imwrite(str(self._dbg_out_dir / savename), img)
+
+     def run(self, img_bgr: np.ndarray) -> np.ndarray:
+         """Run 2x super resolution on the given uint8 BGR image
+         and return the upscaled uint8 BGR image.
+         """
+         assert img_bgr.dtype == np.uint8, img_bgr.dtype
+         assert img_bgr.ndim in (2, 3), img_bgr.ndim
+
+         if img_bgr.ndim == 3:
+             assert is_channel_last(img_bgr), img_bgr.shape
+
+         if self.debug:
+             self._save_dbg_img("original_input_bgr.png", img_bgr)
+
+         # =====================
+         # preprocessing
+         # =====================
+         img_chw = preprocess(img_bgr)
+         tiles_chw, origin_size_hw, padded_size_hw = split_into_tiles_with_context(
+             img_chw, (self._in_h, self._in_w), self.tile_overlap
+         )
+         if self.debug:
+             print(f"tiling to {len(tiles_chw)} tiles")
+             tile_bgr = postprocess(tiles_chw[0])
+             self._save_dbg_img("tile_bgr.png", tile_bgr)
+
+         # =====================
+         # inference
+         # =====================
+         sr_tiles_chw = []
+         for tile_chw in tiles_chw:
+             if self.input_format == "nhwc":
+                 input_3d = np.transpose(tile_chw, [1, 2, 0])  # chw -> hwc
+             else:
+                 input_3d = tile_chw
+
+             outputs = self.ort_session.run(None, {self.input_name: input_3d[None, ...]})
+             sr_tile = outputs[0][0]  # chw or hwc format
+
+             if self.input_format == "nhwc":
+                 sr_tile_chw = np.transpose(sr_tile, [2, 0, 1])  # hwc -> chw
+             else:
+                 sr_tile_chw = sr_tile
+
+             sr_tiles_chw.append(sr_tile_chw)
+
+         if self.debug:
+             sr_padded_tile_bgr = postprocess(sr_tiles_chw[0])
+             self._save_dbg_img("sr_padded_tile_bgr.png", sr_padded_tile_bgr)
+
+         # =====================
+         # postprocessing
+         # =====================
+         sr_origin_hw = (
+             int(origin_size_hw[0] * self.sr_scale),
+             int(origin_size_hw[1] * self.sr_scale),
+         )
+         sr_padded_hw = (
+             int(padded_size_hw[0] * self.sr_scale),
+             int(padded_size_hw[1] * self.sr_scale),
+         )
+         sr_overlap = int(self.tile_overlap * self.sr_scale)
+
+         sr_padded_chw = merge_tiles_with_context(
+             sr_tiles_chw,
+             orig_hw=sr_origin_hw,
+             padded_hw=sr_padded_hw,
+             overlap=sr_overlap,
+         )
+         if self.debug:
+             sr_padded_img_bgr = postprocess(sr_padded_chw)
+             self._save_dbg_img("sr_padded_img_bgr.png", sr_padded_img_bgr)
+
+         sr_chw = sr_padded_chw[..., : sr_origin_hw[0], : sr_origin_hw[1]]
+
+         sr_img_bgr = postprocess(sr_chw)
+
+         return sr_img_bgr
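The tiling scheme in `OnnxRunner` is designed to be lossless at identity: reflect-pad, cut overlapping tiles, then stitch the non-overlapping cores back together. A condensed standalone sketch of the same split/merge round trip (re-implemented here so it runs on its own; `split_merge_roundtrip` is our name, not the repo's):

```python
import math
import numpy as np

def split_merge_roundtrip(img_chw: np.ndarray, ph: int, pw: int, overlap: int) -> np.ndarray:
    """Condensed split_into_tiles_with_context + merge_tiles_with_context:
    pad so the core grid covers the image, cut (ph, pw) tiles with reflected
    context, then reassemble only the core regions and crop to the input size."""
    C, H, W = img_chw.shape
    core_h, core_w = ph - 2 * overlap, pw - 2 * overlap
    n_h, n_w = math.ceil(H / core_h), math.ceil(W / core_w)
    H_pad, W_pad = n_h * core_h, n_w * core_w
    pad = np.pad(img_chw, ((0, 0), (0, H_pad - H), (0, W_pad - W)), mode="reflect")
    big = np.pad(pad, ((0, 0), (overlap, overlap), (overlap, overlap)), mode="reflect")
    recon = np.zeros((C, H_pad, W_pad), dtype=img_chw.dtype)
    for iy in range(n_h):
        for ix in range(n_w):
            tile = big[:, iy * core_h : iy * core_h + ph, ix * core_w : ix * core_w + pw]
            # keep only the core, dropping the overlap context on every side
            recon[:, iy * core_h : (iy + 1) * core_h, ix * core_w : (ix + 1) * core_w] = \
                tile[:, overlap : overlap + core_h, overlap : overlap + core_w]
    return recon[:, :H, :W]

# round trip on an awkward, non-divisible size reproduces the input exactly
img = np.random.default_rng(0).random((3, 37, 53), dtype=np.float32)
assert np.array_equal(split_merge_roundtrip(img, 16, 16, 4), img)
```

With a real model in between, the SR tiles replace `tile` above and the overlap context is what suppresses seams at tile borders.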
requirements-eval.txt ADDED
@@ -0,0 +1,6 @@
+ onnxruntime==1.22
+ numpy==1.26.*
+ opencv-python==4.8.*
+ tqdm
+ torch==2.6.0
+ pyiqa @ git+https://github.com/chaofengc/IQA-PyTorch.git@e851fd62e66a97345e1281d80e8deb4ab7b93c83
requirements-infer.txt ADDED
@@ -0,0 +1,4 @@
+ onnxruntime==1.22
+ numpy==1.26.*
+ opencv-python==4.8.*
+ tqdm
sesr_nchw_fp32.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b686864a8b17cf9aaad0d787f7b7a133c95317f408cac5204701d7291199711
+ size 93732
sesr_nchw_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11b6adfb3d5d9cc46405af6e684237a1efa7ff6956f82c489631476edc813237
+ size 120432