#!/usr/bin/env python3
"""
Medical Knowledge API Server
============================
This script creates a FastAPI backend to serve the medical summarization models.
It exposes a /generate endpoint that the frontend can call.
Based on the original medical_knowledge_test.py script.
"""
import os
import torch
import gc
import logging
from typing import Dict
from pydantic import BaseModel
# ML Libraries
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
    BitsAndBytesConfig, Gemma3ForConditionalGeneration
)
from huggingface_hub import login
from peft import PeftModel
import warnings
# API Framework
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
warnings.filterwarnings("ignore")
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Data Models for API ---
class GenerationRequest(BaseModel):
    input_text: str
    model_name: str
    task_type: str

class GenerationResponse(BaseModel):
    response: str

# --- Medical Knowledge Tester Class (Adapted for API) ---
class MedicalKnowledgeTester:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Hugging Face login (optional, use if models are private or to avoid rate limits).
        # It's better to set this as an environment variable in your deployment.
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("Logged in to Hugging Face using token from environment variable.")

        self.models = {}
        self.tokenizers = {}
        self.model_configs = {
            "led-base": {
                "model_type": "encoder-decoder",
                "base_model": "allenai/led-base-16384",
                "adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
                "max_length": 4096,
                "use_quantization": False,
            },
            "gemma-3-12b-it": {
                "model_type": "decoder",
                "base_model": "google/gemma-3-12b-it",
                "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
                "max_length": 4096,
                "use_quantization": True,
            }
        }

        self.SUMMARY_TEMPLATE = "You are a doctor in a hospital. You must summarize the patient's medical history...\n\nPatient Record:\n\n{input_text}\n\nSummary:"
        self.KNOWLEDGE_TEMPLATE = "You are an experienced physician...\n\nMedical Question/Scenario:\n\n{input_text}\n\nMedical Explanation:"

    def load_model(self, model_name: str):
        # Skip loading if this model is already in memory.
        if model_name in self.models:
            logger.info(f"Model '{model_name}' is already loaded.")
            return
        if model_name not in self.model_configs:
            raise ValueError(f"Model {model_name} not supported.")

        config = self.model_configs[model_name]
        logger.info(f"Loading {model_name}...")

        model_kwargs = {"device_map": "auto", "trust_remote_code": True}
        if config["use_quantization"]:
            # 4-bit NF4 quantization with bfloat16 compute keeps the larger model within GPU memory.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
            )
            model_kwargs["quantization_config"] = bnb_config
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        if config["model_type"] == "encoder-decoder":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
        else:
            base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)

        # Attach the fine-tuned adapter; fall back to the base model if it cannot be loaded.
        try:
            model = PeftModel.from_pretrained(base_model, config["adapter_model"])
            logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
        except Exception as e:
            logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
            model = base_model

        model.eval()
        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer
        logger.info(f"{model_name} loaded successfully.")

    def generate_response(self, model_name: str, input_text: str, task_type: str) -> str:
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]
        config = self.model_configs[model_name]

        prompt = (self.SUMMARY_TEMPLATE if task_type == "summary" else self.KNOWLEDGE_TEMPLATE).format(input_text=input_text)

        # Decoder models receive the full instruction prompt; the encoder-decoder model
        # is given the raw input text without the template.
        if config["model_type"] == "decoder":
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
        else:
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=512, do_sample=True, temperature=0.1,
                pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1
            )

        # Decoder-only models echo the prompt, so strip the prompt tokens before decoding.
        if config["model_type"] == "decoder":
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[0][input_length:]
        else:
            generated_tokens = outputs[0]

        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()

# --- Initialize FastAPI App and Medical Tester ---
app = FastAPI()
tester = MedicalKnowledgeTester()
# Allow Cross-Origin Resource Sharing (CORS) so your website on Hostinger
# can communicate with this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],     # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],     # Allows all methods
    allow_headers=["*"],     # Allows all headers
)
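
# Note: allow_origins=["*"] is convenient while testing, but in production you would
# typically restrict it to the real frontend origin. A hypothetical example (the domain
# below is a placeholder, not taken from this project):
#
#     allow_origins=["https://www.your-hostinger-site.com"]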

@app.on_event("startup")
async def startup_event():
    # Pre-load a default model on startup to reduce wait time for the first user.
    # The gemma model is larger, so loading it first is a good idea.
    logger.info("Server starting up. Pre-loading default model...")
    try:
        tester.load_model("gemma-3-12b-it")
    except Exception as e:
        logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
        logger.info("Attempting to load led-base instead.")
        try:
            tester.load_model("led-base")
        except Exception as e2:
            logger.error(f"Could not pre-load any model: {e2}")

@app.get("/")
def read_root():
    return {"status": "Medical AI API is running"}

@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    logger.info(f"Received request for model: {request.model_name}, task: {request.task_type}")
    try:
        response_text = tester.generate_response(
            model_name=request.model_name,
            input_text=request.input_text,
            task_type=request.task_type
        )
        return GenerationResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# To run this API locally for testing, you would use:
# uvicorn main:app --host 0.0.0.0 --port 8001
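#
# Example request against the /generate endpoint (assumes the server is running on
# port 8001 as above; the JSON fields mirror the GenerationRequest model, and the
# input text is only illustrative placeholder content):
#
#   curl -X POST http://localhost:8001/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Patient admitted with shortness of breath ...", "model_name": "gemma-3-12b-it", "task_type": "summary"}'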