toy3003 commited on
Commit
3061b42
·
verified ·
1 Parent(s): 8fb4997

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +70 -0
handler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoProcessor
3
+ from PIL import Image
4
+ import base64
5
+ import io
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ """
13
+ ฟังก์ชันนี้จะทำงานแค่ครั้งเดียวตอนเริ่มต้น Endpoint เพื่อโหลดโมเดลรอไว้
14
+ """
15
+ logger.info("Initializing model...")
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ # โหลดโมเดลและ processor จาก path ที่ Hugging Face ส่งมาให้
19
+ self.model = AutoModelForCausalLM.from_pretrained(
20
+ path,
21
+ trust_remote_code=True,
22
+ torch_dtype=torch.bfloat16,
23
+ device_map=self.device,
24
+ # revision="b98e57b" # อาจจะต้องใช้ถ้ามีปัญหา
25
+ )
26
+ self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
27
+ logger.info("Model initialized successfully.")
28
+
29
+ def __call__(self, data):
30
+ """
31
+ ฟังก์ชันนี้จะทำงานทุกครั้งที่มี request ส่งเข้ามาที่ API
32
+ """
33
+ logger.info("Processing new request...")
34
+ # ดึงข้อมูลจาก request
35
+ inputs = data.pop("inputs", data)
36
+ image_b64 = inputs.get("image") # แนะนำให้ใช้ key ชื่อ "image"
37
+
38
+ if not image_b64:
39
+ return {"error": "Missing 'image' key with base64 encoded string in inputs."}
40
+
41
+ try:
42
+ # แปลง base64 string กลับเป็นรูปภาพ
43
+ image_bytes = base64.b64decode(image_b64)
44
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
45
+ except Exception as e:
46
+ logger.error(f"Error decoding image: {e}")
47
+ return {"error": f"Invalid base64 image data. {e}"}
48
+
49
+ # สร้าง Prompt ตามรูปแบบที่โมเดลต้องการ
50
+ prompt = "<|user|>\n<image>\n<|assistant|>"
51
+
52
+ # เตรียมข้อมูลสำหรับส่งเข้าโมเดล
53
+ model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, torch.bfloat16)
54
+
55
+ # รันโมเดลเพื่อสร้าง text
56
+ generated_ids = self.model.generate(
57
+ input_ids=model_inputs["input_ids"],
58
+ pixel_values=model_inputs["pixel_values"],
59
+ max_new_tokens=2048,
60
+ do_sample=False,
61
+ num_beams=1
62
+ )
63
+
64
+ # ถอดรหัสผลลัพธ์
65
+ generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
66
+ response_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
67
+
68
+ logger.info("Request processed successfully.")
69
+ # ส่งผลลัพธ์กลับในรูปแบบ JSON
70
+ return {"generated_text": response_text}