import os import torch.nn.functional as F import random from torch_sparse import SparseTensor import pandas as pd import torch from ogb.nodeproppred import PygNodePropPredDataset from torch_geometric.transforms import RandomLinkSplit from torch_geometric.data import Data import requests import gzip import shutil class OGBNLinkPredDataset: def __init__( self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2 ): self._base_dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=root_dir) self.data = self._base_dataset[0] self.root = self._base_dataset.root self.num_features = self._base_dataset.num_features self._download_abstracts() self.corpus = self._load_corpus() self.val_size = val_size self.test_size = test_size def _download_abstracts(self): target_dir = os.path.join(self.root, "mapping") tsv_path = os.path.join(target_dir, "titleabs.tsv") if not os.path.exists(tsv_path): print("Downloading title and abstract information...") gz_path = tsv_path + ".gz" url = "https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz" os.makedirs(target_dir, exist_ok=True) try: response = requests.get(url, stream=True) response.raise_for_status() with open(gz_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) with gzip.open(gz_path, "rb") as f_in: with open(tsv_path, "wb") as f_out: shutil.copyfileobj(f_in, f_out) os.remove(gz_path) except requests.exceptions.RequestException as e: print(f"Error downloading file: {e}") except Exception as e: print(f"An error occurred: {e}") else: print("Title and abstract file already exists.") def _load_corpus(self) -> list[str]: tsv_path = os.path.join(self.root, "mapping", "titleabs.tsv") try: df_text = pd.read_csv( tsv_path, sep="\t", header=None, names=["paper_id", "title", "abstract"], lineterminator="\n", low_memory=False, ) df_text_aligned = df_text.reset_index(drop=True) corpus = ( df_text_aligned["title"].fillna("") + "\n " + df_text_aligned["abstract"].fillna("") ).tolist() return corpus except FileNotFoundError: print("Error: titleabs.tsv not found. Could not create corpus.") return [] def get_splits(self) -> tuple[Data, Data, Data]: transform = RandomLinkSplit( num_val=self.val_size, num_test=self.test_size, is_undirected=False, add_negative_train_samples=True, neg_sampling_ratio=1.0, ) train_split, val_split, test_split = transform(self.data) return train_split, val_split, test_split class OGBNLinkPredNegDataset(OGBNLinkPredDataset): """Degree similar hard negatives sampling""" def __init__( self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2 ): super().__init__(root_dir, val_size, test_size) self.degree_tol = 0 def get_splits(self) -> tuple[Data, Data, Data]: transform = RandomLinkSplit( num_val=self.val_size, num_test=self.test_size, is_undirected=False, add_negative_train_samples=False, neg_sampling_ratio=0.0, ) train_split, val_split, test_split = transform(self.data) print("Generating hard negatives...") adj_matrix = SparseTensor.from_edge_index( train_split.edge_index, # only from train_split sparse_sizes=(self.data.num_nodes, self.data.num_nodes), ) self.degrees = adj_matrix.sum(dim=0).to(torch.long) # to prevent creating negative edges that are positive in other split self.all_edge_set = set(zip(*self.data.edge_index.tolist())) train_split = self._add_balanced_negs(train_split) val_split = self._add_balanced_negs(val_split) test_split = self._add_balanced_negs(test_split) return train_split, val_split, test_split def _add_balanced_negs(self, split_data): assert (split_data.edge_label == 1).all(), "Expected only positive edges" pos_edges = split_data.edge_label_index pos_list = pos_edges.t().tolist() num_negs = pos_edges.size(1) negs = [] for _ in range(num_negs): u, v_orig = random.choice(pos_list) target_deg = int(self.degrees[v_orig]) found = False for _ in range(20): w = random.randrange(self.data.num_nodes) if ( (u, w) not in self.all_edge_set and w != u and abs(int(self.degrees[w]) - target_deg) <= self.degree_tol ): negs.append((u, w)) found = True break if not found: while True: w = random.randrange(self.data.num_nodes) if (u, w) not in self.all_edge_set and w != u: negs.append((u, w)) break neg_edges = torch.tensor(negs, dtype=torch.long).t() N = pos_edges.size(1) split_data.edge_label_index = torch.cat([pos_edges, neg_edges], dim=1) split_data.edge_label = torch.cat( [ torch.ones(N, dtype=torch.long, device=pos_edges.device), torch.zeros(N, dtype=torch.long, device=pos_edges.device), ] ) return split_data # class OGBNLinkPredNegDataset2(OGBNLinkPredDataset): # """Degree and semantically similar hard negatives sampling""" # # def __init__(self, root_dir="data", val_size=0.1, test_size=0.2): # super().__init__(root_dir, val_size, test_size) # # def get_splits(self) -> tuple[Data, Data, Data]: # transform = RandomLinkSplit( # num_val=self.val_size, # num_test=self.test_size, # is_undirected=False, # add_negative_train_samples=False, # neg_sampling_ratio=0.0, # ) # train_split, val_split, test_split = transform(self.data) # # print("Generating semantic hard negatives...") # train_split = self._add_balanced_negs(train_split) # val_split = self._add_balanced_negs(val_split) # test_split = self._add_balanced_negs(test_split) # return train_split, val_split, test_split # # def _add_balanced_negs(self, split_data): # assert (split_data.edge_label == 1).all(), "Expected only positive edges" # # BS = 1_000 # B = self.data.x.to("cuda", dtype=torch.bfloat16) # (num_nodes, dim) # B = F.normalize(B, p=2, dim=1) # K = 100 # # pos_edges = split_data.edge_label_index # adj_matrix = SparseTensor.from_edge_index( # split_data.edge_index, # sparse_sizes=(self.data.num_nodes, self.data.num_nodes), # ) # degrees = adj_matrix.sum(dim=0).to("cuda") # # topk_val = torch.empty((BS, K), dtype=torch.bfloat16, device="cuda") # topk_idx = torch.empty((BS, K), dtype=torch.int64, device="cuda") # # neg_edges = [] # # for i in range(0, pos_edges.shape[1], BS): # batch_end = min(i + BS, pos_edges.shape[1]) # src_idx = pos_edges[0, i:batch_end] # (batch_size,) # dst_idx = pos_edges[1, i:batch_end] # (batch_size,) # # A = B[src_idx] # (batch_size, dim) # # with torch.autocast("cuda", dtype=torch.bfloat16): # sim = torch.mm(A, B.t()) # equivalent to cos-sim # # # mask for similarity with itself and existing edges # sim[torch.arange(len(A)), dst_idx] = -1 # sim[torch.arange(len(A)), src_idx] = -1 # # TODO: exclude edges from val&test sets # # torch.topk(sim, K, out=(topk_val, topk_idx)) # topk_idx2 = topk_idx[: len(A)] # # # sample degree matched negs # topk_deg = degrees[topk_idx2] # src_deg = degrees[src_idx] # # deg_diffs = torch.abs(topk_deg - src_deg.unsqueeze(1)) # closest_idx = torch.argmin(deg_diffs, dim=1) # (batch_size,) # sampled_negs = topk_idx[ # torch.arange(len(A), device="cuda"), closest_idx # ] # neg_edges.append(sampled_negs) # # neg_dsts = torch.cat(neg_edges, dim=0).to("cpu") # neg_edge_index = torch.stack([pos_edges[0].cpu(), neg_dsts], dim=0) # edge_label_index = torch.cat([pos_edges.cpu(), neg_edge_index], dim=1) # edge_label = torch.cat( # [split_data.edge_label, torch.zeros(neg_dsts.shape[0])], dim=0 # ) # assert edge_label.shape[0] == edge_label_index.shape[1], ( # "Label and index shape mismatch" # ) # assert len(neg_dsts) == pos_edges.shape[1], ( # "Expected same amount of positive and negative edges" # ) # return Data( # x=split_data.x, # edge_index=edge_label, # edge_label_index=edge_label_index, # edge_label=edge_label, # ) if __name__ == "__main__": dataset = OGBNLinkPredNegDataset() train, val, test = dataset.get_splits() def extract_pos_neg_edges(split): pos = split.edge_label_index[:, split.edge_label == 1] neg = split.edge_label_index[:, split.edge_label == 0] return pos, neg for name, split in [("train", train), ("val", val), ("test", test)]: assert split.edge_label_index.shape[0] == 2, ( f"{name}: edge_label_index must have 2 rows" ) assert split.edge_label_index.shape[1] == split.edge_label.shape[0], ( f"{name}: label/index shape mismatch" ) assert torch.all(0 <= split.edge_label) and torch.all(split.edge_label <= 1), ( f"{name}: labels not 0/1" ) pos, neg = extract_pos_neg_edges(split) assert pos.size(1) == neg.size(1), f"{name}: pos/neg count mismatch" pos_set = set(tuple(e) for e in pos.t().tolist()) neg_set = set(tuple(e) for e in neg.t().tolist()) assert pos_set.isdisjoint(neg_set), f"{name}: pos/neg overlap" assert all(u != v for u, v in pos_set), f"{name}: pos self-loops" assert all(u != v for u, v in neg_set), f"{name}: neg self-loops" assert len(pos_set) == pos.size(1), f"{name}: pos duplicates" assert len(neg_set) == neg.size(1), f"{name}: neg duplicates" assert pos.size(1) / neg.size(1) == 1.0 if neg.size(1) > 0 else True, ( f"{name}: ratio not 1.0" ) print("All asserts passed!")