#!/usr/bin/env python3
"""
Medical Knowledge API Server
============================

This script creates a FastAPI backend to serve the medical summarization models.
It exposes a /generate endpoint that the frontend can call.
Based on the original medical_knowledge_test.py script.
"""

import os
import torch
import gc
import logging
from typing import Dict
from pydantic import BaseModel

# ML Libraries
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration
)
from huggingface_hub import login
from peft import PeftModel
import warnings

# API Framework
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

warnings.filterwarnings("ignore")

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --- Data Models for API ---
class GenerationRequest(BaseModel):
    input_text: str
    model_name: str
    task_type: str


class GenerationResponse(BaseModel):
    response: str


# --- Medical Knowledge Tester Class (Adapted for API) ---
class MedicalKnowledgeTester:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Hugging Face login (optional, use if models are private or to avoid rate limits).
        # It's better to set this as an environment variable in your deployment.
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("Logged in to Hugging Face using token from environment variable.")

        self.models = {}
        self.tokenizers = {}
        self.model_configs = {
            "led-base": {
                "model_type": "encoder-decoder",
                "base_model": "allenai/led-base-16384",
                "adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
                "max_length": 4096,
                "use_quantization": False,
            },
            "gemma-3-12b-it": {
                "model_type": "decoder",
                "base_model": "google/gemma-3-12b-it",
                "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
                "max_length": 4096,
                "use_quantization": True,
            }
        }

        self.SUMMARY_TEMPLATE = (
            "You are a doctor in a hospital. You must summarize the patient's medical history..."
            "\n\nPatient Record:\n\n{input_text}\n\nSummary:"
        )
        self.KNOWLEDGE_TEMPLATE = (
            "You are an experienced physician..."
            "\n\nMedical Question/Scenario:\n\n{input_text}\n\nMedical Explanation:"
        )

    def load_model(self, model_name: str):
        # This function is designed to prevent re-loading an already loaded model.
        if model_name in self.models:
            logger.info(f"Model '{model_name}' is already loaded.")
            return

        if model_name not in self.model_configs:
            raise ValueError(f"Model {model_name} not supported.")

        config = self.model_configs[model_name]
        logger.info(f"Loading {model_name}...")

        model_kwargs = {"device_map": "auto", "trust_remote_code": True}

        if config["use_quantization"]:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True
            )
            model_kwargs["quantization_config"] = bnb_config
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        if config["model_type"] == "encoder-decoder":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
        else:
            base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)

        try:
            model = PeftModel.from_pretrained(base_model, config["adapter_model"])
            logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
        except Exception as e:
            logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
            model = base_model

        model.eval()
        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer
        logger.info(f"{model_name} loaded successfully.")

    def generate_response(self, model_name: str, input_text: str, task_type: str) -> str:
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]
        config = self.model_configs[model_name]

        prompt = (self.SUMMARY_TEMPLATE if task_type == "summary" else self.KNOWLEDGE_TEMPLATE).format(
            input_text=input_text
        )

        if config["model_type"] == "decoder":
            inputs = tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]
            ).to(self.device)
        else:
            inputs = tokenizer(
                input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]
            ).to(self.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.1,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        if config["model_type"] == "decoder":
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[0][input_length:]
        else:
            generated_tokens = outputs[0]

        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()


# --- Initialize FastAPI App and Medical Tester ---
app = FastAPI()
tester = MedicalKnowledgeTester()

# Allow Cross-Origin Resource Sharing (CORS) so your website on Hostinger
# can communicate with this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],      # Allows all methods
    allow_headers=["*"],      # Allows all headers
)


@app.on_event("startup")
async def startup_event():
    # Pre-load a default model on startup to reduce wait time for the first user.
    # The gemma model is larger, so loading it first is a good idea.
    logger.info("Server starting up. Pre-loading default model...")
    try:
        tester.load_model("gemma-3-12b-it")
    except Exception as e:
        logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
        logger.info("Attempting to load led-base instead.")
        try:
            tester.load_model("led-base")
        except Exception as e2:
            logger.error(f"Could not pre-load any model: {e2}")


@app.get("/")
def read_root():
    return {"status": "Medical AI API is running"}


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    logger.info(f"Received request for model: {request.model_name}, task: {request.task_type}")
    try:
        response_text = tester.generate_response(
            model_name=request.model_name,
            input_text=request.input_text,
            task_type=request.task_type
        )
        return GenerationResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# To run this API locally for testing, you would use:
# uvicorn main:app --host 0.0.0.0 --port 8001
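

# --- Example usage (illustrative) ---
# A minimal sketch of how a client might call this API, assuming the file is saved
# as main.py and the server is reachable on port 8001; the JSON fields mirror the
# GenerationRequest model above, and "led-base" / "summary" are values accepted by
# the configs in MedicalKnowledgeTester.
#
# curl -X POST http://localhost:8001/generate \
#      -H "Content-Type: application/json" \
#      -d '{"input_text": "Patient admitted with chest pain...", "model_name": "led-base", "task_type": "summary"}'
#
# Optionally, the server can also be started with `python main.py` (requires uvicorn installed):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)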