|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import tempfile
import json, os
import numpy as np
import safetensors.torch as sfts
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Optional, Tuple, Union, TypedDict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download, snapshot_download, HfApi
from transformers import AutoModel, AutoConfig, ModelCard, AutoTokenizer, AutoModelForCausalLM, GenerationMixin
from transformers import Cache, DynamicCache, StaticCache
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
class MLoRASingleton:
    """Process-wide store for the conditioning embedding, keyed by device.

    MLoRAlinear layers read the tensor set here at forward time.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.tensors = {}
        return cls._instance

    def set_tensor(self, tensor):
        """Store the tensor in the singleton, keyed by its device."""
        self.tensors[tensor.device] = tensor

    def get_tensor(self, device):
        """Return the tensor stored in the singleton for the given device."""
        return self.tensors[device]
|
|
class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for the query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for the key state.
    """

    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class LossKwargs(TypedDict, total=False):
    """
    Keyword arguments to be passed to the loss function.

    Attributes:
        num_items_in_batch (`int`, *optional*):
            Number of items in the batch. It is recommended to pass it when
            doing gradient accumulation.
    """

    num_items_in_batch: Optional[int]


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|
class MLoRAlinear(nn.Module):
    """Linear layer whose frozen weight is modulated by a text embedding.

    The effective weight is W_eff = W - tanh(outer(b(e), a(e))), where e is the
    embedding stored in MLoRASingleton and a, b are the learnable projections
    `linearinp` and `linearout` below. The update is rank-1 and input-conditioned.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        base_size: int,
        weights: torch.Tensor,
        bias_tensor: Optional[torch.Tensor] = None,
        bias: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.base_size = base_size

        # Frozen base weight taken from the original layer.
        self.weight = nn.Parameter(weights)
        self.weight.requires_grad = False

        # Learnable projections from the conditioning embedding.
        self.linearinp = nn.Linear(self.base_size, self.in_features, bias=False)
        self.linearinp.weight = nn.Parameter(torch.zeros_like(self.linearinp.weight.data))
        self.linearout = nn.Linear(self.base_size, self.out_features, bias=False)
        self.linearout.weight = nn.Parameter(torch.zeros_like(self.linearout.weight.data))
        self.singleton = MLoRASingleton()

        if bias:
            self.bias = nn.Parameter(bias_tensor)
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.xavier_normal_(self.linearout.weight)
        init.xavier_normal_(self.linearinp.weight)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Conditioning embedding for this device, set by MLoRAModel.forward.
        base = self.singleton.get_tensor(input.device)

        A = self.linearinp(base).unsqueeze(1)   # (batch, 1, in_features)
        B = self.linearout(base).unsqueeze(2)   # (batch, out_features, 1)
        AB = torch.tanh(torch.bmm(B, A).squeeze(0))

        if AB.dim() == 3:
            # Batched conditioning: one modulated weight matrix per sample.
            weight = torch.stack([self.weight for _ in range(AB.shape[0])], dim=0) - AB
            return torch.bmm(input, weight.transpose(-1, -2))
        return F.linear(input, self.weight - AB)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
|
|
def set_layer_attribute(model, attribute_str, new_module):
    """Replace the submodule at the dotted path `attribute_str` with `new_module`."""
    attrs = attribute_str.split('.')

    # Walk down to the parent module, indexing into ModuleLists for numeric path parts.
    module = model
    for attr in attrs[:-1]:
        if attr.isdigit():
            module = module[int(attr)]
        else:
            module = getattr(module, attr)

    setattr(module, attrs[-1], new_module)
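
# Illustrative usage sketch (not from the original code): swap a single projection
# for an MLoRAlinear. The Llama-style layer path and sizes are hypothetical examples.
#
#     q_proj = model.model.layers[0].self_attn.q_proj
#     new_linear = MLoRAlinear(in_features=q_proj.in_features, out_features=q_proj.out_features,
#                              base_size=768, weights=q_proj.weight.data)
#     set_layer_attribute(model, "model.layers.0.self_attn.q_proj", new_linear)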
|
|
def get_mlora_model(model, mlora_layers, base_size, lm_head=False, print_=True):
    """Freeze `model` and swap every weight listed in `mlora_layers` for an MLoRAlinear layer."""
    for layer_name, param in model.named_parameters():
        param.requires_grad = False
        if layer_name in mlora_layers:
            layer_name = layer_name.replace(".weight", "")
            if print_:
                print(layer_name)
            out, inp = param.data.shape
            new_linear = MLoRAlinear(in_features=inp, out_features=out, base_size=base_size, weights=param.data, bias=False)
            set_layer_attribute(model, layer_name, new_linear)

    if lm_head:
        out, inp = model.lm_head.weight.data.shape
        new_linear = MLoRAlinear(in_features=inp, out_features=out, base_size=base_size, weights=model.lm_head.weight.data, bias=False)
        model.lm_head = new_linear

    return model
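
# Illustrative usage sketch (not from the original code): wrapping a Llama-style
# causal LM. The checkpoint name and layer paths are hypothetical examples; any
# nn.Linear weight listed in `mlora_layers` can be targeted.
#
#     llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
#     mlora_layers = format_layers(["model.layers.[0,1].self_attn.q_proj.weight"])
#     llm = get_mlora_model(llm, mlora_layers, base_size=768, lm_head=False)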
|
|
class EmbdModel(nn.Module):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.set_require_grad(False)

    def encode(self, text, device):
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=512).to(device)
        scores = self.model(**inputs, return_dict=True)
        out = scores.pooler_output
        return out

    def set_require_grad(self, require_grad):
        for layer_name, param in self.model.named_parameters():
            param.requires_grad = require_grad
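
# Illustrative usage sketch (not from the original code): any encoder exposing
# `pooler_output` (BERT-like) works; the checkpoint name is a hypothetical example.
#
#     embd = EmbdModel(AutoTokenizer.from_pretrained("bert-base-uncased"),
#                      AutoModel.from_pretrained("bert-base-uncased"))
#     vectors = embd.encode(["some conditioning text"], device="cpu")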
|
|
|
class MLoRAModel(nn.Module, GenerationMixin):
    '''
    config: dict ->
        "mlora_layers": list[str],
        "base_size": int,
        "embd_model": str,
        "llm_tokenizer": str,
    '''

    def __init__(self, config, llm, tokenizer, embd_model):
        super().__init__()
        self.config = config
        self.llm = llm
        self.tokenizer = tokenizer
        self.embd_model = embd_model
        self.singleton = MLoRASingleton()
        self.singleton.set_tensor(torch.zeros(1, 1).to("cpu"))
|
|
    @classmethod
    def from_pretrained(cls, repo, token=None, print_=False):
        with tempfile.TemporaryDirectory() as tmp:
            snapshot_download(repo_id=repo, local_dir=f"{tmp}/repo/", token=token)
            # MetaLoRAConfig is expected to be defined/imported elsewhere in the package.
            config = MetaLoRAConfig.from_pretrained(repo, token=token)

            llm_config = AutoConfig.from_pretrained(f"{tmp}/repo/model/config.json")
            print("load LLM")
            llm = AutoModelForCausalLM.from_config(llm_config)
            llm = get_mlora_model(llm, config.mlora_layers, config.base_size, lm_head="lm_head.weight" in config.mlora_layers, print_=print_)
            sfts.load_model(llm, f"{tmp}/repo/model/model.safetensors")
            print("LLM loaded")

            tokenizer = AutoTokenizer.from_pretrained(config.llm_tokenizer)

            print("load Embd model")
            embd_model = EmbdModel(AutoTokenizer.from_pretrained(config.embd_model), AutoModel.from_pretrained(config.embd_model))
            print("Embd model loaded")

            mloramodel = cls(config, llm, tokenizer, embd_model)

        return mloramodel
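
    # Illustrative usage sketch (not from the original code): the repo id below is
    # a hypothetical placeholder for a repository produced by `push_to_hub`.
    #
    #     model = MLoRAModel.from_pretrained("username/my-metalora-repo")
    #     inputs = model.tokenizer("Hello", return_tensors="pt")
    #     outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])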
|
|
    def push_to_hub(self, repo, token=None, private=False):
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo, exist_ok=True, private=private).repo_id

        with tempfile.TemporaryDirectory() as tmp:
            saved_path = Path(tmp) / repo

            # LLM weights (including the MLoRA projections).
            model_path = saved_path / "model"
            model_path.mkdir(parents=True, exist_ok=True)
            self.llm.save_pretrained(model_path)

            api.upload_folder(
                folder_path=model_path,
                repo_id=repo_id,
                path_in_repo="model",
            )
            print(f"llm pushed to {repo}/model")

            # Embedding model and its tokenizer.
            model_path = saved_path / "embd_model"
            model_path.mkdir(parents=True, exist_ok=True)
            self.embd_model.model.save_pretrained(model_path)
            self.embd_model.tokenizer.save_pretrained(model_path)
            api.upload_folder(
                folder_path=model_path,
                repo_id=repo_id,
                path_in_repo="embd_model",
            )
            print(f"Embd model pushed to {repo}/embd_model")

            # MetaLoRA config.
            config_path = saved_path / "config.json"
            with open(config_path, "w") as config_file:
                json.dump(self.config.to_dict(), config_file)

            api.upload_file(
                path_or_fileobj=config_path,
                repo_id=repo_id,
                path_in_repo="config.json",
            )
            print(f"Config pushed to {repo}/config.json")

            # Minimal model card.
            md_path = saved_path / "README.md"
            content = f"""
Model size: {self.get_n_params()}
"""
            with open(md_path, 'w') as f:
                f.write(content)

            api.upload_file(
                path_or_fileobj=md_path,
                repo_id=repo_id,
                path_in_repo="README.md",
            )
|
|
    def get_n_params(self):
        params = sum(p.numel() for p in self.parameters())
        return "{:,}".format(params)

    def get_memory_footprint(self, return_buffers=True):
        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
        if return_buffers:
            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
            mem = mem + mem_bufs
        return mem
|
|
    def count_learnable_params(self):
        """
        Calculate the number of learnable parameters.

        Returns:
            int: The total number of learnable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
    def extract_text_until_eos(self, input_tensor, tokenizer, eos_token_id, pad_token_id=None):
        """
        Truncate input sequences at the first EOS token.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape [batch_size, inp_size]
            tokenizer: Tokenizer used for generation
            eos_token_id (int): ID of the EOS token
            pad_token_id (int, optional): ID of the padding token. If None, uses the tokenizer's pad token

        Returns:
            torch.Tensor: Tensor in which every position from the first EOS onward is replaced by the pad token
        """
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0

        output_tensor = input_tensor.clone()

        # Track which sequences actually contain an EOS token.
        eos_mask = torch.zeros(input_tensor.size(0), dtype=torch.bool, device=input_tensor.device)

        for i in range(input_tensor.size(0)):
            eos_indices = torch.where(input_tensor[i] == eos_token_id)[0]
            if len(eos_indices) > 0:
                first_eos_pos = eos_indices[0].item()
                # Pad out everything from the first EOS onward.
                output_tensor[i, first_eos_pos:] = pad_token_id
                eos_mask[i] = True

        return output_tensor
|
|
    def set_learnable_layers(self):
        """
        Set requires_grad to False for every layer except the MetaLoRA layers (not implemented yet).
        """
        pass
|
|
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        texts=None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # If no conditioning texts are provided, recover them from input_ids by
        # decoding everything up to (and including) the first EOS token.
        if texts is None:
            texts = self.extract_text_until_eos(input_ids, self.tokenizer, self.tokenizer.eos_token_id)
            texts = self.tokenizer.batch_decode(texts)
            texts = [t.split(self.tokenizer.eos_token)[0] + self.tokenizer.eos_token for t in texts]

        # Encode the texts and expose the embedding to every MLoRAlinear layer.
        encoding = self.embd_model.encode(texts, device=input_ids.device)
        self.singleton.set_tensor(encoding)

        outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )
        return outputs
|
|
def format_layers(layers):
    """
    Given a list of strings representing layer paths with embedded index lists,
    return a list with the expanded paths for each index.

    Args:
        layers (list): List of strings with embedded index lists in the path.

    Returns:
        list: Expanded list of layer paths.
    """
    formatted_layers = []

    for layer in layers:
        if "[" in layer or "]" in layer:
            prefix, indexes_and_suffix = layer.split(".[", maxsplit=1)
            indexes, suffix = indexes_and_suffix.split("].", maxsplit=1)

            for index in indexes.split(","):
                formatted_layers.append(f"{prefix}.{index.strip()}.{suffix}")
        else:
            formatted_layers.append(layer)

    return formatted_layers
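
# Example (grounded in format_layers above): a bracketed index list expands into one
# path per index; the layer path itself is a hypothetical Llama-style name.
#
#     format_layers(["model.layers.[0, 1].self_attn.q_proj.weight"])
#     # -> ["model.layers.0.self_attn.q_proj.weight",
#     #     "model.layers.1.self_attn.q_proj.weight"]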