import os
import inspect
from typing import Optional, Dict, Any, Tuple

import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForVision2Seq,
    AutoModelForCausalLM,
)
from peft import PeftModel, PeftConfig

# -----------------------------
# CPU-only enforcement
# -----------------------------
FORCE_CPU = True
DEVICE = torch.device("cpu")

# -----------------------------
# Resolution gate
# -----------------------------
RESOLUTION_MAP = {0: 384, 1: 768, 2: 1024}


def load_and_resize_image(img: Image.Image, max_size: Optional[int] = None) -> Image.Image:
    img = img.convert("RGB")
    if max_size is None:
        return img
    w, h = img.size
    if max(w, h) <= max_size:
        return img
    s = max_size / max(w, h)
    return img.resize((round(w * s), round(h * s)), Image.BICUBIC)


def token_id_for_digit(tokenizer, digit: str) -> int:
    ids = tokenizer.encode(digit, add_special_tokens=False)
    if not ids:
        raise ValueError(f"Could not encode digit {digit!r}")
    return ids[-1]


class GraniteDoclingGateHF:
    def __init__(self, adapter_repo: str, token: Optional[str] = None):
        self.device = DEVICE

        peft_cfg = PeftConfig.from_pretrained(adapter_repo, token=token)
        base_model_name = peft_cfg.base_model_name_or_path

        self.processor = AutoProcessor.from_pretrained(adapter_repo, token=token)

        # CPU: use float32 for safety (bfloat16/float16 often slower or problematic on CPU)
        torch_dtype = torch.float32
        base_model = AutoModelForVision2Seq.from_pretrained(
            base_model_name, torch_dtype=torch_dtype
        )
        self.model = PeftModel.from_pretrained(base_model, adapter_repo, token=token)
        self.model.to(self.device).eval()

        tok = self.processor.tokenizer
        self.class_token_ids = [
            token_id_for_digit(tok, "0"),
            token_id_for_digit(tok, "1"),
            token_id_for_digit(tok, "2"),
        ]

    @torch.no_grad()
    def predict_probs(self, image: Image.Image, question: str):
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
        ]
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=[prompt], images=[image], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items() if hasattr(v, "to")}

        outputs = self.model(**inputs)
        next_token_logits = outputs.logits[:, -1, :]
        class_logits = next_token_logits[:, self.class_token_ids]
        probs = F.softmax(class_logits, dim=-1)[0].detach().float().cpu().tolist()
        return probs

    def predict_expected(self, image: Image.Image, question: str) -> float:
        probs = self.predict_probs(image, question)
        return float(sum(RESOLUTION_MAP[i] * probs[i] for i in range(3)))
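
# Usage sketch for the gate (illustrative only: "example.png" and the question are
# placeholder values, and the adapter repo must be downloadable):
#
#   gate = GraniteDoclingGateHF(adapter_repo="Kimhi/granite-docling-res-gate-lora")
#   probs = gate.predict_probs(Image.open("example.png"), "What is the receipt total?")
#   # probs is the softmax over the digit tokens "0"/"1"/"2", i.e. the 384/768/1024 buckets;
#   # predict_expected() folds it into sum_i RESOLUTION_MAP[i] * probs[i].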


# -----------------------------
# CPU-friendly downstream HF VLM inference
# -----------------------------
# IMPORTANT: Choose models that can run on CPU.
# Many VLMs will be too slow/heavy on CPU; start small.
DOWNSTREAM_MODELS = {
    "ibm-granite/granite-vision-3.3-2b (recommended CPU)": "ibm-granite/granite-vision-3.3-2b",
    "HuggingFaceTB/SmolVLM-256M-Instruct (tiny CPU)": "HuggingFaceTB/SmolVLM-256M-Instruct",
    "google/paligemma-3b-mix-224 (CPU)": "google/paligemma-3b-mix-224",
    # Your list (kept available but not recommended on CPU)
    "Qwen/Qwen2.5-VL-3B-Instruct (slow CPU)": "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2.5-VL-7B-Instruct (very slow CPU)": "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct (not for CPU)": "Qwen/Qwen2.5-VL-72B-Instruct",
    "Qwen/Qwen3-VL-8B-Instruct (very slow CPU)": "Qwen/Qwen3-VL-8B-Instruct",
    "OpenGVLab/InternVL3_5-8B (very slow CPU)": "OpenGVLab/InternVL3_5-8B",
    "OpenGVLab/InternVL3_5-38B (not for CPU)": "OpenGVLab/InternVL3_5-38B",
    "OpenGVLab/InternVL3_5-241B-A28B (not for CPU)": "OpenGVLab/InternVL3_5-241B-A28B",
    "None (gate only)": None,
}

_model_cache: Dict[str, Tuple[Any, Any]] = {}

torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))


def get_vlm(model_id: str):
    if model_id in _model_cache:
        return _model_cache[model_id]

    # CPU-only: float32 and no device_map
    proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    common_kwargs = dict(
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    # low_cpu_mem_usage exists on many HF models; use it if supported
    try:
        sig = inspect.signature(AutoModelForVision2Seq.from_pretrained)
        if "low_cpu_mem_usage" in sig.parameters:
            common_kwargs["low_cpu_mem_usage"] = True
    except Exception:
        pass

    model = None
    err = None
    # Try Vision2Seq first
    try:
        model = AutoModelForVision2Seq.from_pretrained(model_id, **common_kwargs)
    except Exception as e:
        err = e
        # Fallback: CausalLM
        model = AutoModelForCausalLM.from_pretrained(model_id, **common_kwargs)

    model.to(DEVICE).eval()
    _model_cache[model_id] = (proc, model)
    return proc, model
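
# Optional warm-up sketch (a judgment call, not required: models load lazily on the first
# request by default; uncomment to pay the load cost at startup for a chosen default):
#
#   _ = get_vlm("ibm-granite/granite-vision-3.3-2b")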


@torch.no_grad()
def vlm_answer(model_id: str, image: Image.Image, question: str, max_new_tokens: int = 96) -> str:
    proc, model = get_vlm(model_id)

    # ---- Path A: model.chat (InternVL-style, some others) ----
    if hasattr(model, "chat") and callable(getattr(model, "chat")):
        try:
            # Different repos have different signatures; this is the most common pattern.
            # If it fails, we fall back to processor + generate.
            return str(model.chat(proc, image, question)).strip()
        except Exception:
            pass

    # ---- Path B: processor + generate ----
    conversation = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": question}],
    }]

    inputs = None
    if hasattr(proc, "apply_chat_template"):
        # Try the Granite-style "tokenize=True" path
        try:
            inputs = proc.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                images=image,  # supported by some processors
            )
        except Exception:
            # Fallback: build a prompt string, then call the processor
            try:
                prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
                inputs = proc(text=[prompt], images=[image], return_tensors="pt")
            except Exception:
                inputs = None

    if inputs is None:
        # Final fallback: no templates
        inputs = proc(text=[question], images=[image], return_tensors="pt")

    inputs = {k: v.to(DEVICE) for k, v in inputs.items() if hasattr(v, "to")}
    out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = proc.batch_decode(out, skip_special_tokens=True)[0].strip()

    # Prompt-echo cleanup (best effort)
    if question in text and len(text) > 2 * len(question):
        text = text.split(question, 1)[-1].strip()
    return text


def cpu_model_allowed(model_id: str) -> Tuple[bool, str]:
    mid = (model_id or "").lower()
    blocked = ["72b", "38b", "241b", "a28b"]
    if any(b in mid for b in blocked):
        return False, "Too large for CPU Space (will OOM)."
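    # Substring heuristic: the blocked tokens match the parameter-count suffixes of the
    # model ids in DOWNSTREAM_MODELS (e.g. "Qwen2.5-VL-72B-Instruct" lower-cases to "72b").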
    return True, ""


# -----------------------------
# Resolution selection strategy
# -----------------------------
def choose_resolution(expected: float, probs: list, strategy: str) -> int:
    if strategy == "expected":
        return int(round(expected))
    if strategy == "argmax":
        k = int(max(range(len(probs)), key=lambda i: probs[i]))
        return int(RESOLUTION_MAP[k])
    # conservative: choose the highest bucket if it has meaningful mass, else the next, else the lowest
    if probs[2] >= 0.34:
        return int(RESOLUTION_MAP[2])
    if probs[1] >= 0.34:
        return int(RESOLUTION_MAP[1])
    return int(RESOLUTION_MAP[0])


# -----------------------------
# Gradio app
# -----------------------------
GATE_ADAPTER_REPO = os.getenv("GATE_ADAPTER_REPO", "Kimhi/granite-docling-res-gate-lora")
HF_TOKEN = os.getenv("HF_TOKEN", None)
GATE_INPUT_MAX_SIDE = int(os.getenv("GATE_INPUT_MAX_SIDE", "256"))

gate = None


def run(image: Image.Image, question: str, vlm_choice: str, strategy: str):
    global gate
    if gate is None:
        gate = GraniteDoclingGateHF(adapter_repo=GATE_ADAPTER_REPO, token=HF_TOKEN)

    if image is None or not question:
        return "Upload an image and enter a question.", None, None

    native_w, native_h = image.size

    # The gate runs on a small copy of the image
    gate_img = load_and_resize_image(image, GATE_INPUT_MAX_SIDE)
    probs = gate.predict_probs(gate_img, question)
    expected = float(sum(RESOLUTION_MAP[i] * probs[i] for i in range(3)))
    pred = choose_resolution(expected, probs, strategy)

    # Never upscale above the native max side
    native_max = max(native_w, native_h)
    used_max = min(pred, native_max)

    resized = load_and_resize_image(image, used_max)
    resized_w, resized_h = resized.size

    model_id = DOWNSTREAM_MODELS.get(vlm_choice)
    if model_id is None:
        answer = "(gate only) No VLM selected."
    else:
        ok, reason = cpu_model_allowed(model_id)
        if not ok:
            answer = f"Blocked on CPU: {reason}"
        else:
            try:
                answer = vlm_answer(model_id, resized, question)
            except Exception as e:
                answer = f"VLM error: {type(e).__name__}: {e}"

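    # Build the report shown in the Info box; the default "expected" strategy reports only
    # the expected sufficient max side, while other strategies also show the chosen bucket
    # and the clamped size that was actually sent to the VLM.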
    if strategy != "expected":
        info = (
            f"Native: {native_w}×{native_h}\n"
            f"Sufficient max-side: {expected:.1f}\n"
            f"Strategy: {strategy}\n"
            f"Predicted sufficient max-side: {pred}\n"
            f"Used max-side (clamped to native): {used_max}\n"
            f"Resized sent to VLM: {resized_w}×{resized_h}\n"
            f"VLM: {vlm_choice}\n"
        )
    else:
        info = (
            f"Native: {native_w}×{native_h}\n"
            f"Sufficient max-side: {expected:.1f}\n"
            f"VLM: {vlm_choice}\n"
        )

    return info, resized, answer


with gr.Blocks() as demo:
    gr.Markdown("# CARES – Sufficient Resolution Selection for VLMs")

    with gr.Row():
        inp_img = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            inp_q = gr.Textbox(label="Question", placeholder="Ask something about the image…")
            vlm = gr.Dropdown(
                choices=list(DOWNSTREAM_MODELS.keys()),
                value=list(DOWNSTREAM_MODELS.keys())[0],
                label="VLM",
            )
            strategy = gr.Dropdown(
                choices=["expected", "argmax", "conservative"],
                value="expected",
                label="Resolution selection strategy",
            )
            btn = gr.Button("Run")

    out_info = gr.Textbox(label="Info", lines=10)
    out_img = gr.Image(type="pil", label="Image used for inference (sufficient resolution)")
    out_ans = gr.Textbox(label="Answer", lines=6)

    btn.click(run, inputs=[inp_img, inp_q, vlm, strategy], outputs=[out_info, out_img, out_ans])

demo.launch(ssr_mode=False)
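
# Local run sketch (illustrative values; assumes this file is saved as app.py and the
# models/adapter are downloadable; the env vars are the ones read above):
#
#   GATE_ADAPTER_REPO=Kimhi/granite-docling-res-gate-lora \
#   TORCH_NUM_THREADS=8 \
#   GATE_INPUT_MAX_SIDE=256 \
#   python app.py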