Spaces: Running on Zero

[Update] Demo updated with suggestive prompting, and pinning of packages (#5)
- update demo with suggestive prompting (8b69117e3279dc7d26728fe43d34ee22ce575434)
- pinning the libraries (5d44beee15caad5daf575c663b125e6d8ec12294)
Co-authored-by: Aritra Roy Gosthipaty <[email protected]>
- app.py +261 -348
- requirements.txt +7 -10
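
The "suggestive prompting" change asks Moondream for a list of objects in the uploaded image and turns at most three of them into clickable suggestions that pre-fill the prompt box (see get_suggested_objects and on_category_and_image_change in the new app.py below). A minimal, model-free sketch of the parsing-and-capping step; the stubbed answer string and the parse_suggestions helper are illustrative, not part of the Space:

import ast

def parse_suggestions(answer: str, limit: int = 3) -> list[str]:
    """Parse a Python-literal list returned by the VLM and cap the number of suggestions."""
    try:
        objects = ast.literal_eval(answer)
    except (ValueError, SyntaxError):
        return []
    if not isinstance(objects, list):
        return []
    return objects[:limit]

# Example with a stubbed model answer in the format the demo expects.
print(parse_suggestions("['car', 'tree', 'person', 'traffic light']"))  # ['car', 'tree', 'person']
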
app.py CHANGED

Removed: the previous app.py, a free-form prompt comparison demo. It additionally imported json, time, numpy, and AutoProcessor, loaded the models with trust_remote_code=True, and timed each call with time.perf_counter(). Its helpers included create_annotated_image (Qwen-style bbox JSON rendered via sv.Detections.from_vlm with sv.VLM.QWEN_2_5_VL plus box and label annotators), create_annotated_image_normalized (Moondream points and objects via sv.KeyPoints and sv.VLM.MOONDREAM), parse_qwen3_json (stripping code-fenced JSON and repairing truncated output), and create_annotated_image_qwen3 (bbox_2d values clamped to a 0-1000 grid). detect_qwen and detect_moondream returned an annotated image, the raw model text, and per-model timing. The UI offered a category dropdown (Object Detection, Object Counting, Visual Grounding + Keypoint Detection, Visual Grounding + Object Detection, General query), a Generate button, and example prompts that requested bbox coordinates in JSON format.

Added: the new app.py below (unchanged context lines from the diff are included; one collapsed span is marked inline).
import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    AutoModelForCausalLM,
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Enter caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Moondream
moondream = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream3-preview",
    trust_remote_code=True,
    dtype=DTYPE,
    device_map=DEVICE,
    revision="main",
).eval()

# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)


# --- Utility Functions ---
def safe_parse_json(text: str):
    text = text.strip()
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return {}


@GPU
def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Moondream"""
    if image is None:
        return []

    try:
        result = moondream.query(
            image=image,
            question="What objects are in the image, provide the list.",
            reasoning=False,
        )
        suggested_objects = ast.literal_eval(result["answer"])
        if isinstance(suggested_objects, list):
            if len(suggested_objects) > 3:  # send not more than 3 suggestions
                return suggested_objects[:3]
            else:
                return suggested_objects
        return []
    except Exception as e:
        print(f"Error getting suggestions: {e}")
        return []


def annotate_image(image: Image.Image, result: dict):
    if not isinstance(image, Image.Image):
        return image  # Return original if not a valid image
    if not isinstance(result, dict):
        return image  # Return original if result is not a dict

    original_width, original_height = image.size

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = []
        for point in result.get("points", []):
            x = int(point["x"] * original_width)
            y = int(point["y"] * original_height)
            points_list.append([x, y])

        if not points_list:
            return image

        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(
            scene=image.copy(), key_points=key_points
        )
        return annotated_image

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        detections = sv.Detections.from_vlm(
            sv.VLM.MOONDREAM,
            result,
            resolution_wh=image.size,
        )
        if len(detections) == 0:
            return image

        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
        annotated_scene = box_annotator.annotate(
            scene=image.copy(), detections=detections
        )
        return annotated_scene

    return image


# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )

    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text


@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result
    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result
    return "Invalid category", {}


@GPU
def process_moondream(image: Image.Image, category: str, prompt: str):
    if category == "Query":
        result = moondream.query(image=image, question=prompt)
        return result["answer"], {}
    elif category == "Caption":
        result = moondream.caption(image, length=prompt)
        return result["caption"], {}
    elif category == "Point":
        result = moondream.point(image, prompt)
        return json.dumps(result, indent=2), result
    elif category == "Detect":
        result = moondream.detect(image, prompt)
        return json.dumps(result, indent=2), result
    return "Invalid category", {}


# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Generate suggestions when category changes to Point or Detect"""
    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)

    if image is None or category not in ["Point", "Detect", "Caption"]:
        return gr.Radio(choices=[], visible=False), text_box

    if category == "Caption":
        return gr.Radio(choices=["short", "normal", "long"], visible=True), text_box

    suggestions = get_suggested_objects(image)
    if suggestions:
        return gr.Radio(choices=suggestions, visible=True, interactive=True), text_box
    else:
        return gr.Radio(choices=["no choice possible"], visible=True, interactive=True), text_box


def update_prompt_from_radio(selected_object):
    """Update prompt textbox when a radio option is selected"""
    if selected_object:
        return gr.Textbox(value=selected_object)
    return gr.Textbox(value="")


def process_inputs(image, category, prompt):
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please provide a prompt.")

    # Process with Qwen
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image, qwen_data)

    # Process with Moondream
    moondream_text, moondream_data = process_moondream(image, category, prompt)
    moondream_annotated_image = annotate_image(image, moondream_data)

    return qwen_annotated_image, qwen_text, moondream_annotated_image, moondream_text


css_hide_share = """
button#gradio-share-link-button-0 {
    display: none;
}
"""

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# 👓 Object Understanding with Vision Language Models")
    gr.Markdown(
        """
        ... (introductory description unchanged from the previous version; collapsed in this diff) ...
        """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            # Suggested objects radio (hidden by default)
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )

            submit_btn = gr.Button("Compare Models", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct")
                    qwen_img_output = gr.Image(label="Annotated Image")
                    qwen_text_output = gr.Textbox(
                        label="Text Output", lines=8, interactive=False
                    )
                with gr.Column():
                    gr.Markdown("### moondream/moondream3-preview")
                    moon_img_output = gr.Image(label="Annotated Image")
                    moon_text_output = gr.Textbox(
                        label="Text Output", lines=8, interactive=False
                    )

    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Caption", ""],
            ["examples/example_2.JPG", "Point", ""],
            ["examples/example_2.JPG", "Detect", ""],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )

    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )

    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output, moon_img_output, moon_text_output],
    )

if __name__ == "__main__":
    demo.launch()
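
A note on the coordinate handling in process_qwen above: Qwen3-VL reports points and boxes on a 0-1000 grid, and the demo divides by 1000.0 to get the normalized values that annotate_image (and Moondream's own point/detect output) expects. A small self-contained illustration of that conversion, using a made-up model reply rather than a real generation:

import json
import re

# Made-up reply in the shape the demo expects for a "Point" request.
reply = """```json
[{"point_2d": [412, 655], "label": "dog"}]
```"""

# Strip the Markdown fence, as safe_parse_json does, then parse the JSON payload.
cleaned = re.sub(r"^```(json)?|```$", "", reply.strip(), flags=re.MULTILINE).strip()
items = json.loads(cleaned)

# Convert from the 0-1000 grid to normalized [0, 1] coordinates.
points = [{"x": x / 1000.0, "y": y / 1000.0} for x, y in (item["point_2d"] for item in items)]
print(points)  # [{'x': 0.412, 'y': 0.655}]
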
requirements.txt CHANGED
-torch
-transformers
-torchvision
-matplotlib
-supervision
+torch==2.8.0
+transformers==4.57.0
+Pillow==11.3.0
+gradio==5.49.1
+accelerate==1.10.1
+torchvision==0.23.0
+supervision==0.26.1
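
Since the commit pins exact versions, a quick local sanity check against those pins can help before running the app. This snippet is a convenience sketch, not part of the Space; the names are the distribution names from requirements.txt above:

from importlib import metadata

# Pinned versions from requirements.txt (this commit).
PINS = {
    "torch": "2.8.0",
    "transformers": "4.57.0",
    "Pillow": "11.3.0",
    "gradio": "5.49.1",
    "accelerate": "1.10.1",
    "torchvision": "0.23.0",
    "supervision": "0.26.1",
}

for package, expected in PINS.items():
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        print(f"{package}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{package}: {installed} {status}")
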