Perunio committed
Commit 90bc141 · 1 Parent(s): 1a18d8f
.dockerignore ADDED
@@ -0,0 +1,13 @@
+ __pycache__
+ *.pyc
+ .git
+ node_modules
+ .env
+ README.md
+ Dockerfile
+
+ # data folder
+ data
+ predictor/data
+ model/data
+ dataset/data
.gitignore ADDED
@@ -0,0 +1,215 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ # data
+ data
+ predictor/data
+
+
+ # model
+ *.pth
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ FROM python:3.12.4-slim
+
+ WORKDIR /app
+
+ RUN pip install poetry
+
+ COPY pyproject.toml ./
+ RUN poetry install --no-root
+ RUN poetry run pip install torch-scatter torch-sparse torch-cluster pyg-lib -f https://data.pyg.org/whl/torch-2.3.1+cu121.html
+ RUN poetry run pip install torch-geometric
+
+ COPY galis_app.py ./
+ COPY model ./model
+ COPY dataset ./dataset
+ COPY predictor ./predictor
+ COPY llm ./llm
+
+ ENV GOOGLE_API_KEY=""
+
+ EXPOSE 7860
+
+ CMD ["poetry", "run", "streamlit", "run", "galis_app.py", "--server.port=7860", "--server.address=0.0.0.0"]
+
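For a local smoke test of this image, something like docker build -t galis . followed by docker run -p 7860:7860 -e GOOGLE_API_KEY=your_key galis should work: the container serves the Streamlit app on port 7860 and reads the Gemini key from the GOOGLE_API_KEY environment variable (the galis tag is an arbitrary name chosen for the example, not something the Dockerfile defines).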
dataset/__init__.py ADDED
File without changes
dataset/ogbn_link_pred_dataset.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import pandas as pd
+ import torch
+ from ogb.nodeproppred import PygNodePropPredDataset
+ from torch_geometric.transforms import RandomLinkSplit
+ from torch_geometric.loader import LinkNeighborLoader
+ from torch_geometric.data import Data
+
+ import requests
+ import gzip
+ import shutil
+
+
+ class OGBNLinkPredDataset:
+     def __init__(
+         self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2
+     ):
+         self._base_dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=root_dir)
+         self.data = self._base_dataset[0]
+         self.root = self._base_dataset.root
+         self.num_features = self._base_dataset.num_features
+
+         self._download_abstracts()
+         self.corpus = self._load_corpus()
+
+         self.train_data, self.val_data, self.test_data = self._split_data(
+             val_size, test_size
+         )
+
+     def _download_abstracts(self):
+         target_dir = os.path.join(self.root, "mapping")
+         tsv_path = os.path.join(target_dir, "titleabs.tsv")
+
+         if not os.path.exists(tsv_path):
+             print("Downloading title and abstract information...")
+             gz_path = tsv_path + ".gz"
+             url = "https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz"
+             os.makedirs(target_dir, exist_ok=True)
+
+             try:
+                 print(f"Downloading from {url}...")
+                 response = requests.get(url, stream=True)
+                 response.raise_for_status()
+                 with open(gz_path, "wb") as f:
+                     for chunk in response.iter_content(chunk_size=8192):
+                         f.write(chunk)
+                 print(f"File downloaded to: {gz_path}")
+
+                 print(f"Decompressing {gz_path}...")
+                 with gzip.open(gz_path, 'rb') as f_in:
+                     with open(tsv_path, 'wb') as f_out:
+                         shutil.copyfileobj(f_in, f_out)
+                 print(f"File extracted to: {tsv_path}")
+
+                 os.remove(gz_path)
+                 print(f"Removed temporary file: {gz_path}")
+
+             except requests.exceptions.RequestException as e:
+                 print(f"Error downloading file: {e}")
+             except Exception as e:
+                 print(f"An error occurred: {e}")
+
+         else:
+             print("Title and abstract file already exists.")
+
+     def _load_corpus(self) -> list[str]:
+         tsv_path = os.path.join(self.root, "mapping", "titleabs.tsv")
+         try:
+             df_text = pd.read_csv(
+                 tsv_path,
+                 sep="\t",
+                 header=None,
+                 names=["paper_id", "title", "abstract"],
+                 lineterminator="\n",
+                 low_memory=False,
+             )
+             df_text_aligned = df_text.reset_index(drop=True)
+             corpus = (
+                 df_text_aligned["title"].fillna("")
+                 + "\n "
+                 + df_text_aligned["abstract"].fillna("")
+             ).tolist()
+             print(f"Corpus created with {len(corpus)} documents.")
+             return corpus
+         except FileNotFoundError:
+             print("Error: titleabs.tsv not found. Could not create corpus.")
+             return []
+
+     def _split_data(self, val_size: float, test_size: float) -> tuple[Data, Data, Data]:
+         transform = RandomLinkSplit(
+             num_val=val_size,
+             num_test=test_size,
+             is_undirected=False,
+             add_negative_train_samples=False,
+         )
+         train_split, val_split, test_split = transform(self.data)
+         print("Data successfully split into train, validation, and test sets.")
+         return train_split, val_split, test_split
+
+     def get_splits(self) -> tuple[Data, Data, Data]:
+         return self.train_data, self.val_data, self.test_data
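A minimal usage sketch for this class (a hedged example, assuming the default root_dir and that downloading ogbn-arxiv plus titleabs.tsv is possible in the environment):

    from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset

    # First use downloads ogbn-arxiv and the title/abstract file, then splits edges for link prediction.
    dataset = OGBNLinkPredDataset(root_dir="data", val_size=0.1, test_size=0.2)
    train_data, val_data, test_data = dataset.get_splits()
    print(train_data)           # torch_geometric.data.Data with edge_label / edge_label_index
    print(len(dataset.corpus))  # one "title\n abstract" string per paper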
galis_app.py ADDED
@@ -0,0 +1,132 @@
+ from pathlib import Path
+ import streamlit as st
+
+ from predictor.link_predictor import (
+     prepare_system,
+     get_citation_predictions,
+     abstract_to_vector,
+     format_top_k_predictions,
+ )
+ from llm.related_work_generator import generate_related_work
+
+ MODEL_PATH = Path("predictor/model.pth")
+
+
+ @st.cache_resource
+ def load_prediction_system(model_path):
+     return prepare_system(model_path)
+
+
+ def app():
+     st.set_page_config(page_title="Galis", layout="wide")
+     st.title("Galis")
+
+     if "references" not in st.session_state:
+         st.session_state.references = None
+     if "related_work" not in st.session_state:
+         st.session_state.related_work = None
+     if "abstract_title" not in st.session_state:
+         st.session_state.abstract_title = ""
+     if "abstract_text" not in st.session_state:
+         st.session_state.abstract_text = ""
+
+     gcn_model, st_model, dataset, z_all = load_prediction_system(MODEL_PATH)
+
+     col1, col2 = st.columns(2, gap="large")
+
+     with col2:
+         references_placeholder = st.empty()
+         related_work_placeholder = st.empty()
+
+     with col1:
+         st.header("Abstract Title")
+         abstract_title = st.text_input(
+             "Paste your title here",
+             st.session_state.abstract_title,
+             key="abstract_title_input",
+             label_visibility="collapsed",
+         )
+
+         st.header("Abstract Text")
+         abstract_input = st.text_area(
+             "Paste your abstract here",
+             st.session_state.abstract_text,
+             key="abstract_text_input",
+             height=100,
+             label_visibility="collapsed",
+         )
+
+         st.write("...or **upload** a .txt file (first line = title, rest = abstract)")
+         uploaded_file = st.file_uploader(
+             "Drag and drop file here", type=["txt"], help="Limit 200MB per file • TXT"
+         )
+
+         if uploaded_file is not None:
+             content = uploaded_file.getvalue().decode("utf-8").splitlines()
+             st.session_state.abstract_title = content[0] if content else ""
+             st.session_state.abstract_text = (
+                 "\n".join(content[1:]) if len(content) > 1 else ""
+             )
+             st.rerun()
+
+         st.session_state.abstract_title = abstract_title
+         st.session_state.abstract_text = abstract_input
+
+         num_citations = st.number_input(
+             "Number of suggestions",
+             min_value=1,
+             max_value=100,
+             value=10,
+             step=1,
+             help="Choose how many paper suggestions you want to see.",
+         )
+
+         if st.button("Suggest References and Related Work", type="primary"):
+             if not abstract_title.strip() or not abstract_input.strip():
+                 st.warning("Please provide both a title and an abstract.")
+             else:
+                 st.session_state.references = None
+                 st.session_state.related_work = None
+                 references_placeholder.empty()
+                 related_work_placeholder.empty()
+
+                 with st.spinner("Analyzing abstract and predicting references..."):
+                     new_vector = abstract_to_vector(
+                         abstract_title, abstract_input, st_model
+                     )
+                     probabilities = get_citation_predictions(
+                         vector=new_vector,
+                         model=gcn_model,
+                         z_all=z_all,
+                         num_nodes=dataset.data.num_nodes,
+                     )
+                     references = format_top_k_predictions(
+                         probabilities, dataset, top_k=num_citations
+                     )
+                     st.session_state.references = references
+
+                 with references_placeholder.container():
+                     st.header("Suggested References")
+                     with st.container(height=200):
+                         st.markdown(st.session_state.references)
+
+                 with related_work_placeholder.container():
+                     with st.spinner("Generating related work section..."):
+                         related_work = generate_related_work(st.session_state.references)
+                         st.session_state.related_work = related_work
+
+     if st.session_state.references:
+         with references_placeholder.container():
+             st.header("Suggested References")
+             with st.container(height=200):
+                 st.markdown(st.session_state.references)
+
+     if st.session_state.related_work:
+         with related_work_placeholder.container():
+             st.header("Suggested Related Work")
+             with st.container(height=200):
+                 st.markdown(st.session_state.related_work)
+
+
+ if __name__ == "__main__":
+     app()
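Outside the container, the app is started the same way the Dockerfile's CMD does, i.e. streamlit run galis_app.py (or poetry run streamlit run galis_app.py), and the related-work step needs GOOGLE_API_KEY to be available, either exported in the shell or placed in a .env file that llm/related_work_generator.py picks up via load_dotenv.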
llm/__init__.py ADDED
File without changes
llm/related_work_generator.py ADDED
@@ -0,0 +1,124 @@
+ from dotenv import load_dotenv
+ import os
+ import structlog
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+
+ structlog.configure(
+     processors=[
+         structlog.processors.TimeStamper(fmt="iso"),
+         structlog.processors.JSONRenderer(indent=4, sort_keys=True),
+     ]
+ )
+ logger = structlog.get_logger()
+
+ load_dotenv()
+
+ PROMPT_TEXT = """
+ You are a research assistant AI specializing in academic writing. Your task is to generate a "Related Work" section
+ for a research paper. You will be given a list of citations.
+
+ Your goal is to synthesize the provided citations into a coherent and well-structured "Related Work" section that
+ contextualizes the user's project within the existing academic literature.
+
+ **PROVIDED CITATIONS:**
+ {citations}
+
+ **INSTRUCTIONS:**
+
+ 1. **Thematic Organization:** Do not simply list summaries of the papers. Group the provided citations into thematic
+ categories based on shared concepts, methodologies, or research problems. For example, you could create categories like
+ "Transformer-based Language Models," "Sentiment Analysis Techniques," and "Efficient Models for NLP." Introduce each
+ theme before discussing the relevant papers.
+
+ 2. **Synthesis and Analysis:** For each thematic group, synthesize the key contributions and findings of the papers.
+ Go beyond summarization; compare and contrast the different approaches. For instance, you could discuss the evolution
+ of certain methods or the trade-offs between different models (e.g., accuracy vs. computational efficiency).
+
+ 3. **Identify Research Gaps:** Critically analyze the literature you are reviewing. Explicitly identify the
+ limitations, open questions, or research gaps that the cited works leave unresolved. This will set the stage for
+ introducing the project's contribution.
+
+ 4. **Contextualize the User's Project:** After discussing a thematic group of papers and identifying a gap, clearly
+ and explicitly state how the user's project (described above) addresses this gap or builds upon the existing work. Use
+ phrases like: "While these methods have shown great success, they struggle with...", "To address this limitation, our
+ work introduces...", or "Building upon the foundation laid by [Author, Year], we propose a novel approach that...".
+
+ 5. **Academic Tone and Flow:** Maintain a formal, objective, and academic tone throughout the text. Ensure smooth
+ transitions between paragraphs and ideas to create a coherent narrative that logically leads the reader to understand
+ the novelty and importance of the user's project.
+
+ 6. **Output Format:** Generate only the text for the "Related Work" section. Do not include headers like
+ "INSTRUCTIONS" or "PROVIDED CITATIONS" in the final output. The entire response should be the section text itself.
+ """
+
+
+ def check_api_key():
+     api_key = os.getenv("GOOGLE_API_KEY")
+     if not api_key:
+         logger.error("GOOGLE_API_KEY not set")
+         return False
+     logger.info(f"Gemini API Key is loaded: {api_key[:10]}...")
+     return True
+
+
+ def create_related_work_pipeline():
+     """Creates a ready-to-use pipeline for generating the Related Work section."""
+
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-1.5-flash",
+         temperature=0.3
+     )
+
+     prompt = PromptTemplate(
+         input_variables=["citations"],
+         template=PROMPT_TEXT
+     )
+
+     parser = StrOutputParser()
+
+     chain = prompt | llm | parser
+
+     return chain
+
+
+ def generate_related_work(citations_text: str) -> str:
+     """
+     Main function - pass citations, get Related Work
+
+     Args:
+         citations_text: Text with citations (can be a list or a string)
+
+     Returns:
+         The generated Related Work section
+     """
+     pipeline = create_related_work_pipeline()
+     result = pipeline.invoke({"citations": citations_text})
+     return result
+
+
+ if __name__ == "__main__":
+
+     my_citations = """
+     Top 5 Citation Predictions:
+     - Title: 'deterministic construction of rip matrices in compressed sensing from constant weight codes'
+     - Title: 'mizar items exploring fine grained dependencies in the mizar mathematical library'
+     - Title: 'rateless lossy compression via the extremes'
+     - Title: 'towards autonomic service provisioning systems'
+     - Title: 'anonymization with worst case distribution based background knowledge'
+     """
+
+     print("Generating Related Work...")
+     print("=" * 50)
+
+     try:
+         related_work = generate_related_work(my_citations)
+         print(related_work)
+     except Exception as e:
+         print(f"Error: {e}")
+         print("\n=== SETUP INSTRUCTIONS ===")
+         print("1. Create a .env file in the same folder as this script")
+         print("2. Add the line: GOOGLE_API_KEY=your_key")
+         print("3. Get a key at: https://makersuite.google.com/app/apikey")
+         check_api_key()
model/__init__.py ADDED
File without changes
model/simple_gcn_model.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ import torch.nn.functional as F
+ from torch_geometric.nn import GCNConv
+
+
+ class EdgeDecoder(torch.nn.Module):
+     """Predict whether a citation edge exists between two node embeddings."""
+
+     def __init__(self, in_channels):
+         super().__init__()
+         self.linear = torch.nn.Linear(in_channels * 2, 1)
+
+     def forward(self, z, edge_index):
+         row, col = edge_index
+         # Concatenate the embeddings of the two nodes
+         z_cat = torch.cat([z[row], z[col]], dim=-1)
+         return self.linear(z_cat).squeeze(-1)
+
+
+ class SimpleGCN(torch.nn.Module):
+     """Encoder-decoder model: the encoder produces an embedding for each node, and the decoder predicts link existence between node embeddings."""
+
+     def __init__(self, in_channels, hidden_channels, out_channels):
+         super().__init__()
+         self.conv1 = GCNConv(in_channels, hidden_channels)
+         self.conv2 = GCNConv(hidden_channels, out_channels)
+         self.decoder = EdgeDecoder(out_channels)
+
+     def forward(self, x, edge_index):
+         x = self.conv1(x, edge_index).relu()
+         x = F.dropout(x, p=0.5, training=self.training)
+         z = self.conv2(x, edge_index)
+         return z
+
+     def decode(self, z, edge_label_index):
+         # We pass the edge_label_index to the decoder, which contains both pos and neg edges
+         return self.decoder(z, edge_label_index)
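A small sketch of the encode/decode flow with toy tensors (sizes are arbitrary here; only the 128-dimensional feature width mirrors the ogbn-arxiv setup used elsewhere in this commit):

    import torch
    from model.simple_gcn_model import SimpleGCN

    model = SimpleGCN(in_channels=128, hidden_channels=256, out_channels=128)

    x = torch.randn(4, 128)                            # 4 toy nodes with 128-dim features
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # message-passing edges
    candidates = torch.tensor([[0, 2], [3, 1]])        # (source, target) pairs to score

    z = model(x, edge_index)              # encoder: node embeddings of shape [4, 128]
    logits = model.decode(z, candidates)  # decoder: one logit per candidate pair
    probs = torch.sigmoid(logits)         # link probabilities, shape [2]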
model/train.py ADDED
@@ -0,0 +1,139 @@
+ import torch
+ import numpy as np
+ from torch_geometric.loader import LinkNeighborLoader
+ from sklearn.metrics import roc_auc_score, accuracy_score
+ from tqdm import tqdm
+ from model.simple_gcn_model import SimpleGCN
+ from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
+
+
+ BATCH_SIZE = 128
+ NUM_EPOCHS = 20
+ LR = 0.001
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # data
+ dataset = OGBNLinkPredDataset(val_size=0.1, test_size=0.2)
+ train_data, val_data, test_data = dataset.get_splits()
+
+ train_loader = LinkNeighborLoader(
+     train_data,
+     num_neighbors=[-1, -1],  # Use all neighbors
+     neg_sampling_ratio=1.0,  # 1 negative sample per positive edge
+     edge_label_index=train_data.edge_label_index,
+     edge_label=train_data.edge_label,
+     batch_size=BATCH_SIZE,
+     shuffle=True,
+     num_workers=4,
+ )
+
+ val_loader = LinkNeighborLoader(
+     val_data,
+     num_neighbors=[-1, -1],
+     neg_sampling_ratio=0.0,  # RandomLinkSplit already added negative edges
+     edge_label_index=val_data.edge_label_index,
+     edge_label=val_data.edge_label,
+     batch_size=BATCH_SIZE,
+     shuffle=False,
+     num_workers=4,
+ )
+
+ test_loader = LinkNeighborLoader(
+     test_data,
+     num_neighbors=[-1, -1],
+     neg_sampling_ratio=0.0,
+     edge_label_index=test_data.edge_label_index,
+     edge_label=test_data.edge_label,
+     batch_size=BATCH_SIZE,
+     shuffle=False,
+     num_workers=4,
+ )
+
+ # model
+ model = SimpleGCN(
+     in_channels=dataset.num_features,
+     hidden_channels=256,
+     out_channels=128,
+ ).to(DEVICE)
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+ criterion = torch.nn.BCEWithLogitsLoss()
+
+
+ # training
+ def train(train_loader, epoch):
+     model.train()
+     total_loss = 0
+     scaler = torch.cuda.amp.GradScaler()
+
+     pbar = tqdm(train_loader, desc=f"Training Epoch: {epoch}")
+     for batch in pbar:
+         batch = batch.to(DEVICE)
+         optimizer.zero_grad()
+
+         with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16, enabled=True):
+             z = model(batch.x, batch.edge_index)
+             out = model.decode(z, batch.edge_label_index)
+             labels = batch.edge_label.float()
+
+             loss = criterion(out, labels)
+
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         total_loss += loss.item()
+         pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+     return total_loss / len(train_loader)
+
+
+ @torch.no_grad()
+ def calc_metrics(loader):
+     model.eval()
+     all_scores = []
+     all_labels = []
+
+     pbar = tqdm(loader, desc="Testing")
+     for batch in pbar:
+         batch = batch.to(DEVICE)
+         with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
+             z = model(batch.x, batch.edge_index)
+             out = model.decode(z, batch.edge_label_index)
+
+         scores = torch.sigmoid(out).float().cpu().numpy()
+         labels = batch.edge_label.cpu().numpy()
+
+         all_scores.append(scores)
+         all_labels.append(labels)
+
+     all_scores = np.concatenate(all_scores)
+     all_labels = np.concatenate(all_labels)
+
+     return roc_auc_score(all_labels, all_scores), accuracy_score(
+         all_labels, all_scores > 0.5
+     )
+
+
+ if __name__ == "__main__":
+     best_val_auc = 0
+     best_auc = 0
+     for epoch in range(1, NUM_EPOCHS + 1):
+         loss = train(train_loader, epoch)
+         val_auc, val_acc = calc_metrics(val_loader)
+
+
+         print(
+             f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val acc: {val_acc:.4f}",
+             end=" ",
+         )
+         if val_auc > best_val_auc:
+             print("New best")
+             best_val_auc = val_auc
+             best_auc = val_auc
+             torch.save(model.state_dict(), "model.pth")
+
+     test_auc, test_acc = calc_metrics(test_loader)
+
+     print("-" * 30)
+     print(f"Best validation AUC: {best_auc:.4f}, Test AUC: {test_auc:.4f}, Test acc: {test_acc:.4f}")
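A note on running this script: it imports model.simple_gcn_model and dataset.ogbn_link_pred_dataset as packages, so it is presumably meant to be launched from the repository root, e.g. poetry run python -m model.train. The best checkpoint is saved as model.pth in the working directory, while the Streamlit app loads predictor/model.pth (see MODEL_PATH in galis_app.py), so the file has to be copied or moved there before the app or the Docker image can pick it up.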
predictor/__init__.py ADDED
File without changes
predictor/link_predictor.py ADDED
@@ -0,0 +1,156 @@
+ from pathlib import Path
+ import torch
+ import structlog
+
+ from sentence_transformers import SentenceTransformer
+ from model.simple_gcn_model import SimpleGCN
+ from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
+
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ structlog.configure(
+     processors=[
+         structlog.processors.TimeStamper(fmt="iso"),
+         structlog.processors.JSONRenderer(indent=4, sort_keys=True),
+     ]
+ )
+ logger = structlog.get_logger()
+
+
+ def abstract_to_vector(
+     title: str, abstract_text: str, st_model: SentenceTransformer
+ ) -> torch.Tensor:
+     text = title + "\n" + abstract_text
+     with torch.no_grad():
+         vector = st_model.encode(text, convert_to_tensor=True, device=DEVICE)
+     return vector.unsqueeze(0)
+
+
+ def get_citation_predictions(
+     vector: torch.Tensor, model: SimpleGCN, z_all: torch.Tensor, num_nodes: int
+ ) -> torch.Tensor:
+     model.eval()
+     with torch.no_grad():
+         empty_edge_index = torch.empty(2, 0, dtype=torch.long, device=DEVICE)
+         h1_new = model.conv1(vector, edge_index=empty_edge_index).relu()
+         z_new = model.conv2(h1_new, edge_index=empty_edge_index)
+
+     new_node_idx = num_nodes
+     row = torch.full((num_nodes,), fill_value=new_node_idx, device=DEVICE)
+     col = torch.arange(num_nodes, device=DEVICE)
+     edge_label_index_to_check = torch.stack([row, col], dim=0)
+
+     z_combined = torch.cat([z_all, z_new], dim=0)
+
+     with torch.no_grad():
+         logits = model.decode(z_combined, edge_label_index_to_check)
+
+     return torch.sigmoid(logits)
+
+
+ def format_top_k_predictions(
+     probs: torch.Tensor, dataset: OGBNLinkPredDataset, top_k=10, show_prob=False
+ ) -> str:
+     """
+     Formats the top K predictions into a single string for display.
+
+     Args:
+         probs (torch.Tensor): The tensor of probabilities for all potential links.
+         dataset (OGBNLinkPredDataset): The dataset object containing the corpus.
+         top_k (int): The number of top predictions to format.
+
+     Returns:
+         str: A formatted string with the top K predictions.
+     """
+     probs = probs.cpu()
+     top_probs, top_indices = torch.topk(probs, k=top_k)
+
+     output_lines = []
+
+     header = f"Top {top_k} Citation Predictions:"
+     output_lines.append(header)
+
+     for i in range(top_k):
+         paper_idx = top_indices[i].item()
+         prob = top_probs[i].item()
+         paper_info = dataset.corpus[paper_idx]
+         paper_title = paper_info.split("\n")[0]
+         if show_prob:
+             line = f" - Title: '{paper_title.strip()}', Probability: {prob:.4f}"
+         else:
+             line = f" - Title: '{paper_title.strip()}'"
+         output_lines.append(line)
+
+     return "\n".join(output_lines)
+
+
+ def prepare_system(model_path: Path):
+     """
+     Performs all one-time, expensive operations to prepare the system.
+     Initializes models, loads data, and pre-calculates embeddings using structured logging.
+     """
+     logger.info("system_preparation.start")
+
+     dataset = OGBNLinkPredDataset()
+     data = dataset.data.to(DEVICE)
+     logger.info("dataset.load.success")
+
+     model_name = "bongsoo/kpf-sbert-128d-v1"
+     logger.info(
+         "model.load.start", model_type="SentenceTransformer", model_name=model_name
+     )
+     st_model = SentenceTransformer(model_name, device=DEVICE)
+     logger.info("model.load.success", model_type="SentenceTransformer")
+
+     gcn_model = SimpleGCN(
+         in_channels=dataset.num_features, hidden_channels=256, out_channels=128
+     ).to(DEVICE)
+
+     if model_path.exists():
+         gcn_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+         logger.info("model.load.success", model_type="GCN", path=str(model_path))
+     else:
+         logger.warning(
+             "model.load.failure",
+             model_type="GCN",
+             path=str(model_path),
+             reason="File not found, using random weights.",
+         )
+     gcn_model.eval()
+
+     logger.info("embeddings.calculation.start", embedding_name="z_all")
+     with torch.no_grad():
+         z_all = gcn_model(data.x, data.edge_index)
+
+     logger.info(
+         "embeddings.calculation.success",
+         embedding_name="z_all",
+         shape=list(z_all.shape),
+     )
+
+     logger.info("system_preparation.finish", status="ready_for_predictions")
+     return gcn_model, st_model, dataset, z_all
+
+
+ if __name__ == "__main__":
+     MODEL_PATH = Path("model.pth")
+
+     gcn_model, st_model, dataset, z_all = prepare_system(MODEL_PATH)
+
+     my_title = "A Survey of Graph Neural Networks for Link Prediction"
+     my_abstract = """Link prediction is a critical task in graph analysis.
+     In this paper, we review various GNN architectures like GCN and GraphSAGE for predicting edges.
+     """
+
+     new_vector = abstract_to_vector(my_title, my_abstract, st_model)
+
+     probabilities = get_citation_predictions(
+         vector=new_vector,
+         model=gcn_model,
+         z_all=z_all,
+         num_nodes=dataset.data.num_nodes,
+     )
+
+     references = format_top_k_predictions(probabilities, dataset, top_k=5)
+     print(references)
pyproject.toml ADDED
@@ -0,0 +1,49 @@
+ [tool.poetry]
+ name = "galis"
+ version = "0.3.0"
+ description = ""
+ authors = ["Perunio <[email protected]>"]
+ readme = "README.md"
+ packages = [{include = "galis"}]
+
+ [tool.poetry.dependencies]
+ python = ">=3.12, <3.13"
+
+ torch = [
+     {version = "2.3.1+cu121", source = "pytorch-cuda", markers = "sys_platform == 'linux' or sys_platform == 'win32'"},
+     {version = "^2.3.1", source = "pytorch-cpu", markers = "sys_platform == 'darwin'"}
+ ]
+
+ ogb = "^1.3.6"
+ torch-geometric = "^2.6.1"
+ pandas = "^2.3.1"
+ streamlit = "^1.46.1"
+ numpy = "1.26.4"
+ streamlit-extras = "^0.7.5"
+ sentence-transformers = "2.7.0"
+ transformers = "4.39.3"
+ ruff = "^0.12.7"
+ structlog = "^25.4.0"
+ langchain-google-genai = "^2.1.9"
+ langchain-core = "^0.3.72"
+ langchain = "^0.3.27"
+ python-dotenv = "^1.1.1"
+ hf-transfer = "^0.1.9"
+
+ [[tool.poetry.source]]
+ name = "pytorch-cuda"
+ url = "https://download.pytorch.org/whl/cu121"
+ priority = "explicit"
+
+ [[tool.poetry.source]]
+ name = "pytorch-cpu"
+ url = "https://download.pytorch.org/whl/cpu"
+ priority = "explicit"
+
+ [[tool.poetry.source]]
+ name = "PyPI"
+ priority = "primary"
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"
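For a local, non-Docker environment, the Dockerfile above sketches the intended install sequence: poetry install --no-root, then poetry run pip install torch-scatter torch-sparse torch-cluster pyg-lib -f https://data.pyg.org/whl/torch-2.3.1+cu121.html followed by poetry run pip install torch-geometric for the PyG companion wheels matching the pinned torch 2.3.1+cu121 build; on macOS the pytorch-cpu source declared above applies instead, so the CUDA wheel index would not be used.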