|
|
""" |
|
|
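Manual test script comparing the Agent's assistant and Socratic tutor modes.

It sends the same question to an agent in each mode and prints both responses for side-by-side comparison; nothing is asserted automatically.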
|
|
""" |
|
|
import os

# Limit OpenMP to a single thread. This is deliberately set before the
# third-party imports below, presumably so any OpenMP-backed library they load
# reads the value at import time. Do not move it below those imports.
os.environ["OMP_NUM_THREADS"] = "1"

import warnings

# Suppress every warning for the rest of the run, presumably to keep the
# console output readable. NOTE(review): this also hides deprecation and
# runtime warnings, so warnings.warn() is not a usable channel in this script.
warnings.filterwarnings("ignore")

from dotenv import load_dotenv

# Load variables from a .env file before the medrax imports. Presumably this
# is where model API keys/settings come from; confirm against ModelFactory.
load_dotenv()

from langgraph.checkpoint.memory import MemorySaver

from medrax.models import ModelFactory

from medrax.agent import Agent

from medrax.utils import load_prompts_from_file
|
|
|
|
|
|
|
|
# --- Objects shared by both mode tests ----------------------------------
# A single in-memory checkpointer backs both agents. Each run below passes its
# own thread_id, which presumably keeps the two conversations separate.
checkpointer = MemorySaver()

# One chat model instance, reused by both agents so that each mode is
# exercised with the same model settings.
llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000,
)

# Named system prompts. The keys looked up later are MEDICAL_ASSISTANT and
# SOCRATIC_TUTOR. The path is relative, so run from the repository root.
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
|
|
|
|
|
|
|
# --- Assistant mode -------------------------------------------------------
# Ask a plain factual question. In assistant mode the agent is expected to
# answer it directly (see the summary printed at the end of the script).
print("\n" + "=" * 60)
print("Testing ASSISTANT MODE")
print("=" * 60)

assistant_prompt = prompts.get("MEDICAL_ASSISTANT")
if assistant_prompt is None:
    # Make the fallback visible. Otherwise a missing or misnamed key silently
    # replaces the real system prompt with a one-line stand-in, and the run
    # still looks successful. print() is used instead of warnings.warn()
    # because all warnings are filtered out at the top of this script.
    print("NOTE: 'MEDICAL_ASSISTANT' prompt not found; using fallback prompt.")
    assistant_prompt = "You are a helpful medical imaging assistant."

assistant_agent = Agent(
    llm,
    tools=[],
    system_prompt=assistant_prompt,
    checkpointer=checkpointer,
    mode="assistant",
)

# Dedicated thread id so this conversation is tracked separately from the
# Socratic run that reuses the same checkpointer.
config = {"configurable": {"thread_id": "test_assistant"}}
response = assistant_agent.workflow.invoke(
    {"messages": [("user", "What is pneumonia?")]},
    config=config,
)

# The last message in the returned state is printed as the agent's reply.
print("\nAssistant Response:")
print(response["messages"][-1].content)
|
|
|
|
|
|
|
|
# --- Socratic tutor mode --------------------------------------------------
# Ask the same question as the assistant run. In Socratic mode the agent is
# expected to guide the learner with questions instead of answering directly
# (see the summary printed at the end of the script).
print("\n" + "=" * 60)
print("Testing SOCRATIC TUTOR MODE")
print("=" * 60)

socratic_prompt = prompts.get("SOCRATIC_TUTOR")
if socratic_prompt is None:
    # This prompt is the thing under test, so a silent fallback would make
    # the comparison meaningless. Surface it. print() is used because all
    # warnings are filtered out at the top of this script.
    print("NOTE: 'SOCRATIC_TUTOR' prompt not found; using fallback prompt.")
    socratic_prompt = "You are a Socratic medical educator."

socratic_agent = Agent(
    llm,
    tools=[],
    system_prompt=socratic_prompt,
    checkpointer=checkpointer,
    mode="socratic",
)

# Dedicated thread id so this conversation does not pick up the assistant
# run's history from the shared checkpointer.
config = {"configurable": {"thread_id": "test_socratic"}}
response = socratic_agent.workflow.invoke(
    {"messages": [("user", "What is pneumonia?")]},
    config=config,
)

# The last message in the returned state is printed as the agent's reply.
print("\nSocratic Tutor Response:")
print(response["messages"][-1].content)
|
|
|
|
|
# --- Summary --------------------------------------------------------------
# Closing banner plus a reminder of how the two modes are meant to differ.
# Emitted as a single print() call with newline separators, so each argument
# lands on its own line exactly as separate print() calls would produce.
print(
    "\n" + "=" * 60,
    "Test Complete!",
    "=" * 60,
    "\nThe key difference:",
    "- Assistant Mode: Provides direct answers and explanations",
    "- Socratic Mode: Guides learning through questions without giving direct answers",
    sep="\n",
)