ok
Browse files
app.py
CHANGED
|
@@ -2,12 +2,29 @@ import streamlit as st
|
|
| 2 |
import json
|
| 3 |
import random
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
@st.cache_resource
|
| 6 |
def load_model():
|
| 7 |
import adrd
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
| 11 |
return model
|
| 12 |
|
| 13 |
@st.cache_resource
|
|
@@ -37,28 +54,39 @@ dat_tst = CSVDataset(
|
|
| 37 |
if 'input_text' not in st.session_state:
|
| 38 |
st.session_state.input_text = ""
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
if sample_button:
|
| 59 |
idx = random.randint(0, len(dat_tst) - 1)
|
| 60 |
-
|
| 61 |
-
st.session_state.input_text = json.dumps(
|
| 62 |
|
| 63 |
# reset input text after form processing to show updated text in the input box
|
| 64 |
if 'input_text' in st.session_state:
|
|
@@ -69,8 +97,12 @@ elif submit_button:
|
|
| 69 |
# Parse the JSON input into a Python dictionary
|
| 70 |
data_dict = json.loads(json_input)
|
| 71 |
pred_dict = predict_proba(data_dict)
|
| 72 |
-
|
| 73 |
-
|
|
|
|
| 74 |
except json.JSONDecodeError as e:
|
| 75 |
# Handle JSON parsing errors
|
| 76 |
st.error(f"An error occurred: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import random
|
| 4 |
|
| 5 |
+
# set page configuration to wide mode
|
| 6 |
+
st.set_page_config(layout="wide")
|
| 7 |
+
|
| 8 |
+
st.markdown("""
|
| 9 |
+
<style>
|
| 10 |
+
.bounding-box {
|
| 11 |
+
border: 2px solid #4CAF50; # Green border
|
| 12 |
+
border-radius: 5px; # Rounded corners
|
| 13 |
+
padding: 10px; # Padding inside the box
|
| 14 |
+
margin: 10px; # Space outside the box
|
| 15 |
+
}
|
| 16 |
+
</style>
|
| 17 |
+
""", unsafe_allow_html=True)
|
| 18 |
+
|
| 19 |
@st.cache_resource
|
| 20 |
def load_model():
|
| 21 |
import adrd
|
| 22 |
+
try:
|
| 23 |
+
ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
|
| 24 |
+
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
| 25 |
+
except:
|
| 26 |
+
ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_densenet_emb_encoder_2_AUPR.pt'
|
| 27 |
+
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
| 28 |
return model
|
| 29 |
|
| 30 |
@st.cache_resource
|
|
|
|
| 54 |
if 'input_text' not in st.session_state:
|
| 55 |
st.session_state.input_text = ""
|
| 56 |
|
| 57 |
+
# section 1
|
| 58 |
+
st.markdown("#### About ADRD")
|
| 59 |
+
st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
|
| 60 |
+
|
| 61 |
+
# section 2
|
| 62 |
+
st.markdown("#### Demo")
|
| 63 |
+
st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random NACC Case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
|
|
|
|
| 64 |
|
| 65 |
+
# layout
|
| 66 |
+
layout_l, layout_r = st.columns([1, 1])
|
| 67 |
|
| 68 |
+
# create a form for user input
|
| 69 |
+
with layout_l:
|
| 70 |
+
with st.form("json_input_form"):
|
| 71 |
+
json_input = st.text_area(
|
| 72 |
+
"Please enter JSON-formatted input features:",
|
| 73 |
+
value = st.session_state.input_text,
|
| 74 |
+
height = 250
|
| 75 |
+
)
|
| 76 |
|
| 77 |
+
# create three columns
|
| 78 |
+
left_col, middle_col, right_col = st.columns([3, 4, 1])
|
| 79 |
|
| 80 |
+
with left_col:
|
| 81 |
+
sample_button = st.form_submit_button("Random NACC Case")
|
| 82 |
+
|
| 83 |
+
with right_col:
|
| 84 |
+
submit_button = st.form_submit_button("Predict")
|
| 85 |
+
|
| 86 |
if sample_button:
|
| 87 |
idx = random.randint(0, len(dat_tst) - 1)
|
| 88 |
+
random_case = dat_tst[idx][0]
|
| 89 |
+
st.session_state.input_text = json.dumps(random_case, indent=2)
|
| 90 |
|
| 91 |
# reset input text after form processing to show updated text in the input box
|
| 92 |
if 'input_text' in st.session_state:
|
|
|
|
| 97 |
# Parse the JSON input into a Python dictionary
|
| 98 |
data_dict = json.loads(json_input)
|
| 99 |
pred_dict = predict_proba(data_dict)
|
| 100 |
+
with layout_r:
|
| 101 |
+
st.write("Predicted probabilities:")
|
| 102 |
+
st.json(pred_dict)
|
| 103 |
except json.JSONDecodeError as e:
|
| 104 |
# Handle JSON parsing errors
|
| 105 |
st.error(f"An error occurred: {e}")
|
| 106 |
+
|
| 107 |
+
# section 3
|
| 108 |
+
|