#!/usr/bin/env python3
"""Rename keys in a PyTorch state dict stored as ``pytorch_model.bin``.

By default this operates on ``2_Dense/pytorch_model.bin``. An alternative
path may be given as the first command-line argument. Keys listed in
``RENAME`` are mapped to their new names. Every other key is kept as-is,
and insertion order is preserved. The file is rewritten in place.

Before the file is rewritten, a byte-for-byte copy is saved next to it with
a ``.bak`` suffix. An existing backup is never overwritten, so running the
script repeatedly cannot replace the original weights with converted ones.
"""
import shutil
import sys
from collections import OrderedDict
from pathlib import Path

import torch

# File processed when no path is given on the command line.
DEFAULT_PATH = "2_Dense/pytorch_model.bin"

# Mappings: old state-dict key -> new state-dict key.
RENAME = {"layer.weight": "linear.weight"}


def rename_keys(state, mapping=None):
    """Return a new ``OrderedDict`` with keys renamed according to *mapping*.

    Args:
        state: Mapping of parameter names to values (e.g. a state dict).
            It is not modified.
        mapping: Old-name -> new-name dict; defaults to ``RENAME``.

    Returns:
        An ``OrderedDict`` in the same order as *state*, with renamed keys.

    Raises:
        ValueError: if two input keys would end up under the same output key.
            Without this check, the later entry would silently replace the
            earlier one and a tensor would be lost.
    """
    if mapping is None:
        mapping = RENAME
    renamed = OrderedDict()
    for key, params in state.items():
        dst = key
        if key in mapping:
            dst = mapping[key]
            print(f"Mapping {key} to {dst}", file=sys.stderr)
        if dst in renamed:
            raise ValueError(
                f"Key collision: {key!r} maps to {dst!r}, which is already present"
            )
        renamed[dst] = params
    return renamed


def main(argv=None):
    """Back up the model file (once), rename its keys, and save it in place.

    Args:
        argv: Command-line arguments without the program name. If the list is
            non-empty, its first element is the path to process. Defaults to
            ``sys.argv[1:]``.
    """
    if argv is None:
        argv = sys.argv[1:]
    path = Path(argv[0]) if argv else Path(DEFAULT_PATH)
    backup = path.with_name(path.name + ".bak")

    # Keep the first backup: overwriting it on a re-run would replace the
    # original weights with the already-converted file.
    if backup.exists():
        print(f"Backup {backup} already exists; leaving it untouched", file=sys.stderr)
    else:
        shutil.copy2(path, backup)

    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    state = torch.load(path, map_location="cpu")
    torch.save(rename_keys(state), path)


if __name__ == "__main__":
    main()