sergiopaniego (HF Staff) and ariG23498 (HF Staff) committed
Commit c785294 · verified · 1 Parent(s): e80a30e

[Update] Demo updated with suggestive prompting, and pinning of packages (#5)


- update demo with suggestive prompting (8b69117e3279dc7d26728fe43d34ee22ce575434)
- pinning the libraries (5d44beee15caad5daf575c663b125e6d8ec12294)


Co-authored-by: Aritra Roy Gosthipaty <[email protected]>

Files changed (2):
  1. app.py (+261 -348)
  2. requirements.txt (+7 -10)
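In short, the "suggestive prompting" change asks Moondream for a small list of objects in the uploaded image and surfaces them as clickable suggestions that pre-fill the prompt for the Point and Detect tasks. Below is a minimal sketch of that flow, condensed from the new app.py further down; the moondream.query call and the Gradio wiring mirror this commit, while the helper names here are only illustrative.

import ast
import gradio as gr

def suggest_objects(moondream, image):
    # Ask Moondream for a Python-style list of objects, keep at most three suggestions.
    result = moondream.query(
        image=image,
        question="What objects are in the image, provide the list.",
        reasoning=False,
    )
    objects = ast.literal_eval(result["answer"])
    return objects[:3] if isinstance(objects, list) else []

def fill_prompt(selected_object):
    # Selecting a suggestion in the gr.Radio pre-fills the prompt textbox.
    return gr.Textbox(value=selected_object or "")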
app.py CHANGED
@@ -1,227 +1,140 @@
-import json
-import time
-
 import gradio as gr
-import numpy as np
 from gradio.themes.ocean import Ocean
-from PIL import Image
-from qwen_vl_utils import process_vision_info
 from transformers import (
     AutoModelForCausalLM,
-    AutoProcessor,
     Qwen3VLForConditionalGeneration,
 )
-
 from spaces import GPU
-import supervision as sv

-model_qwen_id = "Qwen/Qwen3-VL-4B-Instruct"
-model_moondream_id = "moondream/moondream3-preview"

-model_qwen = Qwen3VLForConditionalGeneration.from_pretrained(
-    model_qwen_id, torch_dtype="auto", device_map="auto",
-)
-model_moondream = AutoModelForCausalLM.from_pretrained(
-    model_moondream_id,
     trust_remote_code=True,
-    device_map={"": "cuda"},
 )


-def extract_model_short_name(model_id):
-    return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
-
-
-model_qwen_name = extract_model_short_name(model_qwen_id)
-model_moondream_name = extract_model_short_name(model_moondream_id)
-
-
-processor_qwen = AutoProcessor.from_pretrained(model_qwen_id)
-
-
-def create_annotated_image(image, json_data, height, width):
     try:
-        parsed_json_data = json_data.split("```json")[1].split("```")[0]
-        bbox_data = json.loads(parsed_json_data)
     except Exception:
-        return image

-    original_width, original_height = image.size
-    x_scale = original_width / width
-    y_scale = original_height / height
-
-    points = []
-    point_labels = []
-
-    for item in bbox_data:
-        label = item.get("label", "")
-        if "point_2d" in item:
-            x, y = item["point_2d"]
-            scaled_x = int(x * x_scale)
-            scaled_y = int(y * y_scale)
-            points.append([scaled_x, scaled_y])
-            point_labels.append(label)
-
-    annotated_image = np.array(image.convert("RGB"))
-
-    detections = sv.Detections.from_vlm(vlm = sv.VLM.QWEN_2_5_VL,
-                                        result=json_data,
-                                        input_wh=(original_width,
-                                                  original_height),
-                                        resolution_wh=(original_width,
-                                                       original_height))
-    bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
-    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
-
-    annotated_image = bounding_box_annotator.annotate(
-        scene=annotated_image, detections=detections
-    )
-    annotated_image = label_annotator.annotate(
-        scene=annotated_image, detections=detections
-    )

-    if points:
-        points_array = np.array(points).reshape(1, -1, 2)
-        key_points = sv.KeyPoints(xy=points_array)
-        vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.BLUE)
-        # vertex_label_annotator = sv.VertexLabelAnnotator(text_scale=0.5, border_radius=2)

-        annotated_image = vertex_annotator.annotate(
-            scene=annotated_image, key_points=key_points
         )
-
-    # annotated_image = vertex_label_annotator.annotate(
-    #     scene=annotated_image,
-    #     key_points=key_points,
-    #     labels=point_labels
-    # )
-
-    return Image.fromarray(annotated_image)


-def create_annotated_image_normalized(image, json_data, label="object"):
-    if not isinstance(json_data, dict):
-        return image

     original_width, original_height = image.size
-    annotated_image = np.array(image.convert("RGB"))

-    points = []
-    if "points" in json_data:
-        for point in json_data.get("points", []):
             x = int(point["x"] * original_width)
             y = int(point["y"] * original_height)
-            points.append([x, y])

-    if "reasoning" in json_data:
-        for grounding in json_data["reasoning"].get("grounding", []):
-            for x_norm, y_norm in grounding.get("points", []):
-                x = int(x_norm * original_width)
-                y = int(y_norm * original_height)
-                points.append([x, y])

-    if points:
-        points_array = np.array(points).reshape(1, -1, 2)
         key_points = sv.KeyPoints(xy=points_array)
-        vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.RED)
         annotated_image = vertex_annotator.annotate(
-            scene=annotated_image, key_points=key_points
-        )
-
-    if "objects" in json_data:
-        detections = sv.Detections.from_vlm(sv.VLM.MOONDREAM, json_data,
-                                            resolution_wh=(original_width,
-                                                           original_height))
-
-        bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
-        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
-
-        labels = [label for _ in detections.xyxy]
-
-        annotated_image = bounding_box_annotator.annotate(
-            scene=annotated_image, detections=detections
         )
-        annotated_image = label_annotator.annotate(
-            scene=annotated_image, detections=detections, labels=labels
         )

-    return Image.fromarray(annotated_image)
-
-
-def parse_qwen3_json(json_output):
-    lines = json_output.splitlines()
-    for i, line in enumerate(lines):
-        if line == "```json":
-            json_output = "\n".join(lines[i+1:])
-            json_output = json_output.split("```")[0]
-            break
-
-    try:
-        boxes = json.loads(json_output)
-    except json.JSONDecodeError:
-        end_idx = json_output.rfind('"}') + len('"}')
-        truncated_text = json_output[:end_idx] + "]"
-        boxes = json.loads(truncated_text)
-
-    if not isinstance(boxes, list):
-        boxes = [boxes]
-
-    return boxes
-

-def create_annotated_image_qwen3(image, json_output):
-    try:
-        boxes = parse_qwen3_json(json_output)
-    except Exception as e:
-        print(f"Error parsing JSON: {e}")
-        return image
-
-    if not boxes:
-        return image
-
-    original_width, original_height = image.size
-    annotated_image = np.array(image.convert("RGB"))
-
-    xyxy = []
-    labels = []
-
-    for box in boxes:
-        if "bbox_2d" in box and "label" in box:
-            x1, y1, x2, y2 = box["bbox_2d"]
-            scale = 1000
-            x1 = max(0, min(scale, x1)) / scale * original_width
-            y1 = max(0, min(scale, y1)) / scale * original_height
-            x2 = max(0, min(scale, x2)) / scale * original_width
-            y2 = max(0, min(scale, y2)) / scale * original_height
-            # Ensure x1 <= x2 and y1 <= y2
-            if x1 > x2: x1, x2 = x2, x1
-            if y1 > y2: y1, y2 = y2, y1
-            xyxy.append([int(x1), int(y1), int(x2), int(y2)])
-            labels.append(box["label"])
-
-    if not xyxy:
-        return image
-
-    detections = sv.Detections(
-        xyxy=np.array(xyxy),
-        class_id=np.arange(len(xyxy))
-    )
-
-    bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
-    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
-
-    annotated_image = bounding_box_annotator.annotate(
-        scene=annotated_image, detections=detections
-    )
-    annotated_image = label_annotator.annotate(
-        scene=annotated_image, detections=detections, labels=labels
-    )
-
-    return Image.fromarray(annotated_image)


-@GPU
-def detect_qwen(image, prompt):
     messages = [
         {
             "role": "user",
@@ -231,75 +144,132 @@ def detect_qwen(image, prompt):
         ],
         }
     ]
-
-    t0 = time.perf_counter()
-    inputs = processor_qwen.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_dict=True,
-        return_tensors="pt"
-    ).to(model_qwen.device)

-    generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-    output_text = processor_qwen.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False,
     )[0]
-    elapsed_ms = (time.perf_counter() - t0) * 1_000

-    annotated_image = create_annotated_image_qwen3(image, output_text)

-    time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
-
-    return annotated_image, output_text, time_taken


 @GPU
-def detect_moondream(image, prompt, category_input):
-    t0 = time.perf_counter()
-    if category_input in ["Object Detection", "Visual Grounding + Object Detection"]:
-        output_text = model_moondream.detect(image=image, object=prompt)
-    elif category_input == "Visual Grounding + Keypoint Detection":
-        output_text = model_moondream.point(image=image, object=prompt)
     else:
-        output_text = model_moondream.query(
-            image=image, question=prompt, reasoning=True
-        )
-    elapsed_ms = (time.perf_counter() - t0) * 1_000

-    annotated_image = create_annotated_image_normalized(
-        image=image, json_data=output_text, label="object"
-    )

-    time_taken = f"**Inference time ({model_moondream_name}):** {elapsed_ms:.0f} ms"
-    return annotated_image, output_text, time_taken


-def detect(image, prompt_model_1, prompt_model_2, category_input):
-    STANDARD_SIZE = (1024, 1024)
-    image.thumbnail(STANDARD_SIZE)

-    annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(
-        image, prompt_model_1
-    )
-    annotated_image_model_2, output_text_model_2, timing_2 = detect_moondream(
-        image, prompt_model_2, category_input
-    )

-    return (
-        annotated_image_model_1,
-        output_text_model_1,
-        timing_1,
-        annotated_image_model_2,
-        output_text_model_2,
-        timing_2,
-    )


 css_hide_share = """
@@ -308,6 +278,7 @@ button#gradio-share-link-button-0 {
 }
 """

 with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     gr.Markdown("# 👓 Object Understanding with Vision Language Models")
     gr.Markdown(
@@ -319,130 +290,72 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     """)

     with gr.Row():
-        with gr.Column(scale=2):
-            image_input = gr.Image(label="Upload an image", type="pil", height=400)
-            prompt_input_model_1 = gr.Textbox(
-                label=f"Enter your prompt for {model_qwen_name}",
-                placeholder="e.g., Detect all red cars in the image",
             )
-
-            prompt_input_model_2 = gr.Textbox(
-                label=f"Enter your prompt for {model_moondream_name}",
-                placeholder="e.g., Detect all blue cars in the image",
             )
-
-            categories = [
-                "Object Detection",
-                "Object Counting",
-                "Visual Grounding + Keypoint Detection",
-                "Visual Grounding + Object Detection",
-                "General query",
-            ]
-
-            category_input = gr.Dropdown(
-                choices=categories, label="Category", interactive=True
             )
-            generate_btn = gr.Button(value="Generate")

-        with gr.Column(scale=1):
-            output_image_model_1 = gr.Image(
-                type="pil", label=f"Annotated image for {model_qwen_name}", height=400
-            )
-            output_textbox_model_1 = gr.Textbox(
-                label=f"Model response for {model_qwen_name}", lines=10
-            )
-            output_time_model_1 = gr.Markdown()

-        with gr.Column(scale=1):
-            output_image_model_2 = gr.Image(
-                type="pil",
-                label=f"Annotated image for {model_moondream_name}",
-                height=400,
-            )
-            output_textbox_model_2 = gr.Textbox(
-                label=f"Model response for {model_moondream_name}", lines=10
-            )
-            output_time_model_2 = gr.Markdown()
-
-    gr.Markdown("### Examples")
-    example_prompts = [
-        [
-            "examples/example_1.jpg",
-            "locate every instance in the image. Report bbox coordinates in JSON format.",
-            "objects",
-            "Object Detection",
-        ],
-        [
-            "examples/example_2.JPG",
-            'locate every instance that belongs to the following categories: "candy, hand". Report bbox coordinates in JSON format.',
-            "candies",
-            "Object Detection",
-        ],
-        [
-            "examples/example_1.jpg",
-            "Count the number of red cars in the image.",
-            "Count the number of red cars in the image.",
-            "Object Counting",
-        ],
-        [
-            "examples/example_2.JPG",
-            "Count the number of blue candies in the image.",
-            "Count the number of blue candies in the image.",
-            "Object Counting",
-        ],
-        [
-            "examples/example_1.jpg",
-            'locate every instance that belongs to the following categories: "red car". Report bbox coordinates in JSON format..',
-            "red cars",
-            "Visual Grounding + Keypoint Detection",
-        ],
-        [
-            "examples/example_2.JPG",
-            "Identify the blue candies in this image, detect their key points and return their positions in the form of points.",
-            "blue candies",
-            "Visual Grounding + Keypoint Detection",
-        ],
-        [
-            "examples/example_1.jpg",
-            'locate every instance that belongs to the following categories: "leading red car". Report bbox coordinates in JSON format..',
-            "leading red car",
-            "Visual Grounding + Object Detection",
-        ],
-        [
-            "examples/example_2.JPG",
-            'locate every instance that belongs to the following categories: "blue candy located at the top of the group". Report bbox coordinates in JSON format.',
-            "blue candy located at the top of the group",
-            "Visual Grounding + Object Detection",
-        ],
-    ]
     gr.Examples(
-        examples=example_prompts,
-        inputs=[
-            image_input,
-            prompt_input_model_1,
-            prompt_input_model_2,
-            category_input,
-        ],
-        label="Click an example to populate the input",
     )

-    generate_btn.click(
-        fn=detect,
-        inputs=[
-            image_input,
-            prompt_input_model_1,
-            prompt_input_model_2,
-            category_input,
-        ],
-        outputs=[
-            output_image_model_1,
-            output_textbox_model_1,
-            output_time_model_1,
-            output_image_model_2,
-            output_textbox_model_2,
-            output_time_model_2,
-        ],
     )

 if __name__ == "__main__":
-    demo.launch()
The hunks above show the left-hand (old) side of the diff: lines removed from app.py, prefixed with "-", plus unchanged context. The right-hand (new) side of app.py follows, with added lines prefixed with "+":

 import gradio as gr
 from gradio.themes.ocean import Ocean
+import torch
+import numpy as np
+import supervision as sv
 from transformers import (
     AutoModelForCausalLM,
     Qwen3VLForConditionalGeneration,
+    Qwen3VLProcessor,
 )
+import json
+import ast
+import re
+from PIL import Image
 from spaces import GPU

+# --- Constants and Configuration ---
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = "auto"

+CATEGORIES = ["Query", "Caption", "Point", "Detect"]
+PLACEHOLDERS = {
+    "Query": "What's in this image?",
+    "Caption": "Enter caption length: short, normal, or long",
+    "Point": "Select an object from suggestions or enter manually",
+    "Detect": "Select an object from suggestions or enter manually",
+}
+
+# --- Model Loading ---
+# Load Moondream
+moondream = AutoModelForCausalLM.from_pretrained(
+    "moondream/moondream3-preview",
     trust_remote_code=True,
+    dtype=DTYPE,
+    device_map=DEVICE,
+    revision="main",
+).eval()
+
+# Load Qwen3-VL
+qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen3-VL-4B-Instruct",
+    dtype=DTYPE,
+    device_map=DEVICE,
+).eval()
+qwen_processor = Qwen3VLProcessor.from_pretrained(
+    "Qwen/Qwen3-VL-4B-Instruct",
 )


+# --- Utility Functions ---
+def safe_parse_json(text: str):
+    text = text.strip()
+    text = re.sub(r"^```(json)?", "", text)
+    text = re.sub(r"```$", "", text)
+    text = text.strip()
+    try:
+        return json.loads(text)
+    except json.JSONDecodeError:
+        pass
     try:
+        return ast.literal_eval(text)
     except Exception:
+        return {}


+@GPU
+def get_suggested_objects(image: Image.Image):
+    """Get suggested objects in the image using Moondream"""
+    if image is None:
+        return []

+    try:
+        result = moondream.query(
+            image=image,
+            question="What objects are in the image, provide the list.",
+            reasoning=False,
         )
+        suggested_objects = ast.literal_eval(result["answer"])
+        if isinstance(suggested_objects, list):
+            if len(suggested_objects) > 3:  # send not more than 3 suggestions
+                return suggested_objects[:3]
+            else:
+                suggested_objects
+        return []
+    except Exception as e:
+        print(f"Error getting suggestions: {e}")
+        return []


+def annotate_image(image: Image.Image, result: dict):
+    if not isinstance(image, Image.Image):
+        return image  # Return original if not a valid image
+    if not isinstance(result, dict):
+        return image  # Return original if result is not a dict

     original_width, original_height = image.size

+    # Handle Point annotations
+    if "points" in result and result["points"]:
+        points_list = []
+        for point in result.get("points", []):
             x = int(point["x"] * original_width)
             y = int(point["y"] * original_height)
+            points_list.append([x, y])

+        if not points_list:
+            return image

+        points_array = np.array(points_list).reshape(1, -1, 2)
         key_points = sv.KeyPoints(xy=points_array)
+        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
         annotated_image = vertex_annotator.annotate(
+            scene=image.copy(), key_points=key_points
         )
+        return annotated_image
+
+    # Handle Detection annotations
+    if "objects" in result and result["objects"]:
+        detections = sv.Detections.from_vlm(
+            sv.VLM.MOONDREAM,
+            result,
+            resolution_wh=image.size,
         )
+        if len(detections) == 0:
+            return image

+        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
+        annotated_scene = box_annotator.annotate(
+            scene=image.copy(), detections=detections
+        )
+        return annotated_scene

+    return image


+# --- Inference Functions ---
+def run_qwen_inference(image: Image.Image, prompt: str):
     messages = [
         {
             "role": "user",

         ],
         }
     ]
+    inputs = qwen_processor.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_dict=True,
+        return_tensors="pt",
+    ).to(DEVICE)
+
+    with torch.inference_mode():
+        generated_ids = qwen_model.generate(
+            **inputs,
+            max_new_tokens=512,
+        )

     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
+    output_text = qwen_processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False,
     )[0]
+    return output_text


+@GPU
+def process_qwen(image: Image.Image, category: str, prompt: str):
+    if category == "Query":
+        return run_qwen_inference(image, prompt), {}
+    elif category == "Caption":
+        full_prompt = f"Provide a {prompt} length caption for the image."
+        return run_qwen_inference(image, full_prompt), {}
+    elif category == "Point":
+        full_prompt = (
+            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
+        )
+        output_text = run_qwen_inference(image, full_prompt)
+        parsed_json = safe_parse_json(output_text)
+        points_result = {"points": []}
+        if isinstance(parsed_json, list):
+            for item in parsed_json:
+                if "point_2d" in item and len(item["point_2d"]) == 2:
+                    x, y = item["point_2d"]
+                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
+        return json.dumps(points_result, indent=2), points_result
+    elif category == "Detect":
+        full_prompt = (
+            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
+        )
+        output_text = run_qwen_inference(image, full_prompt)
+        parsed_json = safe_parse_json(output_text)
+        objects_result = {"objects": []}
+        if isinstance(parsed_json, list):
+            for item in parsed_json:
+                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
+                    xmin, ymin, xmax, ymax = item["bbox_2d"]
+                    objects_result["objects"].append(
+                        {
+                            "x_min": xmin / 1000.0,
+                            "y_min": ymin / 1000.0,
+                            "x_max": xmax / 1000.0,
+                            "y_max": ymax / 1000.0,
+                        }
+                    )
+        return json.dumps(objects_result, indent=2), objects_result
+    return "Invalid category", {}


 @GPU
+def process_moondream(image: Image.Image, category: str, prompt: str):
+    if category == "Query":
+        result = moondream.query(image=image, question=prompt)
+        return result["answer"], {}
+    elif category == "Caption":
+        result = moondream.caption(image, length=prompt)
+        return result["caption"], {}
+    elif category == "Point":
+        result = moondream.point(image, prompt)
+        return json.dumps(result, indent=2), result
+    elif category == "Detect":
+        result = moondream.detect(image, prompt)
+        return json.dumps(result, indent=2), result
+    return "Invalid category", {}
+
+
+# --- Gradio Interface Logic ---
+def on_category_and_image_change(image, category):
+    """Generate suggestions when category changes to Point or Detect"""
+    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
+
+    if image is None or category not in ["Point", "Detect", "Caption"]:
+        return gr.Radio(choices=[], visible=False), text_box
+
+    if category == "Caption":
+        return gr.Radio(choices=["short", "normal", "long"], visible=True), text_box
+
+    suggestions = get_suggested_objects(image)
+    if suggestions:
+        return gr.Radio(choices=suggestions, visible=True, interactive=True), text_box
     else:
+        return gr.Radio(choices=["no choice possible"], visible=True, interactive=True), text_box


+def update_prompt_from_radio(selected_object):
+    """Update prompt textbox when a radio option is selected"""
+    if selected_object:
+        return gr.Textbox(value=selected_object)
+    return gr.Textbox(value="")


+def process_inputs(image, category, prompt):
+    if image is None:
+        raise gr.Error("Please upload an image.")
+    if not prompt:
+        raise gr.Error("Please provide a prompt.")

+    # Process with Qwen
+    qwen_text, qwen_data = process_qwen(image, category, prompt)
+    qwen_annotated_image = annotate_image(image, qwen_data)

+    # Process with Moondream
+    moondream_text, moondream_data = process_moondream(image, category, prompt)
+    moondream_annotated_image = annotate_image(image, moondream_data)
+
+    return qwen_annotated_image, qwen_text, moondream_annotated_image, moondream_text


 css_hide_share = """

 }
 """

+# --- Gradio UI Layout ---
 with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     gr.Markdown("# 👓 Object Understanding with Vision Language Models")
     gr.Markdown(

     """)

     with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Input Image")
+            category_select = gr.Radio(
+                choices=CATEGORIES,
+                value=CATEGORIES[0],
+                label="Select Task Category",
+                interactive=True,
             )
+            # Suggested objects radio (hidden by default)
+            suggestions_radio = gr.Radio(
+                choices=[],
+                label="Suggestions",
+                visible=False,
+                interactive=True,
             )
+            prompt_input = gr.Textbox(
+                placeholder=PLACEHOLDERS[CATEGORIES[0]],
+                label="Prompt",
+                lines=2,
             )

+            submit_btn = gr.Button("Compare Models", variant="primary")
+
+        with gr.Column(scale=2):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct")
+                    qwen_img_output = gr.Image(label="Annotated Image")
+                    qwen_text_output = gr.Textbox(
+                        label="Text Output", lines=8, interactive=False
+                    )
+                with gr.Column():
+                    gr.Markdown("### moondream/moondream3-preview")
+                    moon_img_output = gr.Image(label="Annotated Image")
+                    moon_text_output = gr.Textbox(
+                        label="Text Output", lines=8, interactive=False
+                    )

     gr.Examples(
+        examples=[
+            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
+            ["examples/example_1.jpg", "Caption", ""],
+            ["examples/example_2.JPG", "Point", ""],
+            ["examples/example_2.JPG", "Detect", ""],
+        ],
+        inputs=[image_input, category_select, prompt_input],
     )

+    # --- Event Listeners ---
+    category_select.change(
+        fn=on_category_and_image_change,
+        inputs=[image_input, category_select],
+        outputs=[suggestions_radio, prompt_input],
+    )
+
+    suggestions_radio.change(
+        fn=update_prompt_from_radio,
+        inputs=[suggestions_radio],
+        outputs=[prompt_input],
+    )
+
+    submit_btn.click(
+        fn=process_inputs,
+        inputs=[image_input, category_select, prompt_input],
+        outputs=[qwen_img_output, qwen_text_output, moon_img_output, moon_text_output],
     )

 if __name__ == "__main__":
+    demo.launch()
requirements.txt CHANGED
@@ -1,10 +1,7 @@
-torch
-transformers
-datasets
-Pillow
-gradio
-accelerate
-qwen-vl-utils
-torchvision
-matplotlib
-supervision
+torch==2.8.0
+transformers==4.57.0
+Pillow==11.3.0
+gradio==5.49.1
+accelerate==1.10.1
+torchvision==0.23.0
+supervision==0.26.1