# NOTE: removed page-scrape artifact ("Spaces: Sleeping") that was not part of this module.
import gzip
import os
import random
import shutil

import pandas as pd
import requests
import torch
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_sparse import SparseTensor
class OGBNLinkPredDataset:
    """Link-prediction view of the ogbn-arxiv citation graph.

    Wraps ``PygNodePropPredDataset("ogbn-arxiv")``, downloads the paper
    title/abstract dump into ``<root>/mapping/titleabs.tsv``, and exposes
    random train/val/test edge splits via :meth:`get_splits`.
    """

    def __init__(
        self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2
    ):
        self._base_dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=root_dir)
        self.data = self._base_dataset[0]
        self.root = self._base_dataset.root
        self.num_features = self._base_dataset.num_features
        self._download_abstracts()
        self.corpus = self._load_corpus()
        self.val_size = val_size
        self.test_size = test_size

    def _download_abstracts(self):
        """Fetch and decompress ``titleabs.tsv`` into ``<root>/mapping``.

        No-op when the file already exists. On any failure, partially
        written files are deleted so the next run retries; previously a
        decompression error left a truncated ``titleabs.tsv`` behind,
        which made later runs skip the download and read corrupt data.
        """
        target_dir = os.path.join(self.root, "mapping")
        tsv_path = os.path.join(target_dir, "titleabs.tsv")
        if os.path.exists(tsv_path):
            print("Title and abstract file already exists.")
            return

        print("Downloading title and abstract information...")
        gz_path = tsv_path + ".gz"
        url = "https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz"
        os.makedirs(target_dir, exist_ok=True)
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(gz_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            with gzip.open(gz_path, "rb") as f_in, open(tsv_path, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        except requests.exceptions.RequestException as e:
            print(f"Error downloading file: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")
            # A partial tsv would make later runs skip the download and
            # read a truncated file -- remove it so they retry instead.
            if os.path.exists(tsv_path):
                os.remove(tsv_path)
        finally:
            # The gz archive is only an intermediate artifact; never keep it.
            if os.path.exists(gz_path):
                os.remove(gz_path)

    def _load_corpus(self) -> list[str]:
        """Return one ``"title\\n abstract"`` string per row of titleabs.tsv.

        Returns an empty list (after printing an error) when the tsv is
        missing, so callers can still construct the dataset offline.
        """
        tsv_path = os.path.join(self.root, "mapping", "titleabs.tsv")
        try:
            df_text = pd.read_csv(
                tsv_path,
                sep="\t",
                header=None,
                names=["paper_id", "title", "abstract"],
                lineterminator="\n",
                low_memory=False,
            )
            df_text_aligned = df_text.reset_index(drop=True)
            corpus = (
                df_text_aligned["title"].fillna("")
                + "\n "
                + df_text_aligned["abstract"].fillna("")
            ).tolist()
            return corpus
        except FileNotFoundError:
            print("Error: titleabs.tsv not found. Could not create corpus.")
            return []

    def get_splits(self) -> tuple[Data, Data, Data]:
        """Randomly split edges into train/val/test with 1:1 uniform negatives."""
        transform = RandomLinkSplit(
            num_val=self.val_size,
            num_test=self.test_size,
            is_undirected=False,
            add_negative_train_samples=True,
            neg_sampling_ratio=1.0,
        )
        train_split, val_split, test_split = transform(self.data)
        return train_split, val_split, test_split
class OGBNLinkPredNegDataset(OGBNLinkPredDataset):
    """Degree similar hard negatives sampling.

    Instead of uniform negatives, each negative edge ``(u, w)`` reuses a
    positive source ``u`` and samples a destination ``w`` whose degree in
    the training graph is within ``degree_tol`` of the true destination's
    degree, making negatives harder to distinguish from positives.
    """

    def __init__(
        self, root_dir: str = "data", val_size: float = 0.1, test_size: float = 0.2
    ):
        super().__init__(root_dir, val_size, test_size)
        # Max allowed |deg(w) - deg(v)| for a degree-matched negative.
        self.degree_tol = 0

    def get_splits(self) -> tuple[Data, Data, Data]:
        """Split edges and attach one hard negative per positive edge."""
        transform = RandomLinkSplit(
            num_val=self.val_size,
            num_test=self.test_size,
            is_undirected=False,
            add_negative_train_samples=False,  # negatives are generated below
            neg_sampling_ratio=0.0,
        )
        train_split, val_split, test_split = transform(self.data)

        print("Generating hard negatives...")
        adj_matrix = SparseTensor.from_edge_index(
            train_split.edge_index,  # only from train_split
            sparse_sizes=(self.data.num_nodes, self.data.num_nodes),
        )
        self.degrees = adj_matrix.sum(dim=0).to(torch.long)
        # to prevent creating negative edges that are positive in other split
        self.all_edge_set = set(zip(*self.data.edge_index.tolist()))

        train_split = self._add_balanced_negs(train_split)
        val_split = self._add_balanced_negs(val_split)
        test_split = self._add_balanced_negs(test_split)
        return train_split, val_split, test_split

    def _add_balanced_negs(self, split_data):
        """Append one sampled negative per positive edge to ``split_data``.

        Negatives are unique, never self-loops and never true edges;
        within 20 tries they are degree-matched to the positive
        destination, otherwise any valid destination is accepted so the
        1:1 pos/neg ratio is preserved. Fixes the original sampler,
        which tracked nothing and could emit duplicate negative pairs
        (tripping the ``"neg duplicates"`` check in ``__main__``).
        """
        assert (split_data.edge_label == 1).all(), "Expected only positive edges"

        pos_edges = split_data.edge_label_index
        pos_list = pos_edges.t().tolist()
        num_negs = pos_edges.size(1)
        num_nodes = self.data.num_nodes

        negs = []
        sampled = set()  # negatives chosen so far -- prevents duplicates

        def _valid(u, w):
            # Not an observed edge (any split), not already sampled, no self-loop.
            return (u, w) not in self.all_edge_set and (u, w) not in sampled and w != u

        for _ in range(num_negs):
            u, v_orig = random.choice(pos_list)
            target_deg = int(self.degrees[v_orig])
            found = False
            # First try to find a degree-matched destination.
            for _ in range(20):
                w = random.randrange(num_nodes)
                if _valid(u, w) and abs(int(self.degrees[w]) - target_deg) <= self.degree_tol:
                    negs.append((u, w))
                    sampled.add((u, w))
                    found = True
                    break
            if not found:
                # Fall back to any valid destination to keep the 1:1 ratio.
                while True:
                    w = random.randrange(num_nodes)
                    if _valid(u, w):
                        negs.append((u, w))
                        sampled.add((u, w))
                        break

        neg_edges = torch.tensor(negs, dtype=torch.long).t()
        N = pos_edges.size(1)
        split_data.edge_label_index = torch.cat([pos_edges, neg_edges], dim=1)
        split_data.edge_label = torch.cat(
            [
                torch.ones(N, dtype=torch.long, device=pos_edges.device),
                torch.zeros(N, dtype=torch.long, device=pos_edges.device),
            ]
        )
        return split_data
# class OGBNLinkPredNegDataset2(OGBNLinkPredDataset):
#     """Degree and semantically similar hard negatives sampling"""
#
#     def __init__(self, root_dir="data", val_size=0.1, test_size=0.2):
#         super().__init__(root_dir, val_size, test_size)
#
#     def get_splits(self) -> tuple[Data, Data, Data]:
#         transform = RandomLinkSplit(
#             num_val=self.val_size,
#             num_test=self.test_size,
#             is_undirected=False,
#             add_negative_train_samples=False,
#             neg_sampling_ratio=0.0,
#         )
#         train_split, val_split, test_split = transform(self.data)
#
#         print("Generating semantic hard negatives...")
#         train_split = self._add_balanced_negs(train_split)
#         val_split = self._add_balanced_negs(val_split)
#         test_split = self._add_balanced_negs(test_split)
#         return train_split, val_split, test_split
#
#     def _add_balanced_negs(self, split_data):
#         assert (split_data.edge_label == 1).all(), "Expected only positive edges"
#
#         BS = 1_000
#         B = self.data.x.to("cuda", dtype=torch.bfloat16)  # (num_nodes, dim)
#         B = F.normalize(B, p=2, dim=1)
#         K = 100
#
#         pos_edges = split_data.edge_label_index
#         adj_matrix = SparseTensor.from_edge_index(
#             split_data.edge_index,
#             sparse_sizes=(self.data.num_nodes, self.data.num_nodes),
#         )
#         degrees = adj_matrix.sum(dim=0).to("cuda")
#
#         topk_val = torch.empty((BS, K), dtype=torch.bfloat16, device="cuda")
#         topk_idx = torch.empty((BS, K), dtype=torch.int64, device="cuda")
#
#         neg_edges = []
#
#         for i in range(0, pos_edges.shape[1], BS):
#             batch_end = min(i + BS, pos_edges.shape[1])
#             src_idx = pos_edges[0, i:batch_end]  # (batch_size,)
#             dst_idx = pos_edges[1, i:batch_end]  # (batch_size,)
#
#             A = B[src_idx]  # (batch_size, dim)
#
#             with torch.autocast("cuda", dtype=torch.bfloat16):
#                 sim = torch.mm(A, B.t())  # equivalent to cos-sim
#
#             # mask for similarity with itself and existing edges
#             sim[torch.arange(len(A)), dst_idx] = -1
#             sim[torch.arange(len(A)), src_idx] = -1
#             # TODO: exclude edges from val&test sets
#
#             torch.topk(sim, K, out=(topk_val, topk_idx))
#             topk_idx2 = topk_idx[: len(A)]
#
#             # sample degree matched negs
#             topk_deg = degrees[topk_idx2]
#             src_deg = degrees[src_idx]
#
#             deg_diffs = torch.abs(topk_deg - src_deg.unsqueeze(1))
#             closest_idx = torch.argmin(deg_diffs, dim=1)  # (batch_size,)
#             sampled_negs = topk_idx[
#                 torch.arange(len(A), device="cuda"), closest_idx
#             ]
#             neg_edges.append(sampled_negs)
#
#         neg_dsts = torch.cat(neg_edges, dim=0).to("cpu")
#         neg_edge_index = torch.stack([pos_edges[0].cpu(), neg_dsts], dim=0)
#         edge_label_index = torch.cat([pos_edges.cpu(), neg_edge_index], dim=1)
#         edge_label = torch.cat(
#             [split_data.edge_label, torch.zeros(neg_dsts.shape[0])], dim=0
#         )
#         assert edge_label.shape[0] == edge_label_index.shape[1], (
#             "Label and index shape mismatch"
#         )
#         assert len(neg_dsts) == pos_edges.shape[1], (
#             "Expected same amount of positive and negative edges"
#         )
#         return Data(
#             x=split_data.x,
#             # NOTE(review): `edge_index=edge_label` looks like a typo for
#             # `split_data.edge_index` -- verify before reviving this class.
#             edge_index=edge_label,
#             edge_label_index=edge_label_index,
#             edge_label=edge_label,
#         )
if __name__ == "__main__":
    # Smoke test: build the hard-negative dataset and sanity-check every split.
    dataset = OGBNLinkPredNegDataset()
    splits = dict(zip(("train", "val", "test"), dataset.get_splits()))

    def _edges_with_label(split, label):
        """Columns of edge_label_index whose edge_label equals *label*."""
        return split.edge_label_index[:, split.edge_label == label]

    for name, split in splits.items():
        labels = split.edge_label
        index = split.edge_label_index
        assert index.shape[0] == 2, (
            f"{name}: edge_label_index must have 2 rows"
        )
        assert index.shape[1] == labels.shape[0], (
            f"{name}: label/index shape mismatch"
        )
        assert torch.all(0 <= labels) and torch.all(labels <= 1), (
            f"{name}: labels not 0/1"
        )
        pos = _edges_with_label(split, 1)
        neg = _edges_with_label(split, 0)
        assert pos.size(1) == neg.size(1), f"{name}: pos/neg count mismatch"
        pos_set = {tuple(e) for e in pos.t().tolist()}
        neg_set = {tuple(e) for e in neg.t().tolist()}
        assert pos_set.isdisjoint(neg_set), f"{name}: pos/neg overlap"
        assert all(u != v for u, v in pos_set), f"{name}: pos self-loops"
        assert all(u != v for u, v in neg_set), f"{name}: neg self-loops"
        assert len(pos_set) == pos.size(1), f"{name}: pos duplicates"
        assert len(neg_set) == neg.size(1), f"{name}: neg duplicates"
        assert pos.size(1) / neg.size(1) == 1.0 if neg.size(1) > 0 else True, (
            f"{name}: ratio not 1.0"
        )
    print("All asserts passed!")