#!/usr/bin/env python3
"""
Gradio UI for IP Assist Lite
Interactive interface for medical information retrieval
"""
import sys
import os
import time
import threading
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import logging
from collections import OrderedDict

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

import gradio as gr
from datetime import datetime

from orchestration.langgraph_agent import IPAssistOrchestrator
from retrieval.hybrid_retriever import HybridRetriever
from utils.serialization import to_jsonable

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# TTL Cache implementation (LRU order via OrderedDict, time-based expiry)
class TTLCache:
    def __init__(self, maxsize=256, ttl=600):
        self.maxsize, self.ttl = maxsize, ttl
        self._data = OrderedDict()

    def get(self, key):
        v = self._data.get(key)
        if not v:
            return None
        val, ts = v
        if time.time() - ts > self.ttl:
            # Entry expired: drop it and report a miss
            del self._data[key]
            return None
        self._data.move_to_end(key)  # refresh recency on hit
        return val

    def set(self, key, val):
        self._data[key] = (val, time.time())
        self._data.move_to_end(key)
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict least recently used
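
# A minimal usage sketch for TTLCache (comments only, not executed): values
# expire after `ttl` seconds, and once `maxsize` entries exist the least
# recently used entry is evicted.
#
#     cache = TTLCache(maxsize=2, ttl=60)
#     cache.set("a", 1)
#     cache.set("b", 2)
#     cache.get("a")     # -> 1, and "a" becomes most recently used
#     cache.set("c", 3)  # evicts "b", the least recently used entry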

# Initialize caches
_RESULT_CACHE = TTLCache(
    maxsize=int(os.getenv("RESULT_CACHE_MAX", "256")),
    ttl=int(os.getenv("RESULT_TTL_SEC", "600")),
)
_INDEX_FINGERPRINT = os.getenv("INDEX_FINGERPRINT", "v1")  # bump to invalidate cache after reindex

# Stats cache
_STATS_CACHE = {"html": "", "ts": 0.0}
_STATS_TTL_SEC = int(os.getenv("STATS_TTL_SEC", "900"))

# Thread-safe orchestrator singleton (double-checked locking)
_orchestrator = None
_orch_lock = threading.Lock()


def get_orchestrator():
    global _orchestrator
    if _orchestrator is None:
        with _orch_lock:
            if _orchestrator is None:
                logger.info("Initializing orchestrator...")
                _orchestrator = IPAssistOrchestrator()
                logger.info("Orchestrator initialized")
    return _orchestrator


# Color coding for different elements
EMERGENCY_COLOR = "#ff4444"
WARNING_COLOR = "#ff9800"
SUCCESS_COLOR = "#4caf50"
INFO_COLOR = "#2196f3"

# Allowed models for local UI - include GPT-4 models as fallback options
ALLOWED_MODELS = ["gpt-5-nano", "gpt-5-mini", "gpt-5", "gpt-4o-mini", "gpt-4o"]


def _sanitize_model(selected: str | None) -> str:
    m = (selected or os.getenv("IP_GPT5_MODEL", "gpt-4o-mini")).strip()
    # Default to gpt-4o-mini for better reliability
    return m if m in ALLOWED_MODELS else "gpt-4o-mini"


def format_response_html(result: Dict[str, Any]) -> str:
    """Format the response with proper HTML styling."""
    html_parts = []

    # Emergency banner if needed
    if result["is_emergency"]:
        html_parts.append(f"""
        <div style="background-color: {EMERGENCY_COLOR}; color: white; padding: 12px; border-radius: 6px; font-weight: bold; margin: 8px 0;">
        🚨 EMERGENCY DETECTED - Immediate action required
        </div>
        """)

    # Query type, confidence, and model used
    model_used = result.get("model_used") or result.get("llm_model_used") or "—"
    html_parts.append(f"""
    <div style="background-color: {INFO_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
    Query Type: {result['query_type'].replace('_', ' ').title()}<br>
    Confidence: {result['confidence_score']:.1%}<br>
    Model: {model_used}
    </div>
    """)

    # LLM warning banner (e.g., fallback used)
    if result.get("llm_warning"):
        html_parts.append(f"""
        <div style="background-color: {WARNING_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
        {result.get('llm_warning')}
        </div>
        """)

    # LLM error banner (e.g., GPT-5 unavailable or auth issue)
    if result.get("llm_error"):
        html_parts.append(f"""
        <div style="background-color: {EMERGENCY_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
        ❌ {result.get('llm_error')}
        </div>
        """)

    # Safety flags if present
    if result["safety_flags"]:
        flags_html = ", ".join([f"⚠️ {flag}" for flag in result["safety_flags"]])
        html_parts.append(f"""
        <div style="background-color: {WARNING_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
        Safety Considerations: {flags_html}
        </div>
        """)

    # Main response (convert newlines to HTML line breaks)
    response_text = result["response"].replace("\n", "<br>")
    html_parts.append(f"""
    <div style="padding: 12px; border: 1px solid #ddd; border-radius: 4px; margin: 8px 0; line-height: 1.5;">
    {response_text}
    </div>
    """)

    # Citations, rendered as an ordered list
    if result["citations"]:
        citations_html = "<strong>📚 Sources:</strong><ol>"
        for citation in result["citations"]:
            citations_html += f"<li>{citation}</li>"
        citations_html += "</ol>"
        html_parts.append(citations_html)

    # Review flag
    if result["needs_review"]:
        html_parts.append(f"""
        <div style="background-color: {WARNING_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
        ⚠️ This response has been flagged for review due to safety concerns
        </div>
        """)

    return "".join(html_parts)


def process_query(query: str, use_reranker: bool = True, top_k: int = 5, model: str = "gpt-5-mini") -> Tuple[str, str, str]:
    """Process a query and return formatted results."""
    query_norm = (query or "").strip()
    if not query_norm:
        return "", "Please enter a query", json.dumps(to_jsonable({}), indent=2)

    # Budget knobs (two-stage retrieval)
    retrieve_m = int(os.getenv("RETRIEVE_M", "30"))  # fast retriever fan-out
    rerank_n = int(os.getenv("RERANK_N", "10"))      # cross-encoder candidates
    k = max(1, min(int(top_k), rerank_n))            # final results to display

    # Cache key (includes knobs + index version + model)
    chosen_model = _sanitize_model(model)
    cache_key = (
        f"{_INDEX_FINGERPRINT}|{query_norm.lower()}|rerank={bool(use_reranker)}"
        f"|k={k}|M={retrieve_m}|N={rerank_n}|model={chosen_model}"
    )
    cached = _RESULT_CACHE.get(cache_key)
    if cached:
        html, _, meta = cached
        return html, "⚡ Cached result", meta

    start = time.time()
    orch = get_orchestrator()

    # Set sanitized model in orchestrator
    orch.set_model(chosen_model)

    # Call the orchestrator with proper error handling
    try:
        result = orch.process_query(
            query_norm,
            use_reranker=bool(use_reranker),
            top_k=int(k),
            retrieve_m=int(retrieve_m),
            rerank_n=int(rerank_n),
        )
    except TypeError:
        # Older signature: try passing just the basics
        try:
            result = orch.process_query(
                query_norm,
                use_reranker=bool(use_reranker),
                top_k=int(k),
            )
        except TypeError:
            # Legacy: last resort
            result = orch.process_query(query_norm)
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        # Return a helpful error response
        result = {
            "response": f"An error occurred while processing your query. Please try again with a different model.\n\nError: {str(e)}",
            "query_type": "error",
            "is_emergency": False,
            "confidence_score": 0.0,
            "safety_flags": [],
            "citations": [],
            "needs_review": True,
            "llm_error": str(e),
            "model_used": chosen_model,
        }

    # Format the result as HTML
    response_html = format_response_html(result)

    # Minimal metadata for quick inspection
    metadata = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "latency_ms": int((time.time() - start) * 1000),
        "reranker_used": bool(use_reranker),
        "top_k": int(k),
        "retrieved_m": int(retrieve_m),
        "rerank_n": int(rerank_n) if use_reranker else 0,
        "cache_hit": False,
        "query_type": result.get("query_type", "unknown"),
        "is_emergency": result.get("is_emergency", False),
        "confidence_score": f"{result.get('confidence_score', 0):.2%}",
        "safety_flags": result.get("safety_flags", []),
        "needs_review": result.get("needs_review", False),
        "citations_count": len(result.get("citations", [])),
        # LLM telemetry
        "model_requested": chosen_model,
        "model_used": result.get("model_used"),
        "llm_warning": result.get("llm_warning"),
        "llm_error": result.get("llm_error"),
    }
    metadata_json = json.dumps(to_jsonable(metadata), indent=2, ensure_ascii=False)

    # Status message
    if result.get("is_emergency"):
        status = "🚨 Emergency query processed successfully"
    elif result.get("needs_review"):
        status = "⚠️ Query processed - review recommended"
    else:
        status = "✅ Query processed successfully"

    _RESULT_CACHE.set(cache_key, (response_html, status, metadata_json))
    return response_html, status, metadata_json
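
# Illustrative call, assuming the indexes are built and model credentials are
# configured (a sketch, not part of the UI flow):
#
#     html, status, meta = process_query(
#         "Massive hemoptysis management protocol",
#         use_reranker=True, top_k=5, model="gpt-4o-mini",
#     )
#     # Repeating the identical call within RESULT_TTL_SEC returns the
#     # cached HTML with the status "⚡ Cached result".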

def search_cpt(cpt_code: str) -> str:
    """Search for a specific CPT code."""
    if not cpt_code or not cpt_code.isdigit() or len(cpt_code) != 5:
        return "Please enter a valid 5-digit CPT code"

    try:
        orch = get_orchestrator()
        retriever = orch.retriever

        if cpt_code in retriever.cpt_index:
            chunk_ids = retriever.cpt_index[cpt_code]
            results_html = f"""
            <div style="background-color: {SUCCESS_COLOR}; color: white; padding: 8px; border-radius: 4px; margin: 8px 0;">
            Found {len(chunk_ids)} results for CPT {cpt_code}
            </div>
            """

            for i, chunk_id in enumerate(chunk_ids[:5], 1):
                if chunk_id in retriever.chunk_map:
                    chunk = retriever.chunk_map[chunk_id]
                    results_html += f"""
                    <div style="border: 1px solid #ddd; border-radius: 4px; padding: 12px; margin: 8px 0;">
                    Result {i}<br>
                    Document: {chunk.get('doc_id', 'Unknown')}<br>
                    Section: {chunk.get('section_title', 'Unknown')}<br>
                    Year: {chunk.get('year', 'Unknown')}<br>
                    {chunk['text'][:500]}...
                    </div>
                    """
            return results_html
        else:
            return f"No results found for CPT code {cpt_code}"
    except Exception as e:
        logger.error(f"CPT search error: {e}")
        return f"Error searching for CPT code: {str(e)}"
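
# Illustrative calls (a sketch; assumes CPT 31622 is present in the index):
#
#     search_cpt("31622")  # HTML for up to 5 matching chunks
#     search_cpt("3162")   # -> "Please enter a valid 5-digit CPT code"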

def get_system_stats(force_refresh: bool = False) -> str:
    """Get system statistics."""
    now = time.time()
    if not force_refresh and _STATS_CACHE["html"] and now - _STATS_CACHE["ts"] < _STATS_TTL_SEC:
        return _STATS_CACHE["html"]

    try:
        orch = get_orchestrator()
        chunks = orch.retriever.chunks

        # Calculate statistics
        stats = {
            "Total Chunks": len(chunks),
            "Unique Documents": len(set(c.get("doc_id", "") for c in chunks)),
            "Authority Tiers": {},
            "Evidence Levels": {},
            "Document Types": {},
        }

        for chunk in chunks:
            # Authority
            at = chunk.get("authority_tier", "Unknown")
            stats["Authority Tiers"][at] = stats["Authority Tiers"].get(at, 0) + 1
            # Evidence
            el = chunk.get("evidence_level", "Unknown")
            stats["Evidence Levels"][el] = stats["Evidence Levels"].get(el, 0) + 1
            # Type
            dt = chunk.get("doc_type", "Unknown")
            stats["Document Types"][dt] = stats["Document Types"].get(dt, 0) + 1

        # Format as HTML
        html = "<h3>System Statistics</h3>"
        html += f"<p>Total Chunks: {stats['Total Chunks']:,}</p>"
        html += f"<p>Unique Documents: {stats['Unique Documents']:,}</p>"

        html += "<h4>Authority Distribution</h4><ul>"
        for tier, count in sorted(stats["Authority Tiers"].items()):
            html += f"<li>{tier}: {count:,}</li>"
        html += "</ul>"

        html += "<h4>Evidence Level Distribution</h4><ul>"
        for level, count in sorted(stats["Evidence Levels"].items()):
            html += f"<li>{level}: {count:,}</li>"
        html += "</ul>"

        html += "<h4>Document Type Distribution</h4><ul>"
        for dtype, count in sorted(stats["Document Types"].items()):
            html += f"<li>{dtype}: {count:,}</li>"
        html += "</ul>"

        _STATS_CACHE["html"] = html
        _STATS_CACHE["ts"] = now
        return html
    except Exception as e:
        logger.error(f"Stats error: {e}")
        return f"Error getting statistics: {str(e)}"


# Example queries for quick testing
EXAMPLE_QUERIES = [
    "What are the contraindications for bronchoscopy?",
    "Massive hemoptysis management protocol",
    "CPT code for EBUS-TBNA with needle aspiration",
    "Pediatric bronchoscopy dosing for lidocaine",
    "How to place fiducial markers for SBRT?",
    "Complications of endobronchial valve placement",
    "Sedation options for flexible bronchoscopy",
    "Management of malignant airway obstruction",
    "Cryobiopsy technique and yield rates",
    "Robotic bronchoscopy navigation accuracy",
]

# Build Gradio interface
def build_interface():
    """Build the Gradio interface."""
    with gr.Blocks(title="IP Assist Lite", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🏥 IP Assist Lite
        ### Medical Information Retrieval for Interventional Pulmonology

        **Features:**
        - 🔍 Hybrid search with MedCPT embeddings
        - 📊 Hierarchy-aware ranking (Authority & Evidence)
        - 🚨 Emergency detection and routing
        - ⚠️ Safety checks for critical information
        - 📚 Source citations with confidence scoring
        """)

        with gr.Tabs():
            # Main Query Tab
            with gr.Tab("Query Assistant"):
                with gr.Row():
                    with gr.Column(scale=2):
                        query_input = gr.Textbox(
                            label="Enter your medical query",
                            placeholder="e.g., What are the contraindications for bronchoscopy?",
                            lines=3,
                        )
                        with gr.Row():
                            submit_btn = gr.Button("🔍 Submit Query", variant="primary")
                            clear_btn = gr.Button("🗑️ Clear")
                        gr.Examples(
                            examples=EXAMPLE_QUERIES,
                            inputs=query_input,
                            label="Example Queries",
                        )
                    with gr.Column(scale=1):
                        model_selector = gr.Dropdown(
                            choices=ALLOWED_MODELS,
                            value="gpt-4o-mini",  # Default to GPT-4 for reliability
                            label="Model",
                            info="Select the model (GPT-4 models are more reliable)",
                        )
                        use_reranker = gr.Checkbox(label="Use Reranker", value=True)
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=5,
                            step=1,
                            label="Number of Results",
                        )
                        status_output = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=2,
                        )

                response_output = gr.HTML(label="Response")
                metadata_output = gr.JSON(label="Metadata", visible=True)

                # Connect events
                submit_btn.click(
                    fn=process_query,
                    inputs=[query_input, use_reranker, top_k, model_selector],
                    outputs=[response_output, status_output, metadata_output],
                )
                clear_btn.click(
                    fn=lambda: ("", "", "", None),
                    outputs=[query_input, response_output, status_output, metadata_output],
                )

            # CPT Code Search Tab
            with gr.Tab("CPT Code Search"):
                with gr.Row():
                    cpt_input = gr.Textbox(
                        label="Enter CPT Code",
                        placeholder="e.g., 31622",
                        max_lines=1,
                    )
                    cpt_search_btn = gr.Button("Search CPT", variant="primary")
                cpt_output = gr.HTML(label="CPT Code Information")
                gr.Examples(
                    examples=["31622", "31628", "31633", "31645", "31652"],
                    inputs=cpt_input,
                    label="Common CPT Codes",
                )
                cpt_search_btn.click(
                    fn=search_cpt,
                    inputs=cpt_input,
                    outputs=cpt_output,
                )

            # System Statistics Tab
            with gr.Tab("System Statistics"):
                stats_btn = gr.Button("📊 Refresh Statistics", variant="secondary")
                stats_output = gr.HTML(label="System Statistics")
                # Force a refresh on click (bypasses the stats TTL cache)
                stats_btn.click(
                    fn=lambda: get_system_stats(force_refresh=True),
                    outputs=stats_output,
                )

        gr.Markdown("""
        ---
        ### ⚠️ Important Notice
        This system is for informational purposes only. Always verify medical
        information with official guidelines and consult with qualified healthcare
        professionals before making clinical decisions.

        **Safety Features:**
        - Emergency queries are automatically flagged and prioritized
        - Pediatric and dosage information includes safety warnings
        - Contraindications are highlighted when detected
        - Responses requiring review are clearly marked
        """)

    return demo


# Main execution
if __name__ == "__main__":
    demo = build_interface()

    # Pre-warm orchestrator on startup
    print("🔥 Pre-warming orchestrator...")
    get_orchestrator()
    print("✅ Orchestrator ready")

    # Keep UI responsive under concurrency
    demo.queue(
        max_size=int(os.getenv("GRADIO_QUEUE_MAX", "128"))
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Set to True to create a public link
        show_error=True,
    )
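
# To run locally (a sketch; assumes dependencies and prebuilt indexes):
#
#     python app.py   # hypothetical path to this script
#     # then open http://localhost:7860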