Update app.py
app.py
CHANGED
@@ -24,8 +24,6 @@ warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# --- MODIFIED: Simplified API Request Model ---
-# We no longer need 'task_type' from the frontend.
 class GenerationRequest(BaseModel):
     input_text: str
     model_name: str
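With `task_type` gone, the body validated by `GenerationRequest` reduces to two fields. A minimal illustrative payload (the record text and model key below are made up for the example):

```python
# Illustrative request body for the simplified GenerationRequest model;
# only input_text and model_name are required now that task_type is removed.
payload = {
    "input_text": "67-year-old male admitted with chest pain; troponin elevated...",
    "model_name": "gemma-3-12b-it",
}
```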
@@ -46,6 +44,9 @@ class MedicalKnowledgeTester:
         self.models = {}
         self.tokenizers = {}
 
+        # --- MODIFICATION: Added 'med_gemma' to the model configurations ---
+        # NOTE: I have assumed 'google/med-gemma-2b' as the base model and a corresponding adapter name.
+        # Please verify these are the correct Hugging Face IDs for your model.
         self.model_configs = {
             "led-base": {
                 "model_type": "encoder-decoder",
@@ -60,11 +61,16 @@ class MedicalKnowledgeTester:
                 "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
                 "max_length": 4096,
                 "use_quantization": True,
+            },
+            "med_gemma": {
+                "model_type": "decoder",
+                "base_model": "google/med-gemma-2b",  # Assumed base model, please verify
+                "adapter_model": "ALQAMARI/med-gemma-summary-adapter",  # Assumed adapter, please verify
+                "max_length": 4096,
+                "use_quantization": True,  # Assumed quantization for efficiency
             }
         }
 
-        # --- MODIFIED: New General-Purpose Prompt Template ---
-        # This single prompt instructs the model to be smart about the user's input.
         self.GENERAL_TEMPLATE = """You are a versatile and highly skilled medical AI assistant. Your role is to provide accurate and helpful responses to medical inquiries.
 - If the user provides a patient record, a long medical report, or text that requires summarization, your primary task is to summarize it concisely. Highlight the key findings, diagnoses, and recommendations in a clear format suitable for other medical professionals.
 - If the user asks a direct question, provide a comprehensive and clear medical explanation.
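`load_model` itself is outside this diff, so the following is only a sketch of how a decoder-style entry such as `med_gemma` could be consumed, assuming the usual transformers + peft + bitsandbytes stack; 4-bit NF4 is one plausible reading of `use_quantization`, and the model IDs are the unverified ones from the config above:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def load_decoder_with_adapter(cfg: dict):
    """Sketch: load a (possibly quantized) base model and attach a PEFT adapter."""
    quant_config = None
    if cfg.get("use_quantization"):
        # 4-bit NF4 is an assumption; the real load_model may quantize differently.
        quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
    base = AutoModelForCausalLM.from_pretrained(
        cfg["base_model"],
        quantization_config=quant_config,
        device_map="auto",
    )
    # Attach the LoRA/PEFT adapter on top of the base weights.
    model = PeftModel.from_pretrained(base, cfg["adapter_model"])
    tokenizer = AutoTokenizer.from_pretrained(cfg["base_model"])
    return model, tokenizer
```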
@@ -121,8 +127,6 @@ Your Response:"""
         self.tokenizers[model_name] = tokenizer
         logger.info(f"{model_name} loaded successfully.")
 
-    # --- MODIFIED: Simplified generate_response function ---
-    # It no longer needs 'task_type' and uses the general template for everything.
     def generate_response(self, model_name: str, input_text: str) -> str:
         if model_name not in self.models:
             self.load_model(model_name)
@@ -168,6 +172,7 @@ app.add_middleware(
 async def startup_event():
     logger.info("Server starting up. Pre-loading default model...")
     try:
+        # You might want to change the default pre-loaded model or pre-load all of them
         tester.load_model("gemma-3-12b-it")
     except Exception as e:
         logger.error(f"Could not pre-load gemma-3-12b-it model: {e}")
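Per the new in-line comment, a hypothetical variant could pre-load every configured model instead of a single default. The `@app.on_event("startup")` decorator is assumed from context (it is not visible in this hunk), and a failed load is logged and skipped, as in the original:

```python
@app.on_event("startup")
async def startup_event():
    logger.info("Server starting up. Pre-loading all configured models...")
    # Iterate over every key in model_configs rather than hard-coding one model.
    for name in tester.model_configs:
        try:
            tester.load_model(name)
        except Exception as e:
            logger.error(f"Could not pre-load {name} model: {e}")
```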
@@ -176,7 +181,6 @@ async def startup_event():
 def read_root():
     return {"status": "Medical AI API - I AM THE NEW VERSION"}
 
-# --- MODIFIED: Updated generate endpoint ---
 @app.post("/generate", response_model=GenerationResponse)
 async def generate(request: GenerationRequest):
     logger.info(f"Received request for model: {request.model_name}")
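For completeness, a sketch of exercising the endpoint over HTTP, assuming the app is served locally (e.g. `uvicorn app:app --port 8000`); the `GenerationResponse` schema is defined elsewhere in `app.py`:

```python
import requests

# POST the two-field body that GenerationRequest now expects.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"input_text": "Summarize: patient presents with...", "model_name": "med_gemma"},
    timeout=300,  # generation on large models can be slow
)
resp.raise_for_status()
print(resp.json())
```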