A ruGPT-3 medium model, fine-tuned for 3 epochs on a corpus of 300 electronic books set in the universe of the Zona.
Usage
## Load NN locally
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_tokenizer_and_model(model_name_or_path):
    """Load a GPT-2 tokenizer and LM-head model from *model_name_or_path*.

    Returns a ``(tokenizer, model)`` pair, with the model moved to the
    module-level DEVICE (GPU when available, otherwise CPU).
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)
    return tokenizer, model
# Load the fine-tuned weights and tokenizer from a local directory.
tokenizer, model = load_tokenizer_and_model("./models/path/")
# Switch to inference mode: disables dropout-style layers. (Note: eval() by
# itself does not disable gradient tracking — wrap calls in torch.no_grad()
# for that.)
model.eval()
print('Model was loaded')
# Define generator
def generate_story_actions(
    model,
    tok,
    text,
    max_length=500,
    top_k=5,
    top_p=0.95,
    do_sample=True,
    temperature=1.2,
    num_beams=3,
    no_repeat_ngram_size=3,
    repetition_penalty=2.,
    last_text=None,
    num_sentences=3,
):
    """Generate a story continuation for *text* with the fine-tuned model.

    The sampling parameters (``max_length``, ``top_k``, ``top_p``,
    ``do_sample``, ``temperature``, ``num_beams``, ``no_repeat_ngram_size``,
    ``repetition_penalty``) are passed straight through to
    ``model.generate``.  ``last_text`` and ``num_sentences`` are accepted
    for interface compatibility but are currently unused — TODO: confirm
    whether they were meant to trim the output.

    Returns the decoded generated text as a single string.
    """
    # Resolve the device locally (same expression as the module-level DEVICE)
    # so the function does not depend on a global being in scope.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_ids = tok.encode(text, return_tensors="pt").to(device)
    out = model.generate(
        input_ids,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    # BUG FIX: the original did ' '.join(decoded_string), which inserts a
    # space between every character of the output. Decode only the first
    # returned sequence and return it as-is.
    return tok.decode(out[0])
# Generate a continuation for the player's action.
# NOTE(review): `player_action_promt` and `last_text` are not defined in this
# snippet — they must come from the surrounding script. "promt" looks like a
# typo for "prompt"; confirm against the rest of the code before renaming.
story_action = generate_story_actions(model, tokenizer, text = player_action_promt, last_text = last_text)
- Downloads last month
- 22
Model tree for alexeymosc/ai_stalker_ru_gpt_3_medium
Base model
ai-forever/rugpt3medium_based_on_gpt2