Spaces: Running on Zero
```python
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import spaces

# Available models
MODEL_OPTIONS_TRANSLATE = {
    # "Flan-T5-A-Dhivehi-Latin Model": "alakxender/flan-t5-base-dhivehi-en-latin-v2",
    "Flan-T5-B-Dhivehi-Latin Model": "alakxender/flan-t5-base-dhivehi-en-latin",
    "MT5-B-Dhivehi-English Model": "alakxender/mt5-base-dv-en",
    "MT5-B1-Dhivehi-English Model": "alakxender/mt5-base-dv-en-md",
    "@politecat314-Dhivehi-Latin Model": "politecat314/flan-t5-base-dv2latin-mihaaru"
}

# Cache for loaded models/tokenizers so each model is only loaded once
MODEL_CACHE = {}

def get_model_and_tokenizer(model_dir):
    # Load the tokenizer and model on first use, then serve them from the cache.
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]

max_input_length = 128
max_output_length = 128

# NOTE: the @spaces.GPU placement is an assumption; the original snippet imports
# spaces but did not show the decorator. On ZeroGPU it requests a GPU for the
# duration of each call.
@spaces.GPU
def translate(instruction, input_text, model_choice, max_new_tokens=128, num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3):
    model_dir = MODEL_OPTIONS_TRANSLATE[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)

    # Prepend the instruction to the input text (if any) to form the prompt.
    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,  # cap on generated tokens; the prompt is capped separately by max_input_length
        "min_length": 10,
        "num_beams": num_beams,
        "early_stopping": True,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id
    }

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded_output
```
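
For a quick local sanity check, `translate` can be called directly with one of the keys registered in `MODEL_OPTIONS_TRANSLATE`. The sketch below is illustrative only: the instruction wording and the placeholder input are assumptions rather than prompts taken from the Space, and the `spaces` decorator should have no effect when run outside a Space.

```python
if __name__ == "__main__":
    # Illustrative smoke test; the instruction text and input are assumptions.
    output = translate(
        instruction="Translate Dhivehi to English:",
        input_text="(Dhivehi sentence goes here)",
        model_choice="MT5-B-Dhivehi-English Model",
        max_new_tokens=64,
        num_beams=4,
    )
    print(output)
```

Keeping the check under `if __name__ == "__main__"` means it does not run when the module is imported by the rest of the Space's code.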