|
|
""" |
|
|
MedRax2 Gradio Interface for Hugging Face Spaces |
|
|
Simple standalone version for deployment |
|
|
""" |
|
|
import os |
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
|
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
from huggingface_hub import login |
|
|
hf_token = os.environ.get("HF_TOKEN") |
|
|
if hf_token: |
|
|
login(token=hf_token) |
|
|
print("✓ Logged in to HuggingFace") |
|
|
|
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import torch |
|
|
from PIL import Image |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
from medrax.models import ModelFactory |
|
|
from medrax.agent import Agent |
|
|
from medrax.utils import load_prompts_from_file |
|
|
|
|
|
os.makedirs("temp", exist_ok=True) |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
print(f"Using device: {device}") |
|
|
|
|
|
tools = [] |
|
|
|
|
|
if device == "cuda": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from medrax.tools import NVReasonCXRTool |
|
|
nv_reason_tool = NVReasonCXRTool( |
|
|
device=device, |
|
|
load_in_4bit=True |
|
|
) |
|
|
tools.append(nv_reason_tool) |
|
|
print("✓ Loaded NV-Reason-CXR tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load NV-Reason-CXR tool: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
from medrax.tools import XRayPhraseGroundingTool |
|
|
grounding_tool = XRayPhraseGroundingTool( |
|
|
device=device, |
|
|
temp_dir="temp", |
|
|
load_in_4bit=False |
|
|
) |
|
|
tools.append(grounding_tool) |
|
|
print("✓ Loaded grounding tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load grounding tool: {e}") |
|
|
|
|
|
try: |
|
|
from medrax.tools.vqa import CheXagentXRayVQATool |
|
|
vqa_tool = CheXagentXRayVQATool( |
|
|
device=device, |
|
|
temp_dir="temp", |
|
|
load_in_4bit=True |
|
|
) |
|
|
tools.append(vqa_tool) |
|
|
print("✓ Loaded VQA tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load VQA tool: {e}") |
|
|
|
|
|
try: |
|
|
from medrax.tools.classification import TorchXRayVisionClassifierTool |
|
|
classification_tool = TorchXRayVisionClassifierTool( |
|
|
device=device |
|
|
) |
|
|
tools.append(classification_tool) |
|
|
print("✓ Loaded classification tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load classification tool: {e}") |
|
|
|
|
|
try: |
|
|
from medrax.tools.report_generation import ChestXRayReportGeneratorTool |
|
|
report_tool = ChestXRayReportGeneratorTool( |
|
|
device=device |
|
|
) |
|
|
tools.append(report_tool) |
|
|
print("✓ Loaded report generation tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load report generation tool: {e}") |
|
|
|
|
|
try: |
|
|
from medrax.tools.segmentation import ChestXRaySegmentationTool |
|
|
segmentation_tool = ChestXRaySegmentationTool( |
|
|
device=device, |
|
|
temp_dir="temp" |
|
|
) |
|
|
tools.append(segmentation_tool) |
|
|
print("✓ Loaded segmentation tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load segmentation tool: {e}") |
|
|
|
|
|
|
|
|
try: |
|
|
from medrax.tools.dicom import DicomProcessorTool |
|
|
dicom_tool = DicomProcessorTool(temp_dir="temp") |
|
|
tools.append(dicom_tool) |
|
|
print("✓ Loaded DICOM tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load DICOM tool: {e}") |
|
|
|
|
|
try: |
|
|
from medrax.tools.browsing import WebBrowserTool |
|
|
browsing_tool = WebBrowserTool() |
|
|
tools.append(browsing_tool) |
|
|
print("✓ Loaded web browsing tool") |
|
|
except Exception as e: |
|
|
print(f"✗ Failed to load web browsing tool: {e}") |
|
|
|
|
|
checkpointer = MemorySaver() |
|
|
|
|
|
llm = ModelFactory.create_model( |
|
|
model_name="gemini-2.0-flash", |
|
|
temperature=0.7, |
|
|
max_tokens=5000 |
|
|
) |
|
|
|
|
|
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt") |
|
|
|
|
|
|
|
|
|
|
|
print(f"Tools loaded: {len(tools)}") |
|
|
|
|
|
import glob |
|
|
|
|
|
|
|
|
agents_cache = {} |
|
|
|
|
|
def get_or_create_agent(mode): |
|
|
"""Get or create an agent for the specified mode.""" |
|
|
if mode not in agents_cache: |
|
|
|
|
|
if mode == "socratic": |
|
|
prompt = prompts.get("SOCRATIC_TUTOR", "You are a Socratic medical educator.") |
|
|
else: |
|
|
prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.") |
|
|
|
|
|
|
|
|
agents_cache[mode] = Agent( |
|
|
llm, |
|
|
tools=tools, |
|
|
system_prompt=prompt, |
|
|
checkpointer=checkpointer, |
|
|
mode=mode |
|
|
) |
|
|
return agents_cache[mode] |
|
|
|
|
|
def chat(message, history, mode): |
|
|
"""Chat function that uses the appropriate agent based on mode.""" |
|
|
config = {"configurable": {"thread_id": f"thread_{mode}"}} |
|
|
|
|
|
|
|
|
agent = get_or_create_agent(mode) |
|
|
|
|
|
|
|
|
image_content = None |
|
|
if isinstance(message, dict): |
|
|
text = message.get("text", "") |
|
|
files = message.get("files", []) |
|
|
|
|
|
if files and len(files) > 0: |
|
|
image_path = files[0] |
|
|
|
|
|
|
|
|
is_dicom = image_path.lower().endswith(('.dcm', '.dicom')) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
if is_dicom: |
|
|
|
|
|
|
|
|
text = f"[DICOM file uploaded: {image_path}]\n\n{text}" |
|
|
print(f"DICOM file detected: {image_path}") |
|
|
else: |
|
|
|
|
|
with Image.open(image_path) as img: |
|
|
|
|
|
if img.mode != "RGB": |
|
|
img = img.convert("RGB") |
|
|
|
|
|
|
|
|
max_size = 4096 |
|
|
if img.width > max_size or img.height > max_size: |
|
|
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) |
|
|
|
|
|
|
|
|
buffered = BytesIO() |
|
|
img.save(buffered, format="PNG") |
|
|
img_bytes = buffered.getvalue() |
|
|
img_b64 = base64.b64encode(img_bytes).decode() |
|
|
|
|
|
|
|
|
|
|
|
image_content = { |
|
|
"type": "image_url", |
|
|
"image_url": { |
|
|
"url": f"data:image/png;base64,{img_b64}" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
text = f"[Image: {image_path}]\n\n{text}" |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error processing image: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}" |
|
|
|
|
|
message = text |
|
|
|
|
|
|
|
|
if image_content: |
|
|
|
|
|
user_message = [ |
|
|
{"type": "text", "text": message}, |
|
|
image_content |
|
|
] |
|
|
else: |
|
|
user_message = message |
|
|
|
|
|
response = agent.workflow.invoke( |
|
|
{"messages": [("user", user_message)]}, |
|
|
config=config |
|
|
) |
|
|
|
|
|
|
|
|
assistant_message = response["messages"][-1].content |
|
|
|
|
|
|
|
|
viz_image = None |
|
|
viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png") |
|
|
if viz_files: |
|
|
|
|
|
viz_files.sort(key=os.path.getmtime, reverse=True) |
|
|
latest_viz = viz_files[0] |
|
|
|
|
|
if os.path.exists(latest_viz): |
|
|
viz_image = latest_viz |
|
|
|
|
|
return assistant_message, viz_image |
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
gr.Markdown(f"# MedRAX2 - Medical AI Assistant\n**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
mode_toggle = gr.Radio( |
|
|
["Assistant Mode", "Tutor Mode"], |
|
|
value="Assistant Mode", |
|
|
label="Interaction Mode", |
|
|
info="Assistant Mode: Direct answers | Tutor Mode: Socratic guidance through questions" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
|
|
|
try: |
|
|
chatbot = gr.Chatbot(type="messages", height=520, show_label=False) |
|
|
except TypeError: |
|
|
|
|
|
chatbot = gr.Chatbot(height=520, show_label=False) |
|
|
|
|
|
|
|
|
msg = gr.MultimodalTextbox( |
|
|
placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...", |
|
|
file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"], |
|
|
show_label=False, |
|
|
container=False |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Column(scale=1): |
|
|
viz_output = gr.Image(label="Visualization", height=600) |
|
|
|
|
|
def respond(message, chat_history, mode_selection): |
|
|
|
|
|
mode = "socratic" if mode_selection == "Tutor Mode" else "assistant" |
|
|
|
|
|
|
|
|
bot_message, viz_image = chat(message, chat_history, mode) |
|
|
|
|
|
|
|
|
if chat_history is None: |
|
|
chat_history = [] |
|
|
|
|
|
|
|
|
if isinstance(message, dict): |
|
|
user_text = message.get("text", "") |
|
|
if message.get("files"): |
|
|
user_text = f"[Image uploaded] {user_text}" |
|
|
else: |
|
|
user_text = message |
|
|
|
|
|
|
|
|
chat_history.append({"role": "user", "content": user_text}) |
|
|
chat_history.append({"role": "assistant", "content": bot_message}) |
|
|
|
|
|
return "", chat_history, viz_image |
|
|
|
|
|
msg.submit(respond, [msg, chatbot, mode_toggle], [msg, chatbot, viz_output]) |
|
|
|
|
|
gr.Examples( |
|
|
examples=[ |
|
|
[{"text": "What do you see in this X-ray?", "files": []}], |
|
|
[{"text": "Can you show me where exactly using grounding?", "files": []}], |
|
|
], |
|
|
inputs=msg, |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch(server_name="0.0.0.0", server_port=7860) |
|
|
|