hongw.qin committed on
Commit 3c0e74c
Parent(s): 54abac0
upload models

Browse files:
- .gitignore +6 -0
- README.md +166 -0
- download_edsr_benchmark.py +76 -0
- onnx_eval.py +202 -0
- onnx_inference.py +53 -0
- onnx_runner.py +275 -0
- requirements-eval.txt +6 -0
- requirements-infer.txt +4 -0
- sesr_nchw_fp32.onnx +3 -0
- sesr_nchw_int8.onnx +3 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
.vscode/
.venv/
*.pyc
__pycache__/
outputs/
datasets/
README.md
ADDED
@@ -0,0 +1,166 @@
---
license: apache-2.0
tags:
- RyzenAI
- Int8 quantization
- Single Image Super Resolution
- SESR
- ONNX
- Computer Vision
metrics:
- PSNR
- MS_SSIM
- FID
---

# SESR for 2x Single Image Super Resolution

We provide 2x super-resolution models at resolution 256x256.

SESR was introduced in the paper _Collapsible Linear Blocks for Super-Efficient Super Resolution_ by Bhardwaj et al. The official code for this work is available at [sesr](https://github.com/ARM-software/sesr).

We have developed a modified version optimized for [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).

## Model description

SESR is based on linear overparameterization of CNNs and creates an efficient model architecture for single image super resolution (SISR).

## Intended uses & limitations

You can use this model for single image super resolution tasks. See the [model hub](https://huggingface.co/models?search=amd/ryzenai-sesr) for all available models.

## How to use

### Installation

```bash
# inference only
pip install -r requirements-infer.txt
# inference & evaluation
pip install -r requirements-eval.txt
```

### Data Preparation (optional: for evaluation)

Run `python download_edsr_benchmark.py` to automatically download and extract the EDSR benchmark dataset into the `datasets` directory. After it completes, your `datasets` folder should have the following structure:

```plain
datasets/edsr_benchmark
├── B100
│   ├── HR
│   │   ├── 3096.png
│   │   └── ...
│   └── LR_bicubic/X2
│       ├── 3096x2.png
│       └── ...
└── Set5
    ├── HR
    │   ├── baby.png
    │   └── ...
    └── LR_bicubic/X2
        ├── babyx2.png
        └── ...
```

### Test & Evaluation

- **Run inference on images**

```bash
python onnx_inference.py --onnx sesr_nchw_fp32.onnx --input /Path/To/Image --out-dir outputs
python onnx_inference.py --onnx sesr_nchw_int8.onnx --input /Path/To/Image --out-dir outputs
```

_Arguments:_

`--input`: Accepts either a single image file path or a directory path. If it is a file, the script processes that image only. If it is a directory, the script recursively scans for .png, .jpg, and .jpeg files and processes all of them.

`--out-dir`: Output directory where the restored images will be saved.

- **Evaluate the quantized model**

_Arguments:_

`--onnx`: Path to the ONNX model file.

`--hq-dir`: Directory containing high-quality (ground truth) images.

`--lq-dir`: Directory containing low-quality (input) images.

`--out-dir`: Output directory where evaluation results will be saved.

`--max-samples`: (Optional) Limit the number of samples to evaluate. Useful for debugging. If not specified, all samples are evaluated.

`-clean`: (Optional) If specified, the generated super-resolution images are deleted after evaluation to save disk space.

```bash
# ===================== eval int8 =====================
python onnx_eval.py \
    --onnx sesr_nchw_int8.onnx \
    --hq-dir datasets/edsr_benchmark/Set5/HR \
    --lq-dir datasets/edsr_benchmark/Set5/LR_bicubic/X2 \
    --out-dir outputs/Set5 -clean

python onnx_eval.py \
    --onnx sesr_nchw_int8.onnx \
    --hq-dir datasets/edsr_benchmark/Set14/HR \
    --lq-dir datasets/edsr_benchmark/Set14/LR_bicubic/X2 \
    --out-dir outputs/Set14 -clean

python onnx_eval.py \
    --onnx sesr_nchw_int8.onnx \
    --hq-dir datasets/edsr_benchmark/B100/HR \
    --lq-dir datasets/edsr_benchmark/B100/LR_bicubic/X2 \
    --out-dir outputs/B100 -clean

python onnx_eval.py \
    --onnx sesr_nchw_int8.onnx \
    --hq-dir datasets/edsr_benchmark/Urban100/HR \
    --lq-dir datasets/edsr_benchmark/Urban100/LR_bicubic/X2 \
    --out-dir outputs/Urban100 -clean


# ===================== eval fp32 =====================
python onnx_eval.py \
    --onnx sesr_nchw_fp32.onnx \
    --hq-dir datasets/edsr_benchmark/Set5/HR \
    --lq-dir datasets/edsr_benchmark/Set5/LR_bicubic/X2 \
    --out-dir outputs/Set5 -clean

python onnx_eval.py \
    --onnx sesr_nchw_fp32.onnx \
    --hq-dir datasets/edsr_benchmark/Set14/HR \
    --lq-dir datasets/edsr_benchmark/Set14/LR_bicubic/X2 \
    --out-dir outputs/Set14 -clean

python onnx_eval.py \
    --onnx sesr_nchw_fp32.onnx \
    --hq-dir datasets/edsr_benchmark/B100/HR \
    --lq-dir datasets/edsr_benchmark/B100/LR_bicubic/X2 \
    --out-dir outputs/B100 -clean

python onnx_eval.py \
    --onnx sesr_nchw_fp32.onnx \
    --hq-dir datasets/edsr_benchmark/Urban100/HR \
    --lq-dir datasets/edsr_benchmark/Urban100/LR_bicubic/X2 \
    --out-dir outputs/Urban100 -clean
```

### Performance

| Model      | Set5    |            |        | Set14   |            |        | B100    |            |        | Urban100 |            |        |
| :--------- | ------- | ---------- | ------ | ------- | ---------- | ------ | ------- | ---------- | ------ | -------- | ---------- | ------ |
|            | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑) | MS_SSIM(↑) | FID(↓) | PSNR(↑)  | MS_SSIM(↑) | FID(↓) |
| sesr(fp32) | 35.65   | 0.9971     | 26.46  | 30.98   | 0.9935     | 17.69  | 30.23   | 0.9921     | 17.00  | 28.84    | 0.9929     | 0.25   |
| sesr(int8) | 34.65   | 0.9952     | 28.37  | 30.46   | 0.9916     | 20.70  | 29.80   | 0.9900     | 19.38  | 28.25    | 0.9906     | 1.47   |

---

```bibtex
@article{bhardwaj2021collapsible,
  title={Collapsible Linear Blocks for Super-Efficient Super Resolution},
  author={Bhardwaj, Kartikeya and Milosavljevic, Milos and O'Neil, Liam and Gope, Dibakar and Matas, Ramon and Chalfin, Alex and Suda, Naveen and Meng, Lingchuan and Loh, Danny},
  journal={arXiv preprint arXiv:2103.09404},
  year={2021}
}
```
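The table above reports PSNR, which the evaluation script computes via pyiqa on the Y channel. As a quick reference for what the number means, PSNR is just log-scaled MSE; the following minimal NumPy sketch illustrates the formula (it is not the exact pyiqa pipeline used for the table):

```python
import numpy as np


def psnr(ref: np.ndarray, img: np.ndarray, peak: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two uint8 images, in dB."""
    mse = np.mean((ref.astype(np.float64) - img.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(peak**2 / mse)


# a uniform +5 error gives 10*log10(255^2 / 25) ≈ 34.15 dB
a = np.zeros((8, 8), dtype=np.uint8)
b = a + 5
print(round(psnr(a, b), 2))  # 34.15
```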
download_edsr_benchmark.py
ADDED
@@ -0,0 +1,76 @@
from pathlib import Path
from urllib.request import urlretrieve
import tarfile
from tqdm import tqdm
import shutil

ITEMS = [
    {
        "url": "https://cv.snu.ac.kr/research/EDSR/benchmark.tar",
        "name": "EDSR_benchmark",
    },
]


def download_with_progress(url: str, out_path: Path) -> None:
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Downloading {url} -> {out_path}")

    bar = None
    last_b = 0

    def reporthook(b: int, bsize: int, tsize: int):
        nonlocal bar, last_b
        if bar is None:
            total = tsize if tsize > 0 else None
            bar = tqdm(
                total=total,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                desc=out_path.name,
                dynamic_ncols=True,
            )
        delta_blocks = b - last_b
        if delta_blocks > 0:
            bar.update(delta_blocks * bsize)
        last_b = b

    try:
        urlretrieve(url, out_path, reporthook=reporthook)
    finally:
        if bar is not None:
            bar.close()


def extract_tar_flatten(tar_path: Path, dest_dir: Path) -> None:
    dest_dir.mkdir(parents=True, exist_ok=True)
    print(f"Extracting {tar_path} -> {dest_dir} (flatten top folder)")

    with tarfile.open(tar_path, "r") as tf:
        tf.extractall(dest_dir)

    # rename benchmark -> edsr_benchmark
    print("Renaming benchmark -> edsr_benchmark")
    shutil.move(str(dest_dir / "benchmark"), str(dest_dir / "edsr_benchmark"))


def main() -> None:
    base = Path(__file__).resolve().parent
    root = base / "datasets"
    root.mkdir(parents=True, exist_ok=True)

    for it in ITEMS:
        tar_path = root / f"{it['name']}.tar"

        download_with_progress(it["url"], tar_path)
        extract_tar_flatten(tar_path, root)

    print("All done.")


if __name__ == "__main__":
    main()
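The extract-then-rename flow of `extract_tar_flatten` above can be exercised without a network round trip. The sketch below builds a tiny archive with a top-level `benchmark/` folder under a temp directory and mirrors the same extract + rename steps; the archive contents (`Set5/HR/baby.png`) are made up for illustration:

```python
import shutil
import tarfile
import tempfile
from pathlib import Path


def roundtrip_flatten(workdir: Path) -> bool:
    """Build a tiny tar with a top-level 'benchmark/' folder, extract it,
    then rename benchmark -> edsr_benchmark, mirroring extract_tar_flatten."""
    src = workdir / "benchmark" / "Set5" / "HR"
    src.mkdir(parents=True)
    (src / "baby.png").write_bytes(b"\x89PNG")

    tar_path = workdir / "EDSR_benchmark.tar"
    with tarfile.open(tar_path, "w") as tf:
        tf.add(workdir / "benchmark", arcname="benchmark")
    shutil.rmtree(workdir / "benchmark")

    # extract, then flatten: rename benchmark -> edsr_benchmark
    dest = workdir / "datasets"
    dest.mkdir()
    with tarfile.open(tar_path, "r") as tf:
        tf.extractall(dest)
    shutil.move(str(dest / "benchmark"), str(dest / "edsr_benchmark"))

    return (dest / "edsr_benchmark" / "Set5" / "HR" / "baby.png").exists()


with tempfile.TemporaryDirectory() as tmp:
    print(roundtrip_flatten(Path(tmp)))  # True
```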
onnx_eval.py
ADDED
@@ -0,0 +1,202 @@
import sys
import json
from pathlib import Path

sys.path.insert(0, Path(__file__).parent.as_posix())


import cv2
import pyiqa
import torch
import numpy as np
from tqdm import tqdm
from onnx_runner import OnnxRunner


def collect_common_image_pairs(
    lq_dir: Path, hq_dir: Path
) -> tuple[list[Path], list[Path]]:
    exts = {".png", ".jpg", ".jpeg"}

    def is_img(p: Path) -> bool:
        return p.is_file() and p.suffix.lower() in exts

    hq_map = {p.stem: p for p in hq_dir.iterdir() if is_img(p)}
    hq_names = sorted(hq_map.keys())

    lq_files = [p for p in lq_dir.iterdir() if is_img(p)]

    lq_paths: list[Path] = []
    hq_paths: list[Path] = []
    for base in hq_names:
        # try an exact stem match first
        best_lq = next((p for p in lq_files if p.stem == base), None)

        # then fall back to a prefix match (e.g. "baby" -> "babyx2")
        if best_lq is None:
            best_lq = next(
                (
                    p
                    for p in lq_files
                    if p.stem.startswith(base) and len(p.stem) > len(base)
                ),
                None,
            )

        if best_lq is not None:  # matched
            hq_paths.append(hq_map[base])
            lq_paths.append(best_lq)

    return lq_paths, hq_paths


def align_shape_by_crop(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
    if sr_bgr.shape != hq_bgr.shape:
        min_h = min(sr_bgr.shape[0], hq_bgr.shape[0])
        min_w = min(sr_bgr.shape[1], hq_bgr.shape[1])

        sr_bgr = sr_bgr[:min_h, :min_w]
        hq_bgr = hq_bgr[:min_h, :min_w]

    return sr_bgr, hq_bgr


def gen_sr_images(
    hq_dir: Path, lq_dir: Path, out_dir: Path, onnx_path: Path, max_samples: int | None
):
    out_dir.mkdir(exist_ok=True, parents=True)

    onnx_runner = OnnxRunner(onnx_path, sr_scale=2, tile_overlap=8)

    lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)

    if max_samples is not None:
        lq_paths = lq_paths[: max(max_samples, 1)]
        hq_paths = hq_paths[: max(max_samples, 1)]

    sr_paths = []
    for i in tqdm(range(len(lq_paths)), desc="generating"):
        lq_img_path = lq_paths[i]
        lq_bgr = cv2.imread(lq_img_path.as_posix(), cv2.IMREAD_COLOR)
        assert lq_bgr is not None
        sr_bgr = onnx_runner.run(lq_bgr)

        hq_img_path = hq_paths[i]
        hq_bgr = cv2.imread(hq_img_path.as_posix(), cv2.IMREAD_COLOR)

        aligned_sr_bgr, aligned_hq_bgr = align_shape_by_crop(sr_bgr, hq_bgr)
        if aligned_hq_bgr.shape != hq_bgr.shape:
            # overwrite the HR image with its cropped version so both dirs
            # contain same-sized images for the directory-level FID pass
            cv2.imwrite(hq_img_path.as_posix(), aligned_hq_bgr)

        out_path = out_dir / f"{lq_img_path.stem}.png"
        cv2.imwrite(out_path.as_posix(), aligned_sr_bgr)

        sr_paths.append(out_path)

    return hq_paths, sr_paths


def eval_metrics(
    hq_paths: list[Path],
    sr_paths: list[Path],
    hq_dir: Path,
    sr_dir: Path,
    device: torch.device | None = None,
) -> dict[str, float]:
    assert len(hq_paths) == len(sr_paths)

    device = device or (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    # full-reference metrics: called as metric(sr, ref)
    psnr_metric = pyiqa.create_metric("psnr", device=device, test_y_channel=True)
    ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device, test_y_channel=True)
    fid_metric = pyiqa.create_metric("fid")

    with torch.inference_mode():
        psnr_vals = []
        ms_ssim_vals = []
        for sr_p, hq_p in zip(sr_paths, hq_paths):
            sr_p = sr_p.as_posix()
            hq_p = hq_p.as_posix()
            psnr_vals.append(psnr_metric(sr_p, hq_p).detach())
            ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach())

        psnr = torch.stack(psnr_vals).mean().item()
        ms_ssim = torch.stack(ms_ssim_vals).mean().item()

        fid = fid_metric(
            sr_dir.as_posix(),
            hq_dir.as_posix(),
            mode="clean",
            batch_size=1,
            num_workers=0,
        ).item()

    return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}


def main(args):
    onnx_path = Path(args.onnx)
    hq_dir = Path(args.hq_dir)
    lq_dir = Path(args.lq_dir)
    out_dir = Path(args.out_dir)

    assert onnx_path.suffix == ".onnx"
    assert lq_dir.is_dir(), f"{lq_dir} is not a dir!"
    assert hq_dir.is_dir(), f"{hq_dir} is not a dir!"

    sr_dir = out_dir / "sr"
    hq_paths, sr_paths = gen_sr_images(
        hq_dir, lq_dir, sr_dir, onnx_path, args.max_samples
    )

    scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)

    summary = {
        "onnx": onnx_path.as_posix(),
        "psnr": scores["psnr"],
        "ms_ssim": scores["ms_ssim"],
        "fid": scores["fid"],
    }

    out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
    with open(out_file, "w") as f:
        json.dump(summary, f, indent=2)
    dataset_name = hq_dir.parent.name
    print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
    print(
        f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
    )
    print(f"result saved to {out_file}")

    if args.clean:
        import shutil

        print(f"cleaning generated SR dir: {sr_dir}")
        shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--hq-dir", type=str, required=True)
    parser.add_argument("--lq-dir", type=str, required=True)
    parser.add_argument("--out-dir", type=str, default="outputs")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="limit the number of evaluated samples (debug only); None means unlimited",
    )
    parser.add_argument(
        "-clean",
        action="store_true",
        default=False,
        help="delete the generated SR images when finished",
    )
    main(parser.parse_args())
onnx_inference.py
ADDED
@@ -0,0 +1,53 @@
from pathlib import Path
import sys


sys.path.insert(0, Path(__file__).parent.as_posix())


import cv2
from onnx_runner import OnnxRunner


def main(args):
    onnx_path = Path(args.onnx)
    input_path = Path(args.input)
    out_dir = Path(args.out_dir)

    assert onnx_path.suffix == ".onnx"

    onnx_runner = OnnxRunner(onnx_path, sr_scale=2, debug=False)

    if input_path.is_file():
        input_images_path = [input_path]
    else:
        input_images_path = sorted(
            p
            for p in input_path.rglob("*")
            if p.suffix.lower() in (".png", ".jpg", ".jpeg")
        )

    out_dir.mkdir(exist_ok=True, parents=True)
    for input_img_path in input_images_path:
        img_bgr = cv2.imread(input_img_path.as_posix(), cv2.IMREAD_COLOR)
        assert img_bgr is not None
        sr_img_bgr = onnx_runner.run(img_bgr)

        out_path = out_dir / f"{input_img_path.stem}.png"
        cv2.imwrite(out_path.as_posix(), sr_img_bgr)
        print(f"saved {out_path}")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--out-dir", type=str, required=True)

    main(parser.parse_args())
onnx_runner.py
ADDED
@@ -0,0 +1,275 @@
import math
from pathlib import Path

import numpy as np
import onnxruntime as ort


__all__ = [
    "OnnxRunner",
]


def split_into_tiles_with_context(
    img_chw: np.ndarray,
    patch_size_hw: tuple[int, int],
    overlap: int,
):
    """
    Args:
        img_chw: (C, H, W)
        patch_size_hw: (ph, pw) size of each patch.
        overlap: overlap between neighboring patches.

    Returns:
        tiles_chw: list[np.ndarray], each tile of shape (C, ph, pw)
        orig_hw: (H, W)
        padded_hw: (H_pad, W_pad)
    """
    assert img_chw.ndim == 3
    C, H, W = img_chw.shape
    ph, pw = patch_size_hw

    assert 2 * overlap < ph and 2 * overlap < pw, "2*overlap must be < patch_size"

    # core region size (patch minus the overlap border)
    core_h = ph - 2 * overlap
    core_w = pw - 2 * overlap
    assert core_h > 0 and core_w > 0

    # compute how many tiles are required
    n_tiles_h = math.ceil(H / core_h)
    n_tiles_w = math.ceil(W / core_w)

    # size after padding up to a whole number of cores
    H_pad = n_tiles_h * core_h
    W_pad = n_tiles_w * core_w

    # first padding: make the padded image divisible by the core size
    pad_h = H_pad - H
    pad_w = W_pad - W
    img_pad = np.pad(
        img_chw,
        pad_width=((0, 0), (0, pad_h), (0, pad_w)),
        mode="reflect",
    )  # (C, H_pad, W_pad)

    # second padding: add reflected context at the boundaries
    big_pad = np.pad(
        img_pad,
        pad_width=((0, 0), (overlap, overlap), (overlap, overlap)),
        mode="reflect",
    )  # (C, H_pad + 2*overlap, W_pad + 2*overlap)

    tiles = []
    for iy in range(n_tiles_h):
        for ix in range(n_tiles_w):
            y0 = iy * core_h
            x0 = ix * core_w
            tile = big_pad[:, y0 : y0 + ph, x0 : x0 + pw]
            tiles.append(tile)

    return tiles, (H, W), (H_pad, W_pad)


def merge_tiles_with_context(
    tiles_chw: list[np.ndarray],
    orig_hw: tuple[int, int],
    padded_hw: tuple[int, int],
    overlap: int,
) -> np.ndarray:
    """
    Args:
        tiles_chw: tiles produced by `split_into_tiles_with_context`.
        orig_hw: original image size.
        padded_hw: center-padded image size.
        overlap: overlap between neighboring patches.

    Returns:
        img_chw: (C, H, W)
    """
    assert len(tiles_chw) > 0
    C, ph, pw = tiles_chw[0].shape
    H, W = orig_hw
    H_pad, W_pad = padded_hw

    assert 2 * overlap < ph and 2 * overlap < pw
    core_h = ph - 2 * overlap
    core_w = pw - 2 * overlap
    n_h = H_pad // core_h
    n_w = W_pad // core_w
    assert n_h * n_w == len(tiles_chw), "tiles != padded_hw"

    img_pad_recon = np.zeros((C, H_pad, W_pad), dtype=tiles_chw[0].dtype)

    idx = 0
    for iy in range(n_h):
        for ix in range(n_w):
            cy = iy * core_h
            cx = ix * core_w

            tile = tiles_chw[idx]
            core = tile[:, overlap : overlap + core_h, overlap : overlap + core_w]
            img_pad_recon[:, cy : cy + core_h, cx : cx + core_w] = core
            idx += 1

    img_recon = img_pad_recon[:, :H, :W]
    return np.ascontiguousarray(img_recon)


def parse_input_shape_fmt(input_shape):
    """Parse whether the input shape is in NCHW or NHWC format.
    We assume the channel dimension is smaller than the H and W dimensions.
    """
    assert len(input_shape) == 4

    c1, c2, c3 = input_shape[1:]

    if c1 < min(c2, c3):  # c1 is the channel dimension
        return "nchw"
    elif c3 < min(c1, c2):  # c3 is the channel dimension
        return "nhwc"
    else:
        raise ValueError(f"can not parse input format for shape: {input_shape}")


def is_channel_last(img: np.ndarray):
    return img.shape[2] < min(img.shape[0], img.shape[1])


def preprocess(img_bgr: np.ndarray):
    """Convert a BGR channel-last uint8 image to an RGB channel-first fp32 image."""
    img_rgb = img_bgr[..., ::-1]  # bgr -> rgb
    img_chw = np.transpose(img_rgb, [2, 0, 1])  # hwc -> chw
    img_chw = np.float32(img_chw)  # uint8 -> fp32

    return np.ascontiguousarray(img_chw)


def postprocess(pred_chw: np.ndarray):
    """Convert an RGB channel-first fp32 image to a BGR channel-last uint8 image."""
    uint8_chw = pred_chw.clip(0, 255).astype(np.uint8)  # fp32 -> uint8
    img_rgb = np.transpose(uint8_chw, [1, 2, 0])  # chw -> hwc
    img_bgr = img_rgb[..., ::-1]  # rgb -> bgr

    return np.ascontiguousarray(img_bgr)


class OnnxRunner:
    """Single image super resolution ONNX runner."""

    def __init__(self, onnx_path, sr_scale, tile_overlap: int = 8, debug=False):
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        ort_session = ort.InferenceSession(str(onnx_path), providers=providers)

        input0 = ort_session.get_inputs()[0]
        self.input_name = input0.name
        self.input_shape = tuple(input0.shape)
        self.input_format = parse_input_shape_fmt(input0.shape)
        self.ort_session = ort_session
        self.sr_scale = sr_scale
        self.tile_overlap = max(tile_overlap, 0)
        self.debug = debug

        if self.input_format == "nchw":
            self._in_h, self._in_w = self.input_shape[2:]
        else:  # nhwc
            self._in_h, self._in_w = self.input_shape[1:3]

        if debug:
            self._dbg_out_dir = Path(__file__).parent / "outputs"
            self._dbg_out_dir.mkdir(exist_ok=True, parents=True)

    def _save_dbg_img(self, savename, img):
        if not self.debug:
            return
        import cv2

        cv2.imwrite(str(self._dbg_out_dir / savename), img)

    def run(self, img_bgr: np.ndarray) -> np.ndarray:
        """Run super resolution on the given uint8 BGR image
        and return the upscaled uint8 BGR image.
        """
        assert img_bgr.dtype == np.uint8, img_bgr.dtype
        assert img_bgr.ndim in (2, 3), img_bgr.ndim

        if img_bgr.ndim == 3:
            assert is_channel_last(img_bgr), img_bgr.shape

        if self.debug:
            self._save_dbg_img("original_input_bgr.png", img_bgr)

        # =====================
        # preprocessing
        # =====================
        img_chw = preprocess(img_bgr)
        tiles_chw, origin_size_hw, padded_size_hw = split_into_tiles_with_context(
            img_chw, (self._in_h, self._in_w), self.tile_overlap
        )
        if self.debug:
            print(f"tiling to {len(tiles_chw)} tiles")
            tile_bgr = postprocess(tiles_chw[0])
            self._save_dbg_img("tile_bgr.png", tile_bgr)

        # =====================
        # inference
        # =====================
        sr_tiles_chw = []
        for tile_chw in tiles_chw:
            if self.input_format == "nhwc":
                input_3d = np.transpose(tile_chw, [1, 2, 0])  # chw -> hwc
            else:
                input_3d = tile_chw

            outputs = self.ort_session.run(None, {self.input_name: input_3d[None, ...]})
            sr_tile = outputs[0][0]  # chw or hwc format

            if self.input_format == "nhwc":
                sr_tile_chw = np.transpose(sr_tile, [2, 0, 1])  # hwc -> chw
            else:
                sr_tile_chw = sr_tile

            sr_tiles_chw.append(sr_tile_chw)

        if self.debug:
            sr_padded_tile_bgr = postprocess(sr_tiles_chw[0])
            self._save_dbg_img("sr_padded_tile_bgr.png", sr_padded_tile_bgr)

        # =====================
        # postprocessing
        # =====================
        sr_origin_hw = (
            int(origin_size_hw[0] * self.sr_scale),
            int(origin_size_hw[1] * self.sr_scale),
        )
        sr_padded_hw = (
            int(padded_size_hw[0] * self.sr_scale),
            int(padded_size_hw[1] * self.sr_scale),
        )
        sr_overlap = int(self.tile_overlap * self.sr_scale)

        sr_padded_chw = merge_tiles_with_context(
            sr_tiles_chw,
            orig_hw=sr_origin_hw,
            padded_hw=sr_padded_hw,
            overlap=sr_overlap,
        )
        if self.debug:
            sr_padded_img_bgr = postprocess(sr_padded_chw)
            self._save_dbg_img("sr_padded_img_bgr.png", sr_padded_img_bgr)

        sr_chw = sr_padded_chw[..., : sr_origin_hw[0], : sr_origin_hw[1]]

        sr_img_bgr = postprocess(sr_chw)

        return sr_img_bgr
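The split/merge pair in `onnx_runner.py` is designed so that, with no model in between, pasting the tile cores back together reproduces the input exactly. The condensed round-trip sketch below re-implements the same core-grid arithmetic in simplified form (it is an illustration, not an import of the module):

```python
import math
import numpy as np


def split(img, ph, pw, o):
    """Reflect-pad img (C,H,W) and cut overlapping (ph,pw) tiles whose
    (ph-2o, pw-2o) cores exactly tile the padded image."""
    C, H, W = img.shape
    ch, cw = ph - 2 * o, pw - 2 * o
    nh, nw = math.ceil(H / ch), math.ceil(W / cw)
    Hp, Wp = nh * ch, nw * cw
    big = np.pad(img, ((0, 0), (0, Hp - H), (0, Wp - W)), mode="reflect")
    big = np.pad(big, ((0, 0), (o, o), (o, o)), mode="reflect")
    tiles = [
        big[:, iy * ch : iy * ch + ph, ix * cw : ix * cw + pw]
        for iy in range(nh)
        for ix in range(nw)
    ]
    return tiles, (H, W), (Hp, Wp)


def merge(tiles, orig_hw, padded_hw, o):
    """Paste each tile's core back onto the grid and crop to the original size."""
    C, ph, pw = tiles[0].shape
    ch, cw = ph - 2 * o, pw - 2 * o
    (H, W), (Hp, Wp) = orig_hw, padded_hw
    out = np.zeros((C, Hp, Wp), dtype=tiles[0].dtype)
    idx = 0
    for iy in range(Hp // ch):
        for ix in range(Wp // cw):
            out[:, iy * ch : iy * ch + ch, ix * cw : ix * cw + cw] = tiles[idx][
                :, o : o + ch, o : o + cw
            ]
            idx += 1
    return out[:, :H, :W]


img = np.arange(3 * 37 * 53, dtype=np.float32).reshape(3, 37, 53)
tiles, orig, padded = split(img, 32, 32, 8)
print(np.array_equal(merge(tiles, orig, padded, 8), img))  # True
```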
requirements-eval.txt
ADDED
@@ -0,0 +1,6 @@
onnxruntime==1.22
numpy==1.26.*
opencv-python==4.8.*
tqdm
torch==2.6.0
pyiqa @ git+https://github.com/chaofengc/IQA-PyTorch.git@e851fd62e66a97345e1281d80e8deb4ab7b93c83
requirements-infer.txt
ADDED
@@ -0,0 +1,4 @@
onnxruntime==1.22
numpy==1.26.*
opencv-python==4.8.*
tqdm
sesr_nchw_fp32.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b686864a8b17cf9aaad0d787f7b7a133c95317f408cac5204701d7291199711
size 93732
sesr_nchw_int8.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11b6adfb3d5d9cc46405af6e684237a1efa7ff6956f82c489631476edc813237
size 120432