james-ham committed
Commit 08aebdd · verified · 1 Parent(s): 40e4cbb

create app.py

Files changed (1): app.py (+158, -0)

app.py ADDED
# FastVLM Screenshot Explainer (CPU-only, no uploads)
# Space idea: curated gallery → caption / extract numbers / VQA
# Model: apple/FastVLM-0.5B (Research-only license)
# ─────────────────────────────────────────────────────────────────────────────

import time
import io
import requests
from PIL import Image
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200  # per model card
DEVICE = "cpu"

# A tiny curated gallery (HF/COCO-hosted images)
SAMPLES = {
    # general photo (COCO)
    "Cats on a couch (COCO)": "http://images.cocodataset.org/val2017/000000039769.jpg",
    # charts (ChartMuseum dataset)
    "Chart — Blind wine tasting": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/wine_blind_taste.png",
    "Chart — Life expectancy (Africa vs Asia)": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/life-expectancy-africa-vs-asia.png",
    # document-like page (HF internal testing)
    "Document page — example": "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/1.jpg",
}

TASK_PROMPTS = {
    "Explain": "Describe this image in detail.",
    "Extract numbers": (
        "Extract every number you can see with its label/context. "
        "Return a concise YAML list with fields: value, what_it_refers_to."
    ),
    "Write alt-text": (
        "Write high-quality alt-text (<=200 chars) that would help a blind user understand "
        "the key content and purpose of this image."
    ),
    "Ask a question…": None,  # free-form
}

# ── Model load (CPU) ─────────────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # CPU
    device_map={"": DEVICE},
    trust_remote_code=True,
)
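
# NOTE: the Notes text in the UI mentions GPU use, while this file pins CPU/float32.
# A minimal sketch of a GPU-aware variant (an assumption, not taken from the model
# card) would pick device and dtype before loading, e.g.:
#
#     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#     DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
#     model = AutoModelForCausalLM.from_pretrained(
#         MODEL_ID, torch_dtype=DTYPE, device_map={"": DEVICE}, trust_remote_code=True
#     )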

# ── Helpers ──────────────────────────────────────────────────────────────────
def _fetch_image(url: str) -> Image.Image:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    return Image.open(io.BytesIO(r.content)).convert("RGB")

def _build_inputs(prompt: str):
    # Build chat with the <image> placeholder exactly once (per model card)
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)

    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)

    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)
    return input_ids, attention_mask
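
# For reference, the spliced sequence is [chat-prefix tokens, -200, chat-suffix
# tokens]. My understanding (an assumption about the LLaVA-style remote code, not
# verified against it) is that generate() expands the -200 placeholder into image
# features, which is why the pixel tensor is passed separately via images= below.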

def _prepare_pixels(pil_image: Image.Image):
    # Use the model's own image processor from the vision tower
    px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
    return px.to(model.device, dtype=model.dtype)

@torch.inference_mode()
def run_inference(choice: str, task: str, user_q: str, max_new_tokens: int, temperature: float):
    try:
        img = _fetch_image(SAMPLES[choice])
    except Exception as e:
        return None, f"Could not load image: {e}", ""

    # Decide prompt
    if task == "Ask a question…":
        prompt = user_q.strip() or "Answer questions about this image."
    else:
        prompt = TASK_PROMPTS[task]

    # Build model inputs
    input_ids, attention_mask = _build_inputs(prompt)
    px = _prepare_pixels(img)

    # Generate (sample only when temperature > 0; otherwise greedy decoding)
    t0 = time.perf_counter()
    out = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        max_new_tokens=int(max_new_tokens),
        do_sample=float(temperature) > 0,
        temperature=float(temperature),
    )
    t1 = time.perf_counter()

    text = tok.decode(out[0], skip_special_tokens=True)

    # Rough throughput metric (assumes the returned sequence includes the prompt tokens)
    gen_len = out.shape[-1] - input_ids.shape[-1]
    elapsed = t1 - t0
    meta = f"⏱️ {elapsed:.2f}s • new tokens: {gen_len} • ~{(gen_len / elapsed if elapsed > 0 else 0):.1f} tok/s • device: {DEVICE.upper()}"

    return img, text.strip(), meta

# ── Gradio UI ────────────────────────────────────────────────────────────────
with gr.Blocks(title="FastVLM Screenshot Explainer (CPU)") as demo:
    gr.Markdown(
        """
        # ⚡ FastVLM Screenshot Explainer — CPU-only (no uploads)
        Click an example image, pick a task, and go.
        Model: **apple/FastVLM-0.5B** (research license).
        """
    )

    with gr.Row():
        choice = gr.Dropdown(
            label="Choose example image",
            choices=list(SAMPLES.keys()),
            value=list(SAMPLES.keys())[0],
        )
        task = gr.Radio(
            label="Task",
            choices=list(TASK_PROMPTS.keys()),
            value="Explain",
            info="‘Ask a question…’ enables free-form VQA.",
        )
    user_q = gr.Textbox(
        label="If asking a question, type it here",
        placeholder="e.g., What is the trend from 1950 to 2000?",
    )
    with gr.Accordion("Generation settings", open=False):
        max_new = gr.Slider(32, 256, 128, step=8, label="max_new_tokens")
        temp = gr.Slider(0.0, 1.0, 0.2, step=0.05, label="temperature")

    go = gr.Button("Explain / Answer", variant="primary")
    with gr.Row():
        img_out = gr.Image(label="Image", interactive=False)
        txt_out = gr.Textbox(label="Model output", lines=14)
    meta = gr.Markdown()

    go.click(run_inference, [choice, task, user_q, max_new, temp], [img_out, txt_out, meta])

    gr.Markdown(
        """
        **Notes**
        - Runs on CPU with float32. To use a GPU, adjust `DEVICE` and `torch_dtype` in the code (see the commented sketch near the model load).
        - Model loading and prompting follow the official model card's `trust_remote_code` API and `<image>` token handling.
        - **License:** Apple AML Research License, *research & non-commercial use only*.
        """
    )

if __name__ == "__main__":
    demo.launch()
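
# Space setup hint (an assumption, not specified anywhere in this file): a
# requirements.txt covering the imports above, roughly
#     torch
#     transformers
#     accelerate   # the device_map= argument typically requires it
#     gradio
#     pillow
#     requests
# should be enough to run this demo on a CPU Space.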