import json
import requests
from typing import Iterable, List, Tuple, Dict, Any
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# ---------- Theme ----------
colors.steel_blue = colors.Color(
name="steel_blue",
c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
c800="#2E5378", c900="#264364", c950="#1E3450",
)
class SteelBlueTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.steel_blue,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
text_size=text_size, font=font, font_mono=font_mono,
)
super().set(
background_fill_primary="*primary_50",
background_fill_primary_dark="*primary_900",
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
slider_color="*secondary_500",
slider_color_dark="*secondary_600",
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
button_primary_shadow="*shadow_drop_lg",
button_large_padding="11px",
color_accent_soft="*primary_100",
block_label_background_fill="*primary_200",
)
steel_blue_theme = SteelBlueTheme()
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
# ---------- Models ----------
MODEL_IDS = {
"Florence-2-base": "florence-community/Florence-2-base",
"Florence-2-base-ft": "florence-community/Florence-2-base-ft",
"Florence-2-large": "florence-community/Florence-2-large",
"Florence-2-large-ft": "florence-community/Florence-2-large-ft",
}
models: Dict[str, Florence2ForConditionalGeneration] = {}
processors: Dict[str, AutoProcessor] = {}
print("Loading Florence-2 models... This may take a while.")
for name, repo_id in MODEL_IDS.items():
print(f"Loading {name}...")
model = Florence2ForConditionalGeneration.from_pretrained(
repo_id,
dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
models[name] = model
processors[name] = processor
print(f"✅ Finished loading {name}.")
print("\n🎉 All models loaded successfully!")
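# Note: all four checkpoints are loaded eagerly and kept resident in memory.
# On memory-constrained hardware, a reasonable alternative (not done here)
# would be to load a model lazily on first selection and cache it.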
# ---------- Utilities ----------
def _safe_parse_json_like(text: Any) -> Any:
"""
If text is a dict already, return it. If it's a JSON-like string, try to json.loads it.
Otherwise return the original text.
"""
if isinstance(text, dict):
return text
if isinstance(text, str):
text_str = text.strip()
# try to decode if it looks like JSON
if (text_str.startswith("{") and text_str.endswith("}")) or (text_str.startswith("[") and text_str.endswith("]")):
try:
return json.loads(text_str)
except Exception:
# fallback to returning original string
return text
return text
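# Illustrative behaviour (comments only, not executed):
#   _safe_parse_json_like({"a": 1})    -> {"a": 1}   (dicts pass through)
#   _safe_parse_json_like('{"a": 1}')  -> {"a": 1}   (JSON-like string parsed)
#   _safe_parse_json_like("plain")     -> "plain"    (returned unchanged)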
def _find_bboxes_and_labels(obj: Any) -> List[Tuple[List[int], str]]:
"""
Recursively search `obj` (dict/list) for pairs of 'bboxes' and 'labels' (or region entries).
Returns list of tuples: (bbox, label)
bbox assumed as [x1,y1,x2,y2] (integers/floats)
"""
found: List[Tuple[List[int], str]] = []
def recurse(o: Any):
if isinstance(o, dict):
# direct pair case
if "bboxes" in o:
bboxes = o.get("bboxes", [])
labels = o.get("labels", [])
                # if labels are shorter than bboxes, pad missing ones with ""
                for i, bx in enumerate(bboxes):
                    lbl = labels[i] if i < len(labels) else ""
                    # bboxes may arrive as {"x","y","w","h"} dicts or flat lists
                    if isinstance(bx, dict) and {"x", "y", "w", "h"}.issubset(bx.keys()):
                        # convert x/y/w/h to x1,y1,x2,y2
                        x = bx["x"]; y = bx["y"]; w = bx["w"]; h = bx["h"]
                        found.append(([int(x), int(y), int(x + w), int(y + h)], lbl))
                    else:
                        # assume list-like [x1,y1,x2,y2] or [x,y,w,h]
try:
bx_list = list(map(int, bx))
if len(bx_list) == 4:
x1, y1, x2, y2 = bx_list
                                # Heuristic: if x2 > x1 and y2 > y1, treat as x1,y1,x2,y2;
                                # otherwise fall back to interpreting it as x,y,w,h
                                if x2 > x1 and y2 > y1:
                                    found.append(([x1, y1, x2, y2], lbl))
                                else:
                                    # interpret as x,y,w,h
                                    found.append(([x1, y1, x1 + x2, y1 + y2], lbl))
else:
# skip unexpected format
pass
except Exception:
pass
# also check for region entries like {'bbox': ..., 'text': ...} or list of regions
if "regions" in o and isinstance(o["regions"], list):
for reg in o["regions"]:
if isinstance(reg, dict) and "bbox" in reg:
bx = reg["bbox"]
lbl = reg.get("label", reg.get("text", ""))
try:
bx_list = list(map(int, bx))
if len(bx_list) == 4:
found.append(([bx_list[0], bx_list[1], bx_list[2], bx_list[3]], lbl))
except Exception:
pass
# recurse deeper
for v in o.values():
recurse(v)
elif isinstance(o, list):
for item in o:
recurse(item)
# else ignore primitives
recurse(obj)
return found
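# For reference, the typical post-processed shape this walks for a region task
# (exact structure may vary by task and transformers version):
#   {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": ["car", ...]}}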
def _draw_bboxes_on_image(img: Image.Image, boxes_and_labels: List[Tuple[List[int], str]]) -> Image.Image:
"""
Draw bounding boxes and labels on a copy of `img`.
"""
annotated = img.convert("RGB").copy()
draw = ImageDraw.Draw(annotated)
# try to get a default font (PIL may not have a TTF available)
try:
font = ImageFont.truetype("DejaVuSans.ttf", size=14)
except Exception:
font = ImageFont.load_default()
for bbox, label in boxes_and_labels:
# bbox should be [x1,y1,x2,y2]
x1, y1, x2, y2 = bbox
# keep coordinates within image bounds
x1 = max(0, int(x1)); y1 = max(0, int(y1))
x2 = min(annotated.width - 1, int(x2)); y2 = min(annotated.height - 1, int(y2))
        # draw the box outline, scaling thickness with image size
        thickness = max(2, int(round(min(annotated.width, annotated.height) / 200)))
        draw.rectangle([x1, y1, x2, y2], outline="red", width=thickness)
        # draw label background
        label_text = str(label) if label is not None else ""
        # Pillow >= 10 removed ImageDraw.textsize; measure via textbbox instead
        left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
        text_w, text_h = right - left, bottom - top
        # solid background rectangle behind the label
        label_bg = [x1, max(0, y1 - text_h - 4), x1 + text_w + 6, y1]
        draw.rectangle(label_bg, fill="red")
        # label text
        draw.text((x1 + 3, max(0, y1 - text_h - 2)), label_text, fill="white", font=font)
return annotated
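# Illustrative usage (not executed):
#   annotated = _draw_bboxes_on_image(img, [([10, 10, 120, 90], "car")])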
# ---------- Inference function ----------
# tasks for which we attempt to extract/display bboxes
VISUAL_REGION_TASKS = {"<OD>", "<DENSE_REGION_CAPTION>", "<OCR_WITH_REGION>", "<REGION_PROPOSAL>"}
# On Hugging Face Spaces with ZeroGPU, this function can be decorated with
# @spaces.GPU; it is left undecorated here.
def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
max_new_tokens: int = 1024, num_beams: int = 3):
"""
Runs inference using the selected Florence-2 model.
Returns a tuple: (parsed_answer, annotated_image_or_none)
"""
if image is None:
return {"error": "Please upload an image to get started."}, None
model = models[model_name]
processor = processors[model_name]
    # Prepare inputs and move them to the model device. Only floating-point
    # tensors (e.g. pixel_values) are cast to bfloat16; input_ids must stay
    # integer-typed or the embedding lookup in generate() will fail.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt")
    device = model.device
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(device)
# Generate
generated_ids = model.generate(
input_ids=inputs.get("input_ids"),
pixel_values=inputs.get("pixel_values"),
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=False
)
# Decode
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Post-process: for region tasks the Florence-2 processor returns a
    # structured dict (e.g. with 'bboxes' and 'labels') keyed by the task token
image_size = image.size
parsed_answer = processor.post_process_generation(
generated_text, task=task_prompt, image_size=image_size
)
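    # Florence-2 emits quantized <loc_*> location tokens; post-processing needs
    # the original image size to map them back to pixel coordinates.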
# Try to make parsed_answer JSON-serializable and easily inspectable
parsed_serializable = parsed_answer
# If it's a string that contains JSON, attempt to parse
if isinstance(parsed_answer, str):
parsed_serializable = _safe_parse_json_like(parsed_answer)
annotated_image = None
# If the task is in our visual region tasks, try to find bboxes and labels
if task_prompt in VISUAL_REGION_TASKS:
# parsed_serializable may be dict/list or string; try to find bboxes
boxes_and_labels = _find_bboxes_and_labels(parsed_serializable)
if boxes_and_labels:
try:
annotated_image = _draw_bboxes_on_image(image, boxes_and_labels)
except Exception as e:
# if drawing fails, set annotated_image to None but keep parsed answer
print("Failed to draw boxes:", e)
annotated_image = None
# Return parsed answer (prefer a dict or serializable structure) and annotated image (PIL) or None
return parsed_serializable, annotated_image
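# Illustrative call (assumes the models above finished loading):
#   parsed, annotated = run_florence2_inference("Florence-2-base", pil_img, "<OD>")
# `parsed` is the post-processed answer; `annotated` is a PIL image or None.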
# ---------- UI ----------
florence_tasks = [
"<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
"<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
]
# Example image (fetched at startup; the demo still launches if the download fails)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
try:
    example_image = Image.open(requests.get(url, stream=True, timeout=30).raw).convert("RGB")
except Exception:
    example_image = None
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the parsed output and an annotated image (when bounding boxes are present).")
with gr.Row():
with gr.Column(scale=2):
image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
task_prompt = gr.Dropdown(
label="Select Task",
choices=florence_tasks,
value="<MORE_DETAILED_CAPTION>"
)
model_choice = gr.Radio(
choices=list(MODEL_IDS.keys()),
label="Select Model",
value="Florence-2-large-ft"
)
image_submit = gr.Button("Submit", variant="primary")
with gr.Accordion("Advanced options", open=False):
max_new_tokens = gr.Slider(
label="Max New Tokens", minimum=128, maximum=2048, step=128, value=1024
)
num_beams = gr.Slider(
label="Number of Beams", minimum=1, maximum=10, step=1, value=3
)
with gr.Column(scale=3):
gr.Markdown("## Output", elem_id="output-title")
parsed_output = gr.JSON(label="Parsed Answer")
annotated_output = gr.Image(label="Annotated Image (if available)", type="pil")
image_submit.click(
fn=run_florence2_inference,
inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
outputs=[parsed_output, annotated_output]
)
if __name__ == "__main__":
demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)