"""Gradio demo for the Florence-2 vision models.

Loads several Florence-2 checkpoints at start-up, runs a selected task prompt
on an uploaded image, and shows the parsed answer together with an annotated
copy of the image whenever the answer contains bounding boxes.
"""

import os
import sys
import io
import json
import requests
from typing import Iterable, List, Tuple, Dict, Any
from PIL import Image, ImageDraw, ImageFont

import gradio as gr
import torch
from transformers import AutoProcessor, Florence2ForConditionalGeneration

from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# ---------- Theme (kept from your original) ----------

# Register a custom palette on gradio's `colors` module so it can be used as a
# default hue in the theme below.
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)


class SteelBlueTheme(Soft):
    """`Soft` theme variant that uses the steel-blue palette as secondary hue."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual theme variables on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


steel_blue_theme = SteelBlueTheme()

css = """
#main-title h1 {
    font-size: 2.3em !important;
}
#output-title h2 {
    font-size: 2.1em !important;
}
"""

# ---------- Models ----------

MODEL_IDS = {
    "Florence-2-base": "florence-community/Florence-2-base",
    "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
    "Florence-2-large": "florence-community/Florence-2-large",
    "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
}

models: Dict[str, Florence2ForConditionalGeneration] = {}
processors: Dict[str, AutoProcessor] = {}

# All checkpoints are loaded eagerly at import time so that switching models in
# the UI does not trigger a download/initialisation during a request.
print("Loading Florence-2 models... This may take a while.")
for name, repo_id in MODEL_IDS.items():
    print(f"Loading {name}...")
    model = Florence2ForConditionalGeneration.from_pretrained(
        repo_id,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    models[name] = model
    processors[name] = processor
    print(f"✅ Finished loading {name}.")

print("\n🎉 All models loaded successfully!")


# ---------- Utilities ----------

def _safe_parse_json_like(text: Any) -> Any:
    """Decode `text` as JSON when it is a string that looks like JSON.

    Args:
        text: Any value. Only strings are inspected.

    Returns:
        The decoded object when `text` is a string delimited by `{...}` or
        `[...]` that `json.loads` accepts; otherwise `text` unchanged.
    """
    if not isinstance(text, str):
        return text

    stripped = text.strip()
    looks_like_json = (
        (stripped.startswith("{") and stripped.endswith("}"))
        or (stripped.startswith("[") and stripped.endswith("]"))
    )
    if not looks_like_json:
        return text

    try:
        return json.loads(stripped)
    except ValueError:
        # json.JSONDecodeError is a ValueError: not valid JSON after all, so
        # hand back the original string untouched.
        return text


def _int_coords(raw: Any, count: int) -> List[int] | None:
    """Return `raw` as a list of exactly `count` ints, or None if it is unusable."""
    try:
        coords = [int(v) for v in raw]
    except (TypeError, ValueError, OverflowError):
        # Not iterable, non-numeric entries, or NaN/inf values.
        return None
    return coords if len(coords) == count else None


def _box_from_raw(raw: Any) -> List[int] | None:
    """Convert one entry of a 'bboxes' list to [x1, y1, x2, y2], or None.

    Accepts either a dict with keys x/y/w/h or a 4-element sequence. For a
    sequence, the values are taken as corner coordinates when they describe a
    box with positive width and height, and as x/y/w/h otherwise.
    """
    if isinstance(raw, dict):
        if not {"x", "y", "w", "h"}.issubset(raw):
            return None
        try:
            x, y, w, h = raw["x"], raw["y"], raw["w"], raw["h"]
            return [int(x), int(y), int(x + w), int(y + h)]
        except (TypeError, ValueError, OverflowError):
            return None

    coords = _int_coords(raw, 4)
    if coords is None:
        return None
    x1, y1, x2, y2 = coords
    if x2 > x1 and y2 > y1:
        return [x1, y1, x2, y2]
    # Heuristic fallback: interpret the last two values as width and height.
    return [x1, y1, x1 + x2, y1 + y2]


def _find_bboxes_and_labels(obj: Any) -> List[Tuple[List[int], str]]:
    """Recursively collect (bbox, label) pairs from a nested dict/list structure.

    Three layouts are recognised on any dict encountered:

    * ``{"bboxes": [...], "labels": [...]}`` - parallel lists.
    * ``{"quad_boxes": [...], "labels": [...]}`` - parallel lists where each
      box is an 8-value quadrilateral; it is reduced to its axis-aligned
      bounding rectangle.
    * ``{"regions": [{"bbox": [...], "label" | "text": ...}, ...]}``.

    Missing labels are replaced by an empty string; malformed boxes are skipped.

    Args:
        obj: Parsed model output (dict, list, or anything else).

    Returns:
        List of ``([x1, y1, x2, y2], label)`` tuples with integer coordinates.
    """
    found: List[Tuple[List[int], str]] = []

    def as_sequence(value: Any) -> Any:
        # Tolerate missing keys and explicit None values.
        return value if isinstance(value, (list, tuple)) else []

    def label_at(labels: Any, index: int) -> Any:
        return labels[index] if index < len(labels) else ""

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            labels = as_sequence(node.get("labels"))

            for index, raw in enumerate(as_sequence(node.get("bboxes"))):
                box = _box_from_raw(raw)
                if box is not None:
                    found.append((box, label_at(labels, index)))

            for index, raw in enumerate(as_sequence(node.get("quad_boxes"))):
                quad = _int_coords(raw, 8)
                if quad is not None:
                    xs, ys = quad[0::2], quad[1::2]
                    found.append(
                        ([min(xs), min(ys), max(xs), max(ys)], label_at(labels, index))
                    )

            for region in as_sequence(node.get("regions")):
                if isinstance(region, dict) and "bbox" in region:
                    box = _int_coords(region["bbox"], 4)
                    if box is not None:
                        label = region.get("label", region.get("text", ""))
                        found.append((box, label))

            # Keep searching nested values (e.g. the per-task wrapper dict).
            for value in node.values():
                visit(value)
        elif isinstance(node, list):
            for item in node:
                visit(item)
        # Primitives carry no boxes.

    visit(obj)
    return found


def _draw_bboxes_on_image(img: Image.Image, boxes_and_labels: List[Tuple[List[int], str]]) -> Image.Image:
    """Return an RGB copy of `img` with red boxes and their labels drawn on it.

    Args:
        img: Source image; it is not modified.
        boxes_and_labels: ``([x1, y1, x2, y2], label)`` pairs. Coordinates are
            clamped to the image; empty/None labels draw only the rectangle.

    Returns:
        The annotated copy.
    """
    # convert() returns a new image, so the caller's image stays untouched.
    annotated = img.convert("RGB")
    draw = ImageDraw.Draw(annotated)

    # A TrueType font may not be installed; fall back to PIL's built-in one.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=14)
    except OSError:
        font = ImageFont.load_default()

    max_x = annotated.width - 1
    max_y = annotated.height - 1
    # Scale the outline with the image so boxes stay visible on large inputs.
    thickness = max(2, round(min(annotated.width, annotated.height) / 200))

    def clamp(value: int, upper: int) -> int:
        return min(max(0, value), upper)

    for bbox, label in boxes_and_labels:
        x1, y1, x2, y2 = (int(v) for v in bbox)
        # Order the corners and clamp both of them: ImageDraw rejects
        # rectangles whose second corner precedes the first.
        left, right = sorted((clamp(x1, max_x), clamp(x2, max_x)))
        top, bottom = sorted((clamp(y1, max_y), clamp(y2, max_y)))

        # Thicker outline by drawing concentric rectangles growing outwards.
        for offset in range(thickness):
            draw.rectangle(
                [left - offset, top - offset, right + offset, bottom + offset],
                outline="red",
            )

        label_text = "" if label is None else str(label)
        if not label_text:
            continue

        # ImageDraw.textsize() no longer exists in Pillow >= 10; textbbox()
        # is its replacement. Measure from the drawing origin so the
        # background also covers the font's top bearing.
        _, _, text_w, text_h = draw.textbbox((0, 0), label_text, font=font)

        # Put the label above the box, or inside it when there is no room.
        label_top = top - text_h - 4
        if label_top < 0:
            label_top = top
        draw.rectangle(
            [left, label_top, left + text_w + 6, label_top + text_h + 4],
            fill="red",
        )
        draw.text((left + 3, label_top + 2), label_text, fill="white", font=font)

    return annotated


# ---------- Inference function ----------

# Tasks for which we attempt to extract/display bboxes.
# NOTE(review): in the reviewed source every task-token literal was an empty
# string (the angle-bracket tokens appear to have been stripped). They are
# reconstructed here from the Florence-2 task prompts that need no extra text
# input; the four below are the ones whose output contains regions - confirm.
VISUAL_REGION_TASKS = {
    "<OD>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR_WITH_REGION>",
}


# If you are using Spaces with GPU decorator; keep it as-is in your environment
def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
                            max_new_tokens: int = 1024, num_beams: int = 3):
    """Run a Florence-2 task on an image.

    Args:
        model_name: Key of `MODEL_IDS` selecting the checkpoint to use.
        image: Input image, or None when nothing was uploaded.
        task_prompt: Florence-2 task token, e.g. "<OD>".
        max_new_tokens: Generation length limit.
        num_beams: Beam-search width.

    Returns:
        A tuple ``(parsed_answer, annotated_image_or_none)``. The second item
        is a PIL image only for tasks in `VISUAL_REGION_TASKS` whose answer
        contained boxes that could be drawn.
    """
    if image is None:
        return {"error": "Please upload an image to get started."}, None

    model = models[model_name]
    processor = processors[model_name]

    # Uploaded images may be RGBA/palette/greyscale.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=task_prompt, images=image, return_tensors="pt")

    # Move tensors to the model device. Only floating-point tensors (the pixel
    # values) take the model dtype; integer tensors such as input_ids must
    # keep their dtype because they are used as embedding indices.
    device = model.device
    model_inputs = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            if value.is_floating_point():
                value = value.to(device, dtype=model.dtype)
            else:
                value = value.to(device)
        model_inputs[key] = value

    generated_ids = model.generate(
        input_ids=model_inputs.get("input_ids"),
        pixel_values=model_inputs.get("pixel_values"),
        # UI sliders may deliver floats; generate() expects ints.
        max_new_tokens=int(max_new_tokens),
        num_beams=int(num_beams),
        do_sample=False,
    )

    # Special tokens are kept: post-processing needs the location tokens.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=image.size,
    )

    # If the answer is a string that contains JSON, decode it for inspection.
    parsed_serializable = parsed_answer
    if isinstance(parsed_answer, str):
        parsed_serializable = _safe_parse_json_like(parsed_answer)

    annotated_image = None
    if task_prompt in VISUAL_REGION_TASKS:
        boxes_and_labels = _find_bboxes_and_labels(parsed_serializable)
        if boxes_and_labels:
            try:
                annotated_image = _draw_bboxes_on_image(image, boxes_and_labels)
            except Exception as e:
                # Drawing is best-effort: keep the parsed answer even if the
                # annotation cannot be rendered.
                print("Failed to draw boxes:", e)
                annotated_image = None

    return parsed_serializable, annotated_image


# ---------- UI ----------

# NOTE(review): reconstructed task tokens (see VISUAL_REGION_TASKS) - confirm
# the original order and default selection.
florence_tasks = [
    "<OD>",
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR>",
    "<OCR_WITH_REGION>",
]


def _load_example_image(image_url: str, timeout: float = 30.0) -> Image.Image | None:
    """Download `image_url` as an RGB image; return None if that fails."""
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as e:
        # The example is a convenience only; the app must still start offline.
        print("Could not load example image:", e)
        return None


# Example image (keeps your example)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
example_image = _load_example_image(url)

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the parsed output and an annotated image (when bounding boxes are present).")

    with gr.Row():
        with gr.Column(scale=2):
            image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
            task_prompt = gr.Dropdown(
                label="Select Task",
                choices=florence_tasks,
                value="<OD>",
            )
            model_choice = gr.Radio(
                choices=list(MODEL_IDS.keys()),
                label="Select Model",
                value="Florence-2-large-ft",
            )
            image_submit = gr.Button("Submit", variant="primary")

            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=128,
                    maximum=2048,
                    step=128,
                    value=1024,
                )
                num_beams = gr.Slider(
                    label="Number of Beams",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=3,
                )

        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            parsed_output = gr.JSON(label="Parsed Answer")
            annotated_output = gr.Image(label="Annotated Image (if available)", type="pil")

    image_submit.click(
        fn=run_florence2_inference,
        inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
        outputs=[parsed_output, annotated_output],
    )

if __name__ == "__main__":
    demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)