Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
| """ | |
| import gradio as gr | |
| import spaces | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import supervision as sv | |
| from PIL import Image | |
| import requests | |
| import yaml | |
| import numpy as np | |
| import gc | |
| from src.core import YAMLConfig | |
# Registry of selectable D-FINE checkpoints, keyed by the name shown in the
# model dropdown. Each entry provides:
#   cfgfile       - model YAML handed to YAMLConfig
#   classinfofile - YAML whose "names" entry maps class ids to label strings
#   weights       - URL of the released checkpoint, fetched by download_weights
model_configs = {
    # ---- COCO-trained models ----
    "dfine_n_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_n_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_n_coco.pth",
    },
    "dfine_s_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_s_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_coco.pth",
    },
    "dfine_m_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_m_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_coco.pth",
    },
    "dfine_l_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_l_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_coco.pth",
    },
    "dfine_x_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_x_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_coco.pth",
    },
    # ---- Objects365-trained models ----
    "dfine_s_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj365.pth",
    },
    "dfine_m_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj365.pth",
    },
    "dfine_l_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365.pth",
    },
    # Same architecture config as dfine_l_obj365, different checkpoint.
    "dfine_l_obj365_e25": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365_e25.pth",
    },
    "dfine_x_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj365.pth",
    },
    # ---- Objects365 -> COCO models (COCO label set) ----
    "dfine_s_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj2coco.pth",
    },
    "dfine_m_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj2coco.pth",
    },
    "dfine_l_obj2coco_e25": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj2coco_e25.pth",
    },
    "dfine_x_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj2coco.pth",
    },
}
def download_weights(model_name):
    """Download the checkpoint for ``model_name`` unless it is already cached.

    Weights are stored as ``weights/<model_name>.pth`` in the directory that
    contains this file.

    Args:
        model_name: A key of ``model_configs``.

    Returns:
        The local filesystem path of the weights file.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_configs``.
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    weights_url = model_configs[model_name]["weights"]
    weights_dir = os.path.join(os.path.dirname(__file__), "weights")
    weights_path = os.path.join(weights_dir, model_name + ".pth")

    if not os.path.exists(weights_dir):
        # exist_ok avoids a crash if another call created it concurrently.
        os.makedirs(weights_dir, exist_ok=True)
        print(f"Created directory: {weights_dir}")

    if os.path.exists(weights_path):
        print(f"Weights file already exists at: {weights_path}")
        return weights_path

    print(f"Downloading weights from {weights_url} to {weights_path}...")
    # Stream into a temporary file and move it into place only after the
    # download completes. Writing directly to weights_path would leave a
    # truncated file on failure, which the existence check above would then
    # treat as a valid cached checkpoint forever.
    tmp_path = weights_path + ".part"
    try:
        # The response is used as a context manager so the streamed
        # connection is always released.
        with requests.get(weights_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, weights_path)
    finally:
        # After a successful os.replace the temp file no longer exists; this
        # only removes a partial download left behind by an error.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # never mask the original error with a cleanup failure

    print(f"Downloaded weights to: {weights_path}")
    return weights_path
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
    """Run ``model`` on one image and return an annotated copy plus a text summary.

    Args:
        model: Callable invoked as ``model(images, orig_sizes)`` whose result
            unpacks into ``(labels, boxes, scores)`` batches, as
            ``ModelWrapper`` produces.
        device: Device the input tensors are moved to.
        image: Input image, either a PIL image or a NumPy array.
        model_name: Key of ``model_configs``; selects the class-name file.
        threshold: Detections whose confidence is not strictly greater than
            this value are dropped.

    Returns:
        ``(annotated PIL image, newline-separated detection descriptions)``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No input image provided.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The network input needs 3 channels; RGBA, grayscale or palette images
    # (e.g. PNG uploads or 2-D arrays) would otherwise yield a 4- or
    # 1-channel tensor after ToTensor().
    im_pil = image.convert("RGB")

    # Load class names; close the file deterministically.
    classinfofile = model_configs[model_name]["classinfofile"]
    with open(classinfofile, "r", encoding="utf-8") as f:
        classinfo = yaml.load(f, Loader=yaml.FullLoader)["names"]
    # NOTE(review): the id offset is inferred from the file name only;
    # confirm that the obj365 label list really is 1-based w.r.t. model ids.
    label_offset = 0 if "coco" in classinfofile else 1

    def class_name(class_id):
        """Map a model class id to its label, falling back to the raw id."""
        idx = int(class_id) - label_offset
        # A negative index would silently wrap to the last label.
        if idx < 0:
            return str(class_id)
        try:
            return classinfo[idx]
        except (IndexError, KeyError):
            return str(class_id)

    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]]).to(device)
    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)

    # Pure inference: disable autograd so no computation graph is kept.
    with torch.no_grad():
        labels, boxes, scores = model(im_data, orig_size)

    # Visualize results (first and only image of the batch).
    detections = sv.Detections(
        xyxy=boxes[0].detach().cpu().numpy(),
        confidence=scores[0].detach().cpu().numpy(),
        class_id=labels[0].detach().cpu().numpy().astype(int),
    )
    detections = detections[detections.confidence > threshold]

    text_scale = sv.calculate_optimal_text_scale(resolution_wh=im_pil.size)
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=im_pil.size)
    box_annotator = sv.BoxAnnotator(thickness=line_thickness)
    label_annotator = sv.LabelAnnotator(text_scale=text_scale, smart_position=True)

    names = [class_name(class_id) for class_id in detections.class_id]
    label_texts = [
        f"{name} {confidence:.2f}"
        for name, confidence in zip(names, detections.confidence)
    ]

    result_image = box_annotator.annotate(scene=im_pil.copy(), detections=detections)
    result_image = label_annotator.annotate(
        scene=result_image,
        detections=detections,
        labels=label_texts,
    )

    detection_info = [
        f"{name}: {confidence:.2f}, bbox: [{xyxy[0]:.1f}, {xyxy[1]:.1f}, {xyxy[2]:.1f}, {xyxy[3]:.1f}]"
        for name, confidence, xyxy in zip(names, detections.confidence, detections.xyxy)
    ]
    return result_image, "\n".join(detection_info)
class ModelWrapper(nn.Module):
    """Pair a deployed detection model with its deployed post-processor.

    A forward pass feeds ``images`` through the model, then passes the raw
    outputs and ``orig_target_sizes`` to the post-processor and returns the
    post-processor's result unchanged.
    """

    def __init__(self, cfg):
        """Create both sub-modules by calling ``deploy()`` on ``cfg.model``
        and then on ``cfg.postprocessor``."""
        super().__init__()
        self.model = cfg.model.deploy()
        self.postprocessor = cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Return ``postprocessor(model(images), orig_target_sizes)``."""
        return self.postprocessor(self.model(images), orig_target_sizes)
# Helper that resets YAMLConfig's internal state before a model is (re)loaded.
def reset_yaml_config():
    """Best-effort reset of YAMLConfig class-level caches and ``src.*`` modules.

    Clears the ``_instances`` and ``_configs`` class attributes of
    ``YAMLConfig`` when they exist, then attempts ``importlib.reload`` on
    every already-imported module whose name starts with ``"src."``.

    Reload failures are tolerated (the reset is best-effort) but are
    reported in a single summary line instead of being silently discarded.
    Only ``Exception`` subclasses are caught, so KeyboardInterrupt and
    SystemExit still propagate.

    NOTE(review): ``YAMLConfig`` in this module is bound at import time;
    reloading ``src.*`` does not rebind that name to the reloaded class.
    Confirm the reload step has the intended effect.
    """
    import importlib

    # Drop any cached information stored on the class itself, if present.
    if hasattr(YAMLConfig, "_instances"):
        YAMLConfig._instances = {}
    if hasattr(YAMLConfig, "_configs"):
        YAMLConfig._configs = {}

    # Reset module-level caches by reloading every loaded project module.
    failed = []
    for module_name in list(sys.modules):
        if not module_name.startswith("src."):
            continue
        module = sys.modules.get(module_name)
        if module is None:
            continue
        try:
            importlib.reload(module)
        except Exception as exc:  # best-effort: keep going, but record it
            failed.append(f"{module_name} ({type(exc).__name__}: {exc})")
    if failed:
        print(
            f"reset_yaml_config: could not reload {len(failed)} module(s): "
            + "; ".join(failed)
        )
def load_model(model_name):
    """Build the D-FINE model for ``model_name`` and load its released weights.

    Args:
        model_name: Key of ``model_configs``.

    Returns:
        ``(model, device)``: a ``ModelWrapper`` in eval mode, already moved
        to ``device`` (``"cuda"`` if available, otherwise ``"cpu"``).
    """
    # Release cached CUDA memory and run garbage collection before loading.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Reset YAMLConfig's internal state.
    reset_yaml_config()

    cfgfile = model_configs[model_name]["cfgfile"]
    weights_path = download_weights(model_name)

    # Create a completely new YAMLConfig instance.
    cfg = YAMLConfig(cfgfile, resume=weights_path)
    # Presumably stops the backbone from fetching its own pretrained
    # weights, since the checkpoint below supplies the parameters.
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    checkpoint = torch.load(weights_path, map_location="cpu")
    # Prefer the EMA weights when the checkpoint contains them.
    state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
    # Keep only the extracted state dict; drop the rest of the checkpoint.
    del checkpoint

    # Free memory once more before the model is built.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # strict=False tolerates key mismatches; surface them, because
    # parameters that silently fail to load degrade the detections.
    incompatible = cfg.model.load_state_dict(state, strict=False)
    del state
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: loading weights for {model_name}: "
            f"{len(incompatible.missing_keys)} missing key(s), "
            f"{len(incompatible.unexpected_keys)} unexpected key(s)"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ModelWrapper(cfg).to(device)
    model.eval()
    return model, device
def process_image(image, model_name, confidence_threshold):
    """Gradio callback: load the selected model, run detection, free memory.

    The model is loaded on every call and released afterwards, on both the
    success and the failure path.

    NOTE(review): ``spaces`` is imported by this file but no function is
    decorated with ``@spaces.GPU``; on ZeroGPU hardware CUDA is generally
    only available inside such functions. Confirm whether this callback
    should be decorated.

    Args:
        image: Uploaded image (PIL image from the Gradio input), or ``None``.
        model_name: Key of ``model_configs`` chosen in the dropdown.
        confidence_threshold: Minimum confidence (exclusive) to keep a box.

    Returns:
        ``(annotated image, detection text)`` from ``process_image_for_gradio``.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Fail fast, before downloading/loading a model, when nothing was uploaded.
    if image is None:
        raise gr.Error("Please upload an image.")

    # Reclaim CUDA memory on every available device and collect Python garbage.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    model = None
    try:
        print(f"Loading model: {model_name}")
        model, device = load_model(model_name)
        return process_image_for_gradio(model, device, image, model_name, confidence_threshold)
    finally:
        # Drop the model reference on every path. Otherwise, if an exception
        # propagates, its traceback keeps this frame (and the model) alive,
        # so the cleanup below could not reclaim that memory.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
# Gradio UI: an image, a model choice and a confidence slider go in;
# the annotated image and a text list of detections come out.
_demo_inputs = [
    gr.Image(type="pil", label="Input Image"),
    gr.Dropdown(
        choices=list(model_configs),
        value="dfine_n_coco",
        label="Model Selection",
    ),
    gr.Slider(
        minimum=0.1,
        maximum=0.9,
        value=0.4,
        step=0.05,
        label="Confidence Threshold",
    ),
]

_demo_outputs = [
    gr.Image(type="pil", label="Detection Result"),
    gr.Textbox(label="Detected Objects"),
]

demo = gr.Interface(
    fn=process_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="D-FINE Object Detection Demo",
    description="Upload an image to see object detection results using the D-FINE model. You can select different models and adjust the confidence threshold.",
    examples=[["examples/image1.jpg", "dfine_n_coco", 0.4]],
)

if __name__ == "__main__":
    # Start the Gradio server.
    demo.launch(share=True)