from pathlib import Path
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
import os
import json
import gc
from utils import empty_solution
from predict import predict_wireframe
from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch


if __name__ == "__main__":
    print("------------ Loading dataset ------------")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug: list the working directory and the mounted data directory.
    print('pwd:')
    os.system('pwd')
    os.system('ls -lahtr')
    print('/tmp/data/')
    os.system('ls -lahtr /tmp/data/')
    print('/tmp/data/data')
    os.system('ls -lahtrR /tmp/data/data')

    data_path_test_server = Path('/tmp/data')
    data_path_local = Path.home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'

    if data_path_test_server.exists():
        # Running on the evaluation server: the dataset is already mounted.
        TEST_ENV = True
    else:
        # Running locally: pull the dataset from the Hugging Face Hub first.
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    data_path = data_path_test_server
    print(data_path)

    # Alternative loading path via the hub / arrow files, kept for reference:
    # dataset = load_dataset(params['dataset'], trust_remote_code=True, use_auth_token=params['token'])
    # data_files = {
    #     "validation": [str(p) for p in [*data_path.rglob('*validation*.arrow')] + [*data_path.rglob('*public*/**/*.tar')]],
    #     "test": [str(p) for p in [*data_path.rglob('*test*.arrow')] + [*data_path.rglob('*private*/**/*.tar')]],
    # }

    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print('loaded with webdataset')
    print(dataset, flush=True)

    # Load the two PointNet models; the voxel model is intentionally disabled.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
    pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
    voxel_model = None
    config = {
        'vertex_threshold': 0.59,
        'edge_threshold': 0.65,
        'only_predicted_connections': True,
    }

    print('------------ Now you can do your solution ---------------')
    solution = []

    def process_sample(sample, i):
        """Predict a wireframe for one sample, falling back to an empty solution on failure."""
        try:
            pred_vertices, pred_edges = predict_wireframe(sample, pnet_model, voxel_model, pnet_class_model, config)
        except Exception:
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            gc.collect()  # periodically reclaim memory during long runs
        return {
            'order_id': sample['order_id'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges,
        }

    for subset_name in dataset.keys():
        print(f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------")