# Qwen_UI_final / handler.py
# BoghdadyJR: Add inference handler for HF Endpoints (commit e935d19, verified)
from typing import Dict, List, Any
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from PIL import Image
import torch
import io
import base64
from peft import PeftModel
class EndpointHandler():
    """Inference handler that turns a medical image into a structured report.

    The handler loads the Qwen2-VL base model, applies the LoRA adapter found
    in the model directory, merges the adapter into the base weights and then
    answers every request with the fixed instruction prompt stored in
    ``self.instruction``.
    """

    # Hub id of the base checkpoint the LoRA adapter is applied to.
    BASE_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

    def __init__(self, model_dir: str):
        """Load the processor, the base model and the LoRA adapter.

        Args:
            model_dir: Directory the processor and the LoRA adapter are
                loaded from.
        """
        self.path = model_dir

        # NOTE(review): trust_remote_code=True lets the loaders execute Python
        # code shipped with the checkpoint; confirm it is still required.
        self.processor = AutoProcessor.from_pretrained(
            self.path,
            trust_remote_code=True,
        )

        base_model = AutoModelForVision2Seq.from_pretrained(
            self.BASE_MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach the LoRA adapter, then fold it into the base weights so that
        # inference runs on a plain model without the adapter wrappers.
        adapted_model = PeftModel.from_pretrained(
            base_model,
            self.path,
            device_map="auto",
        )
        self.model = adapted_model.merge_and_unload()
        self.model.eval()

        # Instruction prompt sent together with every image.
        self.instruction = """
A conversation between a Healthcare Provider and an AI Medical Image Analysis Assistant. The provider shares a medical image, and the Assistant generates a clear description/report. The assistant first analyzes the image systematically, then provides a concise report. The analysis process and report are enclosed within <thinking> </thinking><answer> </answer>.
Always respond in this format:
<thinking>
1. Initial Assessment:
- What type of image is this? (X-ray, CT, MRI, etc.)
- Which body part/region is shown?
- Is the image quality adequate?
2. Key Findings:
- What are the normal structures visible?
- Are there any abnormalities?
- What are the important measurements?
3. Clinical Significance:
- What are the main clinical findings?
- Are there any critical findings?
</thinking>
<answer>
Brief Structured Report:
1. EXAM TYPE: [imaging type and body region]
2. FINDINGS: [key observations and abnormalities]
3. IMPRESSION: [summary and clinical significance]
</answer>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate a report for the image contained in *data*.

        Args:
            data: Request payload. ``data["inputs"]`` holds the image as a
                base64 string (optionally a ``data:`` URI), raw encoded image
                bytes, a PIL image, an array-like object, or a dict carrying
                one of those under ``"image"`` / ``"inputs"``. When there is
                no ``"inputs"`` key, *data* itself is treated as that dict.
                ``data["parameters"]`` may provide ``max_new_tokens``
                (default 512), ``temperature`` (default 0.7), ``top_p``
                (default 0.9) and ``do_sample`` (default ``True``).

        Returns:
            A one-element list ``[{"generated_text": <model output>}]``.

        Raises:
            ValueError: If no image is supplied or the base64 text is invalid.
            TypeError: If the image is of an unsupported type.
        """
        # Read with .get() so the caller's request dict is left untouched.
        payload = data.get("inputs", data)
        # ``or {}`` also covers an explicit ``"parameters": null``.
        parameters = data.get("parameters") or {}

        image = self._extract_image(payload)

        # Prepare messages in Qwen chat format.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.instruction},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.processor(
            text=[prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                **self._generation_kwargs(parameters),
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = model_inputs["input_ids"].shape[1]
        output_text = self.processor.batch_decode(
            output_ids[:, prompt_length:],
            skip_special_tokens=True,
        )[0]
        return [{"generated_text": output_text}]

    def _generation_kwargs(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request parameters into keyword arguments for generate()."""
        temperature = parameters.get("temperature")
        if temperature is None:
            temperature = 0.7
        # Sampling needs a strictly positive temperature; a request for
        # temperature <= 0 is served with greedy decoding instead of failing.
        do_sample = bool(parameters.get("do_sample", True)) and temperature > 0

        kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "do_sample": do_sample,
            "pad_token_id": self.processor.tokenizer.pad_token_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = parameters.get("top_p", 0.9)
        return kwargs

    def _extract_image(self, payload: Any) -> Any:
        """Return the request image as an RGB PIL image."""
        if isinstance(payload, dict):
            payload = payload.get("image", payload.get("inputs"))
        if payload is None:
            raise ValueError(
                "No image supplied: expected 'inputs' to contain a "
                "base64-encoded image or an image object."
            )

        if isinstance(payload, str):
            image = self._decode_base64_image(payload)
        elif isinstance(payload, (bytes, bytearray)):
            image = Image.open(io.BytesIO(payload))
        elif hasattr(payload, "mode") and hasattr(payload, "convert"):
            # Already an image object (duck-typed, as the original code did).
            image = payload
        elif hasattr(payload, "__array_interface__"):
            # Array-like input such as a numpy array.
            image = Image.fromarray(payload)
        else:
            raise TypeError(
                f"Unsupported image input type: {type(payload).__name__}"
            )

        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @staticmethod
    def _decode_base64_image(encoded: str) -> Any:
        """Decode a base64 string (optionally a ``data:`` URI) into a PIL image."""
        encoded = encoded.strip()
        if encoded.startswith("data:"):
            # Drop the "data:<mime>;base64," header; ':' is not part of the
            # base64 alphabet, so a bare base64 string can never start this way.
            _, _, encoded = encoded.partition(",")
        if not encoded:
            raise ValueError(
                "Image payload is empty; expected a base64-encoded image."
            )
        try:
            raw = base64.b64decode(encoded)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError("Image payload is not valid base64.") from err
        return Image.open(io.BytesIO(raw))