| import torch |
| from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder |
|
|
|
|
def get_satclip_loc_encoder(ckpt_path, device):
    """Rebuild a ``LocationEncoder`` from a saved checkpoint and return it.

    The checkpoint is read with ``torch.load`` and must provide two entries:
    ``'hyper_parameters'`` (used to construct the positional encoding and the
    neural network) and ``'state_dict'`` (the stored weights).

    Args:
        ckpt_path: Location of the checkpoint, passed straight to
            ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``. The returned
            module is not explicitly moved to this device here.

    Returns:
        The ``LocationEncoder`` instance, cast with ``.double()``, with the
        checkpoint weights loaded and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint or its hyper-parameters lack one of the
            keys read below.
    """
    # NOTE: torch.load may unpickle arbitrary objects; only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    hparams = checkpoint['hyper_parameters']

    # The positional encoding is built from these hyper-parameters, passed
    # positionally in exactly this order.
    posenc_keys = (
        'le_type',
        'legendre_polys',
        'harmonics_calculation',
        'min_radius',
        'max_radius',
        'frequency_num',
    )
    positional_encoder = get_positional_encoding(
        *(hparams[key] for key in posenc_keys)
    )

    # The network's second argument comes from the positional encoding's
    # ``embedding_dim``.
    network = get_neural_network(
        hparams['pe_type'],
        positional_encoder.embedding_dim,
        hparams['embed_dim'],
        hparams['capacity'],
        hparams['num_hidden_layers'],
    )

    # Keep only the entries whose key mentions 'nnet', and drop whatever
    # precedes the first 'nnet' in the key so the names start at 'nnet'.
    nnet_weights = {}
    for name, tensor in checkpoint['state_dict'].items():
        if 'nnet' not in name:
            continue
        nnet_weights[name[name.index('nnet'):]] = tensor

    loc_encoder = LocationEncoder(positional_encoder, network).double()
    loc_encoder.load_state_dict(nnet_weights)
    loc_encoder.eval()

    return loc_encoder
| |