Spaces:
Sleeping
Sleeping
xco2
committed on
Commit
·
ebf6d7b
1
Parent(s):
687cb7c
init
Browse files
- net/UNet.py +0 -96
- requirements.txt +2 -179
net/UNet.py
CHANGED
|
@@ -422,99 +422,3 @@ class UNet(nn.Module):
|
|
| 422 |
# print("decoder:")
|
| 423 |
# print(decoder_out.shape)
|
| 424 |
return decoder_out
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
if __name__ == '__main__':
|
| 428 |
-
import cv2, os
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
def modelSave(model, save_path, save_name):
|
| 432 |
-
if not os.path.exists(save_path):
|
| 433 |
-
os.mkdir(save_path)
|
| 434 |
-
torch.save(model.state_dict(), os.path.join(save_path, save_name))
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
def merge_images(images: np.ndarray):
|
| 438 |
-
"""
|
| 439 |
-
合并图像
|
| 440 |
-
:param images: 图像数组
|
| 441 |
-
:return: 合并后的图像数组
|
| 442 |
-
"""
|
| 443 |
-
n, h, w, c = images.shape
|
| 444 |
-
nn = int(np.ceil(n ** 0.5))
|
| 445 |
-
merged_image = np.zeros((h * nn, w * nn, 3), dtype=images.dtype)
|
| 446 |
-
for i in range(n):
|
| 447 |
-
row = i // nn
|
| 448 |
-
col = i % nn
|
| 449 |
-
merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = images[i]
|
| 450 |
-
|
| 451 |
-
merged_image = np.clip(merged_image, 0, 255)
|
| 452 |
-
merged_image = np.array(merged_image, dtype=np.uint8)
|
| 453 |
-
return merged_image
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
# 320,448,576,832
|
| 457 |
-
config = { # 模型结构相关
|
| 458 |
-
"en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
|
| 459 |
-
"en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
|
| 460 |
-
"en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
|
| 461 |
-
"en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
|
| 462 |
-
"de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
|
| 463 |
-
"de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
|
| 464 |
-
"de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
|
| 465 |
-
"de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8), # skip的地方不做self-attention
|
| 466 |
-
"t_out_c": 256,
|
| 467 |
-
"vae_c": 4,
|
| 468 |
-
"block_deep": 3,
|
| 469 |
-
}
|
| 470 |
-
device = "cuda"
|
| 471 |
-
total_step = 1000
|
| 472 |
-
|
| 473 |
-
unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
|
| 474 |
-
config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
|
| 475 |
-
config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)
|
| 476 |
-
|
| 477 |
-
print("总参数", sum(i.numel() for i in unet.parameters()) / 10000, "单位:万")
|
| 478 |
-
print("encoder", sum(i.numel() for i in unet.encoder.parameters()) / 10000, "单位:万")
|
| 479 |
-
print("decoder", sum(i.numel() for i in unet.decoder.parameters()) / 10000, "单位:万")
|
| 480 |
-
print("t", sum(i.numel() for i in unet.t_encoder.parameters()) / 10000, "单位:万")
|
| 481 |
-
|
| 482 |
-
batch_size = 2
|
| 483 |
-
x = np.random.random((batch_size, config["vae_c"], 32, 32))
|
| 484 |
-
t = np.random.uniform(1, total_step + 0.9999, size=(batch_size, 1))
|
| 485 |
-
t = np.array(t, dtype=np.int16)
|
| 486 |
-
t = t / total_step
|
| 487 |
-
|
| 488 |
-
with torch.no_grad():
|
| 489 |
-
x = torch.Tensor(x).to(device)
|
| 490 |
-
t = torch.Tensor(t).to(device)
|
| 491 |
-
y = unet(x, t)
|
| 492 |
-
print(y.shape)
|
| 493 |
-
|
| 494 |
-
z = y[0].cpu().numpy()
|
| 495 |
-
# z = (z - np.mean(z)) / (np.max(z) - np.min(z))
|
| 496 |
-
z = np.clip(np.asarray((z + 1) * 127.5), 0, 255)
|
| 497 |
-
z = np.asarray(z, dtype=np.uint8)
|
| 498 |
-
|
| 499 |
-
z = [np.tile(z[ii, :, :, np.newaxis], (1, 1, 3)) for ii in range(z.shape[0])]
|
| 500 |
-
noise = merge_images(np.array(z))
|
| 501 |
-
|
| 502 |
-
noise = cv2.resize(noise, None, fx=2, fy=2)
|
| 503 |
-
cv2.imshow("noise", noise)
|
| 504 |
-
cv2.waitKey(0)
|
| 505 |
-
|
| 506 |
-
# modelSave(unet, "./", "test.pth")
|
| 507 |
-
# 导出为onnx格式
|
| 508 |
-
torch.onnx.export(
|
| 509 |
-
unet,
|
| 510 |
-
(x, t),
|
| 511 |
-
'unet.onnx',
|
| 512 |
-
export_params=True,
|
| 513 |
-
opset_version=12,
|
| 514 |
-
)
|
| 515 |
-
import onnx
|
| 516 |
-
|
| 517 |
-
# 增加维度信息
|
| 518 |
-
model_file = 'unet.onnx'
|
| 519 |
-
onnx_model = onnx.load(model_file)
|
| 520 |
-
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), model_file)
|
|
|
|
| 422 |
# print("decoder:")
|
| 423 |
# print(decoder_out.shape)
|
| 424 |
return decoder_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,186 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
aiofiles==23.1.0
|
| 4 |
-
aiohttp==3.8.3
|
| 5 |
-
aiosignal==1.3.1
|
| 6 |
-
aliyun-python-sdk-core==2.13.36
|
| 7 |
-
aliyun-python-sdk-kms==2.16.0
|
| 8 |
-
altair==4.2.0
|
| 9 |
-
anyio==3.6.2
|
| 10 |
-
appdirs==1.4.4
|
| 11 |
-
asttokens==2.3.0
|
| 12 |
-
async-timeout==4.0.2
|
| 13 |
-
attrs==22.1.0
|
| 14 |
-
audioread==3.0.0
|
| 15 |
-
backcall==0.2.0
|
| 16 |
-
certifi==2022.12.7
|
| 17 |
-
cffi==1.15.1
|
| 18 |
-
charset-normalizer==2.1.1
|
| 19 |
-
chumpy==0.70
|
| 20 |
-
click==8.1.3
|
| 21 |
-
clip==1.0
|
| 22 |
-
colorama==0.4.6
|
| 23 |
-
commonmark==0.9.1
|
| 24 |
-
contourpy==1.0.6
|
| 25 |
-
cpm-kernels==1.0.11
|
| 26 |
-
crcmod==1.7
|
| 27 |
-
cryptography==39.0.2
|
| 28 |
-
cycler==0.11.0
|
| 29 |
-
Cython==0.29.32
|
| 30 |
-
datasets==2.8.0
|
| 31 |
-
decorator==5.1.1
|
| 32 |
-
decord==0.6.0
|
| 33 |
-
diffusers==0.20.1
|
| 34 |
-
dill==0.3.6
|
| 35 |
-
docker-pycreds==0.4.0
|
| 36 |
-
einops==0.6.0
|
| 37 |
-
entrypoints==0.4
|
| 38 |
-
exceptiongroup==1.1.3
|
| 39 |
-
executing==1.2.0
|
| 40 |
-
fastapi==0.88.0
|
| 41 |
-
ffmpy==0.3.0
|
| 42 |
-
filelock==3.8.2
|
| 43 |
-
Flask==2.0.2
|
| 44 |
-
Flask-Cors==3.0.10
|
| 45 |
-
fonttools==4.38.0
|
| 46 |
-
frozenlist==1.3.3
|
| 47 |
-
fsspec==2022.11.0
|
| 48 |
-
ftfy==6.1.1
|
| 49 |
-
gast==0.5.3
|
| 50 |
-
gitdb==4.0.10
|
| 51 |
-
GitPython==3.1.32
|
| 52 |
-
gradio==3.39.0
|
| 53 |
-
gradio_client==0.3.0
|
| 54 |
-
h11==0.14.0
|
| 55 |
-
httpcore==0.16.2
|
| 56 |
-
httpx==0.23.1
|
| 57 |
huggingface-hub==0.16.4
|
| 58 |
-
icetk==0.0.4
|
| 59 |
-
idna==3.4
|
| 60 |
-
importlib-metadata==5.2.0
|
| 61 |
-
ipython==8.15.0
|
| 62 |
-
itsdangerous==2.1.2
|
| 63 |
-
jedi==0.19.0
|
| 64 |
-
Jinja2==3.1.2
|
| 65 |
-
jmespath==0.10.0
|
| 66 |
-
joblib==1.2.0
|
| 67 |
-
json-tricks==3.16.1
|
| 68 |
-
jsonplus==0.8.0
|
| 69 |
-
jsonschema==4.17.3
|
| 70 |
-
kiwisolver==1.4.4
|
| 71 |
-
lazy_loader==0.1
|
| 72 |
-
librosa==0.10.0
|
| 73 |
-
linkify-it-py==1.0.3
|
| 74 |
-
lion-pytorch==0.1.2
|
| 75 |
-
llvmlite==0.39.1
|
| 76 |
-
loguru==0.6.0
|
| 77 |
-
Markdown==3.4.1
|
| 78 |
-
markdown-it-py==2.1.0
|
| 79 |
-
MarkupSafe==2.1.1
|
| 80 |
-
matplotlib==3.6.2
|
| 81 |
-
matplotlib-inline==0.1.6
|
| 82 |
-
mdit-py-plugins==0.3.3
|
| 83 |
-
mdurl==0.1.2
|
| 84 |
-
mediapipe==0.8.11
|
| 85 |
-
mmcv-full==1.7.0
|
| 86 |
-
mmdet==2.26.0
|
| 87 |
-
model-index==0.1.11
|
| 88 |
-
modelscope==1.3.2
|
| 89 |
-
mpmath==1.2.1
|
| 90 |
-
msgpack==1.0.4
|
| 91 |
-
multidict==6.0.3
|
| 92 |
-
multiprocess==0.70.14
|
| 93 |
-
munkres==1.1.4
|
| 94 |
-
networkx==3.0
|
| 95 |
-
numba==0.56.4
|
| 96 |
numpy==1.23.4
|
| 97 |
-
onnx==1.14.1
|
| 98 |
-
opencv-contrib-python==4.5.5.64
|
| 99 |
-
opencv-python==4.5.5.64
|
| 100 |
-
openmim==0.3.3
|
| 101 |
-
ordered-set==4.1.0
|
| 102 |
-
orjson==3.8.3
|
| 103 |
-
oss2==2.16.0
|
| 104 |
-
packaging==21.3
|
| 105 |
-
pandas==1.5.2
|
| 106 |
-
parso==0.8.3
|
| 107 |
-
pathtools==0.1.2
|
| 108 |
-
pickleshare==0.7.5
|
| 109 |
-
Pillow==9.2.0
|
| 110 |
-
pip==23.1.2
|
| 111 |
-
platformdirs==3.1.0
|
| 112 |
-
plotly==5.11.0
|
| 113 |
-
pooch==1.7.0
|
| 114 |
-
prodigyopt==1.0
|
| 115 |
-
prompt-toolkit==3.0.39
|
| 116 |
-
protobuf==4.24.2
|
| 117 |
-
psutil==5.9.5
|
| 118 |
-
pure-eval==0.2.2
|
| 119 |
-
pyarrow==11.0.0
|
| 120 |
-
pycocotools==2.0.6
|
| 121 |
-
pycparser==2.21
|
| 122 |
-
pycryptodome==3.16.0
|
| 123 |
-
pydantic==1.10.2
|
| 124 |
-
pydub==0.25.1
|
| 125 |
-
Pygments==2.13.0
|
| 126 |
-
pyparsing==3.0.9
|
| 127 |
-
pyrsistent==0.19.2
|
| 128 |
-
python-dateutil==2.8.2
|
| 129 |
-
python-multipart==0.0.5
|
| 130 |
-
pytorch-fid==0.3.0
|
| 131 |
-
pytz==2022.6
|
| 132 |
-
PyYAML==6.0
|
| 133 |
-
regex==2022.10.31
|
| 134 |
-
requests==2.28.1
|
| 135 |
-
responses==0.18.0
|
| 136 |
-
rfc3986==1.5.0
|
| 137 |
-
rich==12.6.0
|
| 138 |
-
safetensors==0.3.3
|
| 139 |
-
scikit-learn==1.2.1
|
| 140 |
-
scipy==1.9.3
|
| 141 |
-
semantic-version==2.10.0
|
| 142 |
-
sentencepiece==0.1.97
|
| 143 |
-
sentry-sdk==1.28.0
|
| 144 |
-
setproctitle==1.3.2
|
| 145 |
-
setuptools==65.5.0
|
| 146 |
-
simplejson==3.18.3
|
| 147 |
-
six==1.16.0
|
| 148 |
-
smmap==5.0.0
|
| 149 |
-
sniffio==1.3.0
|
| 150 |
-
sortedcontainers==2.4.0
|
| 151 |
-
soundfile==0.12.1
|
| 152 |
-
soxr==0.3.4
|
| 153 |
-
stack-data==0.6.2
|
| 154 |
-
starlette==0.22.0
|
| 155 |
-
sympy==1.11.1
|
| 156 |
-
tabulate==0.9.0
|
| 157 |
-
tenacity==8.1.0
|
| 158 |
-
terminaltables==3.1.10
|
| 159 |
-
threadpoolctl==3.1.0
|
| 160 |
-
timm==0.4.9
|
| 161 |
-
tokenizers==0.13.2
|
| 162 |
-
toolz==0.12.0
|
| 163 |
torch==2.0.0+cu117
|
| 164 |
torchaudio==2.0.1+cu117
|
| 165 |
torchinfo==1.7.1
|
| 166 |
torchvision==0.15.1+cu117
|
| 167 |
tqdm==4.64.1
|
| 168 |
-
traitlets==5.9.0
|
| 169 |
-
transformers==4.26.1
|
| 170 |
-
typing_extensions==4.4.0
|
| 171 |
-
uc-micro-py==1.0.1
|
| 172 |
-
unicodedata2==15.0.0
|
| 173 |
-
urllib3==1.26.12
|
| 174 |
-
uvicorn==0.20.0
|
| 175 |
-
wandb==0.15.5
|
| 176 |
-
wcwidth==0.2.5
|
| 177 |
-
websockets==10.4
|
| 178 |
-
Werkzeug==2.2.2
|
| 179 |
-
wheel==0.37.1
|
| 180 |
-
win32-setctime==1.1.0
|
| 181 |
-
wincertstore==0.2
|
| 182 |
-
xtcocotools==1.12
|
| 183 |
-
xxhash==3.2.0
|
| 184 |
-
yapf==0.32.0
|
| 185 |
-
yarl==1.8.2
|
| 186 |
-
zipp==3.11.0
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
gradio_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
huggingface-hub==0.16.4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
numpy==1.23.4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
torch==2.0.0+cu117
|
| 6 |
torchaudio==2.0.1+cu117
|
| 7 |
torchinfo==1.7.1
|
| 8 |
torchvision==0.15.1+cu117
|
| 9 |
tqdm==4.64.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|