from typing import Dict, List, Any

import base64
import io

import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor


class EndpointHandler():
    def __init__(self, model_dir: str):
        self.path = model_dir

        # Base model the LoRA adapter was trained on
        base_model_id = "Qwen/Qwen2-VL-2B-Instruct"

        # Load tokenizer/processor
        self.processor = AutoProcessor.from_pretrained(
            self.path,
            trust_remote_code=True
        )

        # Load base model
        self.model = AutoModelForVision2Seq.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

        # Load LoRA adapter
        self.model = PeftModel.from_pretrained(
            self.model,
            self.path,
            device_map="auto"
        )

        # Merge and unload for faster inference
        self.model = self.model.merge_and_unload()

        # Set to eval mode
        self.model.eval()

        # Store the instruction template
        self.instruction = """
A conversation between a Healthcare Provider and an AI Medical Image Analysis Assistant.
The provider shares a medical image, and the Assistant generates a clear description/report.
The assistant first analyzes the image systematically, then provides a concise report.
The analysis process and report are enclosed within .

Always respond in this format:

1. Initial Assessment:
   - What type of image is this? (X-ray, CT, MRI, etc.)
   - Which body part/region is shown?
   - Is the image quality adequate?

2. Key Findings:
   - What are the normal structures visible?
   - Are there any abnormalities?
   - What are the important measurements?

3. Clinical Significance:
   - What are the main clinical findings?
   - Are there any critical findings?

Brief Structured Report:
1. EXAM TYPE: [imaging type and body region]
2. FINDINGS: [key observations and abnormalities]
3. IMPRESSION: [summary and clinical significance]
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        data args:
            inputs (:obj:`str` | `PIL.Image` | `np.array`)
            parameters (:obj:`Dict[str, Any]`, *optional*)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle different input formats
        if isinstance(inputs, str):
            # Base64 encoded image
            image_bytes = base64.b64decode(inputs)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        elif isinstance(inputs, dict):
            # Dictionary with image key
            image_data = inputs.get("image", inputs.get("inputs", ""))
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                image = image_data
        else:
            # Direct image
            image = inputs

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare messages in Qwen format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.instruction}
                ]
            }
        ]

        # Build the chat-formatted prompt
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Prepare inputs for the model (named model_inputs to avoid shadowing the request payload)
        model_inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Generate response
        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=parameters.get("max_new_tokens", 512),
                temperature=parameters.get("temperature", 0.7),
                top_p=parameters.get("top_p", 0.9),
                do_sample=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
            )

        # Decode output - only the newly generated tokens
        output_text = self.processor.batch_decode(
            output_ids[:, model_inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )[0]

        return [{"generated_text": output_text}]
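

# Minimal local smoke-test sketch (not part of the Inference Endpoints contract).
# Assumptions: the adapter/processor files live in "./model" and "sample_xray.png"
# is a local image; both paths are hypothetical placeholders for illustration.
if __name__ == "__main__":
    with open("sample_xray.png", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(model_dir="./model")
    result = handler({
        "inputs": encoded_image,
        "parameters": {"max_new_tokens": 256, "temperature": 0.2},
    })
    print(result[0]["generated_text"])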