Abdou committed
Commit db322df · 1 Parent(s): 6727974

Add application file

Files changed (3)
  1. Dockerfile.dockerfile +15 -0
  2. main.py +209 -0
  3. requirements.txt +10 -0
Dockerfile.dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.10-slim
+
+ WORKDIR /code
+
+ # Copy and install requirements
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Copy your application code
+ # Assuming main.py is in a sub-folder named 'app'
+ # If main.py is in the root, change the next line to: COPY ./main.py /code/
+ COPY ./app /code/app
+
+ # The port needs to be 7860 for Hugging Face Spaces to expose it
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,209 @@
+ #!/usr/bin/env python3
+
+ """
+ Medical Knowledge API Server
+ ============================
+
+ This script creates a FastAPI backend to serve the medical summarization models.
+ It exposes a /generate endpoint that the frontend can call.
+
+ Based on the original medical_knowledge_test.py script.
+ """
+
+ import os
+ import torch
+ import gc
+ import logging
+ from typing import Dict
+ from pydantic import BaseModel
+
+ # ML Libraries
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+     BitsAndBytesConfig, Gemma3ForConditionalGeneration
+ )
+ from huggingface_hub import login
+ from peft import PeftModel
+ import warnings
+
+ # API Framework
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+
+ warnings.filterwarnings("ignore")
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # --- Data Models for API ---
+ class GenerationRequest(BaseModel):
+     input_text: str
+     model_name: str
+     task_type: str
+
+ class GenerationResponse(BaseModel):
+     response: str
+
+ # --- Medical Knowledge Tester Class (Adapted for API) ---
+ class MedicalKnowledgeTester:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+
+         # Hugging Face login (optional, use if models are private or to avoid rate limits)
+         # It's better to set this as an environment variable in your deployment
+         hf_token = os.environ.get("HF_TOKEN")
+         if hf_token:
+             login(token=hf_token)
+             logger.info("Logged in to Hugging Face using token from environment variable.")
+
+         self.models = {}
+         self.tokenizers = {}
+
+         self.model_configs = {
+             "led-base": {
+                 "model_type": "encoder-decoder",
+                 "base_model": "allenai/led-base-16384",
+                 "adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
+                 "max_length": 4096,
+                 "use_quantization": False,
+             },
+             "gemma-3-12b-it": {
+                 "model_type": "decoder",
+                 "base_model": "google/gemma-3-12b-it",
+                 "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
+                 "max_length": 4096,
+                 "use_quantization": True,
+             }
+         }
+
+         self.SUMMARY_TEMPLATE = "You are a doctor in a hospital. You must summarize the patient's medical history...\n\nPatient Record:\n\n{input_text}\n\nSummary:"
+         self.KNOWLEDGE_TEMPLATE = "You are an experienced physician...\n\nMedical Question/Scenario:\n\n{input_text}\n\nMedical Explanation:"
+
+     def load_model(self, model_name: str):
+         # This function is designed to prevent re-loading an already loaded model.
+         if model_name in self.models:
+             logger.info(f"Model '{model_name}' is already loaded.")
+             return
+
+         if model_name not in self.model_configs:
+             raise ValueError(f"Model {model_name} not supported.")
+
+         config = self.model_configs[model_name]
+         logger.info(f"Loading {model_name}...")
+
+         model_kwargs = {"device_map": "auto", "trust_remote_code": True}
+
+         if config["use_quantization"]:
+             bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True, bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
+             )
+             model_kwargs["quantization_config"] = bnb_config
+             model_kwargs["torch_dtype"] = torch.bfloat16
+         else:
+             model_kwargs["torch_dtype"] = torch.float16
+
+         tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.padding_side = "left"
+
+         if config["model_type"] == "encoder-decoder":
+             base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
+         else:
+             base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)
+
+         try:
+             model = PeftModel.from_pretrained(base_model, config["adapter_model"])
+             logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
+         except Exception as e:
+             logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
+             model = base_model
+
+         model.eval()
+
+         self.models[model_name] = model
+         self.tokenizers[model_name] = tokenizer
+         logger.info(f"{model_name} loaded successfully.")
+
+     def generate_response(self, model_name: str, input_text: str, task_type: str) -> str:
+         if model_name not in self.models:
+             self.load_model(model_name)
+
+         model = self.models[model_name]
+         tokenizer = self.tokenizers[model_name]
+         config = self.model_configs[model_name]
+
+         prompt = (self.SUMMARY_TEMPLATE if task_type == "summary" else self.KNOWLEDGE_TEMPLATE).format(input_text=input_text)
+
+         if config["model_type"] == "decoder":
+             inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
+         else:
+             inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs, max_new_tokens=512, do_sample=True, temperature=0.1,
+                 pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1
+             )
+
+         if config["model_type"] == "decoder":
+             input_length = inputs.input_ids.shape[1]
+             generated_tokens = outputs[0][input_length:]
+         else:
+             generated_tokens = outputs[0]
+
+         response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+         return response.strip()
+
+ # --- Initialize FastAPI App and Medical Tester ---
+ app = FastAPI()
+ tester = MedicalKnowledgeTester()
+
+ # Allow Cross-Origin Resource Sharing (CORS) so your website on Hostinger
+ # can communicate with this API.
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allows all origins
+     allow_credentials=True,
+     allow_methods=["*"],  # Allows all methods
+     allow_headers=["*"],  # Allows all headers
+ )
+
+ @app.on_event("startup")
+ async def startup_event():
+     # Pre-load a default model on startup to reduce wait time for the first user.
+     # The gemma model is larger, so loading it first is a good idea.
+     logger.info("Server starting up. Pre-loading default model...")
+     try:
+         tester.load_model("gemma-3-12b-it")
+     except Exception as e:
+         logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
+         logger.info("Attempting to load led-base instead.")
+         try:
+             tester.load_model("led-base")
+         except Exception as e2:
+             logger.error(f"Could not pre-load any model: {e2}")
+
+ @app.get("/")
+ def read_root():
+     return {"status": "Medical AI API is running"}
+
+ @app.post("/generate", response_model=GenerationResponse)
+ async def generate(request: GenerationRequest):
+     logger.info(f"Received request for model: {request.model_name}, task: {request.task_type}")
+     try:
+         response_text = tester.generate_response(
+             model_name=request.model_name,
+             input_text=request.input_text,
+             task_type=request.task_type
+         )
+         return GenerationResponse(response=response_text)
+     except Exception as e:
+         logger.error(f"Error during generation: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # To run this API locally for testing, you would use:
+ # uvicorn main:app --host 0.0.0.0 --port 8001
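
A minimal client sketch for exercising the /generate endpoint once the server is running locally (per the comment above, on port 8001). The URL, the sample record text, and the use of the `requests` package are illustrative assumptions and are not part of this commit:

import requests  # assumed to be installed; not listed in requirements.txt

payload = {
    "input_text": "Patient admitted with chest pain ...",  # illustrative sample record
    "model_name": "led-base",  # any key from model_configs in main.py
    "task_type": "summary",    # "summary" selects SUMMARY_TEMPLATE; other values use KNOWLEDGE_TEMPLATE
}

resp = requests.post("http://localhost:8001/generate", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["response"])
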
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn
+ python-multipart
+ torch
+ transformers
+ bitsandbytes
+ accelerate
+ peft
+ huggingface_hub
+ pydantic