"""
MedRax2 Gradio Interface for Hugging Face Spaces
Simple standalone version for deployment
"""
import os
os.environ["OMP_NUM_THREADS"] = "1"
import warnings
warnings.filterwarnings("ignore")
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")
import gradio as gr
from dotenv import load_dotenv
import torch
from PIL import Image
import base64
from io import BytesIO
load_dotenv()
from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file
os.makedirs("temp", exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
tools = []
if device == "cuda":
    # Load GPU-based tools
    # NV-Reason-CXR - Re-enabled for L40S (48GB VRAM)
    # With 48GB, we have room for all tools: MAIRA-2 (22GB) + VQA (2GB) + NV-Reason (7GB)
    try:
        from medrax.tools import NVReasonCXRTool

        nv_reason_tool = NVReasonCXRTool(
            device=device,
            load_in_4bit=True,  # Use quantization to save VRAM (~7GB)
        )
        tools.append(nv_reason_tool)
        print("✓ Loaded NV-Reason-CXR tool")
    except Exception as e:
        print(f"✗ Failed to load NV-Reason-CXR tool: {e}")

    # MAIRA-2 Grounding - Re-enabled for L40S (48GB VRAM)
    try:
        from medrax.tools import XRayPhraseGroundingTool

        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=False,  # Quantization disabled due to compatibility
        )
        tools.append(grounding_tool)
        print("✓ Loaded grounding tool")
    except Exception as e:
        print(f"✗ Failed to load grounding tool: {e}")

    try:
        from medrax.tools.vqa import CheXagentXRayVQATool

        vqa_tool = CheXagentXRayVQATool(
            device=device,
            temp_dir="temp",
            load_in_4bit=True,
        )
        tools.append(vqa_tool)
        print("✓ Loaded VQA tool")
    except Exception as e:
        print(f"✗ Failed to load VQA tool: {e}")

    try:
        from medrax.tools.classification import TorchXRayVisionClassifierTool

        classification_tool = TorchXRayVisionClassifierTool(device=device)
        tools.append(classification_tool)
        print("✓ Loaded classification tool")
    except Exception as e:
        print(f"✗ Failed to load classification tool: {e}")

    try:
        from medrax.tools.report_generation import ChestXRayReportGeneratorTool

        report_tool = ChestXRayReportGeneratorTool(device=device)
        tools.append(report_tool)
        print("✓ Loaded report generation tool")
    except Exception as e:
        print(f"✗ Failed to load report generation tool: {e}")

    try:
        from medrax.tools.segmentation import ChestXRaySegmentationTool

        segmentation_tool = ChestXRaySegmentationTool(device=device, temp_dir="temp")
        tools.append(segmentation_tool)
        print("✓ Loaded segmentation tool")
    except Exception as e:
        print(f"✗ Failed to load segmentation tool: {e}")
# Load non-GPU tools
try:
    from medrax.tools.dicom import DicomProcessorTool

    dicom_tool = DicomProcessorTool(temp_dir="temp")
    tools.append(dicom_tool)
    print("✓ Loaded DICOM tool")
except Exception as e:
    print(f"✗ Failed to load DICOM tool: {e}")

try:
    from medrax.tools.browsing import WebBrowserTool

    browsing_tool = WebBrowserTool()
    tools.append(browsing_tool)
    print("✓ Loaded web browsing tool")
except Exception as e:
    print(f"✗ Failed to load web browsing tool: {e}")
checkpointer = MemorySaver()
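# Gemini 2.0 Flash serves as the orchestrator LLM that plans and issues the tool calls.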
llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000,
)
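# The prompt file provides the MEDICAL_ASSISTANT and SOCRATIC_TUTOR system prompts used below.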
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
# Agents are created lazily per mode instead of a single global agent
print(f"Tools loaded: {len(tools)}")
import glob
# Store agents for each mode to avoid recreating them
agents_cache = {}
def get_or_create_agent(mode):
    """Get or create an agent for the specified mode."""
    if mode not in agents_cache:
        # Select appropriate prompt based on mode
        if mode == "socratic":
            prompt = prompts.get("SOCRATIC_TUTOR", "You are a Socratic medical educator.")
        else:
            prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")
        # Create agent with specified mode
        agents_cache[mode] = Agent(
            llm,
            tools=tools,
            system_prompt=prompt,
            checkpointer=checkpointer,
            mode=mode,  # Pass the mode to the Agent
        )
    return agents_cache[mode]
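# The cached agents share the same tools and checkpointer; chat() keeps their histories separate via per-mode thread_ids.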
def chat(message, history, mode):
    """Chat function that uses the appropriate agent based on mode."""
    config = {"configurable": {"thread_id": f"thread_{mode}"}}
    # Get or create the appropriate agent
    agent = get_or_create_agent(mode)
    # Handle multimodal input - Gemini 2.0 Flash supports vision
    image_content = None
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])
        if files and len(files) > 0:
            image_path = files[0]
            # Check if it's a DICOM file
            is_dicom = image_path.lower().endswith(('.dcm', '.dicom'))
            # Store image path for tools to use
            # LangChain Google GenAI expects images as base64 or PIL
            try:
                if is_dicom:
                    # DICOM files need to be converted first
                    # We'll just pass the path and let the agent handle it
                    text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
                    print(f"DICOM file detected: {image_path}")
                else:
                    # Open and encode image for Gemini
                    with Image.open(image_path) as img:
                        # Convert to RGB if needed
                        if img.mode != "RGB":
                            img = img.convert("RGB")
                        # Resize if too large (max 4096x4096 for Gemini)
                        max_size = 4096
                        if img.width > max_size or img.height > max_size:
                            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
                        # Store as bytes for LangChain
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_bytes = buffered.getvalue()
                        img_b64 = base64.b64encode(img_bytes).decode()
                        # Create multimodal content for Gemini
                        # Format: [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
                        image_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{img_b64}"
                            }
                        }
                        # Include image path in text for tools to use
                        text = f"[Image: {image_path}]\n\n{text}"
            except Exception as e:
                print(f"Error processing image: {e}")
                import traceback
                traceback.print_exc()
                text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"
        message = text
    # Create message content - multimodal if image exists
    if image_content:
        # For Gemini multimodal: pass list of content parts
        user_message = [
            {"type": "text", "text": message},
            image_content
        ]
    else:
        user_message = message
    response = agent.workflow.invoke(
        {"messages": [("user", user_message)]},
        config=config
    )
    # Extract text response
    assistant_message = response["messages"][-1].content
    # Check for visualization images (grounding or segmentation)
    viz_image = None
    viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png")
    if viz_files:
        # Get the most recent visualization
        viz_files.sort(key=os.path.getmtime, reverse=True)
        latest_viz = viz_files[0]
        # Return the file path directly - Gradio can handle it
        if os.path.exists(latest_viz):
            viz_image = latest_viz
    return assistant_message, viz_image
# Custom interface with image output
with gr.Blocks() as demo:
    gr.Markdown(f"# MedRAX2 - Medical AI Assistant\n**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash")
    # Add mode toggle at the top
    with gr.Row():
        mode_toggle = gr.Radio(
            ["Assistant Mode", "Tutor Mode"],
            value="Assistant Mode",
            label="Interaction Mode",
            info="Assistant Mode: Direct answers | Tutor Mode: Socratic guidance through questions"
        )
    # Side-by-side layout: Chat on left, Visualization on right
    with gr.Row():
        # Left column: Chat interface (unified chat + message box)
        with gr.Column(scale=2):
            # Chatbot with reduced height to leave room for message box
            try:
                chatbot = gr.Chatbot(type="messages", height=520, show_label=False)
            except TypeError:
                # Fallback for older Gradio versions
                chatbot = gr.Chatbot(height=520, show_label=False)
            # Message box directly below chatbot (no gap)
            msg = gr.MultimodalTextbox(
                placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
                file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"],
                show_label=False,
                container=False
            )
        # Right column: Visualization
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Visualization", height=600)
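    # respond() bridges chat() and the UI: it builds the displayed history and forwards any visualization path.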
    def respond(message, chat_history, mode_selection):
        # Convert mode selection to internal mode string
        mode = "socratic" if mode_selection == "Tutor Mode" else "assistant"
        # Get response and visualization with mode
        bot_message, viz_image = chat(message, chat_history, mode)
        # Initialize chat history if None
        if chat_history is None:
            chat_history = []
        # Extract text from multimodal message
        if isinstance(message, dict):
            user_text = message.get("text", "")
            if message.get("files"):
                user_text = f"[Image uploaded] {user_text}"
        else:
            user_text = message
        # Add BOTH user message and assistant response to create proper chat flow
        chat_history.append({"role": "user", "content": user_text})
        chat_history.append({"role": "assistant", "content": bot_message})
        return "", chat_history, viz_image
    msg.submit(respond, [msg, chatbot, mode_toggle], [msg, chatbot, viz_output])

    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )
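# Hugging Face Spaces routes traffic to port 7860 by default, so bind there on all interfaces.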
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)