# medrax2/app.py (Hugging Face Space source)
# Commit b6fb0da by samwell: "Add Hugging Face Space deployment files" (2.35 kB)
"""
MedRAX2 Gradio interface for Hugging Face Spaces.
Simple standalone version for deployment.
"""
import os

# Limit OpenMP to a single thread; this is set before torch is imported
# further down so the limit is in place when torch loads.
os.environ["OMP_NUM_THREADS"] = "1"

import warnings

# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")

from dotenv import load_dotenv
from huggingface_hub import login

# Load variables from a .env file *before* reading HF_TOKEN, so a token
# defined only in .env is actually used for the login below. Existing
# environment variables are not overridden (load_dotenv's default).
load_dotenv()

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # Authenticate with the Hugging Face Hub only when a token is present.
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")
import gradio as gr
from dotenv import load_dotenv
import torch
load_dotenv()
from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file
# Scratch directory; it is passed to the grounding tool as its temp_dir.
os.makedirs("temp", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Tools handed to the agent. The X-ray phrase-grounding tool is only
# attempted on CUDA (it is created with load_in_4bit=True). Loading is a
# deliberate best-effort: on any failure the app starts without the tool.
tools = []
if device == "cuda":
    try:
        from medrax.tools import XRayPhraseGroundingTool

        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=True,
        )
    except Exception as e:
        # Include the exception type: some errors (e.g. a bare ImportError)
        # carry an empty message and would otherwise be undiagnosable.
        print(f"✗ Failed to load grounding tool: {type(e).__name__}: {e}")
    else:
        tools.append(grounding_tool)
        print("✓ Loaded grounding tool")
# LangGraph's in-memory checkpointer: the agent's conversation state lives in
# this process only.
checkpointer = MemorySaver()

# Chat model that drives the agent.
llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000,
)

# System prompt: the "MEDICAL_ASSISTANT" entry of the prompts file, or a
# generic fallback when that entry is missing.
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
prompt = prompts.get(
    "MEDICAL_ASSISTANT",
    "You are a helpful medical imaging assistant.",
)

agent = Agent(llm, tools=tools, system_prompt=prompt, checkpointer=checkpointer)
print("Tools loaded: " + str(len(tools)))
def chat(message, history, request: gr.Request = None):
    """Run one user turn through the MedRAX agent and return its reply text.

    Args:
        message: The user's turn. The interface is built with
            ``multimodal=True``, so this is normally a dict with ``"text"``
            and ``"files"`` entries; a plain string is passed through as-is.
        history: Prior turns supplied by ``gr.ChatInterface``. Unused here:
            conversation state is held by the agent's checkpointer under the
            thread id chosen below.
        request: The incoming Gradio request, injected by Gradio for
            parameters annotated as ``gr.Request``. Its ``session_hash``
            gives each browser session its own checkpointer thread; without
            it the shared ``"default"`` thread is used.
            NOTE(review): on Python < 3.11 the ``= None`` default makes the
            hint ``Optional[gr.Request]``; confirm the installed Gradio still
            injects it (otherwise behaviour falls back to the shared thread).

    Returns:
        The ``content`` of the last message in the workflow's result.
    """
    # A single hard-coded thread id would make every visitor of the Space
    # share (and see) one conversation; key the thread per session instead.
    thread_id = getattr(request, "session_hash", None) or "default"
    config = {"configurable": {"thread_id": thread_id}}

    if isinstance(message, dict):
        text = message.get("text") or ""
        files = message.get("files") or []
        # Depending on the Gradio version a file entry is either a path
        # string or a dict carrying it under "path".
        paths = [f.get("path", f) if isinstance(f, dict) else f for f in files]
        # Expose every uploaded file's path to the agent in the prompt text.
        file_info = "".join(f"[Image uploaded: {p}]\n\n" for p in paths)
        message = file_info + text

    response = agent.workflow.invoke(
        {"messages": [("user", message)]},
        config=config,
    )
    return response["messages"][-1].content
# Multimodal chat UI around `chat`; the description reports the device in use
# and how many agent tools were loaded.
demo = gr.ChatInterface(
    fn=chat,
    multimodal=True,
    title="MedRAX2 - Medical AI Assistant",
    description="Device: " + device + " | Tools: " + str(len(tools)) + " loaded",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")