Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import streamlit as st | |
| import torch.nn.functional as F | |
| from predictor.link_predictor import ( | |
| prepare_system, | |
| get_citation_predictions, | |
| abstract_to_vector, | |
| format_top_k_predictions, | |
| ) | |
| from llm.related_work_generator import generate_related_work | |
MODEL_PATH = Path("model.pth")


@st.cache_resource
def load_prediction_system(model_path):
    """Load the citation-prediction system once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching ``prepare_system`` would rebuild everything on each
    keystroke. ``st.cache_resource`` keys the cache on ``model_path`` and
    hands every caller the same (un-copied) objects, so callers must treat
    them as read-only.

    Args:
        model_path: Path passed straight through to ``prepare_system``.

    Returns:
        Whatever ``prepare_system`` returns; ``app`` unpacks it as
        ``(gcn_model, st_model, dataset, z_all)``.
    """
    return prepare_system(model_path)
def _init_session_state():
    """Seed every session-state key the page reads, before any widget exists."""
    defaults = {
        "references": None,
        "related_work": None,
        "abstract_title": "",
        "abstract_text": "",
    }
    for name, default in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = default
    # The text widgets are driven purely by their keys; start them from the
    # plain title/abstract copies.
    if "abstract_title_input" not in st.session_state:
        st.session_state["abstract_title_input"] = st.session_state.abstract_title
    if "abstract_text_input" not in st.session_state:
        st.session_state["abstract_text_input"] = st.session_state.abstract_text
    # Apply text parsed from an upload on the previous run. This has to happen
    # here, before the widgets are created: Streamlit refuses writes to a
    # widget's key once that widget has been instantiated in the current run.
    pending = st.session_state.pop("pending_upload", None)
    if pending is not None:
        title, abstract = pending
        st.session_state["abstract_title_input"] = title
        st.session_state["abstract_text_input"] = abstract


def _parse_uploaded_text(content):
    """Split raw .txt bytes into ``(title, abstract)``.

    The first line is the title; the remaining lines, re-joined with "\\n",
    are the abstract. Empty content yields ``("", "")``.

    Raises:
        UnicodeDecodeError: if ``content`` is not valid UTF-8.
    """
    # "utf-8-sig" strips a leading byte-order mark that would otherwise be
    # glued to the front of the title.
    lines = content.decode("utf-8-sig").splitlines()
    if not lines:
        return "", ""
    return lines[0], "\n".join(lines[1:])


def _handle_upload(uploaded_file):
    """Queue a newly uploaded file's title/abstract and rerun to show them."""
    if uploaded_file is None:
        # Forget the previous file so uploading the same file again re-applies it.
        st.session_state.pop("last_uploaded_file", None)
        return
    # The uploader returns the same file on every rerun; acting on it each
    # time would call st.rerun() forever. Only react to a different file.
    signature = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("last_uploaded_file") == signature:
        return
    try:
        title, abstract = _parse_uploaded_text(uploaded_file.getvalue())
    except UnicodeDecodeError:
        st.error("Could not read the uploaded file: it is not valid UTF-8 text.")
        return
    st.session_state["last_uploaded_file"] = signature
    st.session_state["pending_upload"] = (title, abstract)
    st.rerun()


def _show_references(placeholder):
    """Render the stored reference suggestions into ``placeholder``."""
    with placeholder.container():
        st.header("Suggested References")
        with st.container(height=200):
            st.markdown(st.session_state.references)


def _show_related_work(placeholder):
    """Render the stored related-work text into ``placeholder``."""
    with placeholder.container():
        st.header("Suggested Related Works")
        with st.container(height=200):
            st.markdown(st.session_state.related_work)


def app():
    """Render the Galis page.

    The left column collects a title and abstract, typed in or uploaded as a
    .txt file. On request it predicts references with the GCN model and asks
    ``generate_related_work`` for a related-work section. Results are kept in
    ``st.session_state`` so they survive reruns, and they are shown in the
    right column.
    """
    st.set_page_config(page_title="Galis", layout="wide")
    st.title("Galis")
    _init_session_state()

    gcn_model, st_model, dataset, z_all = load_prediction_system(MODEL_PATH)

    col1, col2 = st.columns(2, gap="large")
    with col2:
        # Reserve the output slots now so they can be cleared and filled later
        # in the run.
        references_placeholder = st.empty()
        related_work_placeholder = st.empty()

    with col1:
        st.header("Abstract Title")
        abstract_title = st.text_input(
            "Paste your title here",
            key="abstract_title_input",
            label_visibility="collapsed",
        )
        st.header("Abstract Text")
        abstract_input = st.text_area(
            "Paste your abstract here",
            key="abstract_text_input",
            height=100,
            label_visibility="collapsed",
        )
        st.write("...or **upload** a .txt file (first line = title, rest = abstract)")
        uploaded_file = st.file_uploader(
            "Drag and drop file here", type=["txt"], help="Limit 200MB per file • TXT"
        )
        _handle_upload(uploaded_file)

        # Keep the plain copies in sync with what the user currently sees.
        st.session_state.abstract_title = abstract_title
        st.session_state.abstract_text = abstract_input

        num_citations = st.number_input(
            "Number of suggestions",
            min_value=1,
            max_value=100,
            value=10,
            step=1,
            help="Choose how many paper suggestions you want to see.",
        )

        if st.button("Suggest References and related work", type="primary"):
            if not abstract_title.strip() or not abstract_input.strip():
                st.warning("Please provide both a title and an abstract.")
            else:
                # Discard results from any earlier request before computing anew.
                st.session_state.references = None
                st.session_state.related_work = None
                references_placeholder.empty()
                related_work_placeholder.empty()
                with st.spinner("Analyzing abstract and predicting references..."):
                    new_vector = abstract_to_vector(
                        abstract_input, abstract_title, st_model
                    )
                    probabilities = get_citation_predictions(
                        vector=F.normalize(new_vector.view(1, -1), p=2, dim=1),
                        model=gcn_model,
                        z_all=z_all,
                        num_nodes=dataset.data.num_nodes,
                    )
                    st.session_state.references = format_top_k_predictions(
                        probabilities, dataset, top_k=num_citations
                    )
                # Show the references immediately so they are visible while
                # the related-work section is being generated.
                _show_references(references_placeholder)
                with related_work_placeholder.container():
                    with st.spinner("Generating related work section..."):
                        st.session_state.related_work = generate_related_work(
                            st.session_state.references
                        )

    if st.session_state.references:
        _show_references(references_placeholder)
    if st.session_state.related_work:
        _show_related_work(related_work_placeholder)
# Script entry point. NOTE(review): presumably also the path taken under
# `streamlit run`, which executes the file as __main__ — confirm.
if __name__ == "__main__":
    app()