"""
MedRAX2 Gradio Interface for Hugging Face Spaces

Simple standalone version for deployment.
"""

import os
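
# Keep OpenMP single-threaded; this typically avoids thread oversubscription
# on the small shared CPUs that Spaces provide (assumed rationale).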
os.environ["OMP_NUM_THREADS"] = "1"

import warnings

warnings.filterwarnings("ignore")
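
# Authenticate with the Hugging Face Hub when an HF_TOKEN secret is set
# (needed to download gated or private model weights at startup).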
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

import gradio as gr
from dotenv import load_dotenv
import torch

load_dotenv()

from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file
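
# Scratch directory for temporary image files produced and read by the tools.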
os.makedirs("temp", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
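
# Tool registry; GPU-only tools are appended below when CUDA is available.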
tools = []
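
# The phrase-grounding tool requires a GPU and is loaded 4-bit quantized to
# reduce memory use; if it fails to load, the app degrades to plain chat.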
if device == "cuda":
    try:
        from medrax.tools import XRayPhraseGroundingTool

        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=True,
        )
        tools.append(grounding_tool)
        print("✓ Loaded grounding tool")
    except Exception as e:
        print(f"✗ Failed to load grounding tool: {e}")
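
# In-memory checkpointer: conversation state persists across turns for the
# lifetime of the process but is lost whenever the Space restarts.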
checkpointer = MemorySaver()

llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000,
)

prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")

agent = Agent(
    llm,
    tools=tools,
    system_prompt=prompt,
    checkpointer=checkpointer,
)

print(f"Tools loaded: {len(tools)}")


def chat(message, history):
    """Run one chat turn through the agent and return its final reply."""
    # A single shared thread_id means every visitor shares one conversation
    # history; acceptable for a demo Space.
    config = {"configurable": {"thread_id": "default"}}
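
    # With multimodal=True, Gradio passes a dict of the form
    # {"text": "...", "files": ["/path/to/upload.png"]}; flatten it to plain
    # text, prefixing the file path so the agent knows where the image is.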
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])
        if files:
            file_info = f"[Image uploaded: {files[0]}]\n\n"
            text = file_info + text
        message = text

    response = agent.workflow.invoke(
        {"messages": [("user", message)]},
        config=config,
    )
    return response["messages"][-1].content
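

# Multimodal chat UI: the textbox accepts text plus image uploads.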
demo = gr.ChatInterface(
    fn=chat,
    title="MedRAX2 - Medical AI Assistant",
    description=f"Device: {device} | Tools: {len(tools)} loaded",
    multimodal=True,
)

if __name__ == "__main__":
    # 0.0.0.0:7860 is the host/port a Gradio app is expected to serve on
    # inside a Hugging Face Space.
    demo.launch(server_name="0.0.0.0", server_port=7860)