from typing import Dict, List, Any
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
from PIL import Image
import torch
import io
import base64
from peft import PeftModel

class EndpointHandler:
    def __init__(self, model_dir: str):
        self.path = model_dir
        # Base model that the LoRA adapter stored in `model_dir` was trained on
        base_model_id = "Qwen/Qwen2-VL-2B-Instruct"
        
        # Load tokenizer/processor
        self.processor = AutoProcessor.from_pretrained(
            self.path,
            trust_remote_code=True
        )
        
        # Load base model
        self.model = AutoModelForVision2Seq.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        
        # Load LoRA adapter
        self.model = PeftModel.from_pretrained(
            self.model,
            self.path,
            device_map="auto"
        )
        
        # Merge and unload for faster inference
        self.model = self.model.merge_and_unload()
        
        # Set to eval mode
        self.model.eval()
        
        # Store the instruction template
        self.instruction = """
A conversation between a Healthcare Provider and an AI Medical Image Analysis Assistant. The provider shares a medical image, and the Assistant generates a clear description/report. The assistant first analyzes the image systematically, then provides a concise report. The analysis process and report are enclosed within <thinking> </thinking><answer> </answer>.

Always respond in this format:

<thinking>
1. Initial Assessment:
   - What type of image is this? (X-ray, CT, MRI, etc.)
   - Which body part/region is shown?
   - Is the image quality adequate?

2. Key Findings:
   - What are the normal structures visible?
   - Are there any abnormalities?
   - What are the important measurements?

3. Clinical Significance:
   - What are the main clinical findings?
   - Are there any critical findings?
</thinking>

<answer>
Brief Structured Report:
1. EXAM TYPE: [imaging type and body region]
2. FINDINGS: [key observations and abnormalities]
3. IMPRESSION: [summary and clinical significance]
</answer>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            parameters (:obj: `Dict[str, Any]`, *optional*)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        
        # Handle different input formats
        if isinstance(inputs, str):
            # Base64 encoded image
            image_bytes = base64.b64decode(inputs)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        elif isinstance(inputs, dict):
            # Dictionary with image key
            image_data = inputs.get("image", inputs.get("inputs", ""))
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                image = image_data
        else:
            # Direct image (PIL image or array-like such as a numpy array)
            image = inputs

        # Convert array-like inputs to a PIL image and ensure RGB mode
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        # Prepare messages in Qwen format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.instruction}
                ]
            }
        ]
        
        # Process inputs
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Prepare inputs for model
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(self.model.device)
        
        # Generate response
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=parameters.get("max_new_tokens", 512),
                temperature=parameters.get("temperature", 0.7),
                top_p=parameters.get("top_p", 0.9),
                do_sample=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
            )
        
        # Decode output - only the generated part
        output_text = self.processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )[0]
        
        return [{"generated_text": output_text}]
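

if __name__ == "__main__":
    # Minimal local smoke-test sketch. Inference Endpoints imports EndpointHandler
    # directly, so this block is never executed in deployment. The model directory
    # "./model" and the image "sample.png" are placeholder paths, not files shipped
    # with this repository.
    handler = EndpointHandler("./model")
    with open("sample.png", "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"max_new_tokens": 256},
        }
    print(handler(payload)[0]["generated_text"])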