import gradio as gr
import spaces
import argparse
from PIL import Image
import numpy as np
import warnings
import torch

warnings.filterwarnings("ignore")

# Transformers provides the MM Grounding DINO model and processor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# supervision handles box/label visualization
import supervision as sv

# Model IDs on the Hugging Face Hub
MODEL_IDS = {
    "MM Grounding DINO Large": "rziga/mm_grounding_dino_large_all",
    "MM Grounding DINO Base": "rziga/mm_grounding_dino_base_all",
}

# Global variables for model caching
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_model_name = None
processor = None
model = None


@spaces.GPU
def run_grounding(input_image, grounding_caption, model_choice, box_threshold, text_threshold):
    global loaded_model_name, processor, model

    # Load or reload the model if the selection changed
    if loaded_model_name != model_choice:
        model_id = MODEL_IDS[model_choice]
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        loaded_model_name = model_choice

    # Gradio supplies images as RGB numpy arrays; convert to PIL if needed
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    init_image = input_image.convert("RGB")

    # MM Grounding DINO expects a list of label lists (one list per image);
    # split the caption on periods and strip whitespace
    text_labels = [[label.strip() for label in grounding_caption.split(".") if label.strip()]]

    # Tokenize the text and preprocess the image
    inputs = processor(images=init_image, text=text_labels, return_tensors="pt").to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process into boxes, scores, and labels in pixel coordinates
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=box_threshold,
        target_sizes=[(init_image.size[1], init_image.size[0])],  # (height, width)
    )
    result = results[0]

    # Convert the image to numpy for supervision annotation
    image_np = np.array(init_image)

    # Collect detections
    boxes = []
    labels = []
    confidences = []
    class_ids = []
    for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
        # box is in xyxy format: [xmin, ymin, xmax, ymax]
        boxes.append(box.tolist())
        labels.append(label)
        confidences.append(float(score))
        class_ids.append(i)  # use the detection index as an integer class_id

    # Build the text summary: "<class> <confidence> <x1>, <y1>, <x2>, <y2>" per line
    if boxes:
        lines = []
        for label, xyxy, conf in zip(labels, boxes, confidences):
            x1, y1, x2, y2 = [int(round(v)) for v in xyxy]
            lines.append(f"{label} {conf:.3f} {x1}, {y1}, {x2}, {y2}")
        detection_text = "\n".join(lines)
    else:
        detection_text = "No detections."
    # Create a supervision Detections object and annotate if anything was detected
    if boxes:
        detections = sv.Detections(
            xyxy=np.array(boxes),
            confidence=np.array(confidences),
            class_id=np.array(class_ids, dtype=np.int32),
        )

        # Scale line thickness and text size to the image resolution
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
        line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)

        # Create annotators
        box_annotator = sv.BoxAnnotator(
            thickness=line_thickness,
            color=sv.ColorPalette.DEFAULT,
        )
        label_annotator = sv.LabelAnnotator(
            color=sv.ColorPalette.DEFAULT,
            text_color=sv.Color.WHITE,
            text_scale=text_scale,
            text_thickness=line_thickness,
            text_padding=3,
        )

        # One "label: confidence" string per detection
        formatted_labels = [
            f"{label}: {conf:.2f}" for label, conf in zip(labels, confidences)
        ]

        # Apply annotations to the image
        annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image,
            detections=detections,
            labels=formatted_labels,
        )
    else:
        annotated_image = image_np

    # Return the annotated image (as PIL) and the detection text
    image_with_box = Image.fromarray(annotated_image)
    return image_with_box, detection_text


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="enable Gradio debug mode")
    parser.add_argument("--share", action="store_true", help="share the app publicly")
    args = parser.parse_args()

    css = """
    #mkd {
        height: 500px;
        overflow: auto;
        border: 1px solid #ccc;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# MM Grounding DINO: Zero-Shot Object Detection")  # placeholder heading (assumed)
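        # The rest of the UI was truncated above; what follows is a minimal sketch of the
        # wiring run_grounding() expects. Component labels, default threshold values, and
        # the prompt placeholder are assumptions, not the original Space's values.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                grounding_caption = gr.Textbox(
                    label="Detection Prompt",
                    placeholder="Separate classes with periods, e.g. 'person . car . dog'",
                )
                model_choice = gr.Dropdown(
                    choices=list(MODEL_IDS.keys()),
                    value="MM Grounding DINO Base",
                    label="Model",
                )
                box_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.01, label="Box Threshold")
                text_threshold = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Text Threshold")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(type="pil", label="Detections")
                output_text = gr.Textbox(label="Detection Results", lines=8)

        run_button.click(
            fn=run_grounding,
            inputs=[input_image, grounding_caption, model_choice, box_threshold, text_threshold],
            outputs=[output_image, output_text],
        )

    demo.launch(debug=args.debug, share=args.share)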