import torch from transformers import AutoModelForCausalLM, AutoProcessor from PIL import Image import base64 import io import logging logger = logging.getLogger(__name__) class EndpointHandler(): def __init__(self, path=""): """ ฟังก์ชันนี้จะทำงานแค่ครั้งเดียวตอนเริ่มต้น Endpoint เพื่อโหลดโมเดลรอไว้ """ logger.info("Initializing model...") self.device = "cuda" if torch.cuda.is_available() else "cpu" # โหลดโมเดลและ processor จาก path ที่ Hugging Face ส่งมาให้ self.model = AutoModelForCausalLM.from_pretrained( path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map=self.device, # revision="b98e57b" # อาจจะต้องใช้ถ้ามีปัญหา ) self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True) logger.info("Model initialized successfully.") def __call__(self, data): """ ฟังก์ชันนี้จะทำงานทุกครั้งที่มี request ส่งเข้ามาที่ API """ logger.info("Processing new request...") # ดึงข้อมูลจาก request inputs = data.pop("inputs", data) image_b64 = inputs.get("image") # แนะนำให้ใช้ key ชื่อ "image" if not image_b64: return {"error": "Missing 'image' key with base64 encoded string in inputs."} try: # แปลง base64 string กลับเป็นรูปภาพ image_bytes = base64.b64decode(image_b64) image = Image.open(io.BytesIO(image_bytes)).convert("RGB") except Exception as e: logger.error(f"Error decoding image: {e}") return {"error": f"Invalid base64 image data. {e}"} # สร้าง Prompt ตามรูปแบบที่โมเดลต้องการ prompt = "<|user|>\n\n<|assistant|>" # เตรียมข้อมูลสำหรับส่งเข้าโมเดล model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, torch.bfloat16) # รันโมเดลเพื่อสร้าง text generated_ids = self.model.generate( input_ids=model_inputs["input_ids"], pixel_values=model_inputs["pixel_values"], max_new_tokens=2048, do_sample=False, num_beams=1 ) # ถอดรหัสผลลัพธ์ generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:] response_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] logger.info("Request processed successfully.") # ส่งผลลัพธ์กลับในรูปแบบ JSON return {"generated_text": response_text}