- .dockerignore +13 -0
- .gitignore +215 -0
- Dockerfile +23 -0
- dataset/__init__.py +0 -0
- dataset/ogbn_link_pred_dataset.py +101 -0
- galis_app.py +132 -0
- llm/__init__.py +0 -0
- llm/related_work_generator.py +124 -0
- model/__init__.py +0 -0
- model/simple_gcn_model.py +37 -0
- model/train.py +139 -0
- predictor/__init__.py +0 -0
- predictor/link_predictor.py +156 -0
- pyproject.toml +49 -0
.dockerignore
ADDED
@@ -0,0 +1,13 @@
+__pycache__
+*.pyc
+.git
+node_modules
+.env
+README.md
+Dockerfile
+
+# data folder
+data
+predictor/data
+model/data
+dataset/data
.gitignore
ADDED
@@ -0,0 +1,215 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+.vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# data
+data
+predictor/data
+
+
+# model
+*.pth
Dockerfile
ADDED
@@ -0,0 +1,23 @@
+FROM python:3.12.4-slim
+
+WORKDIR /app
+
+RUN pip install poetry
+
+COPY pyproject.toml ./
+RUN poetry install --no-root
+RUN poetry run pip install torch-scatter torch-sparse torch-cluster pyg-lib -f https://data.pyg.org/whl/torch-2.3.1+cu121.html
+RUN poetry run pip install torch-geometric
+
+COPY galis_app.py ./
+COPY model ./model
+COPY dataset ./dataset
+COPY predictor ./predictor
+COPY llm ./llm
+
+ENV GOOGLE_API_KEY=""
+
+EXPOSE 7860
+
+CMD ["poetry", "run", "streamlit", "run", "galis_app.py", "--server.port=7860", "--server.address=0.0.0.0"]
+
dataset/__init__.py
ADDED
File without changes
dataset/ogbn_link_pred_dataset.py
ADDED
@@ -0,0 +1,101 @@
+import os
+import pandas as pd
+import torch
+from ogb.nodeproppred import PygNodePropPredDataset
+from torch_geometric.transforms import RandomLinkSplit
+from torch_geometric.loader import LinkNeighborLoader
+from torch_geometric.data import Data
+
+import requests
+import gzip
+import shutil
+
+
+class OGBNLinkPredDataset:
+    def __init__(
+        self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2
+    ):
+        self._base_dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=root_dir)
+        self.data = self._base_dataset[0]
+        self.root = self._base_dataset.root
+        self.num_features = self._base_dataset.num_features
+
+        self._download_abstracts()
+        self.corpus = self._load_corpus()
+
+        self.train_data, self.val_data, self.test_data = self._split_data(
+            val_size, test_size
+        )
+
+    def _download_abstracts(self):
+        target_dir = os.path.join(self.root, "mapping")
+        tsv_path = os.path.join(target_dir, "titleabs.tsv")
+
+        if not os.path.exists(tsv_path):
+            print("Downloading title and abstract information...")
+            gz_path = tsv_path + ".gz"
+            url = "https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz"
+            os.makedirs(target_dir, exist_ok=True)
+
+            try:
+                print(f"Downloading from {url}...")
+                response = requests.get(url, stream=True)
+                response.raise_for_status()
+                with open(gz_path, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+                print(f"File downloaded to: {gz_path}")
+
+                print(f"Decompressing {gz_path}...")
+                with gzip.open(gz_path, 'rb') as f_in:
+                    with open(tsv_path, 'wb') as f_out:
+                        shutil.copyfileobj(f_in, f_out)
+                print(f"File extracted to: {tsv_path}")
+
+                os.remove(gz_path)
+                print(f"Removed temporary file: {gz_path}")
+
+            except requests.exceptions.RequestException as e:
+                print(f"Error downloading file: {e}")
+            except Exception as e:
+                print(f"An error occurred: {e}")
+
+        else:
+            print("Title and abstract file already exists.")
+
+    def _load_corpus(self) -> list[str]:
+        tsv_path = os.path.join(self.root, "mapping", "titleabs.tsv")
+        try:
+            df_text = pd.read_csv(
+                tsv_path,
+                sep="\t",
+                header=None,
+                names=["paper_id", "title", "abstract"],
+                lineterminator="\n",
+                low_memory=False,
+            )
+            df_text_aligned = df_text.reset_index(drop=True)
+            corpus = (
+                df_text_aligned["title"].fillna("")
+                + "\n "
+                + df_text_aligned["abstract"].fillna("")
+            ).tolist()
+            print(f"Corpus created with {len(corpus)} documents.")
+            return corpus
+        except FileNotFoundError:
+            print("Error: titleabs.tsv not found. Could not create corpus.")
+            return []
+
+    def _split_data(self, val_size: float, test_size: float) -> tuple[Data, Data, Data]:
+        transform = RandomLinkSplit(
+            num_val=val_size,
+            num_test=test_size,
+            is_undirected=False,
+            add_negative_train_samples=False,
+        )
+        train_split, val_split, test_split = transform(self.data)
+        print("Data successfully split into train, validation, and test sets.")
+        return train_split, val_split, test_split
+
+    def get_splits(self) -> tuple[Data, Data, Data]:
+        return self.train_data, self.val_data, self.test_data
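A minimal usage sketch (not part of the commit) of how this class is consumed by model/train.py and predictor/link_predictor.py below, assuming the dependencies pinned in pyproject.toml:

# Sketch only: instantiating OGBNLinkPredDataset triggers the ogbn-arxiv
# download, the abstract download, and the link split in one go.
from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset

dataset = OGBNLinkPredDataset(root_dir="data", val_size=0.1, test_size=0.2)

# Each split is a torch_geometric Data object carrying the edge_label and
# edge_label_index fields produced by RandomLinkSplit.
train_data, val_data, test_data = dataset.get_splits()
print(train_data.edge_label_index.shape)  # [2, num_supervision_edges]

# corpus[i] holds "title\n abstract" for paper node i.
print(dataset.corpus[0].split("\n")[0])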
galis_app.py
ADDED
@@ -0,0 +1,132 @@
+from pathlib import Path
+import streamlit as st
+
+from predictor.link_predictor import (
+    prepare_system,
+    get_citation_predictions,
+    abstract_to_vector,
+    format_top_k_predictions,
+)
+from llm.related_work_generator import generate_related_work
+
+MODEL_PATH = Path("predictor/model.pth")
+
+
+@st.cache_resource
+def load_prediction_system(model_path):
+    return prepare_system(model_path)
+
+
+def app():
+    st.set_page_config(page_title="Galis", layout="wide")
+    st.title("Galis")
+
+    if "references" not in st.session_state:
+        st.session_state.references = None
+    if "related_work" not in st.session_state:
+        st.session_state.related_work = None
+    if "abstract_title" not in st.session_state:
+        st.session_state.abstract_title = ""
+    if "abstract_text" not in st.session_state:
+        st.session_state.abstract_text = ""
+
+    gcn_model, st_model, dataset, z_all = load_prediction_system(MODEL_PATH)
+
+    col1, col2 = st.columns(2, gap="large")
+
+    with col2:
+        references_placeholder = st.empty()
+        related_work_placeholder = st.empty()
+
+    with col1:
+        st.header("Abstract Title")
+        abstract_title = st.text_input(
+            "Paste your title here",
+            st.session_state.abstract_title,
+            key="abstract_title_input",
+            label_visibility="collapsed",
+        )
+
+        st.header("Abstract Text")
+        abstract_input = st.text_area(
+            "Paste your abstract here",
+            st.session_state.abstract_text,
+            key="abstract_text_input",
+            height=100,
+            label_visibility="collapsed",
+        )
+
+        st.write("...or **upload** a .txt file (first line = title, rest = abstract)")
+        uploaded_file = st.file_uploader(
+            "Drag and drop file here", type=["txt"], help="Limit 200MB per file • TXT"
+        )
+
+        if uploaded_file is not None:
+            content = uploaded_file.getvalue().decode("utf-8").splitlines()
+            st.session_state.abstract_title = content[0] if content else ""
+            st.session_state.abstract_text = (
+                "\n".join(content[1:]) if len(content) > 1 else ""
+            )
+            st.rerun()
+
+        st.session_state.abstract_title = abstract_title
+        st.session_state.abstract_text = abstract_input
+
+        num_citations = st.number_input(
+            "Number of suggestions",
+            min_value=1,
+            max_value=100,
+            value=10,
+            step=1,
+            help="Choose how many paper suggestions you want to see.",
+        )
+
+        if st.button("Suggest References and Related Work", type="primary"):
+            if not abstract_title.strip() or not abstract_input.strip():
+                st.warning("Please provide both a title and an abstract.")
+            else:
+                st.session_state.references = None
+                st.session_state.related_work = None
+                references_placeholder.empty()
+                related_work_placeholder.empty()
+
+                with st.spinner("Analyzing abstract and predicting references..."):
+                    # abstract_to_vector expects (title, abstract); the original
+                    # call passed the two arguments in the opposite order.
+                    new_vector = abstract_to_vector(
+                        abstract_title, abstract_input, st_model
+                    )
+                    probabilities = get_citation_predictions(
+                        vector=new_vector,
+                        model=gcn_model,
+                        z_all=z_all,
+                        num_nodes=dataset.data.num_nodes,
+                    )
+                    references = format_top_k_predictions(
+                        probabilities, dataset, top_k=num_citations
+                    )
+                    st.session_state.references = references
+
+                with references_placeholder.container():
+                    st.header("Suggested References")
+                    with st.container(height=200):
+                        st.markdown(st.session_state.references)
+
+                with related_work_placeholder.container():
+                    with st.spinner("Generating related work section..."):
+                        related_work = generate_related_work(st.session_state.references)
+                        st.session_state.related_work = related_work
+
+    if st.session_state.references:
+        with references_placeholder.container():
+            st.header("Suggested References")
+            with st.container(height=200):
+                st.markdown(st.session_state.references)
+
+    if st.session_state.related_work:
+        with related_work_placeholder.container():
+            st.header("Suggested Related Works")
+            with st.container(height=200):
+                st.markdown(st.session_state.related_work)
+
+
+if __name__ == "__main__":
+    app()
llm/__init__.py
ADDED
File without changes
llm/related_work_generator.py
ADDED
@@ -0,0 +1,124 @@
+from dotenv import load_dotenv
+import os
+import structlog
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_core.prompts import PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+structlog.configure(
+    processors=[
+        structlog.processors.TimeStamper(fmt="iso"),
+        structlog.processors.JSONRenderer(indent=4, sort_keys=True),
+    ]
+)
+logger = structlog.get_logger()
+
+load_dotenv()
+
+PROMPT_TEXT = """
+You are a research assistant AI specializing in academic writing. Your task is to generate a "Related Work" section
+for a research paper. You will be given a list of citations.
+
+Your goal is to synthesize the provided citations into a coherent and well-structured "Related Work" section that
+contextualizes the user's project within the existing academic literature.
+
+**PROVIDED CITATIONS:**
+{citations}
+
+**INSTRUCTIONS:**
+
+1. **Thematic Organization:** Do not simply list summaries of the papers. Group the provided citations into thematic
+categories based on shared concepts, methodologies, or research problems. For example, you could create categories like
+"Transformer-based Language Models," "Sentiment Analysis Techniques," and "Efficient Models for NLP." Introduce each
+theme before discussing the relevant papers.
+
+2. **Synthesis and Analysis:** For each thematic group, synthesize the key contributions and findings of the papers.
+Go beyond summarization; compare and contrast the different approaches. For instance, you could discuss the evolution
+of certain methods or the trade-offs between different models (e.g., accuracy vs. computational efficiency).
+
+3. **Identify Research Gaps:** Critically analyze the literature you are reviewing. Explicitly identify the
+limitations, open questions, or research gaps that the cited works leave unresolved. This will set the stage for
+introducing the project's contribution.
+
+4. **Contextualize the User's Project:** After discussing a thematic group of papers and identifying a gap, clearly
+and explicitly state how the user's project addresses this gap or builds upon the existing work. Use
+phrases like: "While these methods have shown great success, they struggle with...", "To address this limitation, our
+work introduces...", or "Building upon the foundation laid by [Author, Year], we propose a novel approach that...".
+
+5. **Academic Tone and Flow:** Maintain a formal, objective, and academic tone throughout the text. Ensure smooth
+transitions between paragraphs and ideas to create a coherent narrative that logically leads the reader to understand
+the novelty and importance of the user's project.
+
+6. **Output Format:** Generate only the text for the "Related Work" section. Do not include headers like
+"INSTRUCTIONS" or "PROVIDED CITATIONS" in the final output. The entire response should be the section text itself.
+"""
+
+
+def check_api_key():
+    api_key = os.getenv("GOOGLE_API_KEY")
+    if not api_key:
+        logger.error("GOOGLE_API_KEY not set")
+        return False
+    logger.info(f"Gemini API Key is loaded: {api_key[:10]}...")
+    return True
+
+
+def create_related_work_pipeline():
+    """Creates a ready-to-use pipeline for generating the Related Work section."""
+
+    llm = ChatGoogleGenerativeAI(
+        model="gemini-1.5-flash",
+        temperature=0.3
+    )
+
+    prompt = PromptTemplate(
+        input_variables=["citations"],
+        template=PROMPT_TEXT
+    )
+
+    parser = StrOutputParser()
+
+    chain = prompt | llm | parser
+
+    return chain
+
+
+def generate_related_work(citations_text: str) -> str:
+    """
+    Main function - pass citations, get Related Work.
+
+    Args:
+        citations_text: Text with citations (can be a list or a string)
+
+    Returns:
+        The generated Related Work section
+    """
+    pipeline = create_related_work_pipeline()
+    result = pipeline.invoke({"citations": citations_text})
+    return result
+
+
+if __name__ == "__main__":
+
+    my_citations = """
+    Top 5 Citation Predictions:
+     - Title: 'deterministic construction of rip matrices in compressed sensing from constant weight codes'
+     - Title: 'mizar items exploring fine grained dependencies in the mizar mathematical library'
+     - Title: 'rateless lossy compression via the extremes'
+     - Title: 'towards autonomic service provisioning systems'
+     - Title: 'anonymization with worst case distribution based background knowledge'
+    """
+
+    print("Generating Related Work...")
+    print("=" * 50)
+
+    try:
+        related_work = generate_related_work(my_citations)
+        print(related_work)
+    except Exception as e:
+        print(f"Error: {e}")
+        print("\n=== SETUP INSTRUCTIONS ===")
+        print("1. Create a .env file in the same folder as this script")
+        print("2. Add this line to it: GOOGLE_API_KEY=your_key")
+        print("3. Get a key at: https://makersuite.google.com/app/apikey")
+        check_api_key()
model/__init__.py
ADDED
File without changes
model/simple_gcn_model.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+import torch.nn.functional as F
+from torch_geometric.nn import GCNConv
+
+
+class EdgeDecoder(torch.nn.Module):
+    """Predicts whether a citation edge exists between two node embeddings."""
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_channels * 2, 1)
+
+    def forward(self, z, edge_index):
+        row, col = edge_index
+        # Concatenate the embeddings of the two nodes
+        z_cat = torch.cat([z[row], z[col]], dim=-1)
+        return self.linear(z_cat).squeeze(-1)
+
+
+class SimpleGCN(torch.nn.Module):
+    """Encoder-decoder model: the encoder creates an embedding for each node, and the decoder predicts link existence between node embeddings."""
+
+    def __init__(self, in_channels, hidden_channels, out_channels):
+        super().__init__()
+        self.conv1 = GCNConv(in_channels, hidden_channels)
+        self.conv2 = GCNConv(hidden_channels, out_channels)
+        self.decoder = EdgeDecoder(out_channels)
+
+    def forward(self, x, edge_index):
+        x = self.conv1(x, edge_index).relu()
+        x = F.dropout(x, p=0.5, training=self.training)
+        z = self.conv2(x, edge_index)
+        return z
+
+    def decode(self, z, edge_label_index):
+        # We pass the edge_label_index to the decoder, which contains both pos and neg edges
+        return self.decoder(z, edge_label_index)
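Before the training script, a toy sketch (not part of the commit; shapes are illustrative assumptions) of how the encoder/decoder pair is meant to be used: encode every node once, then score arbitrary candidate pairs in one batched decode call.

# Sketch only: scoring candidate edges with SimpleGCN on random data.
import torch
from model.simple_gcn_model import SimpleGCN

model = SimpleGCN(in_channels=128, hidden_channels=256, out_channels=128)
x = torch.randn(5, 128)                            # 5 nodes, 128-dim features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 directed edges

z = model(x, edge_index)                           # node embeddings, shape [5, 128]
candidates = torch.tensor([[0, 4], [3, 2]])        # score candidate edges 0->3 and 4->2
logits = model.decode(z, candidates)               # shape [2]; sigmoid gives link probability
print(torch.sigmoid(logits))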
model/train.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+import numpy as np
+from torch_geometric.loader import LinkNeighborLoader
+from sklearn.metrics import roc_auc_score, accuracy_score
+from tqdm import tqdm
+from model.simple_gcn_model import SimpleGCN
+from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
+
+
+BATCH_SIZE = 128
+NUM_EPOCHS = 20
+LR = 0.001
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# data
+dataset = OGBNLinkPredDataset(val_size=0.1, test_size=0.2)
+train_data, val_data, test_data = dataset.get_splits()
+
+train_loader = LinkNeighborLoader(
+    train_data,
+    num_neighbors=[-1, -1],  # Use all neighbors
+    neg_sampling_ratio=1.0,  # 1 negative sample per positive edge
+    edge_label_index=train_data.edge_label_index,
+    edge_label=train_data.edge_label,
+    batch_size=BATCH_SIZE,
+    shuffle=True,
+    num_workers=4,
+)
+
+val_loader = LinkNeighborLoader(
+    val_data,
+    num_neighbors=[-1, -1],
+    neg_sampling_ratio=0.0,  # RandomLinkSplit already added negative edges
+    edge_label_index=val_data.edge_label_index,
+    edge_label=val_data.edge_label,
+    batch_size=BATCH_SIZE,
+    shuffle=False,
+    num_workers=4,
+)
+
+test_loader = LinkNeighborLoader(
+    test_data,
+    num_neighbors=[-1, -1],
+    neg_sampling_ratio=0.0,
+    edge_label_index=test_data.edge_label_index,
+    edge_label=test_data.edge_label,
+    batch_size=BATCH_SIZE,
+    shuffle=False,
+    num_workers=4,
+)
+
+# model
+model = SimpleGCN(
+    in_channels=dataset.num_features,
+    hidden_channels=256,
+    out_channels=128,
+).to(DEVICE)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+criterion = torch.nn.BCEWithLogitsLoss()
+
+
+# training
+def train(train_loader, epoch):
+    model.train()
+    total_loss = 0
+    # torch.GradScaler only exists from torch 2.4; with torch 2.3.1 (see
+    # pyproject.toml) use torch.cuda.amp.GradScaler. Scaling is a pass-through
+    # for bfloat16 and is disabled on CPU-only machines.
+    scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.type == "cuda")
+
+    pbar = tqdm(train_loader, desc=f"Training Epoch: {epoch}")
+    for batch in pbar:
+        batch = batch.to(DEVICE)
+        optimizer.zero_grad()
+
+        # device_type follows DEVICE so the script also runs on CPU-only machines
+        with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
+            z = model(batch.x, batch.edge_index)
+            out = model.decode(z, batch.edge_label_index)
+            labels = batch.edge_label.float()
+
+            loss = criterion(out, labels)
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        total_loss += loss.item()
+        pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+    return total_loss / len(train_loader)
+
+
+@torch.no_grad()
+def calc_metrics(loader):
+    model.eval()
+    all_scores = []
+    all_labels = []
+
+    pbar = tqdm(loader, desc="Testing")
+    for batch in pbar:
+        batch = batch.to(DEVICE)
+        with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
+            z = model(batch.x, batch.edge_index)
+            out = model.decode(z, batch.edge_label_index)
+
+        scores = torch.sigmoid(out).float().cpu().numpy()
+        labels = batch.edge_label.cpu().numpy()
+
+        all_scores.append(scores)
+        all_labels.append(labels)
+
+    all_scores = np.concatenate(all_scores)
+    all_labels = np.concatenate(all_labels)
+
+    return roc_auc_score(all_labels, all_scores), accuracy_score(
+        all_labels, all_scores > 0.5
+    )
+
+
+if __name__ == "__main__":
+    best_val_auc = 0
+    for epoch in range(1, NUM_EPOCHS + 1):
+        loss = train(train_loader, epoch)
+        val_auc, val_acc = calc_metrics(val_loader)
+
+        print(
+            f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val acc: {val_acc:.4f}",
+            end=" ",
+        )
+        if val_auc > best_val_auc:
+            print("New best")
+            best_val_auc = val_auc
+            torch.save(model.state_dict(), "model.pth")
+        else:
+            print()
+
+    test_auc, test_acc = calc_metrics(test_loader)
+
+    print("-" * 30)
+    print(f"Best validation AUC: {best_val_auc:.4f}")
+    print(f"Test AUC: {test_auc:.4f}, Test acc: {test_acc:.4f}")
predictor/__init__.py
ADDED
File without changes
predictor/link_predictor.py
ADDED
@@ -0,0 +1,156 @@
+from pathlib import Path
+import torch
+import structlog
+
+from sentence_transformers import SentenceTransformer
+from model.simple_gcn_model import SimpleGCN
+from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
+
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+structlog.configure(
+    processors=[
+        structlog.processors.TimeStamper(fmt="iso"),
+        structlog.processors.JSONRenderer(indent=4, sort_keys=True),
+    ]
+)
+logger = structlog.get_logger()
+
+
+def abstract_to_vector(
+    title: str, abstract_text: str, st_model: SentenceTransformer
+) -> torch.Tensor:
+    text = title + "\n" + abstract_text
+    with torch.no_grad():
+        vector = st_model.encode(text, convert_to_tensor=True, device=DEVICE)
+    return vector.unsqueeze(0)
+
+
+def get_citation_predictions(
+    vector: torch.Tensor, model: SimpleGCN, z_all: torch.Tensor, num_nodes: int
+) -> torch.Tensor:
+    model.eval()
+    with torch.no_grad():
+        # With an empty edge_index there are no neighbors to aggregate, so each
+        # GCNConv reduces to a linear projection; the new abstract gets an
+        # embedding without being inserted into the citation graph.
+        empty_edge_index = torch.empty(2, 0, dtype=torch.long, device=DEVICE)
+        h1_new = model.conv1(vector, edge_index=empty_edge_index).relu()
+        z_new = model.conv2(h1_new, edge_index=empty_edge_index)
+
+        # Candidate edges: new node -> every existing node.
+        new_node_idx = num_nodes
+        row = torch.full((num_nodes,), fill_value=new_node_idx, device=DEVICE)
+        col = torch.arange(num_nodes, device=DEVICE)
+        edge_label_index_to_check = torch.stack([row, col], dim=0)
+
+        z_combined = torch.cat([z_all, z_new], dim=0)
+
+        logits = model.decode(z_combined, edge_label_index_to_check)
+
+    return torch.sigmoid(logits)
+
+
+def format_top_k_predictions(
+    probs: torch.Tensor, dataset: OGBNLinkPredDataset, top_k: int = 10, show_prob: bool = False
+) -> str:
+    """
+    Formats the top K predictions into a single string for display.
+
+    Args:
+        probs (torch.Tensor): The tensor of probabilities for all potential links.
+        dataset (OGBNLinkPredDataset): The dataset object containing the corpus.
+        top_k (int): The number of top predictions to format (must be an int for torch.topk).
+        show_prob (bool): Whether to append each prediction's probability.
+
+    Returns:
+        str: A formatted string with the top K predictions.
+    """
+    probs = probs.cpu()
+    top_probs, top_indices = torch.topk(probs, k=top_k)
+
+    output_lines = []
+
+    header = f"Top {top_k} Citation Predictions:"
+    output_lines.append(header)
+
+    for i in range(top_k):
+        paper_idx = top_indices[i].item()
+        prob = top_probs[i].item()
+        paper_info = dataset.corpus[paper_idx]
+        paper_title = paper_info.split("\n")[0]
+        if show_prob:
+            line = f" - Title: '{paper_title.strip()}', Probability: {prob:.4f}"
+        else:
+            line = f" - Title: '{paper_title.strip()}'"
+        output_lines.append(line)
+
+    return "\n".join(output_lines)
+
+
+def prepare_system(model_path: Path):
+    """
+    Performs all one-time, expensive operations to prepare the system.
+    Initializes models, loads data, and pre-calculates embeddings using structured logging.
+    """
+    logger.info("system_preparation.start")
+
+    dataset = OGBNLinkPredDataset()
+    data = dataset.data.to(DEVICE)
+    logger.info("dataset.load.success")
+
+    model_name = "bongsoo/kpf-sbert-128d-v1"
+    logger.info(
+        "model.load.start", model_type="SentenceTransformer", model_name=model_name
+    )
+    st_model = SentenceTransformer(model_name, device=DEVICE)
+    logger.info("model.load.success", model_type="SentenceTransformer")
+
+    gcn_model = SimpleGCN(
+        in_channels=dataset.num_features, hidden_channels=256, out_channels=128
+    ).to(DEVICE)
+
+    if model_path.exists():
+        gcn_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+        logger.info("model.load.success", model_type="GCN", path=str(model_path))
+    else:
+        logger.warning(
+            "model.load.failure",
+            model_type="GCN",
+            path=str(model_path),
+            reason="File not found, using random weights.",
+        )
+    gcn_model.eval()
+
+    logger.info("embeddings.calculation.start", embedding_name="z_all")
+    with torch.no_grad():
+        z_all = gcn_model(data.x, data.edge_index)
+
+    logger.info(
+        "embeddings.calculation.success",
+        embedding_name="z_all",
+        shape=list(z_all.shape),
+    )
+
+    logger.info("system_preparation.finish", status="ready_for_predictions")
+    return gcn_model, st_model, dataset, z_all
+
+
+if __name__ == "__main__":
+    MODEL_PATH = Path("model.pth")
+
+    gcn_model, st_model, dataset, z_all = prepare_system(MODEL_PATH)
+
+    my_title = "A Survey of Graph Neural Networks for Link Prediction"
+    my_abstract = """Link prediction is a critical task in graph analysis.
+    In this paper, we review various GNN architectures like GCN and GraphSAGE for predicting edges.
+    """
+
+    new_vector = abstract_to_vector(my_title, my_abstract, st_model)
+
+    probabilities = get_citation_predictions(
+        vector=new_vector,
+        model=gcn_model,
+        z_all=z_all,
+        num_nodes=dataset.data.num_nodes,
+    )
+
+    references = format_top_k_predictions(probabilities, dataset, top_k=5)
+    print(references)
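One non-obvious detail in get_citation_predictions deserves a note: with an empty edge_index a GCNConv has no neighbors to aggregate, and its added self-loop carries normalization weight 1, so the layer degenerates to its learned linear projection. That is how a brand-new abstract is embedded in the trained space without modifying the citation graph. A minimal sketch of that behavior, assuming torch_geometric 2.x internals (GCNConv.lin and GCNConv.bias; not part of the commit):

# Sketch only: with no edges, GCNConv reduces to its linear projection,
# so an unseen node can be embedded without touching the graph.
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(128, 64)
x_new = torch.randn(1, 128)                    # one new "abstract" node
empty = torch.empty(2, 0, dtype=torch.long)    # no edges at all

out = conv(x_new, empty)                       # self-loop only: out == x_new @ W + b
ref = conv.lin(x_new) + conv.bias              # the equivalent plain linear map
print(torch.allclose(out, ref, atol=1e-6))     # expected: True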
pyproject.toml
ADDED
@@ -0,0 +1,49 @@
+[tool.poetry]
+name = "galis"
+version = "0.3.0"
+description = ""
+authors = ["Perunio <[email protected]>"]
+readme = "README.md"
+packages = [{include = "galis"}]
+
+[tool.poetry.dependencies]
+python = ">=3.12, <3.13"
+
+torch = [
+    {version = "2.3.1+cu121", source = "pytorch-cuda", markers = "sys_platform == 'linux' or sys_platform == 'win32'"},
+    {version = "^2.3.1", source = "pytorch-cpu", markers = "sys_platform == 'darwin'"}
+]
+
+ogb = "^1.3.6"
+torch-geometric = "^2.6.1"
+pandas = "^2.3.1"
+streamlit = "^1.46.1"
+numpy = "1.26.4"
+streamlit-extras = "^0.7.5"
+sentence-transformers = "2.7.0"
+transformers = "4.39.3"
+ruff = "^0.12.7"
+structlog = "^25.4.0"
+langchain-google-genai = "^2.1.9"
+langchain-core = "^0.3.72"
+langchain = "^0.3.27"
+python-dotenv = "^1.1.1"
+hf-transfer = "^0.1.9"
+
+[[tool.poetry.source]]
+name = "pytorch-cuda"
+url = "https://download.pytorch.org/whl/cu121"
+priority = "explicit"
+
+[[tool.poetry.source]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+priority = "explicit"
+
+[[tool.poetry.source]]
+name = "PyPI"
+priority = "primary"
+
+[build-system]
+requires = ["poetry-core>=2.0.0,<3.0.0"]
+build-backend = "poetry.core.masonry.api"