Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,933 Bytes
03932b3 647e0d3 03932b3 647e0d3 8852109 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 cd499b2 03932b3 647e0d3 03932b3 647e0d3 03932b3 647e0d3 8852109 647e0d3 8852109 647e0d3 e704cdd 647e0d3 e704cdd 647e0d3 8852109 647e0d3 8852109 647e0d3 8852109 647e0d3 8852109 647e0d3 e704cdd 647e0d3 03932b3 e704cdd 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 7424eda 03932b3 647e0d3 03932b3 647e0d3 03932b3 647e0d3 8852109 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 647e0d3 03932b3 212463e 03932b3 647e0d3 03932b3 74d543b 647e0d3 03932b3 647e0d3 03932b3 647e0d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 |
import os
import sys
import io
import json
import requests
from typing import Iterable, List, Tuple, Dict, Any
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
from transformers import AutoProcessor, Florence2ForConditionalGeneration
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# ---------- Theme ----------
# Register a custom "steel_blue" palette on gradio's `colors` module so it can be
# passed as a hue just like the built-in palettes (SteelBlueTheme uses it as its
# default secondary hue).
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class SteelBlueTheme(Soft):
    """
    Gradio theme derived from `Soft`.

    Defaults: gray primary hue, steel-blue secondary hue, slate neutrals, large
    text, Outfit / IBM Plex Mono fonts. On top of the `Soft` base it applies
    gradient page backgrounds, gradient primary buttons, thicker block borders
    and drop shadows (see the override groups in `__init__`).
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Page and accent surfaces (light and dark variants).
        surface_overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            "color_accent_soft": "*primary_100",
        }
        # Primary buttons: white text on a secondary-hue gradient.
        button_overrides = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
        }
        # Slider track colour.
        slider_overrides = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
        }
        # Block chrome: title weight, border, shadow, label background.
        block_overrides = {
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "block_label_background_fill": "*primary_200",
        }
        super().set(
            **surface_overrides,
            **button_overrides,
            **slider_overrides,
            **block_overrides,
        )
# Extra CSS passed to gr.Blocks: enlarges the headings rendered with
# elem_id="main-title" and elem_id="output-title".
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
# The single theme instance handed to gr.Blocks.
steel_blue_theme = SteelBlueTheme()
# ---------- Models ----------
# UI display name -> Hugging Face Hub repository id.
MODEL_IDS = {
    "Florence-2-base": "florence-community/Florence-2-base",
    "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
    "Florence-2-large": "florence-community/Florence-2-large",
    "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
}
# Every checkpoint is loaded eagerly at import time and kept resident, keyed by
# display name, so switching models in the UI needs no reload.
models: Dict[str, Florence2ForConditionalGeneration] = {}
processors: Dict[str, AutoProcessor] = {}
print("Loading Florence-2 models... This may take a while.")
for name in MODEL_IDS:
    repo_id = MODEL_IDS[name]
    print(f"Loading {name}...")
    # Weights in bfloat16; device_map="auto" leaves device placement to transformers.
    model = Florence2ForConditionalGeneration.from_pretrained(
        repo_id,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    models[name], processors[name] = model, processor
    print(f"✅ Finished loading {name}.")
print("\n🎉 All models loaded successfully!")
# ---------- Utilities ----------
def _safe_parse_json_like(text: Any) -> Any:
"""
If text is a dict already, return it. If it's a JSON-like string, try to json.loads it.
Otherwise return the original text.
"""
if isinstance(text, dict):
return text
if isinstance(text, str):
text_str = text.strip()
# try to decode if it looks like JSON
if (text_str.startswith("{") and text_str.endswith("}")) or (text_str.startswith("[") and text_str.endswith("]")):
try:
return json.loads(text_str)
except Exception:
# fallback to returning original string
return text
return text
def _find_bboxes_and_labels(obj: Any) -> List[Tuple[List[int], str]]:
"""
Recursively search `obj` (dict/list) for pairs of 'bboxes' and 'labels' (or region entries).
Returns list of tuples: (bbox, label)
bbox assumed as [x1,y1,x2,y2] (integers/floats)
"""
found: List[Tuple[List[int], str]] = []
def recurse(o: Any):
if isinstance(o, dict):
# direct pair case
if "bboxes" in o:
bboxes = o.get("bboxes", [])
labels = o.get("labels", [])
# if labels length mismatch, fill with empty strings
for i, bx in enumerate(bboxes):
lbl = labels[i] if i < len(labels) else ""
# sometimes bboxes come as dicts with keys or lists
if isinstance(bx, dict) and {"x","y","w","h"}.issubset(bx.keys()):
# convert xywh to x1,y1,x2,y2
x = bx["x"]; y = bx["y"]; w = bx["w"]; h = bx["h"]
found.append(([int(x), int(y), int(x + w), int(y + h)], lbl))
else:
# assume list-like [x1,y1,x2,y2] or [x,y,w,h]
try:
bx_list = list(map(int, bx))
if len(bx_list) == 4:
x1, y1, x2, y2 = bx_list
# Heuristic: if x2>x1 and y2>y1 assume x1,y1,x2,y2 otherwise maybe xywh
if x2 > x1 and y2 > y1:
found.append(([x1, y1, x2, y2], lbl))
else:
# try treat as xywh
found.append(([x1, y1, x1 + x2, y1 + y2], lbl))
else:
# skip unexpected format
pass
except Exception:
pass
# also check for region entries like {'bbox': ..., 'text': ...} or list of regions
if "regions" in o and isinstance(o["regions"], list):
for reg in o["regions"]:
if isinstance(reg, dict) and "bbox" in reg:
bx = reg["bbox"]
lbl = reg.get("label", reg.get("text", ""))
try:
bx_list = list(map(int, bx))
if len(bx_list) == 4:
found.append(([bx_list[0], bx_list[1], bx_list[2], bx_list[3]], lbl))
except Exception:
pass
# recurse deeper
for v in o.values():
recurse(v)
elif isinstance(o, list):
for item in o:
recurse(item)
# else ignore primitives
recurse(obj)
return found
def _draw_bboxes_on_image(img: Image.Image, boxes_and_labels: List[Tuple[List[int], str]]) -> Image.Image:
    """
    Return an RGB copy of `img` with every box outlined in red and its label on a red tag.

    Args:
        img: Source image; it is not modified.
        boxes_and_labels: ``([x1, y1, x2, y2], label)`` pairs in pixel coordinates.
            Inverted boxes and boxes lying entirely outside the image are skipped;
            the rest are clamped to the image bounds. ``None`` or empty labels
            get an outline but no tag.

    Returns:
        The annotated RGB copy.
    """
    # convert() returns a new image, so the caller's image stays untouched.
    annotated = img.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    # Prefer a scalable TrueType font. Fall back to PIL's built-in font when
    # DejaVuSans is not installed (OSError) or FreeType support is missing (ImportError).
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=14)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    def measure(text: str) -> Tuple[int, int, int, int]:
        """Return (x_offset, y_offset, width, height) of `text`'s ink box when drawn at (0, 0)."""
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        except (AttributeError, ValueError):
            # textbbox is missing in very old Pillow and rejects bitmap fonts in some
            # older releases. textsize is the legacy API; it was removed in Pillow 10,
            # so it is only a fallback.
            width, height = draw.textsize(text, font=font)
            return 0, 0, width, height
        return left, top, right - left, bottom - top

    max_x, max_y = annotated.width - 1, annotated.height - 1
    # Outline thickness grows with image size (about 0.5% of the shorter side), minimum 2 px.
    thickness = max(2, int(round(min(annotated.width, annotated.height) / 200)))

    for bbox, label in boxes_and_labels:
        x1, y1, x2, y2 = (int(c) for c in bbox)
        if x2 < x1 or y2 < y1 or x1 > max_x or y1 > max_y or x2 < 0 or y2 < 0:
            # Inverted or fully off-image: nothing visible to draw. Recent Pillow
            # versions also raise ValueError for rectangles whose end precedes their start.
            continue
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(max_x, x2), min(max_y, y2)

        # Thick outline: nested 1-px rectangles growing outward, keeping the box interior clear.
        for t in range(thickness):
            draw.rectangle([x1 - t, y1 - t, x2 + t, y2 + t], outline="red")

        label_text = "" if label is None else str(label)
        if not label_text:
            continue
        off_x, off_y, text_w, text_h = measure(label_text)
        tag_h = text_h + 4
        # Tag sits just above the box; if that would leave the image, tuck it inside the top edge.
        tag_top = y1 - tag_h if y1 - tag_h >= 0 else y1
        draw.rectangle([x1, tag_top, x1 + text_w + 6, tag_top + tag_h], fill="red")
        # Subtract the ink offsets so the glyphs land inside the tag with a 3 px / 2 px margin.
        draw.text((x1 + 3 - off_x, tag_top + 2 - off_y), label_text, fill="white", font=font)
    return annotated
# ---------- Inference function ----------
# Florence-2 task tokens whose parsed answer is scanned for box coordinates;
# when boxes are found, an annotated copy of the input image is produced.
VISUAL_REGION_TASKS = {
    "<OD>",
    "<DENSE_REGION_CAPTION>",
    "<OCR_WITH_REGION>",
    "<REGION_PROPOSAL>",
}
# NOTE(review): no GPU decorator (e.g. `spaces.GPU` on ZeroGPU Spaces) is applied
# here; add one if the deployment requires it.
def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
                            max_new_tokens: int = 1024, num_beams: int = 3):
    """
    Run one Florence-2 task on an image with the selected model.

    Args:
        model_name: Key into the module-level ``models`` / ``processors`` dicts.
        image: Input PIL image, or None when nothing was uploaded.
        task_prompt: Florence-2 task token, e.g. "<OD>" or "<CAPTION>".
        max_new_tokens: Generation length limit (cast to int, since UI sliders may send floats).
        num_beams: Beam-search width (cast to int).

    Returns:
        ``(parsed_answer, annotated_image)``:
        - parsed_answer: the processor's post-processed answer, JSON-decoded when it
          came back as a JSON-looking string.
        - annotated_image: a PIL image with boxes drawn for tasks in
          VISUAL_REGION_TASKS, or None when there are no boxes or drawing fails.
        When `image` is None, returns ``({"error": ...}, None)`` instead.
    """
    if image is None:
        return {"error": "Please upload an image to get started."}, None

    model = models[model_name]
    processor = processors[model_name]

    # Uploads may be RGBA, palette or grayscale; the vision encoder expects 3 channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=task_prompt, images=image, return_tensors="pt")
    device = model.device
    for key, value in inputs.items():
        if not isinstance(value, torch.Tensor):
            continue
        if torch.is_floating_point(value):
            # pixel_values must match the model's weight dtype (bfloat16 here).
            inputs[key] = value.to(device, dtype=model.dtype)
        else:
            # input_ids / attention_mask must stay integral. bfloat16 cannot represent
            # token ids above 256 exactly, and embedding lookups require integer indices.
            inputs[key] = value.to(device)

    generated_ids = model.generate(
        input_ids=inputs.get("input_ids"),
        pixel_values=inputs.get("pixel_values"),
        max_new_tokens=int(max_new_tokens),
        num_beams=int(num_beams),
        do_sample=False,
    )
    # Keep special tokens: post_process_generation needs the task and location tokens.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=image.size
    )

    # Make the answer JSON-friendly for the gr.JSON output.
    parsed_serializable = (
        _safe_parse_json_like(parsed_answer) if isinstance(parsed_answer, str) else parsed_answer
    )

    annotated_image = None
    if task_prompt in VISUAL_REGION_TASKS:
        boxes_and_labels = _find_bboxes_and_labels(parsed_serializable)
        if boxes_and_labels:
            try:
                annotated_image = _draw_bboxes_on_image(image, boxes_and_labels)
            except Exception as e:
                # Best effort: a drawing failure must not hide the parsed answer.
                print("Failed to draw boxes:", e)
                annotated_image = None
    return parsed_serializable, annotated_image
# ---------- UI ----------
# Task tokens offered in the "Select Task" dropdown, in display order.
florence_tasks = [
    "<OD>",
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR>",
    "<OCR_WITH_REGION>",
]
# Example image pre-loaded into the upload widget. It is fetched once at start-up.
# If the download fails, the widget simply starts empty instead of crashing the app.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"


def _load_example_image(source_url: str, timeout: float = 30.0):
    """
    Download `source_url` and return it as an RGB PIL image, or None on failure.

    Args:
        source_url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    The body is read via ``response.content``, not ``response.raw``, so any
    Content-Encoding is decoded, and the response is closed by the ``with`` block.
    """
    try:
        with requests.get(source_url, timeout=timeout) as response:
            response.raise_for_status()
            payload = response.content
        # convert() forces a full decode while the in-memory buffer is still alive.
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        # OSError covers PIL's UnidentifiedImageError and truncated downloads.
        print(f"Could not load example image from {source_url}: {err}")
        return None


example_image = _load_example_image(url)
# Page layout: controls in the left column (scale 2), results in the right column (scale 3).
# Components render in the order they are created inside each context manager.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the parsed output and an annotated image (when bounding boxes are present).")
    with gr.Row():
        # Left column: image upload, task / model selection, submit button, generation options.
        with gr.Column(scale=2):
            # type="pil" delivers a PIL image to run_florence2_inference (None when empty).
            image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
            task_prompt = gr.Dropdown(
                label="Select Task",
                choices=florence_tasks,
                value="<MORE_DETAILED_CAPTION>"
            )
            # Choices are the keys of MODEL_IDS, i.e. the keys of the loaded `models` dict.
            model_choice = gr.Radio(
                choices=list(MODEL_IDS.keys()),
                label="Select Model",
                value="Florence-2-large-ft"
            )
            image_submit = gr.Button("Submit", variant="primary")
            # Collapsed by default. Slider defaults (1024, 3) match run_florence2_inference's defaults.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens", minimum=128, maximum=2048, step=128, value=1024
                )
                num_beams = gr.Slider(
                    label="Number of Beams", minimum=1, maximum=10, step=1, value=3
                )
        # Right column: parsed answer and the optional annotated image.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            parsed_output = gr.JSON(label="Parsed Answer")
            annotated_output = gr.Image(label="Annotated Image (if available)", type="pil")
    # `inputs` order must match run_florence2_inference's positional parameters; its
    # (parsed_answer, annotated_image) return tuple fills `outputs` in the same order.
    image_submit.click(
        fn=run_florence2_inference,
        inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
        outputs=[parsed_output, annotated_output]
    )
if __name__ == "__main__":
    # Enable the request queue, then start the server. The app is also exposed as an
    # MCP server, SSR is disabled, and errors are shown in the browser (show_error)
    # and in the console (debug).
    demo.queue().launch(
        debug=True,
        mcp_server=True,
        ssr_mode=False,
        show_error=True,
    )
|