create app.py
app.py
ADDED
@@ -0,0 +1,158 @@
# FastVLM Screenshot Explainer (CPU-only, no uploads)
# Space idea: curated gallery → caption / extract numbers / VQA
# Model: apple/FastVLM-0.5B (Research-only license)
# ─────────────────────────────────────────────────────────────────────────────

import time
import io
import requests
from PIL import Image
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200  # per model card
DEVICE = "cpu"

# A tiny curated gallery (HF/COCO-hosted images)
SAMPLES = {
    # general photo (COCO; this val2017 image shows two cats on a couch)
    "Cats on a couch (COCO)": "http://images.cocodataset.org/val2017/000000039769.jpg",
    # charts (ChartMuseum dataset)
    "Chart – Blind wine tasting": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/wine_blind_taste.png",
    "Chart – Life expectancy (Africa vs Asia)": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/life-expectancy-africa-vs-asia.png",
    # document-like page (HF internal testing)
    "Document page – example": "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/1.jpg",
}

TASK_PROMPTS = {
    "Explain": "Describe this image in detail.",
    "Extract numbers": (
        "Extract every number you can see with its label/context. "
        "Return a concise YAML list with fields: value, what_it_refers_to."
    ),
    "Write alt-text": (
        "Write high-quality alt-text (<=200 chars) that would help a blind user understand "
        "the key content and purpose of this image."
    ),
    "Ask a question…": None,  # free-form
}

# ── Model load (CPU) ──────────────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # CPU
    device_map={"": DEVICE},
    trust_remote_code=True,
)
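
# Note: nothing here switches devices automatically. A minimal GPU variant
# (a sketch, not exercised in this Space) would change the two knobs above:
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
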
# ── Helpers ──────────────────────────────────────────────────────────────────
def _fetch_image(url: str) -> Image.Image:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    return Image.open(io.BytesIO(r.content)).convert("RGB")

def _build_inputs(prompt: str):
    # Build chat with <image> placeholder exactly once (per model card)
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)

    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)

    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)
    return input_ids, attention_mask
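
# The spliced ids look like [pre tokens][-200][post tokens]. The -200 entry is
# not a real vocab id; the model's LLaVA-style remote code swaps it for the
# projected image features when building input embeddings.
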
def _prepare_pixels(pil_image: Image.Image):
    # Use the model's own processor from the vision tower
    px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
    return px.to(model.device, dtype=model.dtype)
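
# Assumption: the tower's image processor follows the usual HF convention and
# returns pixel_values shaped (1, 3, H, W), with H/W fixed by its preprocessor
# config, so no manual resizing is done here.
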
@torch.inference_mode()
def run_inference(choice: str, task: str, user_q: str, max_new_tokens: int, temperature: float):
    try:
        img = _fetch_image(SAMPLES[choice])
    except Exception as e:
        return None, f"Could not load image: {e}", ""

    # Decide prompt
    if task == "Ask a question…":
        prompt = user_q.strip() or "Answer questions about this image."
    else:
        prompt = TASK_PROMPTS[task]

    # Build model inputs
    input_ids, attention_mask = _build_inputs(prompt)
    px = _prepare_pixels(img)

    # Generate. Greedy when temperature == 0; otherwise sample, so the
    # temperature slider actually takes effect.
    do_sample = float(temperature) > 0
    gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
    t0 = time.perf_counter()
    out = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        **gen_kwargs,
    )
    t1 = time.perf_counter()

    text = tok.decode(out[0], skip_special_tokens=True)

    # Rough throughput metric. This LLaVA-style generate() consumes the prompt
    # as embeddings and returns only the newly generated ids, so the output
    # length is the new-token count.
    gen_len = out.shape[-1]
    elapsed = t1 - t0
    meta = f"⏱️ {elapsed:.2f}s • new tokens: {gen_len} • ~{(gen_len/elapsed if elapsed > 0 else 0):.1f} tok/s • device: {DEVICE.upper()}"

    return img, text.strip(), meta
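
# Caveat on the metric above: tok/s is wall-clock over the whole generate()
# call, so vision encoding and prompt prefill sit in the denominator; treat it
# as a rough CPU indicator rather than a decode-speed benchmark.
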
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="FastVLM Screenshot Explainer (CPU)") as demo:
    gr.Markdown(
        """
        # ⚡ FastVLM Screenshot Explainer – CPU-only (no uploads)
        Click an example image, pick a task, and go.
        Model: **apple/FastVLM-0.5B** (research license).
        """
    )

    with gr.Row():
        choice = gr.Dropdown(
            label="Choose example image",
            choices=list(SAMPLES.keys()),
            value=list(SAMPLES.keys())[0],
        )
        task = gr.Radio(
            label="Task",
            choices=list(TASK_PROMPTS.keys()),
            value="Explain",
            info="“Ask a question…” enables free-form VQA.",
        )
    user_q = gr.Textbox(label="If asking a question, type it here", placeholder="e.g., What is the trend from 1950 to 2000?")
    with gr.Accordion("Generation settings", open=False):
        max_new = gr.Slider(32, 256, 128, step=8, label="max_new_tokens")
        temp = gr.Slider(0.0, 1.0, 0.2, step=0.05, label="temperature")

| 140 |
+
go = gr.Button("Explain / Answer", variant="primary")
|
| 141 |
+
with gr.Row():
|
| 142 |
+
img_out = gr.Image(label="Image", interactive=False)
|
| 143 |
+
txt_out = gr.Textbox(label="Model output", lines=14)
|
| 144 |
+
meta = gr.Markdown()
|
| 145 |
+
|
| 146 |
+
go.click(run_inference, [choice, task, user_q, max_new, temp], [img_out, txt_out, meta])
|
| 147 |
+
|
    gr.Markdown(
        """
        **Notes**
        - Runs on CPU by default (float32). To run on a GPU Space, set `DEVICE = "cuda"` and load with `torch_dtype=torch.float16` (see the sketch after the model-load block).
        - Usage follows the official model card's `trust_remote_code` API and its `<image>` token handling.
        - **License:** Apple AML Research License – *research & non-commercial use only*.
        """
    )

if __name__ == "__main__":
    demo.launch()
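
# Local usage (assuming torch, transformers, gradio, pillow, and requests are
# installed): run `python app.py`; Gradio serves the UI on
# http://127.0.0.1:7860 by default.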