#!/usr/bin/env python3
"""
Medical Knowledge API Server
============================
This script creates a FastAPI backend to serve the medical summarization models.
It exposes a /generate endpoint that the frontend can call.
Based on the original medical_knowledge_test.py script.
"""
import os
import torch
import gc
import logging
from typing import Dict
from pydantic import BaseModel
# ML Libraries
from transformers import (
AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
BitsAndBytesConfig, Gemma3ForConditionalGeneration
)
from huggingface_hub import login
from peft import PeftModel
import warnings
# API Framework
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
warnings.filterwarnings("ignore")
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Data Models for API ---
class GenerationRequest(BaseModel):
input_text: str
model_name: str
task_type: str
class GenerationResponse(BaseModel):
response: str
# --- Medical Knowledge Tester Class (Adapted for API) ---
class MedicalKnowledgeTester:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Hugging Face login (optional, use if models are private or to avoid rate limits)
# It's better to set this as an environment variable in your deployment
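        # For example (token value omitted), export it before launching the server:
        #   HF_TOKEN=<your token> uvicorn main:app --host 0.0.0.0 --port 8001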
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
logger.info("Logged in to Hugging Face using token from environment variable.")
self.models = {}
self.tokenizers = {}
self.model_configs = {
"led-base": {
"model_type": "encoder-decoder",
"base_model": "allenai/led-base-16384",
"adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
"max_length": 4096,
"use_quantization": False,
},
"gemma-3-12b-it": {
"model_type": "decoder",
"base_model": "google/gemma-3-12b-it",
"adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
"max_length": 4096,
"use_quantization": True,
}
}
self.SUMMARY_TEMPLATE = "You are a doctor in a hospital. You must summarize the patient's medical history...\n\nPatient Record:\n\n{input_text}\n\nSummary:"
self.KNOWLEDGE_TEMPLATE = "You are an experienced physician...\n\nMedical Question/Scenario:\n\n{input_text}\n\nMedical Explanation:"
def load_model(self, model_name: str):
        # Skip re-loading if this model is already in memory.
if model_name in self.models:
logger.info(f"Model '{model_name}' is already loaded.")
return
if model_name not in self.model_configs:
raise ValueError(f"Model {model_name} not supported.")
config = self.model_configs[model_name]
logger.info(f"Loading {model_name}...")
model_kwargs = {"device_map": "auto", "trust_remote_code": True}
if config["use_quantization"]:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
)
model_kwargs["quantization_config"] = bnb_config
model_kwargs["torch_dtype"] = torch.bfloat16
else:
model_kwargs["torch_dtype"] = torch.float16
tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
if config["model_type"] == "encoder-decoder":
base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
else:
base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)
try:
model = PeftModel.from_pretrained(base_model, config["adapter_model"])
logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
except Exception as e:
logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
model = base_model
model.eval()
self.models[model_name] = model
self.tokenizers[model_name] = tokenizer
logger.info(f"{model_name} loaded successfully.")
def generate_response(self, model_name: str, input_text: str, task_type: str) -> str:
if model_name not in self.models:
self.load_model(model_name)
model = self.models[model_name]
tokenizer = self.tokenizers[model_name]
config = self.model_configs[model_name]
prompt = (self.SUMMARY_TEMPLATE if task_type == "summary" else self.KNOWLEDGE_TEMPLATE).format(input_text=input_text)
if config["model_type"] == "decoder":
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
else:
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=512, do_sample=True, temperature=0.1,
pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1
)
if config["model_type"] == "decoder":
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs[0][input_length:]
else:
generated_tokens = outputs[0]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
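    def unload_model(self, model_name: str):
        # Hypothetical helper, not part of the original script: a sketch of how a
        # loaded model could be dropped to free GPU memory before another one is
        # loaded (this is also the only use of the `gc` import above).
        if model_name in self.models:
            del self.models[model_name]
            del self.tokenizers[model_name]
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Unloaded {model_name} and released GPU memory.")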
# --- Initialize FastAPI App and Medical Tester ---
app = FastAPI()
tester = MedicalKnowledgeTester()
# Allow Cross-Origin Resource Sharing (CORS) so your website on Hostinger
# can communicate with this API.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
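# Note: browsers reject credentialed requests when the allowed origin is the
# wildcard "*". If the frontend sends cookies or auth headers, list its origin
# explicitly, e.g. allow_origins=["https://your-site.example"] (placeholder
# domain), or set allow_credentials=False.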
@app.on_event("startup")
async def startup_event():
    # Pre-load a default model on startup so the first request does not pay the
    # cold-load cost. The Gemma model is the largest, so it is tried first, with
    # led-base as a fallback.
logger.info("Server starting up. Pre-loading default model...")
try:
tester.load_model("gemma-3-12b-it")
except Exception as e:
logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
logger.info("Attempting to load led-base instead.")
try:
tester.load_model("led-base")
except Exception as e2:
logger.error(f"Could not pre-load any model: {e2}")
@app.get("/")
def read_root():
return {"status": "Medical AI API is running"}
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
logger.info(f"Received request for model: {request.model_name}, task: {request.task_type}")
try:
response_text = tester.generate_response(
model_name=request.model_name,
input_text=request.input_text,
task_type=request.task_type
)
return GenerationResponse(response=response_text)
except Exception as e:
logger.error(f"Error during generation: {e}")
raise HTTPException(status_code=500, detail=str(e))
# To run this API locally for testing, you would use:
# uvicorn main:app --host 0.0.0.0 --port 8001
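# Example request once the server is running (illustrative values; the payload
# fields mirror GenerationRequest and the reply mirrors GenerationResponse):
#
#   curl -X POST http://localhost:8001/generate \
#     -H "Content-Type: application/json" \
#     -d '{"model_name": "led-base", "task_type": "summary", "input_text": "Patient admitted with ..."}'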