""" MedRax2 Gradio Interface for Hugging Face Spaces Simple standalone version for deployment """ import os os.environ["OMP_NUM_THREADS"] = "1" import warnings warnings.filterwarnings("ignore") from huggingface_hub import login hf_token = os.environ.get("HF_TOKEN") if hf_token: login(token=hf_token) print("✓ Logged in to HuggingFace") import gradio as gr from dotenv import load_dotenv import torch from PIL import Image import base64 from io import BytesIO load_dotenv() from langgraph.checkpoint.memory import MemorySaver from medrax.models import ModelFactory from medrax.agent import Agent from medrax.utils import load_prompts_from_file os.makedirs("temp", exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") tools = [] if device == "cuda": # Load GPU-based tools # NV-Reason-CXR - Re-enabled for L40S (48GB VRAM) # With 48GB, we have room for all tools: MAIRA-2 (22GB) + VQA (2GB) + NV-Reason (7GB) try: from medrax.tools import NVReasonCXRTool nv_reason_tool = NVReasonCXRTool( device=device, load_in_4bit=True # Use quantization to save VRAM (~7GB) ) tools.append(nv_reason_tool) print("✓ Loaded NV-Reason-CXR tool") except Exception as e: print(f"✗ Failed to load NV-Reason-CXR tool: {e}") # MAIRA-2 Grounding - Re-enabled for L40S (48GB VRAM) try: from medrax.tools import XRayPhraseGroundingTool grounding_tool = XRayPhraseGroundingTool( device=device, temp_dir="temp", load_in_4bit=False # Quantization disabled due to compatibility ) tools.append(grounding_tool) print("✓ Loaded grounding tool") except Exception as e: print(f"✗ Failed to load grounding tool: {e}") try: from medrax.tools.vqa import CheXagentXRayVQATool vqa_tool = CheXagentXRayVQATool( device=device, temp_dir="temp", load_in_4bit=True ) tools.append(vqa_tool) print("✓ Loaded VQA tool") except Exception as e: print(f"✗ Failed to load VQA tool: {e}") try: from medrax.tools.classification import TorchXRayVisionClassifierTool classification_tool = TorchXRayVisionClassifierTool( device=device ) tools.append(classification_tool) print("✓ Loaded classification tool") except Exception as e: print(f"✗ Failed to load classification tool: {e}") try: from medrax.tools.report_generation import ChestXRayReportGeneratorTool report_tool = ChestXRayReportGeneratorTool( device=device ) tools.append(report_tool) print("✓ Loaded report generation tool") except Exception as e: print(f"✗ Failed to load report generation tool: {e}") try: from medrax.tools.segmentation import ChestXRaySegmentationTool segmentation_tool = ChestXRaySegmentationTool( device=device, temp_dir="temp" ) tools.append(segmentation_tool) print("✓ Loaded segmentation tool") except Exception as e: print(f"✗ Failed to load segmentation tool: {e}") # Load non-GPU tools try: from medrax.tools.dicom import DicomProcessorTool dicom_tool = DicomProcessorTool(temp_dir="temp") tools.append(dicom_tool) print("✓ Loaded DICOM tool") except Exception as e: print(f"✗ Failed to load DICOM tool: {e}") try: from medrax.tools.browsing import WebBrowserTool browsing_tool = WebBrowserTool() tools.append(browsing_tool) print("✓ Loaded web browsing tool") except Exception as e: print(f"✗ Failed to load web browsing tool: {e}") checkpointer = MemorySaver() llm = ModelFactory.create_model( model_name="gemini-2.0-flash", temperature=0.7, max_tokens=5000 ) prompts = load_prompts_from_file("medrax/docs/system_prompts.txt") # We'll initialize the agent dynamically based on mode # Remove the global agent initialization print(f"Tools loaded: {len(tools)}") import 
# Store agents for each mode to avoid recreating them
agents_cache = {}


def get_or_create_agent(mode):
    """Get or create an agent for the specified mode."""
    if mode not in agents_cache:
        # Select appropriate prompt based on mode
        if mode == "socratic":
            prompt = prompts.get("SOCRATIC_TUTOR", "You are a Socratic medical educator.")
        else:
            prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")

        # Create agent with specified mode
        agents_cache[mode] = Agent(
            llm,
            tools=tools,
            system_prompt=prompt,
            checkpointer=checkpointer,
            mode=mode  # Pass the mode to the Agent
        )
    return agents_cache[mode]


def chat(message, history, mode):
    """Chat function that uses the appropriate agent based on mode."""
    config = {"configurable": {"thread_id": f"thread_{mode}"}}

    # Get or create the appropriate agent
    agent = get_or_create_agent(mode)

    # Handle multimodal input - Gemini 2.0 Flash supports vision
    image_content = None
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])

        if files and len(files) > 0:
            image_path = files[0]

            # Check if it's a DICOM file
            is_dicom = image_path.lower().endswith((".dcm", ".dicom"))

            # Store image path for tools to use
            # LangChain Google GenAI expects images as base64 or PIL
            try:
                if is_dicom:
                    # DICOM files need to be converted first
                    # We'll just pass the path and let the agent handle it
                    text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
                    print(f"DICOM file detected: {image_path}")
                else:
                    # Open and encode image for Gemini
                    with Image.open(image_path) as img:
                        # Convert to RGB if needed
                        if img.mode != "RGB":
                            img = img.convert("RGB")

                        # Resize if too large (max 4096x4096 for Gemini)
                        max_size = 4096
                        if img.width > max_size or img.height > max_size:
                            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

                        # Store as bytes for LangChain
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_bytes = buffered.getvalue()
                        img_b64 = base64.b64encode(img_bytes).decode()

                        # Create multimodal content for Gemini
                        # Format: [{"type": "text", "text": "..."},
                        #          {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
                        image_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{img_b64}"
                            }
                        }

                    # Include image path in text for tools to use
                    text = f"[Image: {image_path}]\n\n{text}"
            except Exception as e:
                print(f"Error processing image: {e}")
                import traceback
                traceback.print_exc()
                text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"

        message = text
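
    # Design note (added): an uploaded image travels on two channels. The base64 data
    # URL below lets the Gemini orchestrator look at the pixels directly, while the
    # literal file path embedded in the text lets tool calls (classification, grounding,
    # segmentation, DICOM processing) reopen the same file from disk.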
    # Create message content - multimodal if image exists
    if image_content:
        # For Gemini multimodal: pass list of content parts
        user_message = [
            {"type": "text", "text": message},
            image_content
        ]
    else:
        user_message = message

    response = agent.workflow.invoke(
        {"messages": [("user", user_message)]},
        config=config
    )

    # Extract text response
    assistant_message = response["messages"][-1].content

    # Check for visualization images (grounding or segmentation)
    viz_image = None
    viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png")
    if viz_files:
        # Get the most recent visualization
        viz_files.sort(key=os.path.getmtime, reverse=True)
        latest_viz = viz_files[0]
        # Return the file path directly - Gradio can handle it
        if os.path.exists(latest_viz):
            viz_image = latest_viz

    return assistant_message, viz_image


# Custom interface with image output
with gr.Blocks() as demo:
    gr.Markdown(
        f"# MedRAX2 - Medical AI Assistant\n"
        f"**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash"
    )

    # Add mode toggle at the top
    with gr.Row():
        mode_toggle = gr.Radio(
            ["Assistant Mode", "Tutor Mode"],
            value="Assistant Mode",
            label="Interaction Mode",
            info="Assistant Mode: Direct answers | Tutor Mode: Socratic guidance through questions"
        )

    # Side-by-side layout: Chat on left, Visualization on right
    with gr.Row():
        # Left column: Chat interface (unified chat + message box)
        with gr.Column(scale=2):
            # Chatbot with reduced height to leave room for message box
            try:
                chatbot = gr.Chatbot(type="messages", height=520, show_label=False)
            except TypeError:
                # Fallback for older Gradio versions
                chatbot = gr.Chatbot(height=520, show_label=False)

            # Message box directly below chatbot (no gap)
            msg = gr.MultimodalTextbox(
                placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
                file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"],
                show_label=False,
                container=False
            )

        # Right column: Visualization
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Visualization", height=600)

    def respond(message, chat_history, mode_selection):
        # Convert mode selection to internal mode string
        mode = "socratic" if mode_selection == "Tutor Mode" else "assistant"

        # Get response and visualization with mode
        bot_message, viz_image = chat(message, chat_history, mode)

        # Initialize chat history if None
        if chat_history is None:
            chat_history = []

        # Extract text from multimodal message
        if isinstance(message, dict):
            user_text = message.get("text", "")
            if message.get("files"):
                user_text = f"[Image uploaded] {user_text}"
        else:
            user_text = message

        # Add BOTH user message and assistant response to create proper chat flow
        chat_history.append({"role": "user", "content": user_text})
        chat_history.append({"role": "assistant", "content": bot_message})

        return "", chat_history, viz_image

    msg.submit(respond, [msg, chatbot, mode_toggle], [msg, chatbot, viz_output])

    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
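
# Usage note (added): on Hugging Face Spaces this file is executed directly and launch()
# binds to port 7860, which is the port the platform expects. For local testing the same
# entry point works; HF_TOKEN and the Gemini API key read by ModelFactory must be set in
# the environment (the exact variable name for the Gemini key depends on how MedRAX's
# ModelFactory / langchain-google-genai is configured). If concurrent users are expected,
# Gradio's request queue can serialize the GPU-heavy calls, e.g.
# demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860).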