# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
#   - MambaTrainer renamed to ProtTrainer

import os
import re

import torch
from transformers import Trainer, TrainerCallback

from protxlstm.utils import AA_TO_ID, find_fim_indices


class ProtTrainer(Trainer):
    """Base HuggingFace Trainer used for training.

    Adapted from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py
    """

    def __init__(self, compute_only_fim_loss, **kwargs):
        super().__init__(**kwargs)
        self.compute_only_fim_loss = compute_only_fim_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        if "seq_position_ids" in inputs and "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            seq_position_ids = inputs.pop("seq_position_ids")
            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
        elif "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            output = model(input_ids, position_ids=position_ids)
        else:
            output = model(input_ids)
        lm_logits = output.logits

        labels = labels.to(lm_logits.device)
        # Shift so that each position predicts the next token.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        if self.compute_only_fim_loss:
            # start and end tokens
            is_cls_tokens = (labels == AA_TO_ID["<cls>"])
            is_eos_tokens = (labels == AA_TO_ID["<eos>"])
            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
            # include also the cls token
            bool_fim = bool_fim | is_cls_tokens
            inds = torch.where(bool_fim)
            lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
        else:
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, output) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call=False):
        # Only the main process writes the checkpoint.
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            self.model.save_pretrained(output_dir)


PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"-(\d+)$")


def get_last_checkpoint(folder, max_steps=None):
    """Return the path of the latest `checkpoint-<step>` directory in `folder`,
    optionally restricted to checkpoints with fewer than `max_steps` steps."""
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return None
    max_steps = max_steps if max_steps is not None else float("inf")

    def func(x):
        num = int(_re_checkpoint.search(x).groups()[0])
        return num if num < max_steps else -1

    return os.path.join(folder, max(checkpoints, key=func))


class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, train_path, config=None):
        self.step_counter = 0
        self.best_loss = None
        self.train_path = train_path
        self.patience = config["patience"]
        self.metric_name = config["early_stopping_metric"]
        self.checkpoint_path = None
        self.should_restart = False
        self.eval_steps = config["eval_steps"]
        self.loss_increase_factor = config["loss_increase_factor"]

    def get_checkpoint_path(self, max_steps):
        last_checkpoint = None
        if os.path.exists(self.train_path):
            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
        if last_checkpoint is None:
            print("No checkpoint found, starting training from scratch.")
        else:
            print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
        return last_checkpoint

    def on_evaluate(self, args, state, control, model, metrics, **kwargs):
        if self.metric_name in metrics:
            if self.best_loss is None:
                self.best_loss = metrics[self.metric_name]
            elif self.best_loss * self.loss_increase_factor < metrics[self.metric_name]:
                self.step_counter += 1
                if self.step_counter >= self.patience:
                    # Roll back to the last checkpoint written before the loss started increasing.
                    checkpoint_path = self.get_checkpoint_path(
                        max_steps=(state.global_step - self.patience * self.eval_steps)
                    )
                    control.should_training_stop = True
                    self.checkpoint_path = checkpoint_path
                    self.should_restart = True
            else:
                self.step_counter = 0
                self.best_loss = min(self.best_loss, metrics[self.metric_name])
                self.should_restart = False

    def on_train_begin(self, args, state, control, **kwargs):
        self.step_counter = 0
        self.best_loss = None
        self.should_restart = False
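

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way ProtTrainer and
# EarlyStoppingCallback might be wired together. `model`, `train_dataset`,
# `eval_dataset`, and the config values below are hypothetical placeholders,
# not part of this module; the callback's config keys match __init__ above.
# ---------------------------------------------------------------------------
#
# from transformers import TrainingArguments
#
# training_args = TrainingArguments(
#     output_dir="runs/protxlstm",
#     evaluation_strategy="steps",
#     eval_steps=500,
# )
# early_stopping = EarlyStoppingCallback(
#     train_path="runs/protxlstm",
#     config={
#         "patience": 3,
#         "early_stopping_metric": "eval_loss",
#         "eval_steps": 500,
#         "loss_increase_factor": 1.05,
#     },
# )
# trainer = ProtTrainer(
#     compute_only_fim_loss=False,
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     callbacks=[early_stopping],
# )
# trainer.train(resume_from_checkpoint=get_last_checkpoint("runs/protxlstm"))
#
# After training stops, `early_stopping.should_restart` signals that the run
# was halted by the callback, and `early_stopping.checkpoint_path` points at
# the checkpoint a restart should resume from.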