Adibvafa commited on
Commit
50157f0
·
1 Parent(s): 128b355

Add security to gradio

Browse files
Files changed (1) hide show
  1. main.py +62 -7
main.py CHANGED
@@ -135,7 +135,8 @@ def initialize_agent(
135
  return agent, tools_dict
136
 
137
 
138
- def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
 
139
  """
140
  Run the Gradio web interface.
141
 
@@ -144,10 +145,32 @@ def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
144
  tools_dict: Dictionary of available tools
145
  host (str): Host to bind the server to
146
  port (int): Port to run the server on
 
 
147
  """
148
  print(f"Starting Gradio interface on {host}:{port}")
 
 
 
 
 
 
 
 
 
149
  demo = create_demo(agent, tools_dict)
150
- demo.launch(server_name=host, server_port=port, share=True)
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
  def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
@@ -194,15 +217,26 @@ def parse_arguments():
194
  """Parse command line arguments."""
195
  parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
196
 
197
- # Server configuration
198
  parser.add_argument(
199
  "--mode",
200
  choices=["gradio", "api", "both"],
201
  default="gradio",
202
  help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
203
  )
 
 
204
  parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
205
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
 
 
 
 
 
 
 
 
 
206
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
207
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
208
  parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
@@ -316,6 +350,14 @@ if __name__ == "__main__":
316
  print(f"Using model: {args.model}")
317
  print(f"Selected tools: {selected_tools}")
318
  print(f"Using system prompt: {args.system_prompt}")
 
 
 
 
 
 
 
 
319
 
320
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
321
  if "MedGemmaVQATool" in selected_tools:
@@ -355,7 +397,13 @@ if __name__ == "__main__":
355
 
356
  # Launch based on selected mode
357
  if args.mode == "gradio":
358
- run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
 
 
 
 
 
 
359
 
360
  elif args.mode == "api":
361
  run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
@@ -363,10 +411,17 @@ if __name__ == "__main__":
363
  elif args.mode == "both":
364
  # Run both services in separate threads
365
  api_thread = threading.Thread(
366
- target=run_api_server, args=(agent, tools_dict, args.api_host, args.api_port, args.public)
 
367
  )
368
  api_thread.daemon = True
369
  api_thread.start()
370
 
371
- # Run Gradio in main thread
372
- run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
 
 
 
 
 
 
 
135
  return agent, tools_dict
136
 
137
 
138
+ def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686,
139
+ auth=None, share=False):
140
  """
141
  Run the Gradio web interface.
142
 
 
145
  tools_dict: Dictionary of available tools
146
  host (str): Host to bind the server to
147
  port (int): Port to run the server on
148
+ auth: Authentication credentials (tuple)
149
+ share (bool): Whether to create a shareable public link
150
  """
151
  print(f"Starting Gradio interface on {host}:{port}")
152
+
153
+ if auth:
154
+ print(f"🔐 Authentication enabled for user: {auth[0]}")
155
+ else:
156
+ print("⚠️ Running without authentication (public access)")
157
+
158
+ if share:
159
+ print("🌍 Creating shareable public link (expires in 1 week)...")
160
+
161
  demo = create_demo(agent, tools_dict)
162
+
163
+ # Prepare launch parameters
164
+ launch_kwargs = {
165
+ "server_name": host,
166
+ "server_port": port,
167
+ "share": share
168
+ }
169
+
170
+ if auth:
171
+ launch_kwargs["auth"] = auth
172
+
173
+ demo.launch(**launch_kwargs)
174
 
175
 
176
  def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
 
217
  """Parse command line arguments."""
218
  parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
219
 
220
+ # Run mode
221
  parser.add_argument(
222
  "--mode",
223
  choices=["gradio", "api", "both"],
224
  default="gradio",
225
  help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
226
  )
227
+
228
+ # Gradio interface options
229
  parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
230
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
231
+ parser.add_argument("--auth", nargs=2, metavar=("USERNAME", "PASSWORD"),
232
+ default=["admin", "adibjun"],
233
+ help="Enable password authentication (default: admin adibjun)")
234
+ parser.add_argument("--no-auth", action="store_true",
235
+ help="Disable authentication (public access)")
236
+ parser.add_argument("--share", action="store_true",
237
+ help="Create a temporary shareable link (expires in 1 week)")
238
+
239
+ # API server options
240
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
241
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
242
  parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
 
350
  print(f"Using model: {args.model}")
351
  print(f"Selected tools: {selected_tools}")
352
  print(f"Using system prompt: {args.system_prompt}")
353
+
354
+ # Set up authentication (simplified with argparse defaults)
355
+ if args.no_auth:
356
+ auth_credentials = None
357
+ print("⚠️ Authentication disabled (public access)")
358
+ else:
359
+ auth_credentials = tuple(args.auth) # Uses default ["admin", "adibjun"] if not specified
360
+ print(f"✅ Authentication enabled for user: {auth_credentials[0]}")
361
 
362
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
363
  if "MedGemmaVQATool" in selected_tools:
 
397
 
398
  # Launch based on selected mode
399
  if args.mode == "gradio":
400
+ run_gradio_interface(
401
+ agent, tools_dict,
402
+ host=args.gradio_host,
403
+ port=args.gradio_port,
404
+ auth=auth_credentials,
405
+ share=args.share
406
+ )
407
 
408
  elif args.mode == "api":
409
  run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
 
411
  elif args.mode == "both":
412
  # Run both services in separate threads
413
  api_thread = threading.Thread(
414
+ target=run_api_server,
415
+ args=(agent, tools_dict, args.api_host, args.api_port, args.public)
416
  )
417
  api_thread.daemon = True
418
  api_thread.start()
419
 
420
+ # Run Gradio in main thread with authentication and sharing
421
+ run_gradio_interface(
422
+ agent, tools_dict,
423
+ host=args.gradio_host,
424
+ port=args.gradio_port,
425
+ auth=auth_credentials,
426
+ share=args.share
427
+ )