import json
import os
import shutil
import time
import zipfile
from pathlib import Path

import lancedb
import streamlit as st
from huggingface_hub import hf_hub_download
from openai import OpenAI
from streamlit.runtime.scriptrunner import RerunException
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx

from evaluater import test
from graphrag.cli.query import run_local_search
from graphrag.utils import storage


def extract_guidelines():
    """Download the guideline artefacts from the Hugging Face Hub and unpack them locally."""
    repo_id = "Cryo3978/medguidelines"
    zip_filename = "guidelines.zip"
    other_filenames = [
        "embeddings.npy",
        "embeddings_choices.npy",
        "index.faiss",
        "index_choices.faiss",
    ]

    target_dir = os.getcwd()
    print(f"Current working directory: {target_dir}\n")

    # Copy the standalone embedding/index files next to the app if they are not already there.
    for file in other_filenames:
        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=file,
            repo_type="dataset",
            token=os.getenv("HF_TOKEN"),
        )
        target_path = os.path.join(target_dir, file)
        if not os.path.exists(target_path):
            shutil.copy(cached_path, target_path)
            print(f"Copied: {target_path}")

    # Download and extract the zipped GraphRAG output once.
    zip_path = hf_hub_download(
        repo_id=repo_id,
        filename=zip_filename,
        repo_type="dataset",
        token=os.getenv("HF_TOKEN"),
    )
    extract_dir = os.path.join(target_dir, "guidelines")
    if not os.path.exists(extract_dir) or not os.listdir(extract_dir):
        os.makedirs(extract_dir, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)
        print(f"\nExtracted to: {os.path.abspath(extract_dir)}")
    else:
        print(f"\nAlready exists: {extract_dir}")

    # List the extracted output tree for debugging.
    output_dir = os.path.join(extract_dir, "output")
    if os.path.exists(output_dir):
        print(f"\nListing all files in: {output_dir}\n")
        for root, dirs, files in os.walk(output_dir):
            level = root.replace(output_dir, "").count(os.sep)
            indent = " " * 4 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = " " * 4 * (level + 1)
            for f in files:
                print(f"{subindent}{f}")
    else:
        print(f"\nOutput directory not found: {output_dir}")

    print("\nExtraction and listing complete.\n")
    return extract_dir


def clear_chat():
    st.session_state.messages = []


def initialize_provider_settings(provider_choice):
    """Configure API settings for the selected provider."""
    provider_configs = {
        "OpenAI": {
            "api_key_source": os.getenv("OPENAI_API_KEY"),
            "base_url_source": "https://api.openai.com/v1",
            "fallback_model": "gpt-4o-mini",
        },
        "Deepseek": {
            "api_key_source": os.getenv("DEEPSEEK_API_KEY"),
            "base_url_source": "https://api.deepseek.com/v1",
            "fallback_model": "deepseek-chat",
        },
    }
    return provider_configs.get(provider_choice, {})


if __name__ == "__main__":
    extract_guidelines()
    print(os.listdir(os.path.expanduser("~")))
    # prepare_output_structure(Path("guidelines/output"))

    st.title("Med-GraphRAG Demo")

    # GraphRAG writes its LanceDB tables under default names; copy them to the
    # names expected by the query layer if the copies do not exist yet.
    base = "guidelines/output/lancedb"
    mapping = {
        "default-entity-description.lance": "entities.lance",
        "default-text_unit-text.lance": "text_units.lance",
        "default-community-full_content.lance": "relationships.lance",
    }
    for old, new in mapping.items():
        src = os.path.join(base, old)
        dst = os.path.join(base, new)
        if os.path.exists(src) and not os.path.exists(dst):
            shutil.copytree(src, dst)

    db_path = Path("guidelines/output/lancedb")
    print("Checking LanceDB folder:", db_path.resolve())
    db = lancedb.connect(str(db_path))
    tables = db.table_names()
    print("Existing tables:", tables)

    if not tables:
        # Register the copied .lance directories as tables.
        for name in ["entities", "text_units", "relationships"]:
            lance_dir = db_path / f"{name}.lance"
            if lance_dir.exists():
                print(f"Registering {name} from {lance_dir}")
                db.create_table(name, str(lance_dir), mode="create_if_not_exists")
        print("Tables registered successfully.")

    print("Available tables after registration:", db.table_names())
    with st.sidebar:
        # Provider selection dropdown
        available_providers = ["OpenAI", "Deepseek"]
        if "current_provider_choice" not in st.session_state:
            st.session_state.current_provider_choice = available_providers[0]
        provider_selection = st.selectbox(
            "Choose AI Provider:",
            available_providers,
            key="current_provider_choice",
        )

        # Get provider-specific settings
        provider_settings = initialize_provider_settings(provider_selection)

        # Validate required credentials
        if not provider_settings.get("api_key_source") or not provider_settings.get("base_url_source"):
            st.error(f"Configuration missing for {provider_selection}. Check environment variables.")
            st.stop()

        # Set up the OpenAI-compatible client
        try:
            # Debug output: note that this prints the raw API key to the server log.
            print(f'api_key used in app.py: {provider_settings["api_key_source"]}')
            api_client = OpenAI(
                api_key=provider_settings["api_key_source"],
                base_url=provider_settings["base_url_source"],
            )
            available_models = api_client.models.list()
            model_list = sorted([m.id for m in available_models])

            # Handle model selection with provider switching
            session_key = f"model_for_{provider_selection}"
            if session_key not in st.session_state or st.session_state.get("last_provider") != provider_selection:
                preferred_model = provider_settings.get("fallback_model")
                if preferred_model and preferred_model in model_list:
                    st.session_state[session_key] = preferred_model
                elif model_list:
                    st.session_state[session_key] = model_list[0]
                st.session_state.last_provider = provider_selection

            if not model_list:
                st.error(f"No models found for {provider_selection}")
                st.stop()

            # Model selection interface
            chosen_model = st.selectbox(
                f"Available models from {provider_selection}:",
                model_list,
                key=session_key,
            )
            st.info(f"Active model: {chosen_model}")

            # Instantiate the evaluator (rebinds the imported `test` name to the instance).
            test = test(OPENAI_API_KEY=provider_settings["api_key_source"], model=chosen_model)
        except Exception as connection_error:
            st.error(f"Connection failed for {provider_selection}: {connection_error}")
            st.stop()

        st.button("Reset Conversation", on_click=clear_chat)
        st.markdown("---")

        # Display provider-specific information
        if provider_selection == "OpenAI":
            st.markdown(
                """
                Example: A 60-year-old man with a history of alcoholic liver cirrhosis presented to the
                emergency department with new onset of haematemesis. Blood pressure was 90/40 mmHg and the
                heart rate was 98 beats per minute. The haemoglobin was 8 g/dL (normal range: 11.5-16 g/dL)
                with normal clotting parameters and platelets. Electrolytes were also normal. Upper GI
                endoscopy showed large oesophageal varices that continued to bleed despite attempts at
                endoscopic treatment with band ligation and sclerotherapy. The patient was referred for
                consideration of a transjugular intrahepatic portosystemic shunt (TIPS).
                """
            )
        elif provider_selection == "Deepseek":
            st.markdown(
                """
                Deepseek
                """
            )

    # Chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"], unsafe_allow_html=True)

    user_input = st.chat_input("Please enter the patient's conditions...")

    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        try:
            (
                response_patientcase,
                response_generated_plan,
                response_key_info,
            ) = test.evaluate_SOA_P_with_patientbase_graphrag(user_input)

            with st.chat_message("assistant"):
                # === 1) Most Similar Patient (folded) ===
                st.markdown("""