# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - MambaTrainer renamed to ProtTrainer

import os
import re

import torch
from transformers import Trainer, TrainerCallback

from protxlstm.utils import AA_TO_ID, find_fim_indices


class ProtTrainer(Trainer):
"""
Base HuggingFace Trainer used for training.
from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py"""
def __init__(self, compute_only_fim_loss, **kwargs,):
super().__init__(**kwargs)
self.compute_only_fim_loss = compute_only_fim_loss
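
    # `compute_loss` overrides the default HF loss: it computes next-token
    # cross-entropy and, when `compute_only_fim_loss` is set, restricts the
    # loss to the fill-in-the-middle (FIM) spans located by `find_fim_indices`.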
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        # Forward whichever position encodings the data collator provided.
        if "seq_position_ids" in inputs and "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            seq_position_ids = inputs.pop("seq_position_ids")
            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
        elif "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            output = model(input_ids, position_ids=position_ids)
        else:
            output = model(input_ids)

        lm_logits = output.logits
        labels = labels.to(lm_logits.device)
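        # Causal shift: logits at position i are scored against the token at i + 1.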
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        if self.compute_only_fim_loss:
            # start and end tokens
            is_cls_tokens = (labels == AA_TO_ID["<cls>"])
            is_eos_tokens = (labels == AA_TO_ID["<eos>"])
            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
            # include also the cls token
            bool_fim = bool_fim | is_cls_tokens
            inds = torch.where(bool_fim)
            lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
        else:
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, output) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call):
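        # Save only from the main process so multi-GPU runs write a single copy.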
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            self.model.save_pretrained(output_dir)


PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder, max_steps=None):
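    """Return the path of the most recent checkpoint in `folder`, or None.

    When `max_steps` is given, checkpoints at or beyond that step are ranked
    last, so the newest checkpoint from before `max_steps` is preferred.
    """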
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return None
    max_steps = max_steps if max_steps is not None else float("inf")

    def func(x):
        num = int(_re_checkpoint.search(x).groups()[0])
        return num if num < max_steps else -1

    return os.path.join(folder, max(checkpoints, key=func))


class EarlyStoppingCallback(TrainerCallback):
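    """Stop training when the evaluation metric diverges.

    After `patience` consecutive evaluations whose metric exceeds
    `loss_increase_factor` times the best value seen so far, training is
    stopped and `checkpoint_path` is set to a checkpoint from before the
    divergence so the run can be restarted.
    """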

    def __init__(self, train_path, config=None):
        self.step_counter = 0
        self.best_loss = None
        self.train_path = train_path
        self.patience = config["patience"]
        self.metric_name = config["early_stopping_metric"]
        self.checkpoint_path = None
        self.should_restart = False
        self.eval_steps = config["eval_steps"]
        self.loss_increase_factor = config["loss_increase_factor"]

    def get_checkpoint_path(self, max_steps):
        last_checkpoint = None
        if os.path.exists(self.train_path):
            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
        if last_checkpoint is None:
            print("No checkpoint found, starting training from scratch.")
        else:
            print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
        return last_checkpoint

    def on_evaluate(self, args, state, control, model, metrics, **kwargs):
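        # Count consecutive evaluations whose metric exceeds the allowed
        # increase over the best value; after `patience` of them, stop and
        # point `checkpoint_path` at a pre-divergence checkpoint.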
        if self.metric_name in metrics:
            if self.best_loss is None:
                self.best_loss = metrics[self.metric_name]
            elif self.best_loss * self.loss_increase_factor < metrics[self.metric_name]:
                self.step_counter += 1
                if self.step_counter >= self.patience:
                    checkpoint_path = self.get_checkpoint_path(
                        max_steps=(state.global_step - self.patience * self.eval_steps)
                    )
                    control.should_training_stop = True
                    self.checkpoint_path = checkpoint_path
                    self.should_restart = True
            else:
                self.step_counter = 0
                self.best_loss = min(self.best_loss, metrics[self.metric_name])
                self.should_restart = False

    def on_train_begin(self, args, state, control, **kwargs):
        self.step_counter = 0
        self.best_loss = None
        self.should_restart = False
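
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). `model`,
# `train_ds`, and `eval_ds` are hypothetical placeholders, and the config
# values are arbitrary examples rather than recommended settings:
#
#     from transformers import TrainingArguments
#
#     callback = EarlyStoppingCallback(
#         train_path="outputs/run1",
#         config={
#             "patience": 3,
#             "early_stopping_metric": "eval_loss",
#             "eval_steps": 500,
#             "loss_increase_factor": 1.05,
#         },
#     )
#     trainer = ProtTrainer(
#         compute_only_fim_loss=True,
#         model=model,
#         args=TrainingArguments(
#             output_dir="outputs/run1",
#             evaluation_strategy="steps",
#             eval_steps=500,
#         ),
#         train_dataset=train_ds,
#         eval_dataset=eval_ds,
#         callbacks=[callback],
#     )
#     trainer.train()
# ---------------------------------------------------------------------------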