File size: 21,256 Bytes
2f79879
 
 
 
 
 
 
 
 
 
 
eaca108
fbcbf94
b9f142a
ad5b346
b9f142a
 
e08f161
eaca108
 
 
 
f237c31
eaca108
 
b9f142a
eaca108
 
 
 
2f79879
eaca108
 
2f79879
 
eaca108
 
 
c4d7934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c8268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea89378
2f79879
 
465688a
2f79879
465688a
b9f142a
fbcbf94
2f79879
1fd95dc
2f79879
5849e07
a7d0aad
c4d7934
ea89378
 
 
 
 
 
 
 
 
f237c31
19ede8f
a6e10ac
f237c31
a7d0aad
b5eed77
ea89378
 
 
 
2f79879
eaca108
a7d0aad
eaca108
ea89378
ef6fc50
5849e07
ea89378
 
205758b
9a2c640
ea89378
eaca108
 
ea89378
eaca108
 
ea89378
 
a6e10ac
f237c31
e97f266
9a2c640
5b96cf3
 
 
 
0ad1012
9a2c640
5b96cf3
09ddc41
ea89378
2f79879
7b3e756
fbcbf94
 
9a2c640
ea89378
fbcbf94
7b3e756
 
 
 
 
ea89378
 
 
2f79879
eaca108
2f79879
f237c31
 
 
1fd95dc
f237c31
 
 
 
 
2f79879
eaca108
f237c31
eaca108
 
 
 
 
2f79879
eaca108
 
 
50157f0
 
ea89378
b9f142a
9a2c640
b9f142a
 
 
 
 
50157f0
 
ea89378
b9f142a
50157f0
 
 
 
 
 
 
 
 
b9f142a
50157f0
 
 
 
 
 
 
 
 
 
 
 
bc86327
 
ad5b346
b9f142a
 
9a2c640
b9f142a
 
 
 
 
ad5b346
b9f142a
 
9a2c640
ad5b346
 
 
 
9a2c640
 
 
ad5b346
 
 
 
 
 
9a2c640
b9f142a
9a2c640
ad5b346
 
 
 
 
 
 
 
 
bc86327
 
b9f142a
 
 
9a2c640
50157f0
b9f142a
9a2c640
 
b9f142a
9a2c640
b9f142a
50157f0
 
b9f142a
 
50157f0
31c8268
 
50157f0
 
 
 
 
 
b9f142a
 
ad5b346
9a2c640
b9f142a
 
9a2c640
b9f142a
9a2c640
b9f142a
 
9a2c640
b9f142a
 
9a2c640
b9f142a
9a2c640
b9f142a
9a2c640
 
b9f142a
9a2c640
b9f142a
9a2c640
b9f142a
 
9a2c640
b9f142a
9a2c640
b9f142a
 
9a2c640
b9f142a
 
9a2c640
b9f142a
 
9a2c640
b9f142a
9a2c640
 
 
 
 
 
 
b9f142a
 
9a2c640
b9f142a
9a2c640
 
 
 
 
 
b9f142a
9a2c640
c4d7934
 
 
 
 
 
0ad1012
b9f142a
bc86327
ea89378
b9f142a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5b346
b9f142a
 
 
 
 
ad5b346
b9f142a
 
 
 
 
 
 
 
 
 
 
 
 
ad5b346
50157f0
31c8268
 
b2aba7d
205758b
c4d7934
 
205758b
c4d7934
 
 
 
 
 
 
 
 
 
 
 
 
205758b
2f79879
 
a6e10ac
b9f142a
 
 
 
 
 
 
 
 
c0ca604
d5baad4
a6e10ac
 
f237c31
 
a122f1b
11020f4
b9f142a
19ede8f
f1b994a
b9f142a
f1b994a
b9f142a
 
5849e07
2f79879
b9f142a
c4d7934
11020f4
eaca108
b9f142a
 
50157f0
 
 
 
 
 
 
b9f142a
 
ad5b346
b9f142a
 
 
 
50157f0
 
b9f142a
 
 
9a2c640
50157f0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
"""
MedRAX Application Main Module

This module serves as the entry point for the MedRAX medical imaging AI assistant.
It provides functionality to initialize an AI agent with various medical imaging tools
and to launch a Gradio web interface and/or a FastAPI REST server for interacting
with the system.

The system uses a configurable language model (created via ModelFactory, e.g. an
OpenAI GPT or Google Gemini model) for reasoning, and can be configured with
different model weights, tools, and parameters.
"""

import warnings
import os
import argparse
from pyngrok import ngrok
import threading
import uvicorn
from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
from transformers import logging

from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory

from interface import create_demo
from api import create_api
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

# Quiet third-party noise: drop all Python warnings and only let the
# transformers logger report errors.
warnings.filterwarnings(action="ignore")
logging.set_verbosity_error()

# Populate the process environment from a local .env file (when one can be
# found) before any configuration is read; the return value is discarded.
_ = load_dotenv()


def resolve_medgemma_api_url_from_value(value: Optional[str]) -> str:
    """Return the MedGemma API base URL to use.

    Candidates are tried in priority order:
      1. ``value`` itself (typically the --medgemma-api-url CLI flag), if non-empty;
      2. the MEDGEMMA_API_URL environment variable, if non-empty;
      3. on a SLURM job (SLURM_JOB_ID or SLURM_NODEID set) there is no safe
         default, so a RuntimeError is raised;
      4. otherwise the single-machine default ``http://127.0.0.1:8002``.

    Args:
        value: Explicitly supplied URL, or None/empty to fall back.

    Returns:
        The resolved base URL.

    Raises:
        RuntimeError: If no URL is configured and the process runs under SLURM.
    """
    # Explicit value first, environment second; `or` only consults the
    # environment when the explicit value is empty/None.
    candidate = value or os.getenv("MEDGEMMA_API_URL")
    if candidate:
        return candidate

    running_under_slurm = bool(os.getenv("SLURM_JOB_ID") or os.getenv("SLURM_NODEID"))
    if not running_under_slurm:
        return "http://127.0.0.1:8002"

    # Under SLURM the server is typically on another node, so guessing
    # localhost would fail in confusing ways.
    raise RuntimeError(
        "MEDGEMMA_API_URL not set and --medgemma-api-url not provided. "
        "On SLURM, the client usually runs on a different node, "
        "so you must point to the server’s reachable IP, e.g. http://<node-ip>:8002"
    )


def resolve_medgemma_api_url(args) -> str:
    """Resolve the MedGemma API URL from a parsed-arguments object.

    Reads the optional ``medgemma_api_url`` attribute (treated as None when the
    attribute is absent) and delegates to ``resolve_medgemma_api_url_from_value``.
    """
    cli_value = getattr(args, "medgemma_api_url", None)
    return resolve_medgemma_api_url_from_value(cli_value)


def resolve_auth_credentials(args) -> Optional[tuple]:
    """Resolve authentication credentials from CLI args or environment variables.

    Resolution order:
    1) --no-auth flag: authentication disabled (returns None). It takes
       precedence over --auth; a warning is printed if both were given.
    2) --auth USERNAME PASSWORD: returns the (username, password) tuple.
    3) MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD environment variables;
       both must be non-empty. If only one is set, a warning names the
       missing one and the environment credentials are ignored.
    4) Otherwise None, with warning messages explaining how to enable auth.

    Args:
        args: Parsed command-line arguments; must provide ``no_auth`` (bool)
            and ``auth`` (a two-item USERNAME/PASSWORD sequence or None).

    Returns:
        Optional[tuple]: (username, password) tuple if auth is enabled, None otherwise
    """
    if args.no_auth:
        if args.auth:
            # Previously the conflicting --auth values were dropped silently.
            print("⚠️  Both --auth and --no-auth were given; --no-auth takes precedence")
        print("⚠️  Authentication disabled (public access)")
        return None

    if args.auth:
        username, password = args.auth
        print(f"✅ Authentication enabled for user: {username}")
        return (username, password)

    # Try to read from environment variables
    auth_username = os.getenv("MEDRAX_AUTH_USERNAME")
    auth_password = os.getenv("MEDRAX_AUTH_PASSWORD")

    if auth_username and auth_password:
        print(f"✅ Authentication enabled from environment for user: {auth_username}")
        return (auth_username, auth_password)

    if auth_username or auth_password:
        # A half-configured environment would otherwise fall through to public
        # access with only the generic warning; say exactly what is missing.
        if auth_username:
            present, missing = "MEDRAX_AUTH_USERNAME", "MEDRAX_AUTH_PASSWORD"
        else:
            present, missing = "MEDRAX_AUTH_PASSWORD", "MEDRAX_AUTH_USERNAME"
        print(f"⚠️  {present} is set but {missing} is missing or empty; ignoring environment credentials")

    # No usable auth specified anywhere - default to no auth with warning
    print("⚠️  No authentication configured!")
    print("⚠️  Running without authentication (public access)")
    print("⚠️  To enable auth, either:")
    print("    - Use --auth USERNAME PASSWORD")
    print("    - Set MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD in .env")
    print("    - Or explicitly use --no-auth to suppress this warning")
    return None


def initialize_agent(
    prompt_file: str,
    tools_to_use: Optional[List[str]] = None,
    model_dir: str = "/model-weights",
    temp_dir: str = "temp",
    device: str = "cuda",
    model: str = "gpt-4.1",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_tokens: int = 5000,
    rag_config: Optional[RAGConfig] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
    system_prompt: str = "MEDICAL_ASSISTANT",
    medgemma_api_url: Optional[str] = None,
):
    """Initialize the MedRAX agent with specified tools and configuration.

    Tools are built lazily: only names listed in ``tools_to_use`` are
    instantiated, so models behind unrequested tools are never constructed.

    Args:
        prompt_file (str): Path to file containing system prompts.
        tools_to_use (List[str], optional): Names of tools to initialize. If None
            or empty, no tools are initialized. Unrecognized names are reported
            and skipped. "XRayVQATool" is accepted as an alias for
            "CheXagentXRayVQATool".
        model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
        temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
        device (str, optional): Device to run models on. Defaults to "cuda".
        model (str, optional): Language model name passed to ModelFactory. Defaults to "gpt-4.1".
        temperature (float, optional): Sampling temperature for the model. Defaults to 1.0.
        top_p (float, optional): top_p value passed to the model factory. Defaults to 0.95.
        max_tokens (int, optional): max_tokens value passed to the model factory. Defaults to 5000.
        rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
        model_kwargs (dict, optional): Additional keyword arguments forwarded to
            ModelFactory.create_model. Defaults to None (no extra arguments).
        system_prompt (str, optional): Key of the prompt to use from ``prompt_file``.
            Defaults to "MEDICAL_ASSISTANT".
        medgemma_api_url (str, optional): MedGemma API base URL; resolved via
            resolve_medgemma_api_url_from_value. Only used when "MedGemmaVQATool"
            is selected.

    Returns:
        Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances

    Raises:
        KeyError: If ``system_prompt`` is not defined in ``prompt_file``.
        ValueError: If the language model cannot be created by ModelFactory.
        RuntimeError: If "MedGemmaVQATool" is selected on SLURM without a MedGemma URL.
    """
    # Load system prompts from file
    prompts = load_prompts_from_file(prompt_file)
    try:
        prompt = prompts[system_prompt]
    except KeyError:
        raise KeyError(f"System prompt {system_prompt!r} not found in {prompt_file!r}") from None

    def _chexagent_vqa():
        # Shared by the canonical key and its "XRayVQATool" alias.
        return CheXagentXRayVQATool(cache_dir=model_dir, device=device)

    all_tools = {
        "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
        "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "CheXagentXRayVQATool": _chexagent_vqa,
        # The CLI help and the default tool list refer to this tool as
        # "XRayVQATool"; without this alias that name matched nothing.
        "XRayVQATool": _chexagent_vqa,
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(cache_dir=model_dir, device=device),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
        "MedicalRAGTool": lambda: RAGTool(config=rag_config),
        "WebBrowserTool": lambda: WebBrowserTool(),
        "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
        "MedSAM2Tool": lambda: MedSAM2Tool(device=device, cache_dir=model_dir, temp_dir=temp_dir),
        "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(
            cache_dir=model_dir,
            device=device,
            load_in_8bit=True,
            api_url=resolve_medgemma_api_url_from_value(medgemma_api_url),
        ),
    }

    # Initialize only the selected tools; None/empty selects no tools.
    tools_dict: Dict[str, BaseTool] = {}
    for tool_name in tools_to_use or []:
        if tool_name == "PythonSandboxTool":
            # Sandbox creation is best-effort: a failure must not abort startup.
            try:
                tools_dict["PythonSandboxTool"] = create_python_sandbox()
            except Exception as e:
                print(f"Error creating PythonSandboxTool: {e}")
                print("Skipping PythonSandboxTool")
        elif tool_name in all_tools:
            tools_dict[tool_name] = all_tools[tool_name]()
        else:
            available = ", ".join(sorted([*all_tools, "PythonSandboxTool"]))
            print(f"Warning: unknown tool '{tool_name}' ignored. Available tools: {available}")

    # Set up checkpointing for conversation state
    checkpointer = MemorySaver()

    # Create the language model using the factory
    try:
        llm = ModelFactory.create_model(
            model_name=model,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            **(model_kwargs or {}),
        )
    except ValueError as e:
        print(f"Error creating language model: {e}")
        print(f"Available model providers: {list(ModelFactory._model_providers.keys())}")
        raise

    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")

    return agent, tools_dict


def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686, 
                        auth=None, share=False):
    """
    Build the Gradio demo for the agent and launch it.

    Args:
        agent: The initialized MedRAX agent
        tools_dict: Mapping of tool name to tool instance
        host (str): Interface/address the server listens on
        port (int): TCP port the server listens on
        auth: Optional (username, password) tuple; when falsy, no login is required
        share (bool): Whether Gradio should also create a public share link
    """
    print(f"Starting Gradio interface on {host}:{port}")

    auth_banner = (
        f"🔐 Authentication enabled for user: {auth[0]}"
        if auth
        else "⚠️  Running without authentication (public access)"
    )
    print(auth_banner)

    if share:
        print("🌍 Creating shareable public link (expires in 1 week)...")

    demo = create_demo(agent, tools_dict)

    # Only pass `auth` through to Gradio when credentials were supplied.
    launch_kwargs = dict(server_name=host, server_port=port, share=share)
    if auth:
        launch_kwargs.update(auth=auth)

    demo.launch(**launch_kwargs)


def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
    """
    Run the FastAPI server with uvicorn (blocks until the server stops).

    Args:
        agent: The initialized MedRAX agent
        tools_dict: Dictionary of available tools
        host (str): Host to bind the server to
        port (int): Port to run the server on
        public (bool): Whether to expose via ngrok tunnel. If the tunnel cannot
            be created, a message is printed and the server runs locally only.
    """
    print(f"Starting API server on {host}:{port}")

    # The tunnel object (if any) drives cleanup, so a failed connect never
    # leads to a disconnect attempt on an undefined tunnel.
    public_tunnel = None
    if public:
        try:
            public_tunnel = ngrok.connect(port)
            public_url = public_tunnel.public_url
            print(
                f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}"
            )
        except ImportError:
            print("⚠️  pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
        except Exception as e:
            print(f"⚠️  Failed to create public tunnel: {e}\nRunning locally only...")

    try:
        # create_api is inside the try so the tunnel is torn down even if
        # building the app fails.
        app = create_api(agent, tools_dict)
        uvicorn.run(app, host=host, port=port)
    finally:
        if public_tunnel is not None:
            # Best-effort cleanup; `except Exception` (not a bare except) so
            # KeyboardInterrupt/SystemExit are never swallowed here.
            try:
                ngrok.disconnect(public_tunnel.public_url)
                ngrok.kill()
            except Exception as e:
                print(f"⚠️  Failed to shut down ngrok tunnel cleanly: {e}")


def parse_arguments():
    """Parse command line arguments.

    ``--model-dir`` and ``--device`` default to the MODEL_WEIGHTS_DIR and
    MEDRAX_DEVICE environment variables (read at call time, so values loaded
    from .env apply), falling back to '/model-weights' and 'cuda'.
    """
    parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")

    # Run mode
    parser.add_argument(
        "--mode",
        choices=["gradio", "api", "both"],
        default="gradio",
        help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
    )

    # Gradio interface options
    parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
    parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
    parser.add_argument("--auth", nargs=2, metavar=("USERNAME", "PASSWORD"),
                        default=None,
                        help="Enable password authentication with specified username and password")
    parser.add_argument("--no-auth", action="store_true",
                        help="Disable authentication (public access)")
    parser.add_argument("--share", action="store_true",
                        help="Create a temporary shareable link (expires in 1 week)")

    # API server options
    parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
    parser.add_argument("--api-port", type=int, default=8000, help="API port (default: 8000)")
    parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")

    # Model and system configuration.
    # The help text promised env-var defaults, but hard-coded defaults meant
    # the env vars were never consulted; read them here (empty counts as unset).
    parser.add_argument(
        "--model-dir",
        default=os.getenv("MODEL_WEIGHTS_DIR") or "/model-weights",
        help="Directory containing model weights (default: MODEL_WEIGHTS_DIR env var, else '/model-weights')",
    )
    parser.add_argument(
        "--device",
        default=os.getenv("MEDRAX_DEVICE") or "cuda",
        help="Device to run models on (default: MEDRAX_DEVICE env var, else 'cuda')",
    )
    parser.add_argument(
        "--model",
        default="gpt-4.1",
        help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5",
    )
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for the model (default: 1.0)")
    parser.add_argument("--temp-dir", default="temp2", help="Directory for temporary files (default: temp2)")
    parser.add_argument(
        "--prompt-file",
        default="medrax/docs/system_prompts.txt",
        help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)",
    )
    parser.add_argument(
        "--system-prompt", default="MEDICAL_ASSISTANT", help="System prompt to use (default: MEDICAL_ASSISTANT)"
    )

    # RAG configuration
    parser.add_argument(
        "--rag-model", default="command-a-03-2025", help="Chat model for RAG responses (default: command-a-03-2025)"
    )
    parser.add_argument(
        "--rag-embedding-model", default="embed-v4.0", help="Embedding model for RAG system (default: embed-v4.0)"
    )
    parser.add_argument(
        "--rag-rerank-model", default="rerank-v3.5", help="Reranking model for RAG system (default: rerank-v3.5)"
    )
    parser.add_argument("--rag-temperature", type=float, default=0.3, help="Temperature for RAG model (default: 0.3)")
    parser.add_argument("--pinecone-index", default="medrax2", help="Pinecone index name (default: medrax2)")
    parser.add_argument("--chunk-size", type=int, default=1500, help="RAG chunk size (default: 1500)")
    parser.add_argument("--chunk-overlap", type=int, default=300, help="RAG chunk overlap (default: 300)")
    parser.add_argument("--retriever-k", type=int, default=3, help="Number of documents to retrieve (default: 3)")
    parser.add_argument("--rag-docs-dir", default="rag_docs", help="Directory for RAG documents (default: rag_docs)")

    # Tools configuration
    parser.add_argument(
        "--tools",
        nargs="*",
        help="Specific tools to enable (if not provided, uses default set). Available tools: "
        + "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, "
        + "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, "
        + "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, "
        + "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, "
        + "PythonSandboxTool",
    )

    # MedGemma API configuration
    parser.add_argument(
        "--medgemma-api-url",
        default=None,
        help="MedGemma API base URL, e.g. http://127.0.0.1:8002 or http://<node-ip>:8002"
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Main entry point for the MedRAX application: parse the command line,
    # initialize the agent with the selected tools, then launch the Gradio
    # interface and/or the REST API according to --mode.
    args = parse_arguments()
    print(f"Starting MedRAX in {args.mode} mode...")

    # Configure tools based on arguments
    if args.tools is not None:
        # Use tools specified via command line (a bare --tools selects no tools)
        selected_tools = args.tools
    else:
        # Use default tools selection
        selected_tools = [
            # Image Processing Tools
            "ImageVisualizerTool",  # For displaying images in the UI
            # "DicomProcessorTool",  # For processing DICOM medical image files
            # Segmentation Tools
            "MedSAM2Tool",  # For advanced medical image segmentation using MedSAM2
            "ChestXRaySegmentationTool",  # For segmenting anatomical regions in chest X-rays
            # Generation Tools
            # "ChestXRayGeneratorTool",  # For generating synthetic chest X-rays
            # Classification Tools
            "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
            "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
            # Report Generation Tools
            "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
            # Grounding Tools
            "XRayPhraseGroundingTool",  # For locating described features in X-rays
            # VQA Tools
            # "MedGemmaVQATool",  # Google MedGemma VQA tool
            # Use the key registered in initialize_agent's tool table.
            "CheXagentXRayVQATool",  # For visual question answering on X-rays
            # "LlavaMedTool",  # For multimodal medical image understanding
            # RAG Tools
            "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
            # Search Tools
            # "WebBrowserTool",  # For web browsing and search capabilities
            "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
            # Development Tools
            # "PythonSandboxTool",  # Add the Python sandbox tool
        ]

    # Configure model directory and device; the environment/default fallbacks
    # here apply only when the CLI value is empty or None.
    model_dir = args.model_dir or os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
    device = args.device or os.getenv("MEDRAX_DEVICE", "cuda:0")

    print(f"Using model directory: {model_dir}")
    print(f"Using device: {device}")
    print(f"Using model: {args.model}")
    print(f"Selected tools: {selected_tools}")
    print(f"Using system prompt: {args.system_prompt}")

    # Set up authentication (reads from CLI or env vars; warns when none is configured)
    auth_credentials = resolve_auth_credentials(args)

    # Setup the MedGemma environment if the MedGemmaVQATool is selected
    medgemma_base_url_from_setup: Optional[str] = None
    medgemma_api_url_effective: Optional[str] = args.medgemma_api_url
    if "MedGemmaVQATool" in selected_tools:
        # Launch server and capture its URL if no explicit URL/ENV provided
        try:
            if medgemma_api_url_effective is None and os.getenv("MEDGEMMA_API_URL") is None:
                medgemma_base_url_from_setup = setup_medgemma_env(cache_dir=model_dir, device=device)
                # If we auto-launched, use this URL unless overridden later
                if medgemma_base_url_from_setup:
                    medgemma_api_url_effective = medgemma_base_url_from_setup
                    print(f"MedGemma API auto-launched at {medgemma_api_url_effective}")
            else:
                # Still ensure environment is set up; it will bind to provided host/port
                setup_medgemma_env(cache_dir=model_dir, device=device)
        except Exception as e:
            # Best-effort: the agent can still start (and point at an external URL).
            print(f"Warning: Failed to launch MedGemma service automatically: {e}")

    # Configure the Retrieval Augmented Generation (RAG) system
    # This allows the agent to access and use medical knowledge documents
    rag_config = RAGConfig(
        model=args.rag_model,
        embedding_model=args.rag_embedding_model,
        rerank_model=args.rag_rerank_model,
        temperature=args.rag_temperature,
        pinecone_index_name=args.pinecone_index,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        retriever_k=args.retriever_k,
        local_docs_dir=args.rag_docs_dir,
        huggingface_datasets=["VictorLJZ/medrax2"],  # List of HuggingFace datasets to load
        dataset_split="train",  # Which split of the datasets to use
    )

    # Prepare any additional model-specific kwargs
    model_kwargs = {}

    agent, tools_dict = initialize_agent(
        prompt_file=args.prompt_file,
        tools_to_use=selected_tools,
        model_dir=model_dir,
        temp_dir=args.temp_dir,
        device=device,
        model=args.model,
        temperature=args.temperature,
        model_kwargs=model_kwargs,
        rag_config=rag_config,
        system_prompt=args.system_prompt,
        medgemma_api_url=medgemma_api_url_effective,
    )

    # Launch based on selected mode; argparse restricts --mode to
    # "gradio", "api" or "both".
    if args.mode == "api":
        run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
    else:
        if args.mode == "both":
            # Run the API in a daemon thread so it does not keep the process
            # alive once the Gradio server in the main thread exits.
            api_thread = threading.Thread(
                target=run_api_server,
                args=(agent, tools_dict, args.api_host, args.api_port, args.public),
                daemon=True,
            )
            api_thread.start()

        # Run Gradio in main thread with authentication and sharing
        run_gradio_interface(
            agent, tools_dict,
            host=args.gradio_host,
            port=args.gradio_port,
            auth=auth_credentials,
            share=args.share,
        )