"""Initialize a Finnish roberta-large checkpoint via WECHSEL embedding transfer.

Loads the English roberta-large model and tokenizer, computes target-language
input embeddings for a Finnish tokenizer (expected to already exist in the
current directory) using WECHSEL, swaps them into the model, and saves the
resulting checkpoint to the current directory.
"""

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, FlaxAutoModelForMaskedLM
from datasets import load_dataset  # NOTE(review): unused here; kept in case a later step needs it
from wechsel import WECHSEL, load_embeddings

# Source model/tokenizer: English roberta-large from the Hub.
source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")

# Target tokenizer: a pre-trained Finnish tokenizer saved in the current directory.
target_tokenizer = AutoTokenizer.from_pretrained("./")

wechsel = WECHSEL(
    load_embeddings("en"),
    load_embeddings("fi"),
    bilingual_dictionary="finnish",
)

# Align source embeddings to the target vocabulary; returns a numpy array of
# shape (target_vocab_size, hidden_dim) plus diagnostic info.
target_embeddings, info = wechsel.apply(
    source_tokenizer,
    target_tokenizer,
    model.get_input_embeddings().weight.detach().numpy(),
)

# Resize first so config.vocab_size and the tied LM head match the new
# tokenizer; overwriting weight.data alone would leave the saved config
# inconsistent when the target vocab size differs from the source's.
model.resize_token_embeddings(len(target_tokenizer))
model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings).to(torch.float32)

model.save_pretrained("./")

# Optional: convert the PyTorch checkpoint to Flax and save it alongside.
# flax_model = FlaxAutoModelForMaskedLM.from_pretrained("./", from_pt=True)
# flax_model.save_pretrained("./")