from typing import Dict, List, Any
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import torch
import io
import base64
from peft import PeftModel


class EndpointHandler():
    def __init__(self, model_dir: str):
        self.path = model_dir

        # Base model that the LoRA adapter was fine-tuned from
        base_model_id = "Qwen/Qwen2-VL-2B-Instruct"

        # Load the processor (tokenizer + image processor) saved alongside the adapter
        self.processor = AutoProcessor.from_pretrained(
            self.path,
            trust_remote_code=True
        )

        # Load the base vision-language model
        self.model = AutoModelForVision2Seq.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

        # Load the LoRA adapter on top of the base model
        self.model = PeftModel.from_pretrained(
            self.model,
            self.path,
            device_map="auto"
        )

        # Merge adapter weights into the base model and drop the PEFT wrapper for faster inference
        self.model = self.model.merge_and_unload()

        # Set to eval mode
        self.model.eval()

        # Store the instruction template sent with every image
        self.instruction = """
A conversation between a Healthcare Provider and an AI Medical Image Analysis Assistant. The provider shares a medical image, and the Assistant generates a clear description/report. The assistant first analyzes the image systematically, then provides a concise report. The analysis process and report are enclosed within .
Always respond in this format:
1. Initial Assessment:
- What type of image is this? (X-ray, CT, MRI, etc.)
- Which body part/region is shown?
- Is the image quality adequate?
2. Key Findings:
- What are the normal structures visible?
- Are there any abnormalities?
- What are the important measurements?
3. Clinical Significance:
- What are the main clinical findings?
- Are there any critical findings?
Brief Structured Report:
1. EXAM TYPE: [imaging type and body region]
2. FINDINGS: [key observations and abnormalities]
3. IMPRESSION: [summary and clinical significance]
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        data args:
            inputs (:obj:`str` | `PIL.Image` | `np.ndarray`): base64-encoded image string, PIL image, or image array
            parameters (:obj:`Dict[str, Any]`, *optional*): generation settings such as max_new_tokens, temperature, top_p
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle different input formats
        if isinstance(inputs, str):
            # Base64-encoded image string
            image_bytes = base64.b64decode(inputs)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        elif isinstance(inputs, dict):
            # Dictionary with an "image" (or nested "inputs") key
            image_data = inputs.get("image", inputs.get("inputs", ""))
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                image = image_data
        else:
            # Already an image object (PIL image or array)
            image = inputs

        # Ensure we have an RGB PIL image
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare messages in the Qwen chat format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.instruction}
                ]
            }
        ]

        # Build the chat-formatted prompt text
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Prepare model inputs (text + image tensors)
        model_inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Generate the report
        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=parameters.get("max_new_tokens", 512),
                temperature=parameters.get("temperature", 0.7),
                top_p=parameters.get("top_p", 0.9),
                do_sample=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        output_text = self.processor.batch_decode(
            output_ids[:, model_inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )[0]

        return [{"generated_text": output_text}]
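

# Minimal local smoke test (illustrative sketch, not part of the Inference Endpoints contract).
# Assumes a fine-tuned adapter directory "./model" and an image file "sample_image.png";
# both paths are hypothetical placeholders.
if __name__ == "__main__":
    with open("sample_image.png", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler("./model")
    result = handler({
        "inputs": encoded_image,
        "parameters": {"max_new_tokens": 256, "temperature": 0.7}
    })
    print(result[0]["generated_text"])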