# predictor/link_predictor.py
from pathlib import Path

import structlog
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
from model.mlp import edge_features, PairMLP

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

structlog.configure(
processors=[
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.JSONRenderer(indent=4, sort_keys=True),
]
)
logger = structlog.get_logger()


def abstract_to_vector(
    title: str, abstract_text: str, st_model: SentenceTransformer
) -> torch.Tensor:
    """Encode a paper's title and abstract into a single sentence embedding on DEVICE."""
text = title + "\n" + abstract_text
with torch.no_grad():
vector = st_model.encode(text, convert_to_tensor=True, device=DEVICE)
return vector


def get_citation_predictions(
    vector: torch.Tensor,
    model: PairMLP,
    z_all: torch.Tensor,
    num_nodes: int,
) -> torch.Tensor:
    """Score the new paper's embedding against every corpus paper and return citation probabilities."""
model.eval()
with torch.no_grad():
        # Row 0 is the new paper; rows 1..num_nodes are the corpus embeddings.
        combined_embeddings = torch.cat([vector.view(1, -1), z_all], dim=0)
        # Pair the new paper (node 0) with every corpus paper (nodes 1..num_nodes).
        edge_index = torch.tensor(
            [[0] * num_nodes, list(range(1, num_nodes + 1))], device=DEVICE
        )
        # Build per-pair features and map the MLP logits to probabilities.
        feat = edge_features(combined_embeddings, edge_index).to(DEVICE)
        scores = torch.sigmoid(model(feat))
return scores.squeeze()
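

# Usage sketch (commented out, not part of the pipeline): toy tensors with an assumed
# 128-dimensional encoder and an untrained PairMLP, just to illustrate the expected
# shapes; probs[i] is the predicted citation probability for dataset.corpus[i].
#
#   toy_corpus = F.normalize(torch.randn(3, 128, device=DEVICE), p=2, dim=1)
#   toy_query = F.normalize(torch.randn(1, 128, device=DEVICE), p=2, dim=1)
#   toy_model = PairMLP(128 * 2).to(DEVICE)
#   toy_probs = get_citation_predictions(toy_query, toy_model, toy_corpus, num_nodes=3)
#   print(toy_probs.shape)  # expected: torch.Size([3])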


def format_top_k_predictions(
    probs: torch.Tensor,
    dataset: OGBNLinkPredDataset,
    top_k: int = 10,
    show_prob: bool = False,
) -> str:
    """Format the top-k predicted citations as a readable list of paper titles."""
probs = probs.cpu()
top_probs, top_indices = torch.topk(probs, k=top_k)
output_lines = []
header = f"Top {top_k} Citation Predictions:"
output_lines.append(header)
for i in range(top_k):
paper_idx = top_indices[i].item()
prob = top_probs[i].item()
        # Each corpus entry starts with the paper title on its first line.
        paper_info = dataset.corpus[paper_idx]
        paper_title = paper_info.split("\n")[0]
if show_prob:
line = f" - Title: '{paper_title.strip()}', Probability: {prob:.4f}"
else:
line = f" - Title: '{paper_title.strip()}'"
output_lines.append(line)
return "\n".join(output_lines)


def prepare_system(
    model_path: Path,
) -> tuple[PairMLP, SentenceTransformer, OGBNLinkPredDataset, torch.Tensor]:
    """Load the dataset, sentence encoder, corpus embeddings, and PairMLP weights."""
    logger.info("system_preparation.start")
dataset = OGBNLinkPredDataset()
logger.info("dataset.load.success")
model_name = "bongsoo/kpf-sbert-128d-v1"
logger.info(
"model.load.start", model_type="SentenceTransformer", model_name=model_name
)
st_model = SentenceTransformer(model_name, device=DEVICE)
logger.info("model.load.success", model_type="SentenceTransformer")
    # Load cached corpus embeddings if present; otherwise encode the corpus and cache them.
if Path("model/embeddings.pth").exists():
corpus_embeddings = torch.load("model/embeddings.pth", map_location=DEVICE)
logger.info("embeddings.load.success")
else:
logger.info("embeddings.calculation.start")
corpus_embeddings = st_model.encode(
dataset.corpus, convert_to_tensor=True, show_progress_bar=True
)
Path("model").mkdir(parents=True, exist_ok=True)
torch.save(corpus_embeddings, "model/embeddings.pth")
logger.info("embeddings.calculation.success")
    # L2-normalize so corpus embeddings match the normalization applied to new abstract vectors.
    corpus_embeddings = F.normalize(corpus_embeddings.to(DEVICE), p=2, dim=1)
# Initialize PairMLP
embedding_dim = corpus_embeddings.size(1)
pair_mlp = PairMLP(embedding_dim * 2).to(DEVICE)
if model_path.exists():
pair_mlp.load_state_dict(torch.load(model_path, map_location=DEVICE))
logger.info("model.load.success", model_type="PairMLP", path=str(model_path))
else:
logger.warning(
"model.load.failure",
model_type="PairMLP",
path=str(model_path),
reason="File not found, using random weights.",
)
pair_mlp.eval()
    logger.info(
        "embeddings.ready",
        embedding_name="corpus_embeddings",
        shape=list(corpus_embeddings.shape),
    )
logger.info("system_preparation.finish", status="ready_for_predictions")
return pair_mlp, st_model, dataset, corpus_embeddings
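

# Note (usage sketch): corpus embeddings are cached to model/embeddings.pth on the first
# run; remove that file to force re-encoding, e.g. after switching the SentenceTransformer:
#
#   Path("model/embeddings.pth").unlink(missing_ok=True)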


if __name__ == "__main__":
    # Example: predict which corpus papers a new (title, abstract) pair is likely to cite.
    MODEL_PATH = Path("model.pth")
pair_model, st_model, dataset, corpus_embeddings = prepare_system(MODEL_PATH)
my_title = "A Survey of Graph Neural Networks for Link Prediction"
my_abstract = """Link prediction is a critical task in graph analysis.
In this paper, we review various GNN architectures like GCN and GraphSAGE for predicting edges."""
new_vector = abstract_to_vector(my_title, my_abstract, st_model)
new_vector = F.normalize(
new_vector.view(1, -1), p=2, dim=1
) # Normalize like corpus embeddings
probabilities = get_citation_predictions(
vector=new_vector,
model=pair_model,
z_all=corpus_embeddings,
num_nodes=dataset.data.num_nodes,
)
references = format_top_k_predictions(
probabilities, dataset, top_k=5, show_prob=True
)
print(references)