# medrax2/interface.py
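"""Gradio chat interface for the MedRAX-2 medical reasoning agent.

Handles image and DICOM uploads, streams responses from the agent workflow,
and renders tool-call cards in the chat history.
"""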
import re
import base64
import json
import ast
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import AsyncGenerator, List, Optional, Tuple
from gradio import ChatMessage
from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage
class ChatInterface:
"""
A chat interface for interacting with a medical AI agent through Gradio.
Handles file uploads, message processing, and chat history management.
Supports both regular image files and DICOM medical imaging files.
"""
def __init__(self, agent, tools_dict):
"""
Initialize the chat interface.
Args:
agent: The medical AI agent to handle requests
tools_dict (dict): Dictionary of available tools for image processing
"""
self.agent = agent
self.tools_dict = tools_dict
self.upload_dir = Path(f"temp/{time.time()}")
        self.upload_dir.mkdir(parents=True, exist_ok=True)
self.current_thread_id = None
# Separate storage for original and display paths
self.original_file_path = None # For LLM (.dcm or other)
self.display_file_path = None # For UI (always viewable format)
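        # Maps tool_call_id -> {"name", "args"} so each ToolMessage coming back
        # from the workflow can be matched to the call that produced it.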
self.pending_tool_calls = {}
    def handle_upload(self, file_path: str) -> Optional[str]:
"""
Handle new file upload and set appropriate paths.
Args:
file_path (str): Path to the uploaded file
Returns:
            Optional[str]: Display path for the UI, or None if no file was uploaded
"""
if not file_path:
return None
source = Path(file_path)
timestamp = int(time.time())
# Save original file with proper suffix
suffix = source.suffix.lower()
saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
        shutil.copy2(file_path, saved_path)
self.original_file_path = str(saved_path)
# Handle DICOM conversion for display only
if suffix == ".dcm":
output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
self.display_file_path = output["image_path"]
else:
self.display_file_path = str(saved_path)
return self.display_file_path
def add_message(self, message: str, display_image: str, history: List[dict]) -> Tuple[List[dict], gr.Textbox]:
"""
Add a new message to the chat history.
Args:
message (str): Text message to add
display_image (str): Path to image being displayed
history (List[dict]): Current chat history
Returns:
Tuple[List[dict], gr.Textbox]: Updated history and textbox component
"""
image_path = self.original_file_path or display_image
if image_path is not None:
history.append({"role": "user", "content": {"path": image_path}})
if message is not None:
history.append({"role": "user", "content": message})
return history, gr.Textbox(value=message, interactive=False)
async def process_message(
self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
"""
Process a message and generate responses.
Args:
message (str): User message to process
display_image (Optional[str]): Path to currently displayed image
chat_history (List[ChatMessage]): Current chat history
Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and an empty string to clear the textbox
"""
chat_history = chat_history or []
# Initialize thread if needed
if not self.current_thread_id:
self.current_thread_id = str(time.time())
messages = []
image_path = self.original_file_path or display_image
if image_path is not None:
# Send path for tools
messages.append({"role": "user", "content": f"image_path: {image_path}"})
# Load and encode image for multimodal
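            # NOTE: the data URL below always declares image/jpeg; most
            # multimodal endpoints tolerate this for PNG uploads too, but it is
            # a simplification rather than a detected MIME type.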
with open(image_path, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
}
],
}
)
if message is not None:
messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
try:
accumulated_content = ""
final_message = None
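            # stream_mode="updates" yields one {node_name: node_output} mapping
            # per graph step; node_output holds that node's new messages.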
for chunk in self.agent.workflow.stream(
{"messages": messages},
{"configurable": {"thread_id": self.current_thread_id}},
stream_mode="updates",
):
if not isinstance(chunk, dict):
continue
for node_name, node_output in chunk.items():
if "messages" not in node_output:
continue
for msg in node_output["messages"]:
if isinstance(msg, AIMessageChunk) and msg.content:
accumulated_content += msg.content
if final_message is None:
final_message = ChatMessage(role="assistant", content=accumulated_content)
chat_history.append(final_message)
else:
final_message.content = accumulated_content
yield chat_history, self.display_file_path, ""
elif isinstance(msg, AIMessage):
if msg.content:
final_content = re.sub(r"temp/[^\s]*", "", msg.content).strip()
if final_message:
final_message.content = final_content
else:
chat_history.append(ChatMessage(role="assistant", content=final_content))
yield chat_history, self.display_file_path, ""
if msg.tool_calls:
for tool_call in msg.tool_calls:
self.pending_tool_calls[tool_call["id"]] = {
"name": tool_call["name"],
"args": tool_call["args"],
}
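                                # Reset streaming buffers so the model's reply
                                # after the tool call starts a fresh chat bubble.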
final_message = None
accumulated_content = ""
elif isinstance(msg, ToolMessage):
tool_call_id = msg.tool_call_id
if tool_call_id in self.pending_tool_calls:
pending_call = self.pending_tool_calls.pop(tool_call_id)
tool_name = pending_call["name"]
tool_args = pending_call["args"]
# Parse content
try:
# Try JSON parsing first
result = json.loads(msg.content)
tool_output_str = json.dumps(result, indent=2)
except json.JSONDecodeError:
                                    try:
                                        # Fall back to ast.literal_eval for Python
                                        # literals, e.g. a "(result, metadata)"
                                        # tuple rendered as a string.
                                        parsed = ast.literal_eval(msg.content)
                                        result = parsed[0] if isinstance(parsed, (tuple, list)) else parsed
                                        tool_output_str = json.dumps(result, indent=2)
                                    except (ValueError, SyntaxError, TypeError):
                                        # Treat the content as a plain string
                                        result = msg.content
                                        tool_output_str = str(msg.content)
# Display tool usage card
tool_args_str = json.dumps(tool_args, indent=2)
description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
metadata = {
"title": f"βš’οΈ Tool: {tool_name}",
"description": description,
"status": "done",
}
chat_history.append(
ChatMessage(
role="assistant",
content=description,
metadata=metadata,
)
)
# Special handling for image_visualizer
if tool_name == "image_visualizer":
image_path = None
try:
image_path = result["image_path"]
except (TypeError, KeyError):
try:
image_path = result[0]["image_path"]
except (TypeError, KeyError, IndexError):
pass
if image_path:
self.display_file_path = image_path
chat_history.append(
ChatMessage(
role="assistant",
content={"path": self.display_file_path},
)
)
# Yield a single update for this tool event
yield chat_history, self.display_file_path, ""
except Exception as e:
chat_history.append(
ChatMessage(role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"})
)
yield chat_history, self.display_file_path, ""
def create_demo(agent, tools_dict):
"""
Create a Gradio demo interface for the medical AI agent.
Args:
agent: The medical AI agent to handle requests
tools_dict (dict): Dictionary of available tools for image processing
Returns:
gr.Blocks: Gradio Blocks interface
"""
interface = ChatInterface(agent, tools_dict)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Column():
gr.Markdown(
"""
# πŸ₯ MedRAX-2
Medical Reasoning Agent for Chest X-ray
"""
)
with gr.Row():
with gr.Column(scale=5):
chatbot = gr.Chatbot(
[],
height=1000,
container=True,
show_label=True,
elem_classes="chat-box",
type="messages",
label="Agent",
avatar_images=(
None,
"assets/medrax_logo.jpg",
),
)
with gr.Row():
with gr.Column(scale=3):
txt = gr.Textbox(
show_label=False,
placeholder="Ask about the X-ray...",
container=False,
)
with gr.Column(scale=3):
image_display = gr.Image(label="Image", type="filepath", height=600, container=True)
with gr.Row():
upload_button = gr.UploadButton(
"πŸ“Ž Upload X-Ray",
file_types=["image"],
)
dicom_upload = gr.UploadButton(
"πŸ“„ Upload DICOM",
file_types=["file"],
)
with gr.Row():
new_chat_btn = gr.Button("New Chat")
# Event handlers
        def new_chat():
            interface.original_file_path = None
            interface.display_file_path = None
            interface.pending_tool_calls = {}
            interface.current_thread_id = str(time.time())
            return [], None
def handle_file_upload(file):
return interface.handle_upload(file.name)
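        # Chain: record the user turn, stream the agent's reply, then
        # re-enable the textbox once streaming finishes.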
chat_msg = txt.submit(interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt])
bot_msg = chat_msg.then(
interface.process_message,
inputs=[txt, image_display, chatbot],
outputs=[chatbot, image_display, txt],
)
bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
new_chat_btn.click(new_chat, outputs=[chatbot, image_display])
return demo
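

# A minimal launch sketch, assuming project-specific factories `build_agent()`
# and `build_tools()` (hypothetical names, not defined in this file):
#
#     from medrax2.agent import build_agent    # hypothetical import
#     from medrax2.tools import build_tools    # hypothetical import
#
#     agent, tools_dict = build_agent(), build_tools()
#     demo = create_demo(agent, tools_dict)
#     demo.queue().launch()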