Adibvafa committed on
Commit
ad5b346
·
1 Parent(s): 235ec3d

Support customizing system prompt and API endpoint

Browse files
Files changed (3) hide show
  1. main.py +32 -6
  2. medrax/docs/system_prompts.txt +3 -1
  3. pyproject.toml +1 -0
main.py CHANGED
@@ -12,6 +12,7 @@ with different model weights, tools, and parameters.
12
  import warnings
13
  import os
14
  import argparse
 
15
  import threading
16
  import uvicorn
17
  from typing import Dict, List, Optional, Any
@@ -156,7 +157,7 @@ def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
156
  demo.launch(server_name=host, server_port=port, share=True)
157
 
158
 
159
- def run_api_server(agent, tools_dict, host="0.0.0.0", port=8000):
160
  """
161
  Run the FastAPI server.
162
 
@@ -165,10 +166,33 @@ def run_api_server(agent, tools_dict, host="0.0.0.0", port=8000):
165
  tools_dict: Dictionary of available tools
166
  host (str): Host to bind the server to
167
  port (int): Port to run the server on
 
168
  """
169
  print(f"Starting API server on {host}:{port}")
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  app = create_api(agent, tools_dict)
171
- uvicorn.run(app, host=host, port=port)
 
 
 
 
 
 
 
 
 
172
 
173
 
174
  def parse_arguments():
@@ -186,6 +210,7 @@ def parse_arguments():
186
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
187
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
188
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
 
189
 
190
  # Model and system configuration
191
  parser.add_argument(
@@ -328,7 +353,7 @@ if __name__ == "__main__":
328
  "XRayPhraseGroundingTool", # For locating described features in X-rays
329
 
330
  # VQA Tools
331
- "MedGemmaVQATool", # Google MedGemma VQA tool
332
  "XRayVQATool", # For visual question answering on X-rays
333
  # "LlavaMedTool", # For multimodal medical image understanding
334
 
@@ -336,7 +361,7 @@ if __name__ == "__main__":
336
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
337
 
338
  # Search Tools
339
- "WebBrowserTool", # For web browsing and search capabilities
340
  "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
341
 
342
  # Development Tools
@@ -351,6 +376,7 @@ if __name__ == "__main__":
351
  print(f"Using device: {device}")
352
  print(f"Using model: {args.model}")
353
  print(f"Selected tools: {selected_tools}")
 
354
 
355
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
356
  if "MedGemmaVQATool" in selected_tools:
@@ -393,13 +419,13 @@ if __name__ == "__main__":
393
  run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
394
 
395
  elif args.mode == "api":
396
- run_api_server(agent, tools_dict, args.api_host, args.api_port)
397
 
398
  elif args.mode == "both":
399
  # Run both services in separate threads
400
  api_thread = threading.Thread(
401
  target=run_api_server,
402
- args=(agent, tools_dict, args.api_host, args.api_port)
403
  )
404
  api_thread.daemon = True
405
  api_thread.start()
 
12
  import warnings
13
  import os
14
  import argparse
15
+ from pyngrok import ngrok
16
  import threading
17
  import uvicorn
18
  from typing import Dict, List, Optional, Any
 
157
  demo.launch(server_name=host, server_port=port, share=True)
158
 
159
 
160
+ def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
161
  """
162
  Run the FastAPI server.
163
 
 
166
  tools_dict: Dictionary of available tools
167
  host (str): Host to bind the server to
168
  port (int): Port to run the server on
169
+ public (bool): Whether to expose via ngrok tunnel
170
  """
171
  print(f"Starting API server on {host}:{port}")
172
+
173
+ if public:
174
+ try:
175
+ public_tunnel = ngrok.connect(port)
176
+ public_url = public_tunnel.public_url
177
+ print(f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}")
178
+ except ImportError:
179
+ print("⚠️ pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
180
+ public = False
181
+ except Exception as e:
182
+ print(f"⚠️ Failed to create public tunnel: {e}\nRunning locally only...")
183
+ public = False
184
+
185
  app = create_api(agent, tools_dict)
186
+
187
+ try:
188
+ uvicorn.run(app, host=host, port=port)
189
+ finally:
190
+ if public:
191
+ try:
192
+ ngrok.disconnect(public_tunnel.public_url)
193
+ ngrok.kill()
194
+ except:
195
+ pass
196
 
197
 
198
  def parse_arguments():
 
210
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
211
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
212
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
213
+ parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
214
 
215
  # Model and system configuration
216
  parser.add_argument(
 
353
  "XRayPhraseGroundingTool", # For locating described features in X-rays
354
 
355
  # VQA Tools
356
+ # "MedGemmaVQATool", # Google MedGemma VQA tool
357
  "XRayVQATool", # For visual question answering on X-rays
358
  # "LlavaMedTool", # For multimodal medical image understanding
359
 
 
361
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
362
 
363
  # Search Tools
364
+ # "WebBrowserTool", # For web browsing and search capabilities
365
  "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
366
 
367
  # Development Tools
 
376
  print(f"Using device: {device}")
377
  print(f"Using model: {args.model}")
378
  print(f"Selected tools: {selected_tools}")
379
+ print(f"Using system prompt: {args.system_prompt}")
380
 
381
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
382
  if "MedGemmaVQATool" in selected_tools:
 
419
  run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
420
 
421
  elif args.mode == "api":
422
+ run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
423
 
424
  elif args.mode == "both":
425
  # Run both services in separate threads
426
  api_thread = threading.Thread(
427
  target=run_api_server,
428
+ args=(agent, tools_dict, args.api_host, args.api_port, args.public)
429
  )
430
  api_thread.daemon = True
431
  api_thread.start()
medrax/docs/system_prompts.txt CHANGED
@@ -22,4 +22,6 @@ Use your state-of-the art reasoning and critical thinking skills to answer the q
22
  You may use tools (if available) to complement your reasoning and you are allowed to make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
  Think critically about how to best use the tools available to you and scrutinize the tool outputs.
24
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
25
- It is extremely important that you answer strictly in the format described above.
 
 
 
22
  You may use tools (if available) to complement your reasoning and you are allowed to make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
  Think critically about how to best use the tools available to you and scrutinize the tool outputs.
24
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
25
+ It is extremely important that you answer strictly in the format described above.
26
+
27
+ [EMPTY]
pyproject.toml CHANGED
@@ -74,6 +74,7 @@ dependencies = [
74
  "huggingface_hub>=0.17.0",
75
  "iopath>=0.1.10",
76
  "duckduckgo-search>=4.0.0",
 
77
  ]
78
 
79
  [project.optional-dependencies]
 
74
  "huggingface_hub>=0.17.0",
75
  "iopath>=0.1.10",
76
  "duckduckgo-search>=4.0.0",
77
+ "pyngrok>=7.0.0",
78
  ]
79
 
80
  [project.optional-dependencies]