|
|
""" |
|
|
MedRAX API Module |
|
|
|
|
|
This module provides a FastAPI-based REST API for the MedRAX medical imaging AI assistant. |
|
|
It offers endpoints for processing medical images with text queries using the same agent |
|
|
architecture as the Gradio interface. |
|
|
|
|
|
The API supports: |
|
|
- Text-only queries |
|
|
- Single or multiple image inputs |
|
|
- Optional custom system prompts |
|
|
- Automatic thread management for each request |
|
|
- Tool execution and result aggregation |
|
|
""" |
|
|
|
|
|
import uuid |
|
|
import base64 |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Dict, Any |
|
|
import re |
|
|
import time |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
from langchain_core.messages import AIMessage, ToolMessage |
|
|
|
|
|
|
|
|
from medrax.agent import Agent |
|
|
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """
    JSON body accepted by the text-only query endpoint.

    Attributes:
        question (str): Required. The question or query forwarded to the agent.
        system_prompt (Optional[str]): When given, used in place of the agent's
            default system prompt for this request.
        thread_id (Optional[str]): Identifier of an existing conversation thread;
            omit it to have a new thread ID generated.
    """

    # Required field: `...` marks it as having no default.
    question: str = Field(
        default=...,
        description="The question or query to ask the agent",
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="Custom system prompt to override default",
    )
    thread_id: Optional[str] = Field(
        default=None,
        description="Optional thread ID for conversation continuity",
    )
|
|
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """
    Payload returned by the query endpoints.

    Attributes:
        response (str): Text produced by the agent.
        thread_id (str): Thread ID the request was processed under.
        tools_used (List[str]): Tools that were executed while answering.
        processing_time (float): Seconds spent processing the request.
    """

    # All four fields are required (`...` = no default).
    response: str = Field(
        default=...,
        description="The agent's text response",
    )
    thread_id: str = Field(
        default=...,
        description="The thread ID used for this conversation",
    )
    tools_used: List[str] = Field(
        default=...,
        description="List of tools that were executed",
    )
    processing_time: float = Field(
        default=...,
        description="Time taken to process the request in seconds",
    )
|
|
|
|
|
|
|
|
class MedRAXAPI:
    """
    FastAPI application wrapper for the MedRAX agent.

    The class builds a ``FastAPI`` app (exposed as ``self.app``), installs CORS
    middleware and registers the ``/health``, ``/tools``, ``/query`` and
    ``/query-with-images`` routes. Uploaded images are written to ``temp_dir``
    for the duration of a request and removed afterwards.
    """

    # Matches path-like tokens whose first component starts with "temp"
    # (e.g. "temp/x.png", "temp_api/abc_0_scan.png"). A path separator is
    # required so ordinary words such as "temperature" or "temporal" are kept.
    _TEMP_PATH_RE = re.compile(r"(?<!\w)temp\w*[/\\]\S*")

    # Characters outside this set are replaced when building temp file names.
    _UNSAFE_FILENAME_CHARS_RE = re.compile(r"[^A-Za-z0-9_.-]")

    def __init__(self, agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api"):
        """
        Initialize the MedRAX API.

        Args:
            agent (Agent): The initialized MedRAX agent
            tools_dict (Dict[str, Any]): Dictionary of available tools
            temp_dir (str): Directory for temporary file storage; created
                (including missing parents) if it does not exist
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.temp_dir = Path(temp_dir)
        # parents=True so a nested temp_dir such as "var/tmp/api" also works.
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        self.app = FastAPI(
            title="MedRAX API",
            description="Medical Reasoning Agent for Chest X-ray Analysis",
            version="2.0.0",
            docs_url="/docs",
            redoc_url="/redoc",
        )

        # NOTE(review): wildcard origins combined with allow_credentials=True is
        # very permissive; consider restricting origins for production use.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._register_routes()

    def _register_routes(self):
        """Register all API routes on ``self.app``."""

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "service": "MedRAX API"}

        @self.app.get("/tools")
        async def list_tools():
            """List available tools."""
            return {"available_tools": list(self.tools_dict.keys()), "total_count": len(self.tools_dict)}

        @self.app.post("/query", response_model=QueryResponse)
        async def query_text_only(request: QueryRequest):
            """
            Process a text-only query without images.

            Args:
                request (QueryRequest): The query request

            Returns:
                QueryResponse: The agent's response
            """
            return await self._process_query(
                question=request.question,
                system_prompt=request.system_prompt,
                thread_id=request.thread_id,
                images=None,
            )

        @self.app.post("/query-with-images", response_model=QueryResponse)
        async def query_with_images(
            question: str = Form(..., description="The question or query to ask the agent"),
            system_prompt: Optional[str] = Form(None, description="Custom system prompt to override default"),
            thread_id: Optional[str] = Form(None, description="Optional thread ID for conversation continuity"),
            images: List[UploadFile] = File(..., description="One or more medical images to analyze"),
        ):
            """
            Process a query with one or more images.

            Args:
                question (str): The question or query to ask the agent
                system_prompt (Optional[str]): Custom system prompt to override default
                thread_id (Optional[str]): Optional thread ID for conversation continuity
                images (List[UploadFile]): List of uploaded image files

            Returns:
                QueryResponse: The agent's response

            Raises:
                HTTPException: 400 if no image is supplied or an image has an
                    unsupported content type
            """
            if not images:
                raise HTTPException(status_code=400, detail="At least one image is required")

            allowed_types = {"image/jpeg", "image/jpg", "image/png", "image/bmp", "image/tiff", "application/dicom"}
            for image in images:
                if image.content_type not in allowed_types:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Unsupported file type: {image.content_type}. Allowed types: {allowed_types}",
                    )

            return await self._process_query(
                question=question, system_prompt=system_prompt, thread_id=thread_id, images=images
            )

    @classmethod
    def _safe_filename_part(cls, value: Optional[str], fallback: str) -> str:
        """Reduce a client-supplied string to a single safe path component."""
        # Keep only the final component (drops any directory parts), then
        # replace everything outside a conservative character set. This also
        # neutralises backslash-separated paths.
        base_name = Path(value).name if value else ""
        cleaned = cls._UNSAFE_FILENAME_CHARS_RE.sub("_", base_name)
        return cleaned or fallback

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content (a string or a list of content blocks) to text."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    async def _process_query(
        self,
        question: str,
        system_prompt: Optional[str] = None,
        thread_id: Optional[str] = None,
        images: Optional[List[UploadFile]] = None,
    ) -> QueryResponse:
        """
        Internal method to process queries through the agent.

        Each uploaded image is saved under ``self.temp_dir`` and passed to the
        agent twice: once as an ``image_path: ...`` text message and once as a
        base64 ``image_url`` message. Temporary files are always removed before
        this method returns or raises.

        Args:
            question (str): The question to ask
            system_prompt (Optional[str]): Custom system prompt; when set it
                temporarily replaces ``self.agent.system_prompt``
            thread_id (Optional[str]): Thread ID for conversation; a UUID4 is
                generated when omitted
            images (Optional[List[UploadFile]]): List of images

        Returns:
            QueryResponse: The processed response

        Raises:
            HTTPException: 500 if saving the images or running the agent fails
        """
        start_time = time.perf_counter()

        if not thread_id:
            thread_id = str(uuid.uuid4())

        image_paths: List[Path] = []

        try:
            messages = []

            if images:
                # thread_id may come from the client, so sanitise it for use in
                # a file name (the raw value is still used as the thread key).
                safe_thread = self._safe_filename_part(thread_id, "thread")

                for i, image in enumerate(images):
                    safe_name = self._safe_filename_part(image.filename, "upload")
                    temp_path = self.temp_dir / f"{safe_thread}_{i}_{safe_name}"

                    # Register the path first so a partially written file is
                    # still cleaned up if the write fails.
                    image_paths.append(temp_path)
                    content = await image.read()
                    temp_path.write_bytes(content)

                    # Path message, so tools can open the file from disk.
                    messages.append({"role": "user", "content": f"image_path: {temp_path}"})

                    image_base64 = base64.b64encode(content).decode("utf-8")

                    # Prefer the declared content type; otherwise infer from
                    # the extension, defaulting to JPEG.
                    mime_type = "image/jpeg"
                    if image.content_type:
                        mime_type = image.content_type
                    elif temp_path.suffix.lower() == ".png":
                        mime_type = "image/png"

                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
                                }
                            ],
                        }
                    )

            messages.append({"role": "user", "content": [{"type": "text", "text": question}]})

            response_parts: List[str] = []
            tools_used: List[str] = []

            # NOTE(review): the prompt is swapped on the shared agent object, so
            # overlapping requests with different prompts can affect each other.
            override_prompt = bool(system_prompt)
            original_prompt = None
            if override_prompt:
                original_prompt = self.agent.system_prompt
                self.agent.system_prompt = system_prompt

            try:
                async for chunk in self._stream_agent_response(messages, thread_id):
                    chunk_type = chunk.get("type")
                    if chunk_type == "text":
                        response_parts.append(chunk.get("content", ""))
                    elif chunk_type == "tool":
                        tools_used.append(chunk.get("tool_name", ""))
                    elif chunk_type == "error":
                        # Surface agent failures instead of returning an empty
                        # 200 response.
                        raise RuntimeError(chunk.get("content") or "Agent workflow failed")
            finally:
                # Restore unconditionally when overridden, even if the original
                # prompt was None or empty.
                if override_prompt:
                    self.agent.system_prompt = original_prompt

            processing_time = time.perf_counter() - start_time

            return QueryResponse(
                response="".join(response_parts).strip(),
                thread_id=thread_id,
                # De-duplicate while keeping first-seen order.
                tools_used=list(dict.fromkeys(tools_used)),
                processing_time=processing_time,
            )

        except HTTPException:
            # Already an HTTP error; do not re-wrap it as a 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}") from e
        finally:
            # Best-effort cleanup on both the success and the failure path.
            for image_path in image_paths:
                try:
                    image_path.unlink(missing_ok=True)
                except OSError:
                    pass

    async def _stream_agent_response(self, messages: List[Dict], thread_id: str):
        """
        Stream responses from the agent workflow.

        NOTE(review): ``self.agent.workflow.stream`` is iterated synchronously
        inside this async generator, so the event loop is not yielded to while
        the workflow computes its next update.

        Args:
            messages (List[Dict]): Messages to process
            thread_id (str): Thread ID for the conversation

        Yields:
            Dict: Response chunks. ``{"type": "text", "content": str}`` for AI
            messages, ``{"type": "tool", "tool_name": str}`` for tool results,
            and ``{"type": "error", "content": str}`` if the workflow raises.
        """
        try:
            for chunk in self.agent.workflow.stream(
                {"messages": messages},
                {"configurable": {"thread_id": thread_id}},
                stream_mode="updates",
            ):
                if not isinstance(chunk, dict):
                    continue

                for node_output in chunk.values():
                    # Skip nodes that produced no dict update or no messages.
                    if not isinstance(node_output, dict) or "messages" not in node_output:
                        continue

                    for msg in node_output["messages"]:
                        if isinstance(msg, AIMessage) and msg.content:
                            # Content may be a plain string or a list of
                            # content blocks; flatten before cleaning.
                            text = self._content_to_text(msg.content)
                            # Hide temp file paths from the user-facing answer.
                            clean_content = self._TEMP_PATH_RE.sub("", text).strip()
                            if clean_content:
                                yield {"type": "text", "content": clean_content}

                        elif isinstance(msg, ToolMessage):
                            # Report the tool's name when the message carries
                            # one; otherwise keep the generic marker.
                            tool_name = getattr(msg, "name", None) or "tool_executed"
                            yield {"type": "tool", "tool_name": tool_name}

        except Exception as e:
            yield {"type": "error", "content": str(e)}
|
|
|
|
|
|
|
|
def create_api(agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api") -> FastAPI:
    """
    Build the MedRAX FastAPI application.

    Wraps the given agent in a ``MedRAXAPI`` instance and hands back the
    FastAPI app that instance created.

    Args:
        agent (Agent): The initialized MedRAX agent
        tools_dict (Dict[str, Any]): Dictionary of available tools
        temp_dir (str): Directory for temporary file storage

    Returns:
        FastAPI: Configured FastAPI application
    """
    return MedRAXAPI(agent, tools_dict, temp_dir).app
|
|
|