""" |
|
|
Medical Knowledge API Server |
|
|
============================ |
|
|
|
|
|
This script creates a FastAPI backend to serve the medical summarization models. |
|
|
It exposes a /generate endpoint that the frontend can call. |
|
|
|
|
|
Based on the original medical_knowledge_test.py script. |
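
Example request (a sketch using the third-party requests package; the host and
port are assumptions and depend on how uvicorn is launched):

    import requests

    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "input_text": "Patient admitted with chest pain...",
            "model_name": "gemma-3-12b-it",
            "task_type": "summary",
        },
    )
    print(resp.json()["response"])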
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import gc |
|
|
import logging |
|
|
from typing import Dict |
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, |
|
|
BitsAndBytesConfig, Gemma3ForConditionalGeneration |
|
|
) |
|
|
from huggingface_hub import login |
|
|
from peft import PeftModel |
|
|
import warnings |
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class GenerationRequest(BaseModel):
    input_text: str
    model_name: str
    task_type: str


class GenerationResponse(BaseModel):
    response: str


class MedicalKnowledgeTester:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Authenticate with Hugging Face when a token is available (required for gated models such as Gemma).
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("Logged in to Hugging Face using token from environment variable.")

        # Loaded models and tokenizers are cached here, keyed by model name.
        self.models = {}
        self.tokenizers = {}

        self.model_configs = {
            "led-base": {
                "model_type": "encoder-decoder",
                "base_model": "allenai/led-base-16384",
                "adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
                "max_length": 4096,
                "use_quantization": False,
            },
            "gemma-3-12b-it": {
                "model_type": "decoder",
                "base_model": "google/gemma-3-12b-it",
                "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
                "max_length": 4096,
                "use_quantization": True,
            },
        }

        self.SUMMARY_TEMPLATE = "You are a doctor in a hospital. You must summarize the patient's medical history...\n\nPatient Record:\n\n{input_text}\n\nSummary:"
        self.KNOWLEDGE_TEMPLATE = "You are an experienced physician...\n\nMedical Question/Scenario:\n\n{input_text}\n\nMedical Explanation:"

    def load_model(self, model_name: str):
        if model_name in self.models:
            logger.info(f"Model '{model_name}' is already loaded.")
            return

        if model_name not in self.model_configs:
            raise ValueError(f"Model {model_name} not supported.")

        config = self.model_configs[model_name]
        logger.info(f"Loading {model_name}...")

        model_kwargs = {"device_map": "auto", "trust_remote_code": True}

        if config["use_quantization"]:
            # 4-bit NF4 quantization keeps the larger decoder model within a single-GPU memory budget.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
            )
            model_kwargs["quantization_config"] = bnb_config
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        if config["model_type"] == "encoder-decoder":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
        else:
            base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)

        # Attach the fine-tuned PEFT adapter; fall back to the base model if the adapter cannot be loaded.
        try:
            model = PeftModel.from_pretrained(base_model, config["adapter_model"])
            logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
        except Exception as e:
            logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
            model = base_model

        model.eval()

        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer
        logger.info(f"{model_name} loaded successfully.")

    def generate_response(self, model_name: str, input_text: str, task_type: str) -> str:
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]
        config = self.model_configs[model_name]

        prompt = (self.SUMMARY_TEMPLATE if task_type == "summary" else self.KNOWLEDGE_TEMPLATE).format(input_text=input_text)

        if config["model_type"] == "decoder":
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
        else:
            # The encoder-decoder (LED) model is given the raw input text rather than the instruction prompt.
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=512, do_sample=True, temperature=0.1,
                pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1
            )

        if config["model_type"] == "decoder":
            # Decoder-only models echo the prompt, so drop the input tokens before decoding.
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[0][input_length:]
        else:
            generated_tokens = outputs[0]

        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()


app = FastAPI()
tester = MedicalKnowledgeTester()

# CORS is left wide open so the frontend can call the API from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    # Pre-load the default model so the first request does not pay the full load time.
    logger.info("Server starting up. Pre-loading default model...")
    try:
        tester.load_model("gemma-3-12b-it")
    except Exception as e:
        logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
        logger.info("Attempting to load led-base instead.")
        try:
            tester.load_model("led-base")
        except Exception as e2:
            logger.error(f"Could not pre-load any model: {e2}")


@app.get("/")
def read_root():
    return {"status": "Medical AI API is running"}


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    logger.info(f"Received request for model: {request.model_name}, task: {request.task_type}")
    try:
        response_text = tester.generate_response(
            model_name=request.model_name,
            input_text=request.input_text,
            task_type=request.task_type
        )
        return GenerationResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
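

# The docstring says this backend is served to a frontend, but the original script as
# shown includes no launcher. A minimal sketch for running the API directly; the host
# and port below are assumptions, not taken from the original.
if __name__ == "__main__":
    import uvicorn

    # Assumes uvicorn is installed alongside FastAPI; adjust host/port for your deployment.
    uvicorn.run(app, host="0.0.0.0", port=8000)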