import json
import requests
from typing import Iterable, List, Tuple, Dict, Any
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# ---------- Theme (kept from your original) ----------
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
    c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
    c800="#2E5378", c900="#264364", c950="#1E3450",
)

class SteelBlueTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
            text_size=text_size, font=font, font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

steel_blue_theme = SteelBlueTheme()
| css = """ | |
| #main-title h1 { | |
| font-size: 2.3em !important; | |
| } | |
| #output-title h2 { | |
| font-size: 2.1em !important; | |
| } | |
| """ | |
# ---------- Models ----------
MODEL_IDS = {
    "Florence-2-base": "florence-community/Florence-2-base",
    "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
    "Florence-2-large": "florence-community/Florence-2-large",
    "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
}

models: Dict[str, Florence2ForConditionalGeneration] = {}
processors: Dict[str, AutoProcessor] = {}

print("Loading Florence-2 models... This may take a while.")
for name, repo_id in MODEL_IDS.items():
    print(f"Loading {name}...")
    model = Florence2ForConditionalGeneration.from_pretrained(
        repo_id,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    models[name] = model
    processors[name] = processor
    print(f"✅ Finished loading {name}.")
print("\n🎉 All models loaded successfully!")
# ---------- Utilities ----------
def _safe_parse_json_like(text: Any) -> Any:
    """
    If `text` is already a dict, return it. If it is a JSON-like string,
    try to json.loads it. Otherwise return the original text.
    """
    if isinstance(text, dict):
        return text
    if isinstance(text, str):
        text_str = text.strip()
        # Only try to decode strings that look like JSON objects/arrays.
        if (text_str.startswith("{") and text_str.endswith("}")) or (
            text_str.startswith("[") and text_str.endswith("]")
        ):
            try:
                return json.loads(text_str)
            except Exception:
                # Fall back to returning the original string.
                return text
    return text
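
# Example (illustrative): a stringified detection payload is decoded,
# anything else passes through untouched.
#   _safe_parse_json_like('{"bboxes": [[0, 0, 10, 10]], "labels": ["cat"]}')
#       -> {'bboxes': [[0, 0, 10, 10]], 'labels': ['cat']}
#   _safe_parse_json_like("plain caption text") -> "plain caption text"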
def _find_bboxes_and_labels(obj: Any) -> List[Tuple[List[int], str]]:
    """
    Recursively search `obj` (dict/list) for paired 'bboxes' and 'labels'
    entries (or region entries). Returns a list of (bbox, label) tuples,
    with each bbox assumed to be [x1, y1, x2, y2] (ints/floats).
    """
    found: List[Tuple[List[int], str]] = []

    def recurse(o: Any):
        if isinstance(o, dict):
            # Direct 'bboxes'/'labels' pair.
            if "bboxes" in o:
                bboxes = o.get("bboxes", [])
                labels = o.get("labels", [])
                # If the label list is shorter, pad with empty strings.
                for i, bx in enumerate(bboxes):
                    lbl = labels[i] if i < len(labels) else ""
                    # Boxes sometimes arrive as dicts with x/y/w/h keys rather than lists.
                    if isinstance(bx, dict) and {"x", "y", "w", "h"}.issubset(bx.keys()):
                        # Convert xywh to x1, y1, x2, y2.
                        x, y, w, h = bx["x"], bx["y"], bx["w"], bx["h"]
                        found.append(([int(x), int(y), int(x + w), int(y + h)], lbl))
                    else:
                        # Assume list-like [x1, y1, x2, y2] or [x, y, w, h].
                        try:
                            bx_list = list(map(int, bx))
                            if len(bx_list) == 4:
                                x1, y1, x2, y2 = bx_list
                                # Heuristic: if x2 > x1 and y2 > y1, treat as corners;
                                # otherwise treat as xywh.
                                if x2 > x1 and y2 > y1:
                                    found.append(([x1, y1, x2, y2], lbl))
                                else:
                                    found.append(([x1, y1, x1 + x2, y1 + y2], lbl))
                            # Unexpected lengths are skipped.
                        except Exception:
                            pass
            # Also handle region entries like {'bbox': ..., 'text': ...}.
            if "regions" in o and isinstance(o["regions"], list):
                for reg in o["regions"]:
                    if isinstance(reg, dict) and "bbox" in reg:
                        bx = reg["bbox"]
                        lbl = reg.get("label", reg.get("text", ""))
                        try:
                            bx_list = list(map(int, bx))
                            if len(bx_list) == 4:
                                found.append((bx_list, lbl))
                        except Exception:
                            pass
            # Recurse into nested values.
            for v in o.values():
                recurse(v)
        elif isinstance(o, list):
            for item in o:
                recurse(item)
        # Primitives are ignored.

    recurse(obj)
    return found
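
# Example (illustrative values): Florence-2's post-processed "<OD>" output
# typically has the shape
#   {'<OD>': {'bboxes': [[34.2, 160.1, 597.9, 371.7]], 'labels': ['car']}}
# for which this helper returns [([34, 160, 597, 371], 'car')]. Because the
# walk is recursive, the task key wrapping the payload does not matter.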
def _draw_bboxes_on_image(img: Image.Image, boxes_and_labels: List[Tuple[List[int], str]]) -> Image.Image:
    """
    Draw bounding boxes and labels on a copy of `img`.
    """
    annotated = img.convert("RGB").copy()
    draw = ImageDraw.Draw(annotated)
    # Try to load a TTF font; fall back to PIL's built-in bitmap font.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=14)
    except Exception:
        font = ImageFont.load_default()
    for bbox, label in boxes_and_labels:
        # bbox is expected as [x1, y1, x2, y2].
        x1, y1, x2, y2 = bbox
        # Clamp coordinates to the image bounds.
        x1 = max(0, int(x1)); y1 = max(0, int(y1))
        x2 = min(annotated.width - 1, int(x2)); y2 = min(annotated.height - 1, int(y2))
        # Draw a thicker rectangle by stacking several offset outlines.
        thickness = max(2, int(round(min(annotated.width, annotated.height) / 200)))
        for t in range(thickness):
            draw.rectangle([x1 - t, y1 - t, x2 + t, y2 + t], outline="red")
        label_text = "" if label is None else str(label)
        # Measure the label with textbbox (ImageDraw.textsize was removed in Pillow 10).
        left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
        text_w, text_h = right - left, bottom - top
        # Semi-opaque background rectangle behind the label.
        label_bg = [x1, max(0, y1 - text_h - 4), x1 + text_w + 6, y1]
        draw.rectangle(label_bg, fill="red")
        draw.text((x1 + 3, max(0, y1 - text_h - 2)), label_text, fill="white", font=font)
    return annotated
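
# Example usage (illustrative; "car.jpg" is a hypothetical local file):
#   boxed = _draw_bboxes_on_image(Image.open("car.jpg"), [([34, 160, 597, 371], "car")])
#   boxed.save("car_annotated.jpg")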
# ---------- Inference function ----------
# Tasks for which we attempt to extract/display bboxes.
VISUAL_REGION_TASKS = {"<OD>", "<DENSE_REGION_CAPTION>", "<OCR_WITH_REGION>", "<REGION_PROPOSAL>"}

# On ZeroGPU Spaces, decorate this function with the @spaces.GPU decorator;
# otherwise keep it as-is in your environment.
def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
                            max_new_tokens: int = 1024, num_beams: int = 3):
    """
    Runs inference using the selected Florence-2 model.
    Returns a tuple: (parsed_answer, annotated_image_or_none).
    """
    if image is None:
        return {"error": "Please upload an image to get started."}, None
    model = models[model_name]
    processor = processors[model_name]
    # Prepare inputs and move them to the model's device. Only floating-point
    # tensors (pixel_values) are cast to bfloat16; input_ids must stay integer.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt")
    device = model.device
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            if v.is_floating_point():
                inputs[k] = v.to(device, dtype=torch.bfloat16)
            else:
                inputs[k] = v.to(device)
    # Generate.
    generated_ids = model.generate(
        input_ids=inputs.get("input_ids"),
        pixel_values=inputs.get("pixel_values"),
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
    )
    # Decode, keeping special tokens: the task-specific parser relies on them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Post-process. For region tasks the Florence-2 processor returns a
    # structured output (e.g. a dict with 'bboxes' and 'labels').
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=image.size
    )
    # Make the parsed answer JSON-serializable and easy to inspect.
    parsed_serializable = parsed_answer
    # If it is a string that contains JSON, attempt to parse it.
    if isinstance(parsed_answer, str):
        parsed_serializable = _safe_parse_json_like(parsed_answer)
    annotated_image = None
    # For visual region tasks, try to find bboxes and labels and draw them.
    if task_prompt in VISUAL_REGION_TASKS:
        boxes_and_labels = _find_bboxes_and_labels(parsed_serializable)
        if boxes_and_labels:
            try:
                annotated_image = _draw_bboxes_on_image(image, boxes_and_labels)
            except Exception as e:
                # If drawing fails, keep the parsed answer and return no image.
                print("Failed to draw boxes:", e)
                annotated_image = None
    # Return the parsed answer (preferably a serializable structure) and the
    # annotated PIL image, or None.
    return parsed_serializable, annotated_image
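
# Example usage outside the UI (illustrative; model/task names are from the
# definitions above, "car.jpg" is a hypothetical local file):
#   img = Image.open("car.jpg")
#   answer, boxed = run_florence2_inference("Florence-2-large-ft", img, "<OD>")
#   print(answer)  # e.g. {'<OD>': {'bboxes': [...], 'labels': [...]}}
#   if boxed is not None:
#       boxed.save("car_boxed.jpg")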
# ---------- UI ----------
florence_tasks = [
    "<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
    "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>",
]

# Example image (keeps your example).
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the parsed output and an annotated image (when bounding boxes are present).")
    with gr.Row():
        with gr.Column(scale=2):
            image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
            task_prompt = gr.Dropdown(
                label="Select Task",
                choices=florence_tasks,
                value="<MORE_DETAILED_CAPTION>",
            )
            model_choice = gr.Radio(
                choices=list(MODEL_IDS.keys()),
                label="Select Model",
                value="Florence-2-large-ft",
            )
            image_submit = gr.Button("Submit", variant="primary")
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens", minimum=128, maximum=2048, step=128, value=1024
                )
                num_beams = gr.Slider(
                    label="Number of Beams", minimum=1, maximum=10, step=1, value=3
                )
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            parsed_output = gr.JSON(label="Parsed Answer")
            annotated_output = gr.Image(label="Annotated Image (if available)", type="pil")

    image_submit.click(
        fn=run_florence2_inference,
        inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
        outputs=[parsed_output, annotated_output],
    )

if __name__ == "__main__":
    demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)