# MetaLoRA-code / modeling_metalora.py
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import torch.nn.functional as F
import torch.nn.init as init
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download, snapshot_download, HfApi
from transformers import AutoModel, AutoConfig, ModelCard, AutoTokenizer, AutoModelForCausalLM, GenerationMixin
import tempfile
from pathlib import Path
import json, os
import numpy as np
import safetensors.torch as sfts
from typing import List, Optional, Tuple, Union, TypedDict
from transformers import Cache, DynamicCache, StaticCache
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
class MLoRASingleton:
    _instance = None  # stores the unique instance
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # create the instance if it does not exist yet
            cls._instance = super().__new__(cls)
            cls._instance.tensors = {}  # one conditioning tensor per device
        return cls._instance
    def set_tensor(self, tensor):
        """Store a tensor in the singleton, keyed by its device."""
        self.tensors[tensor.device] = tensor
    def get_tensor(self, device):
        """Retrieve the tensor stored for the given device."""
        return self.tensors[device]
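# Hedged sketch of the singleton contract: every MLoRAlinear reads the shared
# conditioning tensor for its own device, so one set_tensor call per device is
# enough for the whole model. Defined for illustration, not called on import.
def _demo_singleton():
    s1, s2 = MLoRASingleton(), MLoRASingleton()
    assert s1 is s2  # a single shared instance
    s1.set_tensor(torch.zeros(2, 8))        # keyed by tensor.device ("cpu" here)
    t = s2.get_tensor(torch.device("cpu"))  # any handle sees the same storage
    assert t.shape == (2, 8)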
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
        cu_seq_lens_q (`torch.LongTensor`, *optional*):
            Gets cumulative sequence length for query state.
        cu_seq_lens_k (`torch.LongTensor`, *optional*):
            Gets cumulative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
class LossKwargs(TypedDict, total=False):
"""
Keyword arguments to be passed to the loss function
Attributes:
num_items_in_batch (`int`, *optional*):
Number of items in the batch. It is recommended to pass it when
you are doing gradient accumulation.
"""
num_items_in_batch: Optional[int]
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class MLoRAlinear(nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self,
in_features: int,
out_features: int,
base_size:int,
weights: torch.Tensor,
        bias_tensor: Optional[torch.Tensor] = None,
bias: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.base_size = base_size
self.weight = nn.Parameter(weights)
self.weight.requires_grad = False
self.linearinp = nn.Linear(self.base_size, self.in_features, bias=False)
self.linearinp.weight = nn.Parameter(torch.zeros_like(self.linearinp.weight.data))
self.linearout = nn.Linear(self.base_size, self.out_features, bias=False)
self.linearout.weight = nn.Parameter(torch.zeros_like(self.linearout.weight.data))
self.singleton = MLoRASingleton()
if bias:
self.bias = nn.Parameter(bias_tensor)
else:
self.register_parameter("bias", None)
self.reset_parameters()
    def reset_parameters(self) -> None:
        # Re-initialize the hypernetwork projections with Xavier normal init;
        # the frozen base weight is left untouched.
        init.xavier_normal_(self.linearout.weight)
        init.xavier_normal_(self.linearinp.weight)
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base = self.singleton.get_tensor(input.device)  # (batch_size, base_size)
        A = self.linearinp(base).unsqueeze(1)   # (batch_size, 1, in_features)
        B = self.linearout(base).unsqueeze(2)   # (batch_size, out_features, 1)
        # Rank-1 update per sample; squeeze(0) collapses the batch dimension
        # when batch_size == 1, which selects the plain F.linear path below.
        AB = torch.tanh(torch.bmm(B, A).squeeze(0))  # (batch_size, out_features, in_features) or (out_features, in_features)
        if AB.dim() == 3:
            # self.weight broadcasts over the batch dimension of AB
            return torch.bmm(input, (self.weight - AB).transpose(-1, -2))
        return F.linear(input, self.weight - AB)
def extra_repr(self) -> str:
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
def set_layer_attribute(model, attribute_str, new_module):
    # split the dotted path into its components
    attrs = attribute_str.split('.')
    # walk down to the parent of the target module
    module = model
    for attr in attrs[:-1]:
        if attr.isdigit():  # numeric segment, e.g. an index into layers[0]
            module = module[int(attr)]
        else:
            module = getattr(module, attr)
    # replace the final attribute with the new module
    setattr(module, attrs[-1], new_module)
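# Hedged example of set_layer_attribute: the dotted path mirrors the names
# produced by model.named_parameters(), with numeric segments indexing into
# container modules such as nn.ModuleList.
def _demo_set_layer_attribute():
    class _Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)
    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([_Block(), _Block()])
    model = _Toy()
    set_layer_attribute(model, "layers.1.proj", nn.Linear(4, 8))
    assert model.layers[1].proj.out_features == 8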
def get_mlora_model(model, mlora_layers, base_size, lm_head=False, print_=True):
for layer_name, param in model.named_parameters():
param.requires_grad = False
if layer_name in mlora_layers:
layer_name = layer_name.replace(".weight", "")
if print_:
print(layer_name)
out, inp = param.data.shape
new_linear = MLoRAlinear(in_features=inp, out_features=out, base_size=base_size, weights=param.data, bias=False)
set_layer_attribute(model, layer_name, new_linear)
if lm_head:
out, inp = model.lm_head.weight.data.shape
new_linear = MLoRAlinear(in_features=inp, out_features=out, base_size=base_size, weights=model.lm_head.weight.data, bias=False)
model.lm_head = new_linear
return model
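# Hedged sketch of get_mlora_model. The checkpoint id is a placeholder, and
# the target list assumes Llama-style q_proj parameter names; real names come
# from model.named_parameters().
def _demo_get_mlora_model():
    llm = AutoModelForCausalLM.from_pretrained("some-org/tiny-llama-like")  # placeholder repo id
    targets = [n for n, _ in llm.named_parameters() if n.endswith("q_proj.weight")]
    llm = get_mlora_model(llm, targets, base_size=128, print_=False)
    # Afterwards only the injected linearinp/linearout projections require grad.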
class EmbdModel(nn.Module):
def __init__(self, tokenizer, model):
super().__init__()
self.tokenizer = tokenizer
self.model = model
self.set_require_grad(False)
def encode(self, text, device):
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=512).to(device)
scores = self.model(**inputs, return_dict=True)
out = scores.pooler_output
return out
def set_require_grad(self, require_grad):
for layer_name, param in self.model.named_parameters():
param.requires_grad = require_grad
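# Hedged sketch of EmbdModel: the encoder must expose pooler_output, i.e. a
# BERT-style model; the checkpoint name below is an assumption.
def _demo_embd_model():
    name = "sentence-transformers/all-MiniLM-L6-v2"  # assumed pooled encoder
    embd = EmbdModel(AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name))
    vec = embd.encode(["hello world"], device="cpu")  # (batch, hidden_size)
    assert vec.dim() == 2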
class MLoRAModel(nn.Module, GenerationMixin):
    '''
    config (dict):
        "mlora_layers": list[str],
        "base_size": int,
        "embd_model": str,
        "llm_tokenizer": str,
    '''
def __init__(self, config, llm, tokenizer, embd_model):
super().__init__()
self.config = config
self.llm = llm
self.tokenizer = tokenizer
self.embd_model = embd_model
self.singleton = MLoRASingleton()
self.singleton.set_tensor(torch.zeros(1, 1).to("cpu"))
@classmethod
def from_pretrained(cls, repo, token=None, print_=False):
with tempfile.TemporaryDirectory() as tmp:
# load the repo and the config
snapshot_download(repo_id=repo, local_dir=f"{tmp}/repo/", token=token)
config = MetaLoRAConfig.from_pretrained(repo, token=token)
### load model
llm_config = AutoConfig.from_pretrained(f"{tmp}/repo/model/config.json")
print("load LLM")
llm = AutoModelForCausalLM.from_config(llm_config)
llm = get_mlora_model(llm, config.mlora_layers, config.base_size, lm_head="lm_head.weight" in config.mlora_layers, print_=print_)
sfts.load_model(llm, f"{tmp}/repo/model/model.safetensors")
print("LLM loaded")
### load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.llm_tokenizer)
### load embd_model
print("load Embd model")
embd_model = EmbdModel(AutoTokenizer.from_pretrained(config.embd_model), AutoModel.from_pretrained(config.embd_model))
print("Embd model loaded")
### create instance
mloramodel = cls(config, llm, tokenizer, embd_model)
return mloramodel
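    # Hedged usage sketch (the repo id is a placeholder for a repo created by
    # push_to_hub below):
    #   model = MLoRAModel.from_pretrained("username/my-metalora-model")
    #   print(model.get_n_params())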
def push_to_hub(self, repo, token=None, private=False):
api = HfApi(token=token)
repo_id = api.create_repo(repo_id=repo, exist_ok=True, private=private).repo_id
with tempfile.TemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo
### push model
model_path = saved_path / "model"
model_path.mkdir(parents=True, exist_ok=True)
self.llm.save_pretrained(model_path)
api.upload_folder(
folder_path=model_path,
repo_id=repo_id,
path_in_repo="model",
)
print(f"llm pushed to {repo}/model")
### push embd model
model_path = saved_path / "embd_model"
model_path.mkdir(parents=True, exist_ok=True)
self.embd_model.model.save_pretrained(model_path)
self.embd_model.tokenizer.save_pretrained(model_path)
api.upload_folder(
folder_path=model_path,
repo_id=repo_id,
path_in_repo="embd_model",
)
print(f"Embd model pushed to {repo}/embd_model")
### push config
config_path = saved_path / "config.json"
with open(config_path, "w") as config_file:
json.dump(self.config.to_dict(), config_file)
api.upload_file(
path_or_fileobj=config_path,
repo_id=repo_id,
path_in_repo="config.json", # Push to the main folder
)
print(f"Config pushed to {repo}/config.json")
### push model card
md_path = saved_path / "README.md"
content = f"""
Model size: {self.get_n_params()}
"""
with open(md_path, 'w') as f:
f.write(content)
api.upload_file(
path_or_fileobj=md_path,
repo_id=repo_id,
path_in_repo="README.md", # Push to the main folder
)
def get_n_params(self):
params = sum(p.numel() for p in self.parameters())
return "{:,}".format(params)
def get_memory_footprint(self, return_buffers=True):
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
    def count_learnable_params(self):
        """
        Calculate the number of learnable parameters.
        Returns:
            int: The total number of learnable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
def extract_text_until_eos(self, input_tensor, tokenizer, eos_token_id, pad_token_id=None):
"""
Truncate input sequences at the first EOS token.
Args:
- input_tensor (torch.Tensor): Input tensor of shape [batch_size, inp_size]
- tokenizer: Tokenizer used for generation
- eos_token_id (int): ID of the EOS token
- pad_token_id (int, optional): ID of the padding token. If None, uses tokenizer's pad token
Returns:
        - torch.Tensor: Truncated sequences, with the first EOS and everything after it replaced by padding
"""
        # Determine the padding token (fall back to 0 if the tokenizer has none)
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        output_tensor = input_tensor.clone()
        for i in range(input_tensor.size(0)):  # iterate over the batch
            # find the first EOS token, if any
            eos_indices = torch.where(input_tensor[i] == eos_token_id)[0]
            if len(eos_indices) > 0:
                # replace the EOS and everything after it with padding
                first_eos_pos = eos_indices[0].item()
                output_tensor[i, first_eos_pos:] = pad_token_id
        return output_tensor
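    # Hedged illustration of extract_text_until_eos with eos_token_id=2 and
    # pad_token_id=0:
    #   input:  [[5, 7, 2, 9, 9],   ->   output: [[5, 7, 0, 0, 0],
    #            [5, 7, 7, 7, 7]]                  [5, 7, 7, 7, 7]]
    # The first EOS and everything after it are replaced by padding; rows
    # without an EOS pass through unchanged.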
    def set_learnable_layers(self):
        """
        Set requires_grad to False for every layer except the MetaLoRA layers.
        (Not implemented yet.)
        """
        pass
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
        texts: Optional[List[str]] = None,
num_logits_to_keep: int = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
if texts is None:
texts = self.extract_text_until_eos(input_ids, self.tokenizer, self.tokenizer.eos_token_id)
texts = self.tokenizer.batch_decode(texts)
texts = [t.split(self.tokenizer.eos_token)[0]+self.tokenizer.eos_token for t in texts]
encoding = self.embd_model.encode(texts, device=input_ids.device)
self.singleton.set_tensor(encoding)
        outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )
        return outputs
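# Hedged end-to-end sketch of a MetaLoRA training step. Checkpoint ids and the
# layer pattern are placeholders; note that base_size must match the encoder's
# pooled hidden size (384 assumed here).
def _demo_training_step():
    llm_name = "some-org/tiny-llama-like"   # placeholder causal LM
    embd_name = "some-org/tiny-encoder"     # placeholder pooled encoder
    layers = format_layers(["model.layers.[0,1].self_attn.q_proj.weight"])
    llm = get_mlora_model(AutoModelForCausalLM.from_pretrained(llm_name),
                          layers, base_size=384, print_=False)
    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    embd = EmbdModel(AutoTokenizer.from_pretrained(embd_name),
                     AutoModel.from_pretrained(embd_name))
    config = {"mlora_layers": layers, "base_size": 384,
              "embd_model": embd_name, "llm_tokenizer": llm_name}
    model = MLoRAModel(config, llm, tokenizer, embd)
    batch = tokenizer(["Hello world" + tokenizer.eos_token], return_tensors="pt")
    out = model(input_ids=batch.input_ids, labels=batch.input_ids)
    out.loss.backward()  # gradients flow only into the MLoRA projections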
def format_layers(layers):
"""
Given a list of strings representing layer paths with embedded index lists,
return a list with the expanded paths for each index.
Args:
layers (list): List of strings with embedded index lists in the path.
Returns:
list: Expanded list of layer paths.
"""
formatted_layers = []
for layer in layers:
if "[" in layer or "]" in layer:
# Split the string to separate the index list
prefix, indexes_and_suffix = layer.split(".[", maxsplit=1)
indexes, suffix = indexes_and_suffix.split("].", maxsplit=1)
# Parse the indexes and reconstruct individual layer paths
for index in indexes.split(","):
formatted_layers.append(f"{prefix}.{index.strip()}.{suffix}")
else:
formatted_layers.append(layer)
return formatted_layers
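# Worked example of format_layers: the "prefix.[i,j].suffix" pattern fans out
# into one path per index.
def _demo_format_layers():
    expanded = format_layers(["model.layers.[0, 1].self_attn.q_proj.weight"])
    assert expanded == [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.1.self_attn.q_proj.weight",
    ]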