Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from PIL import Image
|
|
| 13 |
import requests
|
| 14 |
import yaml
|
| 15 |
import numpy as np
|
|
|
|
| 16 |
|
| 17 |
from src.core import YAMLConfig
|
| 18 |
|
|
@@ -108,7 +109,7 @@ def download_weights(model_name):
|
|
| 108 |
print(f"Downloaded weights to: {weights_path}")
|
| 109 |
return weights_path
|
| 110 |
|
| 111 |
-
|
| 112 |
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
|
| 113 |
"""Process image function for Gradio interface"""
|
| 114 |
if isinstance(image, np.ndarray):
|
|
@@ -185,10 +186,37 @@ class ModelWrapper(nn.Module):
|
|
| 185 |
return outputs
|
| 186 |
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def load_model(model_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
cfgfile = model_configs[model_name]["cfgfile"]
|
| 190 |
weights_path = download_weights(model_name)
|
| 191 |
|
|
|
|
| 192 |
cfg = YAMLConfig(cfgfile, resume=weights_path)
|
| 193 |
|
| 194 |
if "HGNetv2" in cfg.yaml_cfg:
|
|
@@ -197,7 +225,11 @@ def load_model(model_name):
|
|
| 197 |
checkpoint = torch.load(weights_path, map_location="cpu")
|
| 198 |
state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
|
| 199 |
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 203 |
model = ModelWrapper(cfg).to(device)
|
|
@@ -205,26 +237,34 @@ def load_model(model_name):
|
|
| 205 |
|
| 206 |
return model, device
|
| 207 |
|
| 208 |
-
|
| 209 |
-
# Dictionary to store loaded models
|
| 210 |
-
loaded_models = {}
|
| 211 |
-
|
| 212 |
@spaces.GPU
|
| 213 |
def process_image(image, model_name, confidence_threshold):
|
| 214 |
"""Main processing function for Gradio interface"""
|
| 215 |
-
global loaded_models
|
| 216 |
|
| 217 |
-
#
|
| 218 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
print(f"Loading model: {model_name}")
|
| 220 |
model, device = load_model(model_name)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
return process_image_for_gradio(model, device, image, model_name, confidence_threshold)
|
| 228 |
|
| 229 |
|
| 230 |
# Create Gradio interface
|
|
@@ -256,4 +296,6 @@ demo = gr.Interface(
|
|
| 256 |
]
|
| 257 |
)
|
| 258 |
|
| 259 |
-
|
|
|
|
|
|
|
|
|
| 13 |
import requests
|
| 14 |
import yaml
|
| 15 |
import numpy as np
|
| 16 |
+
import gc
|
| 17 |
|
| 18 |
from src.core import YAMLConfig
|
| 19 |
|
|
|
|
| 109 |
print(f"Downloaded weights to: {weights_path}")
|
| 110 |
return weights_path
|
| 111 |
|
| 112 |
+
@torch.no_grad()
|
| 113 |
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
|
| 114 |
"""Process image function for Gradio interface"""
|
| 115 |
if isinstance(image, np.ndarray):
|
|
|
|
| 186 |
return outputs
|
| 187 |
|
| 188 |
|
| 189 |
+
# YAMLConfig ํด๋์ค์ ๋ด๋ถ ์ํ๋ฅผ ์ด๊ธฐํํ๋ ํจ์ ์ถ๊ฐ
|
| 190 |
+
def reset_yaml_config():
|
| 191 |
+
"""YAMLConfig ํด๋์ค์ ๋ด๋ถ ์ํ๋ฅผ ์ด๊ธฐํ"""
|
| 192 |
+
# ํด๋์ค ๋ด๋ถ์ ์บ์ฑ๋ ์ ๋ณด๊ฐ ์๋ค๋ฉด ์ญ์
|
| 193 |
+
if hasattr(YAMLConfig, '_instances'):
|
| 194 |
+
YAMLConfig._instances = {}
|
| 195 |
+
if hasattr(YAMLConfig, '_configs'):
|
| 196 |
+
YAMLConfig._configs = {}
|
| 197 |
+
|
| 198 |
+
# ๊ฐ๋ฅํ ๋ค๋ฅธ ๋ชจ๋ ๋ชจ๋ ์บ์ ๋ฆฌ์
|
| 199 |
+
import importlib
|
| 200 |
+
for module_name in list(sys.modules.keys()):
|
| 201 |
+
if module_name.startswith('src.'):
|
| 202 |
+
try:
|
| 203 |
+
importlib.reload(sys.modules[module_name])
|
| 204 |
+
except:
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
def load_model(model_name):
|
| 208 |
+
# ๋ชจ๋ธ ๋ก๋ ์ ์ CUDA ์บ์์ ๊ฐ๋น์ง ์ปฌ๋ ์
์ ๋ฆฌ
|
| 209 |
+
if torch.cuda.is_available():
|
| 210 |
+
torch.cuda.empty_cache()
|
| 211 |
+
gc.collect()
|
| 212 |
+
|
| 213 |
+
# YAMLConfig ๋ด๋ถ ์ํ ์ด๊ธฐํ
|
| 214 |
+
reset_yaml_config()
|
| 215 |
+
|
| 216 |
cfgfile = model_configs[model_name]["cfgfile"]
|
| 217 |
weights_path = download_weights(model_name)
|
| 218 |
|
| 219 |
+
# ์์ ํ ์๋ก์ด YAMLConfig ์ธ์คํด์ค ์์ฑ
|
| 220 |
cfg = YAMLConfig(cfgfile, resume=weights_path)
|
| 221 |
|
| 222 |
if "HGNetv2" in cfg.yaml_cfg:
|
|
|
|
| 225 |
checkpoint = torch.load(weights_path, map_location="cpu")
|
| 226 |
state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
|
| 227 |
|
| 228 |
+
# ๋ชจ๋ธ ์์ฑ ์ ํ๋ฒ ๋ ํ์ธ
|
| 229 |
+
torch.cuda.empty_cache()
|
| 230 |
+
gc.collect()
|
| 231 |
+
|
| 232 |
+
cfg.model.load_state_dict(state, strict=False)
|
| 233 |
|
| 234 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 235 |
model = ModelWrapper(cfg).to(device)
|
|
|
|
| 237 |
|
| 238 |
return model, device
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
@spaces.GPU
|
| 241 |
def process_image(image, model_name, confidence_threshold):
|
| 242 |
"""Main processing function for Gradio interface"""
|
|
|
|
| 243 |
|
| 244 |
+
# ๋ชจ๋ ์ฌ์ฉ ๊ฐ๋ฅํ CUDA ์ฅ์น ๋ฉ๋ชจ๋ฆฌ ํ๋ณด
|
| 245 |
+
if torch.cuda.is_available():
|
| 246 |
+
torch.cuda.empty_cache()
|
| 247 |
+
|
| 248 |
+
# ๋ชจ๋ Python ๊ฐ์ฒด ๊ฐ๋น์ง ์ปฌ๋ ์
|
| 249 |
+
gc.collect()
|
| 250 |
+
|
| 251 |
+
try:
|
| 252 |
print(f"Loading model: {model_name}")
|
| 253 |
model, device = load_model(model_name)
|
| 254 |
+
|
| 255 |
+
# ์ด๋ฏธ์ง ์ฒ๋ฆฌ
|
| 256 |
+
result = process_image_for_gradio(model, device, image, model_name, confidence_threshold)
|
| 257 |
+
|
| 258 |
+
# ๋ชจ๋ธ ๊ฐ์ฒด ๋ฐ ๊ด๋ จ ๋ฐ์ดํฐ ๋ช
์์ ์ ๊ฑฐ
|
| 259 |
+
del model
|
| 260 |
+
|
| 261 |
+
finally:
|
| 262 |
+
# ํญ์ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ๋ณด์ฅ
|
| 263 |
+
if torch.cuda.is_available():
|
| 264 |
+
torch.cuda.empty_cache()
|
| 265 |
+
gc.collect()
|
| 266 |
|
| 267 |
+
return result
|
|
|
|
| 268 |
|
| 269 |
|
| 270 |
# Create Gradio interface
|
|
|
|
| 296 |
]
|
| 297 |
)
|
| 298 |
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
# Launch the Gradio app
|
| 301 |
+
demo.launch(share=True)
|