hikaruX nicolaus625 committed on
Commit 41c19fb · verified · 0 Parent(s):

Duplicate from m-a-p/MusiLingo-long-v1


Co-authored-by: Yinghao Ma <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,132 @@
+ ---
+ language:
+ - en
+ license: cc-by-4.0
+ tags:
+ - music
+ - art
+ ---
+ # Model Card for MusiLingo-long-v1
+ ## Model Details
+ ### Model Description
+ The model consists of a music encoder `MERT-v1-330M`, a natural-language decoder `vicuna-7b-delta-v0`, and a linear projection layer between the two.
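+
+ As a rough illustration of the adapter idea only (not this repository's actual module code), the projection maps frame-level MERT features into the decoder's token-embedding space so the language model can attend over audio frames like ordinary text tokens. The 1024-dimensional feature size below is an assumption based on MERT-v1-330M; 4096 matches the `hidden_size` in this repo's `config.json`.
+ ```
+ import torch
+ import torch.nn as nn
+
+ # Illustrative sketch only: audio frames become "soft tokens" for the language decoder.
+ projection = nn.Linear(1024, 4096)          # assumed MERT feature dim -> Vicuna-7B embedding dim
+ audio_frames = torch.randn(1, 250, 1024)    # (batch, frames, mert_dim), dummy values
+ audio_embeds = projection(audio_frames)     # (batch, frames, llama_dim)
+ ```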
+
+ This MusiLingo checkpoint was developed on the long-question subset of the MusicInstruct (MI) dataset, so it can answer long instructions about raw music audio, such as queries about the subjective feelings the music evokes.
+ You can use the [MI](https://huggingface.co/datasets/m-a-p/Music-Instruct) dataset with the demo below.
+
+
+ ### Model Sources
+ - **Repository:** [GitHub repo](https://github.com/zihaod/MusiLingo)
+ - **Paper:** __[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__
+ <!-- - **Demo:** [More Information Needed] -->
+
+
+ ## Getting Started
+ ```
+ import torch
+ from transformers import AutoModel, Wav2Vec2FeatureExtractor
+ from transformers import StoppingCriteria, StoppingCriteriaList
+
+
+ class StoppingCriteriaSub(StoppingCriteria):
+     def __init__(self, stops=[], encounters=1):
+         super().__init__()
+         self.stops = stops
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+         for stop in self.stops:
+             if torch.all((stop == input_ids[0][-len(stop):])).item():
+                 return True
+         return False
+
+
+ def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
+                        max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
+     # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for the load_audio function definition
+     audio = load_audio(audio_path, target_sr=24000,
+                        is_mono=True,
+                        is_normalize=False,
+                        crop_to_length_in_sample_points=int(30*16000)+1,
+                        crop_randomly=True,
+                        pad=False).cuda()
+     processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
+     audio = processor(audio,
+                       sampling_rate=24000,
+                       return_tensors="pt")['input_values'][0].cuda()
+
+     audio_embeds, atts_audio = model.encode_audio(audio)
+
+     prompt = '<Audio><AudioHere></Audio> ' + text
+     instruction_prompt = [model.prompt_template.format(prompt)]
+     audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
+
+     model.llama_tokenizer.padding_side = "right"
+     batch_size = audio_embeds.shape[0]
+     bos = torch.ones([batch_size, 1],
+                      dtype=torch.long,
+                      device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
+     bos_embeds = model.llama_model.model.embed_tokens(bos)
+     # atts_bos = atts_audio[:, :1]
+     inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
+     # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
+     outputs = model.llama_model.generate(
+         inputs_embeds=inputs_embeds,
+         max_new_tokens=max_new_tokens,
+         stopping_criteria=stopping,
+         num_beams=num_beams,
+         do_sample=True,
+         min_length=min_length,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         length_penalty=length_penalty,
+         temperature=temperature,
+     )
+     output_token = outputs[0]
+     if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
+         output_token = output_token[1:]
+     if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
+         output_token = output_token[1:]
+     output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+     output_text = output_text.split('###')[0]  # remove the stop sign '###'
+     output_text = output_text.split('Assistant:')[-1].strip()
+     return output_text
+
+
+ musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-long-v1", trust_remote_code=True)
+ musilingo.to("cuda")
+ musilingo.eval()
+
+ prompt = "this is the task instruction and input question for the MusiLingo model"
+ audio_path = "/path/to/the/audio"
+ stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
+                                                       torch.tensor([2277, 29937]).cuda()])])
+ response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
+ ```
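+
+ The snippet above relies on a `load_audio` helper whose canonical definition lives in the [MusiLingo-musicqa-v1](https://huggingface.co/m-a-p/MusiLingo-musicqa-v1) repository. As a convenience, here is a minimal stand-in sketch assuming `torchaudio` is available; the original helper's exact cropping and normalization behaviour may differ.
+ ```
+ import random
+ import torch
+ import torchaudio
+
+ # Minimal stand-in for load_audio (illustrative only; see the linked repo for the original).
+ def load_audio(path, target_sr=24000, is_mono=True, is_normalize=False,
+                crop_to_length_in_sample_points=None, crop_randomly=False, pad=False):
+     wav, sr = torchaudio.load(path)                      # (channels, samples)
+     if is_mono and wav.shape[0] > 1:
+         wav = wav.mean(dim=0, keepdim=True)              # downmix to mono
+     if sr != target_sr:
+         wav = torchaudio.functional.resample(wav, sr, target_sr)
+     if is_normalize:
+         wav = wav / (wav.abs().max() + 1e-8)
+     wav = wav.squeeze(0)                                 # (samples,)
+     if crop_to_length_in_sample_points is not None:
+         n = crop_to_length_in_sample_points
+         if wav.shape[-1] > n:
+             start = random.randint(0, wav.shape[-1] - n) if crop_randomly else 0
+             wav = wav[start:start + n]
+         elif pad:
+             wav = torch.nn.functional.pad(wav, (0, n - wav.shape[-1]))
+     return wav
+ ```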
+
+ # Citing This Work
+
+ If you find the work useful for your research, please consider citing it using the following BibTeX entry:
+ ```
+ @inproceedings{deng2024musilingo,
+     title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
+     author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
+     booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
+     year={2024},
+     organization={Association for Computational Linguistics}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,34 @@
+ {
+ "architectures": [
+ "MusilingoModel"
+ ],
+ "auto_map": {
+ "AutoConfig": "configuration_musilingo.MusiLingoConfig",
+ "AutoModel": "modelling_musilingo.MusilingoModel"
+ },
+ "bos_token_id": 1,
+ "device_8bit": 0,
+ "end_sym": "\n",
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 11008,
+ "llama_model": "lmsys/vicuna-7b-delta-v0",
+ "low_resource": false,
+ "max_position_embeddings": 2048,
+ "max_txt_len": 32,
+ "mert_model": "m-a-p/MERT-v1-330M",
+ "model_type": "musilingo",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "pad_token_id": 0,
+ "prompt_path": "",
+ "prompt_template": "###Human: {} ###Assistant: ",
+ "rms_norm_eps": 1e-06,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.39.3",
+ "use_cache": true,
+ "vocab_size": 32001
+ }
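
The `auto_map` entries above wire `AutoConfig`/`AutoModel` to the Python modules shipped in this repo, so the configuration can be loaded with `trust_remote_code=True`. A small illustrative check (repo id taken from the README):

```
from transformers import AutoConfig

# Downloads configuration_musilingo.py from the repo and instantiates MusiLingoConfig
cfg = AutoConfig.from_pretrained("m-a-p/MusiLingo-long-v1", trust_remote_code=True)
print(cfg.mert_model)        # "m-a-p/MERT-v1-330M"
print(cfg.llama_model)       # "lmsys/vicuna-7b-delta-v0"
print(cfg.prompt_template)   # "###Human: {} ###Assistant: "
```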
configuration_musilingo.py ADDED
@@ -0,0 +1,29 @@
+ from transformers import PretrainedConfig
+
+
+ PATH = "."
+
+ class MusiLingoConfig(PretrainedConfig):
+     model_type = "musilingo"
+     is_encoder_decoder = True
+     def __init__(self,
+                  mert_model="m-a-p/MERT-v1-330M",
+                  llama_model="lmsys/vicuna-7b-delta-v0",
+                  prompt_path="",
+                  prompt_template='###Human: {} ###Assistant: ',
+                  max_txt_len=32,
+                  end_sym='\n',
+                  low_resource=False,
+                  device_8bit=0,
+                  # linear_ckpt_path="",
+                  **kwargs):
+         self.mert_model = mert_model
+         self.llama_model = llama_model
+         self.prompt_path = prompt_path
+         self.prompt_template = prompt_template
+         self.max_txt_len = max_txt_len
+         self.end_sym = end_sym
+         self.low_resource = low_resource
+         self.device_8bit = device_8bit
+         # self.linear_ckpt_path = linear_ckpt_path
+         super().__init__(**kwargs)
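
For completeness, the config class can also be instantiated directly; keyword arguments override the defaults above and anything unrecognised is forwarded to `PretrainedConfig` (illustrative only):

```
# Illustrative: override a couple of defaults defined in MusiLingoConfig.__init__
config = MusiLingoConfig(max_txt_len=64, low_resource=True)
print(config.mert_model, config.max_txt_len)   # m-a-p/MERT-v1-330M 64
```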
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3dfb974964f6b723d558e3107ee5b98ae1f13a08e681dfe7106806a01078e74
+ size 4986465504
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54cb40c370505a6378fa41c4a30c352df90599ce458f55e36716d874cf292c31
+ size 4947397256
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2641d9106fd9436f8698e9a0a7ca44fc88ef5a51aa6703df879379ad3af6aa8c
+ size 4821600024
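
The three shards above are stitched together by `model.safetensors.index.json` (next file): its `weight_map` records which shard holds each tensor, and `from_pretrained` uses it to resolve the sharded checkpoint. A small illustrative lookup, assuming the files have been downloaded locally:

```
import json

# Map a parameter name to the shard that stores it, e.g. the LM head lives in shard 3
with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]
print(weight_map["model.llama_model.lm_head.weight"])  # model-00003-of-00003.safetensors
```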
model.safetensors.index.json ADDED
@@ -0,0 +1,735 @@
+ {
+ "metadata": {
+ "total_size": 14755365376
+ },
+ "weight_map": {
+ "model.audio_encoder.encoder.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.0.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.1.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.10.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.11.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.12.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.13.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.14.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.15.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.16.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.17.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.18.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.19.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.2.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.20.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.21.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.22.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.23.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.3.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.4.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.5.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.6.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.7.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.8.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.k_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.out_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.out_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.q_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.v_proj.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.attention.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.feed_forward.intermediate_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.feed_forward.intermediate_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.feed_forward.output_dense.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.feed_forward.output_dense.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.final_layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.final_layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.layers.9.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.pos_conv_embed.conv.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.pos_conv_embed.conv.parametrizations.weight.original0": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.encoder.pos_conv_embed.conv.parametrizations.weight.original1": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.0.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.0.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.0.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.1.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.2.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.3.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.4.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.5.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_extractor.conv_layers.6.conv.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_projection.layer_norm.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_projection.layer_norm.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_projection.projection.bias": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.feature_projection.projection.weight": "model-00001-of-00003.safetensors",
+ "model.audio_encoder.masked_spec_embed": "model-00001-of-00003.safetensors",
+ "model.llama_model.lm_head.weight": "model-00003-of-00003.safetensors",
+ "model.llama_model.model.embed_tokens.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "model.llama_model.model.layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.10.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "model.llama_model.model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
457
+ "model.llama_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
458
+ "model.llama_model.model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
459
+ "model.llama_model.model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
460
+ "model.llama_model.model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
461
+ "model.llama_model.model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
462
+ "model.llama_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
463
+ "model.llama_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
464
+ "model.llama_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
465
+ "model.llama_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
466
+ "model.llama_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
467
+ "model.llama_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
468
+ "model.llama_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
469
+ "model.llama_model.model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
470
+ "model.llama_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
471
+ "model.llama_model.model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
472
+ "model.llama_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
473
+ "model.llama_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
474
+ "model.llama_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
475
+ "model.llama_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
476
+ "model.llama_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
477
+ "model.llama_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
478
+ "model.llama_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
479
+ "model.llama_model.model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
480
+ "model.llama_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
481
+ "model.llama_model.model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
482
+ "model.llama_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
483
+ "model.llama_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
484
+ "model.llama_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
485
+ "model.llama_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
486
+ "model.llama_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
487
+ "model.llama_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
488
+ "model.llama_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
489
+ "model.llama_model.model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
490
+ "model.llama_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
491
+ "model.llama_model.model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
492
+ "model.llama_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
493
+ "model.llama_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
494
+ "model.llama_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
495
+ "model.llama_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
496
+ "model.llama_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
497
+ "model.llama_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
498
+ "model.llama_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
499
+ "model.llama_model.model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
500
+ "model.llama_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
501
+ "model.llama_model.model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
502
+ "model.llama_model.model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
503
+ "model.llama_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
504
+ "model.llama_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
505
+ "model.llama_model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
506
+ "model.llama_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
507
+ "model.llama_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
508
+ "model.llama_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
509
+ "model.llama_model.model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
510
+ "model.llama_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
511
+ "model.llama_model.model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
512
+ "model.llama_model.model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
513
+ "model.llama_model.model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
514
+ "model.llama_model.model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
515
+ "model.llama_model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
516
+ "model.llama_model.model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
517
+ "model.llama_model.model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
518
+ "model.llama_model.model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
519
+ "model.llama_model.model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
520
+ "model.llama_model.model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
521
+ "model.llama_model.model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
522
+ "model.llama_model.model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
523
+ "model.llama_model.model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
524
+ "model.llama_model.model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
525
+ "model.llama_model.model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
526
+ "model.llama_model.model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
527
+ "model.llama_model.model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
528
+ "model.llama_model.model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
529
+ "model.llama_model.model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
530
+ "model.llama_model.model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
531
+ "model.llama_model.model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
532
+ "model.llama_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
533
+ "model.llama_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
534
+ "model.llama_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
535
+ "model.llama_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
536
+ "model.llama_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
537
+ "model.llama_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
538
+ "model.llama_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
539
+ "model.llama_model.model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
540
+ "model.llama_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
541
+ "model.llama_model.model.layers.20.input_layernorm.weight": "model-00003-of-00003.safetensors",
542
+ "model.llama_model.model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
543
+ "model.llama_model.model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
544
+ "model.llama_model.model.layers.20.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
545
+ "model.llama_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
546
+ "model.llama_model.model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
547
+ "model.llama_model.model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
548
+ "model.llama_model.model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
549
+ "model.llama_model.model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
550
+ "model.llama_model.model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
551
+ "model.llama_model.model.layers.21.input_layernorm.weight": "model-00003-of-00003.safetensors",
552
+ "model.llama_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
553
+ "model.llama_model.model.layers.21.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
554
+ "model.llama_model.model.layers.21.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
555
+ "model.llama_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
556
+ "model.llama_model.model.layers.21.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
557
+ "model.llama_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
558
+ "model.llama_model.model.layers.21.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
559
+ "model.llama_model.model.layers.21.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
560
+ "model.llama_model.model.layers.21.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
561
+ "model.llama_model.model.layers.22.input_layernorm.weight": "model-00003-of-00003.safetensors",
562
+ "model.llama_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
563
+ "model.llama_model.model.layers.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
564
+ "model.llama_model.model.layers.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
565
+ "model.llama_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
566
+ "model.llama_model.model.layers.22.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
567
+ "model.llama_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
568
+ "model.llama_model.model.layers.22.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
569
+ "model.llama_model.model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
570
+ "model.llama_model.model.layers.22.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
571
+ "model.llama_model.model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
572
+ "model.llama_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
573
+ "model.llama_model.model.layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
574
+ "model.llama_model.model.layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
575
+ "model.llama_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
576
+ "model.llama_model.model.layers.23.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
577
+ "model.llama_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
578
+ "model.llama_model.model.layers.23.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
579
+ "model.llama_model.model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
580
+ "model.llama_model.model.layers.23.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
581
+ "model.llama_model.model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
582
+ "model.llama_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
583
+ "model.llama_model.model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
584
+ "model.llama_model.model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
585
+ "model.llama_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
586
+ "model.llama_model.model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
587
+ "model.llama_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
588
+ "model.llama_model.model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
589
+ "model.llama_model.model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
590
+ "model.llama_model.model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
591
+ "model.llama_model.model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
592
+ "model.llama_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
593
+ "model.llama_model.model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
594
+ "model.llama_model.model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
595
+ "model.llama_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
596
+ "model.llama_model.model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
597
+ "model.llama_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
598
+ "model.llama_model.model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
599
+ "model.llama_model.model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
600
+ "model.llama_model.model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
601
+ "model.llama_model.model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
602
+ "model.llama_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
603
+ "model.llama_model.model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
604
+ "model.llama_model.model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
605
+ "model.llama_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
606
+ "model.llama_model.model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
607
+ "model.llama_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
608
+ "model.llama_model.model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
609
+ "model.llama_model.model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
610
+ "model.llama_model.model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
611
+ "model.llama_model.model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
612
+ "model.llama_model.model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
613
+ "model.llama_model.model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
614
+ "model.llama_model.model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
615
+ "model.llama_model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
616
+ "model.llama_model.model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
617
+ "model.llama_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
618
+ "model.llama_model.model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
619
+ "model.llama_model.model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
620
+ "model.llama_model.model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
621
+ "model.llama_model.model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
622
+ "model.llama_model.model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
623
+ "model.llama_model.model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
624
+ "model.llama_model.model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
625
+ "model.llama_model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
626
+ "model.llama_model.model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
627
+ "model.llama_model.model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
628
+ "model.llama_model.model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
629
+ "model.llama_model.model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
630
+ "model.llama_model.model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
631
+ "model.llama_model.model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
632
+ "model.llama_model.model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
633
+ "model.llama_model.model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
634
+ "model.llama_model.model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
635
+ "model.llama_model.model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
636
+ "model.llama_model.model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
637
+ "model.llama_model.model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
638
+ "model.llama_model.model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
639
+ "model.llama_model.model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
640
+ "model.llama_model.model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
641
+ "model.llama_model.model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
642
+ "model.llama_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
643
+ "model.llama_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
644
+ "model.llama_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
645
+ "model.llama_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
646
+ "model.llama_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
647
+ "model.llama_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
648
+ "model.llama_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
649
+ "model.llama_model.model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
650
+ "model.llama_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
651
+ "model.llama_model.model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
652
+ "model.llama_model.model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
653
+ "model.llama_model.model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
654
+ "model.llama_model.model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
655
+ "model.llama_model.model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
656
+ "model.llama_model.model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
657
+ "model.llama_model.model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
658
+ "model.llama_model.model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
659
+ "model.llama_model.model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
660
+ "model.llama_model.model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
661
+ "model.llama_model.model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
662
+ "model.llama_model.model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
663
+ "model.llama_model.model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
664
+ "model.llama_model.model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
665
+ "model.llama_model.model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
666
+ "model.llama_model.model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
667
+ "model.llama_model.model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
668
+ "model.llama_model.model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
669
+ "model.llama_model.model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00003.safetensors",
670
+ "model.llama_model.model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
671
+ "model.llama_model.model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
672
+ "model.llama_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
673
+ "model.llama_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
674
+ "model.llama_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
675
+ "model.llama_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
676
+ "model.llama_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
677
+ "model.llama_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
678
+ "model.llama_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
679
+ "model.llama_model.model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
680
+ "model.llama_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
681
+ "model.llama_model.model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
682
+ "model.llama_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
683
+ "model.llama_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
684
+ "model.llama_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
685
+ "model.llama_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
686
+ "model.llama_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
687
+ "model.llama_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
688
+ "model.llama_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
689
+ "model.llama_model.model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
690
+ "model.llama_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
691
+ "model.llama_model.model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
692
+ "model.llama_model.model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
693
+ "model.llama_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
694
+ "model.llama_model.model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
695
+ "model.llama_model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
696
+ "model.llama_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
697
+ "model.llama_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
698
+ "model.llama_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
699
+ "model.llama_model.model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
700
+ "model.llama_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
701
+ "model.llama_model.model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
702
+ "model.llama_model.model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
703
+ "model.llama_model.model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
704
+ "model.llama_model.model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
705
+ "model.llama_model.model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
706
+ "model.llama_model.model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
707
+ "model.llama_model.model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
708
+ "model.llama_model.model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
709
+ "model.llama_model.model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
710
+ "model.llama_model.model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
711
+ "model.llama_model.model.layers.8.input_layernorm.weight": "model-00002-of-00003.safetensors",
712
+ "model.llama_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
713
+ "model.llama_model.model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
714
+ "model.llama_model.model.layers.8.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
715
+ "model.llama_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
716
+ "model.llama_model.model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
717
+ "model.llama_model.model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
718
+ "model.llama_model.model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
719
+ "model.llama_model.model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
720
+ "model.llama_model.model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
721
+ "model.llama_model.model.layers.9.input_layernorm.weight": "model-00002-of-00003.safetensors",
722
+ "model.llama_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
723
+ "model.llama_model.model.layers.9.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
724
+ "model.llama_model.model.layers.9.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
725
+ "model.llama_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
726
+ "model.llama_model.model.layers.9.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
727
+ "model.llama_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
728
+ "model.llama_model.model.layers.9.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
729
+ "model.llama_model.model.layers.9.self_attn.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
730
+ "model.llama_model.model.layers.9.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
731
+ "model.llama_model.model.norm.weight": "model-00003-of-00003.safetensors",
732
+ "model.llama_proj.bias": "model-00003-of-00003.safetensors",
733
+ "model.llama_proj.weight": "model-00003-of-00003.safetensors"
734
+ }
735
+ }
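
The entries above close out the `weight_map` of `model.safetensors.index.json`: every parameter of the audio encoder (`model.audio_encoder.*`), the LLaMA-based language decoder (`model.llama_model.*`) and the projection layer (`model.llama_proj.*`) is mapped to one of the three shard files `model-0000X-of-00003.safetensors`. This index is normally consumed automatically by `from_pretrained`, but as a rough illustration (assuming the `safetensors` package is installed and the shard files sit in the working directory) a sharded checkpoint can also be reassembled by hand roughly like this:

```
import json
from collections import defaultdict

from safetensors.torch import load_file

# Group parameter names by the shard file that stores them, as recorded
# in the "weight_map" of the index, then load each shard exactly once.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

shard_to_keys = defaultdict(list)
for name, shard in index["weight_map"].items():
    shard_to_keys[shard].append(name)

state_dict = {}
for shard, keys in shard_to_keys.items():
    tensors = load_file(shard)  # one model-0000X-of-00003.safetensors file
    state_dict.update({k: tensors[k] for k in keys})
```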
modelling_musilingo.py ADDED
@@ -0,0 +1,2275 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import math
5
+ import re
6
+ import shutil
7
+ import warnings
8
+ import datetime
9
+ import time
+ import urllib.error  # referenced in download_url() below
10
+ from collections import defaultdict, deque
11
+ from typing import List, Optional, Tuple, Union
12
+
13
+ from torch.cuda.amp import autocast as autocast
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn import CrossEntropyLoss
19
+ from transformers import Wav2Vec2FeatureExtractor
20
+ from omegaconf import OmegaConf
21
+
22
+ from .configuration_musilingo import MusiLingoConfig, PATH
23
+ import timm.models.hub as timm_hub
24
+
25
+
26
+ from transformers import LlamaTokenizer, Wav2Vec2FeatureExtractor, AutoModel
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
31
+ from transformers.models.llama.configuration_llama import LlamaConfig
32
+ from transformers import PreTrainedModel
33
+
34
+
35
+
36
+ def download_url(
37
+ url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
38
+ ) -> None:
39
+ """Download a file from a url and place it in root.
40
+
41
+ Args:
42
+ url (str): URL to download file from
43
+ root (str): Directory to place downloaded file in
44
+ filename (str, optional): Name to save the file under. If None, use the basename of the URL
45
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
46
+ max_redirect_hops (int, optional): Maximum number of redirect hops allowed
47
+ """
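+ # NOTE: this mirrors torchvision.datasets.utils.download_url; it relies on
+ # helpers such as check_integrity, _urlretrieve, _get_redirect_url,
+ # _get_google_drive_file_id and download_file_from_google_drive, which are
+ # not defined in this function and must be provided elsewhere
+ # (e.g. torchvision.datasets.utils).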
48
+ root = os.path.expanduser(root)
49
+ if not filename:
50
+ filename = os.path.basename(url)
51
+ fpath = os.path.join(root, filename)
52
+
53
+ os.makedirs(root, exist_ok=True)
54
+
55
+ # check if file is already present locally
56
+ if check_integrity(fpath, md5):
57
+ print("Using downloaded and verified file: " + fpath)
58
+ return
59
+
60
+ if _is_remote_location_available():
61
+ _download_file_from_remote_location(fpath, url)
62
+ else:
63
+ # expand redirect chain if needed
64
+ url = _get_redirect_url(url, max_hops=max_redirect_hops)
65
+
66
+ # check if file is located on Google Drive
67
+ file_id = _get_google_drive_file_id(url)
68
+ if file_id is not None:
69
+ return download_file_from_google_drive(file_id, root, filename, md5)
70
+
71
+ # download the file
72
+ try:
73
+ print("Downloading " + url + " to " + fpath)
74
+ _urlretrieve(url, fpath)
75
+ except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
76
+ if url[:5] == "https":
77
+ url = url.replace("https:", "http:")
78
+ print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
79
+ _urlretrieve(url, fpath)
80
+ else:
81
+ raise e
82
+
83
+ # check integrity of downloaded file
84
+ if not check_integrity(fpath, md5):
85
+ raise RuntimeError("File not found or corrupted.")
86
+
87
+
88
+
89
+ def load_dataset_config(cfg_path):
90
+ cfg = OmegaConf.load(cfg_path).datasets
91
+ cfg = cfg[list(cfg.keys())[0]]
92
+
93
+ return cfg
94
+
95
+ class SmoothedValue(object):
96
+ """Track a series of values and provide access to smoothed values over a
97
+ window or the global series average.
98
+ """
99
+
100
+ def __init__(self, window_size=20, fmt=None):
101
+ if fmt is None:
102
+ fmt = "{median:.4f} ({global_avg:.4f})"
103
+ self.deque = deque(maxlen=window_size)
104
+ self.total = 0.0
105
+ self.count = 0
106
+ self.fmt = fmt
107
+
108
+ def update(self, value, n=1):
109
+ self.deque.append(value)
110
+ self.count += n
111
+ self.total += value * n
112
+
113
+ def synchronize_between_processes(self):
114
+ """
115
+ Warning: does not synchronize the deque!
116
+ """
117
+ if not is_dist_avail_and_initialized():
118
+ return
119
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
120
+ dist.barrier()
121
+ dist.all_reduce(t)
122
+ t = t.tolist()
123
+ self.count = int(t[0])
124
+ self.total = t[1]
125
+
126
+ @property
127
+ def median(self):
128
+ d = torch.tensor(list(self.deque))
129
+ return d.median().item()
130
+
131
+ @property
132
+ def avg(self):
133
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
134
+ return d.mean().item()
135
+
136
+ @property
137
+ def global_avg(self):
138
+ return self.total / self.count
139
+
140
+ @property
141
+ def max(self):
142
+ return max(self.deque)
143
+
144
+ @property
145
+ def value(self):
146
+ return self.deque[-1]
147
+
148
+ def __str__(self):
149
+ return self.fmt.format(
150
+ median=self.median,
151
+ avg=self.avg,
152
+ global_avg=self.global_avg,
153
+ max=self.max,
154
+ value=self.value,
155
+ )
156
+
157
+
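
As the docstring above says, `SmoothedValue` tracks a bounded window of recent values alongside running totals. A minimal usage sketch (the loss values are made up for illustration):

```
meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
for loss in (0.9, 0.7, 0.65, 0.6):
    meter.update(loss)                   # one sample per update (n=1)
print(meter.median, meter.global_avg)    # windowed median vs. all-time average
print(str(meter))                        # rendered with the fmt string above
```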
158
+ class MetricLogger(object):
159
+ def __init__(self, delimiter="\t"):
160
+ self.meters = defaultdict(SmoothedValue)
161
+ self.delimiter = delimiter
162
+
163
+ def update(self, **kwargs):
164
+ for k, v in kwargs.items():
165
+ if isinstance(v, torch.Tensor):
166
+ v = v.item()
167
+ assert isinstance(v, (float, int))
168
+ self.meters[k].update(v)
169
+
170
+ def __getattr__(self, attr):
171
+ if attr in self.meters:
172
+ return self.meters[attr]
173
+ if attr in self.__dict__:
174
+ return self.__dict__[attr]
175
+ raise AttributeError(
176
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
177
+ )
178
+
179
+ def __str__(self):
180
+ loss_str = []
181
+ for name, meter in self.meters.items():
182
+ loss_str.append("{}: {}".format(name, str(meter)))
183
+ return self.delimiter.join(loss_str)
184
+
185
+ def global_avg(self):
186
+ loss_str = []
187
+ for name, meter in self.meters.items():
188
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
189
+ return self.delimiter.join(loss_str)
190
+
191
+ def synchronize_between_processes(self):
192
+ for meter in self.meters.values():
193
+ meter.synchronize_between_processes()
194
+
195
+ def add_meter(self, name, meter):
196
+ self.meters[name] = meter
197
+
198
+ def log_every(self, iterable, print_freq, header=None):
199
+ i = 0
200
+ if not header:
201
+ header = ""
202
+ start_time = time.time()
203
+ end = time.time()
204
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
205
+ data_time = SmoothedValue(fmt="{avg:.4f}")
206
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
207
+ log_msg = [
208
+ header,
209
+ "[{0" + space_fmt + "}/{1}]",
210
+ "eta: {eta}",
211
+ "{meters}",
212
+ "time: {time}",
213
+ "data: {data}",
214
+ ]
215
+ if torch.cuda.is_available():
216
+ log_msg.append("max mem: {memory:.0f}")
217
+ log_msg = self.delimiter.join(log_msg)
218
+ MB = 1024.0 * 1024.0
219
+ for obj in iterable:
220
+ data_time.update(time.time() - end)
221
+ yield obj
222
+ iter_time.update(time.time() - end)
223
+ if i % print_freq == 0 or i == len(iterable) - 1:
224
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
225
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
226
+ if torch.cuda.is_available():
227
+ print(
228
+ log_msg.format(
229
+ i,
230
+ len(iterable),
231
+ eta=eta_string,
232
+ meters=str(self),
233
+ time=str(iter_time),
234
+ data=str(data_time),
235
+ memory=torch.cuda.max_memory_allocated() / MB,
236
+ )
237
+ )
238
+ else:
239
+ print(
240
+ log_msg.format(
241
+ i,
242
+ len(iterable),
243
+ eta=eta_string,
244
+ meters=str(self),
245
+ time=str(iter_time),
246
+ data=str(data_time),
247
+ )
248
+ )
249
+ i += 1
250
+ end = time.time()
251
+ total_time = time.time() - start_time
252
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
253
+ print(
254
+ "{} Total time: {} ({:.4f} s / it)".format(
255
+ header, total_time_str, total_time / len(iterable)
256
+ )
257
+ )
258
+
259
+
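
`MetricLogger` keeps one `SmoothedValue` per metric name and, via `log_every`, wraps any iterable to print progress, ETA, per-iteration timings and (on GPU) peak memory. A sketch of the intended call pattern, with a dummy range standing in for a real data loader:

```
logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in logger.log_every(range(100), print_freq=10, header="Train:"):
    loss = 1.0 / (step + 1)              # dummy value for illustration
    logger.update(loss=loss, lr=1e-4)    # unknown meters are created on first use
print("Averaged stats:", logger.global_avg())
```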
260
+ def move_to_cuda(sample):
261
+ def _move_to_cuda(tensor):
262
+ return tensor.cuda()
263
+
264
+ return apply_to_sample(_move_to_cuda, sample)
265
+
266
+ def apply_to_sample(f, sample):
267
+ if len(sample) == 0:
268
+ return {}
269
+
270
+ def _apply(x):
271
+ if torch.is_tensor(x):
272
+ return f(x)
273
+ elif isinstance(x, dict):
274
+ return {key: _apply(value) for key, value in x.items()}
275
+ elif isinstance(x, list):
276
+ return [_apply(item) for item in x]
277
+ else:
278
+ return x
279
+
280
+ return _apply(sample)
281
+
282
+ def prepare_sample(samples, cuda_enabled=True):
283
+ if cuda_enabled:
284
+ samples = move_to_cuda(samples)
285
+
286
+ # TODO fp16 support
287
+
288
+ return samples
289
+
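
`apply_to_sample` recurses through nested dicts and lists, applying a function to every tensor it finds; `prepare_sample` uses it to move a whole batch to the GPU in one call. For example, with a hypothetical batch layout:

```
batch = {
    "audio": torch.randn(2, 24000),
    "text_input": ["describe the music", "what instruments do you hear?"],
    "meta": [{"id": torch.tensor(0)}, {"id": torch.tensor(1)}],
}
batch = prepare_sample(batch, cuda_enabled=torch.cuda.is_available())
# every tensor (including the nested "id" fields) is now on the GPU when one
# is available; non-tensor entries such as the strings pass through unchanged.
```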
290
+ def get_world_size():
291
+ if not is_dist_avail_and_initialized():
292
+ return 1
293
+ return dist.get_world_size()
294
+
295
+ class BaseTask:
296
+ def __init__(self, **kwargs):
297
+ super().__init__()
298
+
299
+ self.inst_id_key = "instance_id"
300
+
301
+ @classmethod
302
+ def setup_task(cls, **kwargs):
303
+ return cls()
304
+
305
+ def build_model(self, cfg):
306
+ model_config = cfg.model_cfg
307
+
308
+ model_cls = registry.get_model_class(model_config.arch)
309
+ return model_cls.from_config(model_config)
310
+
311
+ def build_datasets(self, cfg):
312
+ """
313
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
314
+ Download datasets and annotations automatically if they do not already exist.
315
+
316
+ Args:
317
+ cfg (common.config.Config): configuration object; its datasets_cfg section lists the datasets to build.
318
+
319
+ Returns:
320
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
321
+ """
322
+
323
+ datasets = dict()
324
+
325
+ datasets_config = cfg.datasets_cfg
326
+
327
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
328
+
329
+ for name in datasets_config:
330
+ dataset_config = datasets_config[name]
331
+
332
+ builder = registry.get_builder_class(name)(dataset_config)
333
+ dataset = builder.build_datasets()
334
+
335
+ dataset['train'].name = name
336
+ if 'sample_ratio' in dataset_config:
337
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
338
+
339
+ datasets[name] = dataset
340
+
341
+ return datasets
342
+
343
+ def train_step(self, model, samples):
344
+ loss = model(samples)["loss"]
345
+ return loss
346
+
347
+ def valid_step(self, model, samples):
348
+ raise NotImplementedError
349
+
350
+ def before_evaluation(self, model, dataset, **kwargs):
351
+ model.before_evaluation(dataset=dataset, task_type=type(self))
352
+
353
+ def after_evaluation(self, **kwargs):
354
+ pass
355
+
356
+ def inference_step(self):
357
+ raise NotImplementedError
358
+
359
+ def evaluation(self, model, data_loader, cuda_enabled=True):
360
+ metric_logger = MetricLogger(delimiter=" ")
361
+ header = "Evaluation"
362
+ # TODO make it configurable
363
+ print_freq = 10
364
+
365
+ results = []
366
+
367
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
368
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
369
+
370
+ eval_output = self.valid_step(model=model, samples=samples)
371
+ results.extend(eval_output)
372
+
373
+ if is_dist_avail_and_initialized():
374
+ dist.barrier()
375
+
376
+ return results
377
+
378
+ def train_epoch(
379
+ self,
380
+ epoch,
381
+ model,
382
+ data_loader,
383
+ optimizer,
384
+ lr_scheduler,
385
+ scaler=None,
386
+ cuda_enabled=False,
387
+ log_freq=50,
388
+ accum_grad_iters=1,
389
+ ):
390
+ return self._train_inner_loop(
391
+ epoch=epoch,
392
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
393
+ model=model,
394
+ data_loader=data_loader,
395
+ optimizer=optimizer,
396
+ scaler=scaler,
397
+ lr_scheduler=lr_scheduler,
398
+ log_freq=log_freq,
399
+ cuda_enabled=cuda_enabled,
400
+ accum_grad_iters=accum_grad_iters,
401
+ )
402
+
403
+ def train_iters(
404
+ self,
405
+ epoch,
406
+ start_iters,
407
+ iters_per_inner_epoch,
408
+ model,
409
+ data_loader,
410
+ optimizer,
411
+ lr_scheduler,
412
+ scaler=None,
413
+ cuda_enabled=False,
414
+ log_freq=50,
415
+ accum_grad_iters=1,
416
+ ):
417
+ return self._train_inner_loop(
418
+ epoch=epoch,
419
+ start_iters=start_iters,
420
+ iters_per_epoch=iters_per_inner_epoch,
421
+ model=model,
422
+ data_loader=data_loader,
423
+ optimizer=optimizer,
424
+ scaler=scaler,
425
+ lr_scheduler=lr_scheduler,
426
+ log_freq=log_freq,
427
+ cuda_enabled=cuda_enabled,
428
+ accum_grad_iters=accum_grad_iters,
429
+ )
430
+
431
+ def _train_inner_loop(
432
+ self,
433
+ epoch,
434
+ iters_per_epoch,
435
+ model,
436
+ data_loader,
437
+ optimizer,
438
+ lr_scheduler,
439
+ scaler=None,
440
+ start_iters=None,
441
+ log_freq=50,
442
+ cuda_enabled=False,
443
+ accum_grad_iters=1,
444
+ ):
445
+ """
446
+ An inner training loop compatible with both epoch-based and iter-based training.
447
+
448
+ When using epoch-based, training stops after one epoch; when using iter-based,
449
+ training stops after #iters_per_epoch iterations.
450
+ """
451
+ use_amp = scaler is not None
452
+
453
+ if not hasattr(data_loader, "__next__"):
454
+ # convert to iterator if not already
455
+ data_loader = iter(data_loader)
456
+
457
+ metric_logger = MetricLogger(delimiter=" ")
458
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
459
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
460
+
461
+ # if iter-based runner, schedule lr based on inner epoch.
462
+ logging.info(
463
+ "Start training epoch {}, {} iters per inner epoch.".format(
464
+ epoch, iters_per_epoch
465
+ )
466
+ )
467
+ header = "Train: data epoch: [{}]".format(epoch)
468
+ if start_iters is None:
469
+ # epoch-based runner
470
+ inner_epoch = epoch
471
+ else:
472
+ # In iter-based runner, we schedule the learning rate based on iterations.
473
+ inner_epoch = start_iters // iters_per_epoch
474
+ header = header + "; inner epoch [{}]".format(inner_epoch)
475
+
476
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
477
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
478
+ if i >= iters_per_epoch:
479
+ break
480
+
481
+ samples = next(data_loader)
482
+
483
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
484
+ samples.update(
485
+ {
486
+ "epoch": inner_epoch,
487
+ "num_iters_per_epoch": iters_per_epoch,
488
+ "iters": i,
489
+ }
490
+ )
491
+
492
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
493
+
494
+ with torch.cuda.amp.autocast(enabled=use_amp):
495
+ loss = self.train_step(model=model, samples=samples)
496
+
497
+ # after_train_step()
498
+ if use_amp:
499
+ scaler.scale(loss).backward()
500
+ else:
501
+ loss.backward()
502
+
503
+ # update gradients every accum_grad_iters iterations
504
+ if (i + 1) % accum_grad_iters == 0:
505
+ if use_amp:
506
+ scaler.step(optimizer)
507
+ scaler.update()
508
+ else:
509
+ optimizer.step()
510
+ optimizer.zero_grad()
511
+
512
+ metric_logger.update(loss=loss.item())
513
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
514
+
515
+ # after train_epoch()
516
+ # gather the stats from all processes
517
+ metric_logger.synchronize_between_processes()
518
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
519
+ return {
520
+ k: "{:.3f}".format(meter.global_avg)
521
+ for k, meter in metric_logger.meters.items()
522
+ }
523
+
524
+ @staticmethod
525
+ def save_result(result, result_dir, filename, remove_duplicate=""):
526
+ import json
527
+
528
+ result_file = os.path.join(
529
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
530
+ )
531
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
532
+
533
+ json.dump(result, open(result_file, "w"))
534
+
535
+ if is_dist_avail_and_initialized():
536
+ dist.barrier()
537
+
538
+ if is_main_process():
539
+ logging.warning("rank %d starts merging results." % get_rank())
540
+ # combine results from all processes
541
+ result = []
542
+
543
+ for rank in range(get_world_size()):
544
+ result_file = os.path.join(
545
+ result_dir, "%s_rank%d.json" % (filename, rank)
546
+ )
547
+ res = json.load(open(result_file, "r"))
548
+ result += res
549
+
550
+ if remove_duplicate:
551
+ result_new = []
552
+ id_list = []
553
+ for res in result:
554
+ if res[remove_duplicate] not in id_list:
555
+ id_list.append(res[remove_duplicate])
556
+ result_new.append(res)
557
+ result = result_new
558
+
559
+ json.dump(result, open(final_result_file, "w"))
560
+ print("result file saved to %s" % final_result_file)
561
+
562
+ return final_result_file
563
+
564
+
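
Note that `_train_inner_loop` above only assumes two things of `lr_scheduler`: an `iters_per_epoch` attribute (read by `train_epoch`) and a `step(cur_epoch=..., cur_step=...)` method called once per iteration; the concrete schedulers are registered elsewhere (see the `lr_scheduler_name_mapping` in the `Registry` below). Purely as an interface sketch, a hypothetical constant-LR stand-in (not one of the registered schedulers) would look like:

```
class ConstantLRSketch:
    """Minimal object satisfying the scheduler interface BaseTask relies on."""

    def __init__(self, optimizer, init_lr, iters_per_epoch):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.iters_per_epoch = iters_per_epoch   # read by BaseTask.train_epoch

    def step(self, cur_epoch, cur_step):
        # real schedulers warm up / decay here; this stand-in keeps the LR fixed
        for group in self.optimizer.param_groups:
            group["lr"] = self.init_lr
```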
565
+ class BaseProcessor:
566
+ def __init__(self):
567
+ self.transform = lambda x: x
568
+ return
569
+
570
+ def __call__(self, item):
571
+ return self.transform(item)
572
+
573
+ @classmethod
574
+ def from_config(cls, cfg=None):
575
+ return cls()
576
+
577
+ def build(self, **kwargs):
578
+ cfg = OmegaConf.create(kwargs)
579
+
580
+ return self.from_config(cfg)
581
+
582
+ def get_cache_path(rel_path):
583
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
584
+
585
+
586
+ class BaseDatasetBuilder:
587
+ train_dataset_cls, eval_dataset_cls = None, None
588
+
589
+ def __init__(self, cfg=None):
590
+ super().__init__()
591
+
592
+ if cfg is None:
593
+ # help to create datasets from default config.
594
+ self.config = load_dataset_config(self.default_config_path())
595
+ elif isinstance(cfg, str):
596
+ self.config = load_dataset_config(cfg)
597
+ else:
598
+ # when called from task.build_dataset()
599
+ self.config = cfg
600
+
601
+ self.data_type = self.config.data_type
602
+
603
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
604
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
605
+
606
+ def build_datasets(self):
607
+ # download, split, etc...
608
+ # only called on 1 GPU/TPU in distributed
609
+
610
+ if is_main_process():
611
+ self._download_data()
612
+
613
+ if is_dist_avail_and_initialized():
614
+ dist.barrier()
615
+
616
+ # at this point, all the annotations and images/videos should have been downloaded to the specified locations.
617
+ logging.info("Building datasets...")
618
+ datasets = self.build() # dataset['train'/'val'/'test']
619
+
620
+ return datasets
621
+
622
+ def build_processors(self):
623
+ vis_proc_cfg = self.config.get("vis_processor")
624
+ txt_proc_cfg = self.config.get("text_processor")
625
+
626
+ if vis_proc_cfg is not None:
627
+ vis_train_cfg = vis_proc_cfg.get("train")
628
+ vis_eval_cfg = vis_proc_cfg.get("eval")
629
+
630
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
631
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
632
+
633
+ if txt_proc_cfg is not None:
634
+ txt_train_cfg = txt_proc_cfg.get("train")
635
+ txt_eval_cfg = txt_proc_cfg.get("eval")
636
+
637
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
638
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
639
+
640
+ @staticmethod
641
+ def _build_proc_from_cfg(cfg):
642
+ return (
643
+ registry.get_processor_class(cfg.name).from_config(cfg)
644
+ if cfg is not None
645
+ else None
646
+ )
647
+
648
+ @classmethod
649
+ def default_config_path(cls, type="default"):
650
+ return get_abs_path(cls.DATASET_CONFIG_DICT[type])
651
+
652
+ def _download_data(self):
653
+ self._download_ann()
654
+ self._download_vis()
655
+
656
+ def _download_ann(self):
657
+ """
658
+ Download annotation files if necessary.
659
+ All the vision-language datasets should have annotations in a unified format.
660
+
661
+ storage_path can be:
662
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
663
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
664
+
665
+ Local annotation paths should be relative.
666
+ """
667
+ anns = self.config.build_info.annotations
668
+
669
+ splits = anns.keys()
670
+
671
+ cache_root = registry.get_path("cache_root")
672
+
673
+ for split in splits:
674
+ info = anns[split]
675
+
676
+ urls, storage_paths = info.get("url", None), info.storage
677
+
678
+ if isinstance(urls, str):
679
+ urls = [urls]
680
+ if isinstance(storage_paths, str):
681
+ storage_paths = [storage_paths]
682
+
683
+ assert len(urls) == len(storage_paths)
684
+
685
+ for url_or_filename, storage_path in zip(urls, storage_paths):
686
+ # if storage_path is relative, make it full by prefixing with cache_root.
687
+ if not os.path.isabs(storage_path):
688
+ storage_path = os.path.join(cache_root, storage_path)
689
+
690
+ dirname = os.path.dirname(storage_path)
691
+ if not os.path.exists(dirname):
692
+ os.makedirs(dirname)
693
+
694
+ if os.path.isfile(url_or_filename):
695
+ src, dst = url_or_filename, storage_path
696
+ if not os.path.exists(dst):
697
+ shutil.copyfile(src=src, dst=dst)
698
+ else:
699
+ logging.info("Using existing file {}.".format(dst))
700
+ else:
701
+ if os.path.isdir(storage_path):
702
+ # if only dirname is provided, suffix with basename of URL.
703
+ raise ValueError(
704
+ "Expecting storage_path to be a file path, got directory {}".format(
705
+ storage_path
706
+ )
707
+ )
708
+ else:
709
+ filename = os.path.basename(storage_path)
710
+
711
+ download_url(url=url_or_filename, root=dirname, filename=filename)
712
+
713
+ def _download_vis(self):
714
+
715
+ storage_path = self.config.build_info.get(self.data_type).storage
716
+ storage_path = get_cache_path(storage_path)
717
+
718
+ if not os.path.exists(storage_path):
719
+ warnings.warn(
720
+ f"""
721
+ The specified path {storage_path} for visual inputs does not exist.
722
+ Please provide a correct path to the visual inputs or
723
+ refer to datasets/download_scripts/README.md for downloading instructions.
724
+ """
725
+ )
726
+
727
+ def build(self):
728
+ """
729
+ Create per-split datasets, each inheriting torch.utils.data.Dataset.
730
+
731
+ # build() can be dataset-specific. Overwrite to customize.
732
+ """
733
+ self.build_processors()
734
+
735
+ build_info = self.config.build_info
736
+
737
+ ann_info = build_info.annotations
738
+ vis_info = build_info.get(self.data_type)
739
+
740
+ datasets = dict()
741
+ for split in ann_info.keys():
742
+ if split not in ["train", "val", "test"]:
743
+ continue
744
+
745
+ is_train = split == "train"
746
+
747
+ # processors
748
+ vis_processor = (
749
+ self.vis_processors["train"]
750
+ if is_train
751
+ else self.vis_processors["eval"]
752
+ )
753
+ text_processor = (
754
+ self.text_processors["train"]
755
+ if is_train
756
+ else self.text_processors["eval"]
757
+ )
758
+
759
+ # annotation path
760
+ ann_paths = ann_info.get(split).storage
761
+ if isinstance(ann_paths, str):
762
+ ann_paths = [ann_paths]
763
+
764
+ abs_ann_paths = []
765
+ for ann_path in ann_paths:
766
+ if not os.path.isabs(ann_path):
767
+ ann_path = get_cache_path(ann_path)
768
+ abs_ann_paths.append(ann_path)
769
+ ann_paths = abs_ann_paths
770
+
771
+ # visual data storage path
772
+ vis_path = os.path.join(vis_info.storage, split)
773
+
774
+ if not os.path.isabs(vis_path):
775
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
776
+ vis_path = get_cache_path(vis_path)
777
+
778
+ if not os.path.exists(vis_path):
779
+ warnings.warn("storage path {} does not exist.".format(vis_path))
780
+
781
+ # create datasets
782
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
783
+ datasets[split] = dataset_cls(
784
+ vis_processor=vis_processor,
785
+ text_processor=text_processor,
786
+ ann_paths=ann_paths,
787
+ vis_root=vis_path,
788
+ )
789
+
790
+ return datasets
791
+
792
+
793
+
794
+
795
+ class Registry:
796
+ mapping = {
797
+ "builder_name_mapping": {},
798
+ "task_name_mapping": {},
799
+ "processor_name_mapping": {},
800
+ "model_name_mapping": {},
801
+ "lr_scheduler_name_mapping": {},
802
+ "runner_name_mapping": {},
803
+ "state": {},
804
+ "paths": {},
805
+ }
806
+
807
+ @classmethod
808
+ def register_builder(cls, name):
809
+ r"""Register a dataset builder to registry with key 'name'
810
+
811
+ Args:
812
+ name: Key with which the builder will be registered.
813
+
814
+ Usage:
815
+
816
+ # from lavi.common.registry import registry
817
+ # from lavi.datasets.base_dataset_builder import BaseDatasetBuilder
818
+ """
819
+
820
+ def wrap(builder_cls):
821
+ # from musilingo.datasets.builders.base_dataset_builder import BaseDatasetBuilder
822
+
823
+ assert issubclass(
824
+ builder_cls, BaseDatasetBuilder
825
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
826
+ builder_cls
827
+ )
828
+ if name in cls.mapping["builder_name_mapping"]:
829
+ raise KeyError(
830
+ "Name '{}' already registered for {}.".format(
831
+ name, cls.mapping["builder_name_mapping"][name]
832
+ )
833
+ )
834
+ cls.mapping["builder_name_mapping"][name] = builder_cls
835
+ return builder_cls
836
+
837
+ return wrap
838
+
839
+ @classmethod
840
+ def register_task(cls, name):
841
+ r"""Register a task to registry with key 'name'
842
+
843
+ Args:
844
+ name: Key with which the task will be registered.
845
+
846
+ Usage:
847
+
848
+ # from lavi.common.registry import registry
849
+ """
850
+
851
+ def wrap(task_cls):
852
+ # from musilingo.tasks.base_task import BaseTask
853
+
854
+ assert issubclass(
855
+ task_cls, BaseTask
856
+ ), "All tasks must inherit BaseTask class"
857
+ if name in cls.mapping["task_name_mapping"]:
858
+ raise KeyError(
859
+ "Name '{}' already registered for {}.".format(
860
+ name, cls.mapping["task_name_mapping"][name]
861
+ )
862
+ )
863
+ cls.mapping["task_name_mapping"][name] = task_cls
864
+ return task_cls
865
+
866
+ return wrap
867
+
868
+ @classmethod
869
+ def register_model(cls, name):
870
+ r"""Register a task to registry with key 'name'
871
+
872
+ Args:
873
+ name: Key with which the model will be registered.
874
+
875
+ Usage:
876
+
877
+ # from lavi.common.registry import registry
878
+ """
879
+
880
+ def wrap(model_cls):
881
+
882
+ assert issubclass(
883
+ model_cls, BaseModel
884
+ ), "All models must inherit BaseModel class"
885
+ if name in cls.mapping["model_name_mapping"]:
886
+ raise KeyError(
887
+ "Name '{}' already registered for {}.".format(
888
+ name, cls.mapping["model_name_mapping"][name]
889
+ )
890
+ )
891
+ cls.mapping["model_name_mapping"][name] = model_cls
892
+ return model_cls
893
+
894
+ return wrap
895
+
896
+ @classmethod
897
+ def register_processor(cls, name):
898
+ r"""Register a processor to registry with key 'name'
899
+
900
+ Args:
901
+ name: Key with which the processor will be registered.
902
+
903
+ Usage:
904
+
905
+ # from lavi.common.registry import registry
906
+ """
907
+
908
+ def wrap(processor_cls):
909
+ # from musilingo.processors import BaseProcessor
910
+
911
+ assert issubclass(
912
+ processor_cls, BaseProcessor
913
+ ), "All processors must inherit BaseProcessor class"
914
+ if name in cls.mapping["processor_name_mapping"]:
915
+ raise KeyError(
916
+ "Name '{}' already registered for {}.".format(
917
+ name, cls.mapping["processor_name_mapping"][name]
918
+ )
919
+ )
920
+ cls.mapping["processor_name_mapping"][name] = processor_cls
921
+ return processor_cls
922
+
923
+ return wrap
924
+
925
+ @classmethod
926
+ def register_lr_scheduler(cls, name):
927
+ r"""Register a model to registry with key 'name'
928
+
929
+ Args:
930
+ name: Key with which the lr scheduler will be registered.
931
+
932
+ Usage:
933
+
934
+ # from minigpt4.common.registry import registry
935
+ """
936
+
937
+ def wrap(lr_sched_cls):
938
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
939
+ raise KeyError(
940
+ "Name '{}' already registered for {}.".format(
941
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
942
+ )
943
+ )
944
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
945
+ return lr_sched_cls
946
+
947
+ return wrap
948
+
949
+ @classmethod
950
+ def register_runner(cls, name):
951
+ r"""Register a model to registry with key 'name'
952
+
953
+ Args:
954
+ name: Key with which the runner will be registered.
955
+
956
+ Usage:
957
+
958
+ # from minigpt4.common.registry import registry
959
+ """
960
+
961
+ def wrap(runner_cls):
962
+ if name in cls.mapping["runner_name_mapping"]:
963
+ raise KeyError(
964
+ "Name '{}' already registered for {}.".format(
965
+ name, cls.mapping["runner_name_mapping"][name]
966
+ )
967
+ )
968
+ cls.mapping["runner_name_mapping"][name] = runner_cls
969
+ return runner_cls
970
+
971
+ return wrap
972
+
973
+ @classmethod
974
+ def register_path(cls, name, path):
975
+ r"""Register a path to registry with key 'name'
976
+
977
+ Args:
978
+ name: Key with which the path will be registered.
979
+
980
+ Usage:
981
+
982
+ # from minigpt4.common.registry import registry
983
+ """
984
+ assert isinstance(path, str), "All paths must be str."
985
+ if name in cls.mapping["paths"]:
986
+ raise KeyError("Name '{}' already registered.".format(name))
987
+ cls.mapping["paths"][name] = path
988
+
989
+ @classmethod
990
+ def register(cls, name, obj):
991
+ r"""Register an item to registry with key 'name'
992
+
993
+ Args:
994
+ name: Key with which the item will be registered.
995
+
996
+ Usage::
997
+
998
+ # from minigpt4.common.registry import registry
999
+
1000
+ registry.register("config", {})
1001
+ """
1002
+ path = name.split(".")
1003
+ current = cls.mapping["state"]
1004
+
1005
+ for part in path[:-1]:
1006
+ if part not in current:
1007
+ current[part] = {}
1008
+ current = current[part]
1009
+
1010
+ current[path[-1]] = obj
1011
+
1012
+ # @classmethod
1013
+ # def get_trainer_class(cls, name):
1014
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
1015
+
1016
+ @classmethod
1017
+ def get_builder_class(cls, name):
1018
+ return cls.mapping["builder_name_mapping"].get(name, None)
1019
+
1020
+ @classmethod
1021
+ def get_model_class(cls, name):
1022
+ return cls.mapping["model_name_mapping"].get(name, None)
1023
+
1024
+ @classmethod
1025
+ def get_task_class(cls, name):
1026
+ return cls.mapping["task_name_mapping"].get(name, None)
1027
+
1028
+ @classmethod
1029
+ def get_processor_class(cls, name):
1030
+ return cls.mapping["processor_name_mapping"].get(name, None)
1031
+
1032
+ @classmethod
1033
+ def get_lr_scheduler_class(cls, name):
1034
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
1035
+
1036
+ @classmethod
1037
+ def get_runner_class(cls, name):
1038
+ return cls.mapping["runner_name_mapping"].get(name, None)
1039
+
1040
+ @classmethod
1041
+ def list_runners(cls):
1042
+ return sorted(cls.mapping["runner_name_mapping"].keys())
1043
+
1044
+ @classmethod
1045
+ def list_models(cls):
1046
+ return sorted(cls.mapping["model_name_mapping"].keys())
1047
+
1048
+ @classmethod
1049
+ def list_tasks(cls):
1050
+ return sorted(cls.mapping["task_name_mapping"].keys())
1051
+
1052
+ @classmethod
1053
+ def list_processors(cls):
1054
+ return sorted(cls.mapping["processor_name_mapping"].keys())
1055
+
1056
+ @classmethod
1057
+ def list_lr_schedulers(cls):
1058
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
1059
+
1060
+ @classmethod
1061
+ def list_datasets(cls):
1062
+ return sorted(cls.mapping["builder_name_mapping"].keys())
1063
+
1064
+ @classmethod
1065
+ def get_path(cls, name):
1066
+ return cls.mapping["paths"].get(name, None)
1067
+
1068
+ @classmethod
1069
+ def get(cls, name, default=None, no_warning=False):
1070
+ r"""Get an item from registry with key 'name'
1071
+
1072
+ Args:
1073
+ name (string): Key whose value needs to be retrieved.
1074
+ default: If passed and key is not in registry, default value will
1075
+ be returned with a warning. Default: None
1076
+ no_warning (bool): If passed as True, warning when key doesn't exist
1077
+ will not be generated. Useful for MMF's
1078
+ internal operations. Default: False
1079
+ """
1080
+ original_name = name
1081
+ name = name.split(".")
1082
+ value = cls.mapping["state"]
1083
+ for subname in name:
1084
+ value = value.get(subname, default)
1085
+ if value is default:
1086
+ break
1087
+
1088
+ if (
1089
+ "writer" in cls.mapping["state"]
1090
+ and value == default
1091
+ and no_warning is False
1092
+ ):
1093
+ cls.mapping["state"]["writer"].warning(
1094
+ "Key {} is not present in registry, returning default value "
1095
+ "of {}".format(original_name, default)
1096
+ )
1097
+ return value
1098
+
1099
+ @classmethod
1100
+ def unregister(cls, name):
1101
+ r"""Remove an item from registry with key 'name'
1102
+
1103
+ Args:
1104
+ name: Key which needs to be removed.
1105
+ Usage::
1106
+
1107
+ # from mmf.common.registry import registry
1108
+
1109
+ config = registry.unregister("config")
1110
+ """
1111
+ return cls.mapping["state"].pop(name, None)
1112
+
1113
+
1114
+ registry = Registry()
1115
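+ # Example (illustrative; the name "my_model" is hypothetical): components register
+ # themselves by name and are looked up later through the same registry.
+ #   @registry.register_model("my_model")
+ #   class MyModel(BaseModel): ...
+ #   model_cls = registry.get_model_class("my_model")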
+
1116
+
1117
+ def get_abs_path(rel_path):
1118
+ return os.path.join(registry.get_path("library_root"), rel_path)
1119
+
1120
+ def is_url(input_url):
1121
+ """
1122
+ Check if an input string is a URL: looks for http(s):// and ignores case.
1123
+ """
1124
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
1125
+ return is_url
1126
+
1127
+
1128
+ def download_cached_file(url, check_hash=True, progress=False):
1129
+ """
1130
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
1131
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
1132
+ """
1133
+
1134
+ def get_cached_file_path():
1135
+ # a hack to sync the file path across processes
1136
+ parts = torch.hub.urlparse(url)
1137
+ filename = os.path.basename(parts.path)
1138
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
1139
+
1140
+ return cached_file
1141
+
1142
+ if is_main_process():
1143
+ timm_hub.download_cached_file(url, check_hash, progress)
1144
+
1145
+ if is_dist_avail_and_initialized():
1146
+ dist.barrier()
1147
+
1148
+ return get_cached_file_path()
1149
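+ # Note: in the distributed case only the main process downloads; the other ranks wait
+ # at dist.barrier() and then resolve the same cached path from timm's cache directory.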
+
1150
+ def is_dist_avail_and_initialized():
1151
+ if not dist.is_available():
1152
+ return False
1153
+ if not dist.is_initialized():
1154
+ return False
1155
+ return True
1156
+
1157
+ def is_main_process():
1158
+ return get_rank() == 0
1159
+
1160
+ def get_rank():
1161
+ if not is_dist_avail_and_initialized():
1162
+ return 0
1163
+ return dist.get_rank()
1164
+
1165
+ class BaseModel(nn.Module):
1166
+ """Base class for models."""
1167
+
1168
+ def __init__(self):
1169
+ super().__init__()
1170
+
1171
+ @property
1172
+ def device(self):
1173
+ return list(self.parameters())[0].device
1174
+
1175
+ def load_checkpoint(self, url_or_filename):
1176
+ """
1177
+ Load from a finetuned checkpoint.
1178
+
1179
+ This should expect no mismatch in the model keys and the checkpoint keys.
1180
+ """
1181
+
1182
+ if is_url(url_or_filename):
1183
+ cached_file = download_cached_file(
1184
+ url_or_filename, check_hash=False, progress=True
1185
+ )
1186
+ checkpoint = torch.load(cached_file, map_location="cpu")
1187
+ elif os.path.isfile(url_or_filename):
1188
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
1189
+ else:
1190
+ raise RuntimeError("checkpoint url or path is invalid")
1191
+
1192
+ if "model" in checkpoint.keys():
1193
+ state_dict = checkpoint["model"]
1194
+ else:
1195
+ state_dict = checkpoint
1196
+
1197
+ msg = self.load_state_dict(state_dict, strict=False)
1198
+
1199
+ logging.info("Missing keys {}".format(msg.missing_keys))
1200
+ logging.info("load checkpoint from %s" % url_or_filename)
1201
+
1202
+ return msg
1203
+
1204
+ @classmethod
1205
+ def from_pretrained(cls, model_type):
1206
+ """
1207
+ Build a pretrained model from default configuration file, specified by model_type.
1208
+
1209
+ Args:
1210
+ - model_type (str): model type, specifying architecture and checkpoints.
1211
+
1212
+ Returns:
1213
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
1214
+ """
1215
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
1216
+ model = cls.from_config(model_cfg)
1217
+
1218
+ return model
1219
+
1220
+ @classmethod
1221
+ def default_config_path(cls, model_type):
1222
+ assert (
1223
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
1224
+ ), "Unknown model type {}".format(model_type)
1225
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
1226
+
1227
+ def load_checkpoint_from_config(self, cfg, **kwargs):
1228
+ """
1229
+ Load checkpoint as specified in the config file.
1230
+
1231
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
1232
+ When loading the pretrained model, each task-specific architecture may define its
1233
+ own load_from_pretrained() method.
1234
+ """
1235
+ load_finetuned = cfg.get("load_finetuned", True)
1236
+ if load_finetuned:
1237
+ finetune_path = cfg.get("finetuned", None)
1238
+ assert (
1239
+ finetune_path is not None
1240
+ ), "Found load_finetuned is True, but finetune_path is None."
1241
+ self.load_checkpoint(url_or_filename=finetune_path)
1242
+ else:
1243
+ # load pre-trained weights
1244
+ pretrain_path = cfg.get("pretrained", None)
1245
+ assert "Found load_finetuned is False, but pretrain_path is None."
1246
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
1247
+
1248
+ def before_evaluation(self, **kwargs):
1249
+ pass
1250
+
1251
+ def show_n_params(self, return_str=True):
1252
+ tot = 0
1253
+ for p in self.parameters():
1254
+ w = 1
1255
+ for x in p.shape:
1256
+ w *= x
1257
+ tot += w
1258
+ if return_str:
1259
+ if tot >= 1e6:
1260
+ return "{:.1f}M".format(tot / 1e6)
1261
+ else:
1262
+ return "{:.1f}K".format(tot / 1e3)
1263
+ else:
1264
+ return tot
1265
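+ # Usage sketch (illustrative; `model_type` must be a key of the subclass's
+ # PRETRAINED_MODEL_CONFIG_DICT, e.g. "pretrain_vicuna" for MusiLingo below):
+ #   model = MusiLingo.from_pretrained("pretrain_vicuna")
+ #   model.load_checkpoint_from_config(cfg)  # finetuned or pretrained weights per cfg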
+
1266
+ LLAMA_INPUTS_DOCSTRING = r"""
1267
+ Args:
1268
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1269
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1270
+ it.
1271
+
1272
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1273
+ [`PreTrainedTokenizer.__call__`] for details.
1274
+
1275
+ [What are input IDs?](../glossary#input-ids)
1276
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1277
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1278
+
1279
+ - 1 for tokens that are **not masked**,
1280
+ - 0 for tokens that are **masked**.
1281
+
1282
+ [What are attention masks?](../glossary#attention-mask)
1283
+
1284
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1285
+ [`PreTrainedTokenizer.__call__`] for details.
1286
+
1287
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1288
+ `past_key_values`).
1289
+
1290
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1291
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1292
+ information on the default strategy.
1293
+
1294
+ - 1 indicates the head is **not masked**,
1295
+ - 0 indicates the head is **masked**.
1296
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1297
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1298
+ config.n_positions - 1]`.
1299
+
1300
+ [What are position IDs?](../glossary#position-ids)
1301
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1302
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1303
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1304
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1305
+
1306
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1307
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1308
+
1309
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1310
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1311
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1312
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1313
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1314
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1315
+ model's internal embedding lookup matrix.
1316
+ use_cache (`bool`, *optional*):
1317
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1318
+ `past_key_values`).
1319
+ output_attentions (`bool`, *optional*):
1320
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1321
+ tensors for more detail.
1322
+ output_hidden_states (`bool`, *optional*):
1323
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1324
+ more detail.
1325
+ return_dict (`bool`, *optional*):
1326
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1327
+ """
1328
+
1329
+
1330
+ LLAMA_START_DOCSTRING = r"""
1331
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1332
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1333
+ etc.)
1334
+
1335
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1336
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1337
+ and behavior.
1338
+
1339
+ Parameters:
1340
+ config ([`LlamaConfig`]):
1341
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1342
+ load the weights associated with the model, only the configuration. Check out the
1343
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1344
+ """
1345
+
1346
+
1347
+ logger = logging.get_logger(__name__)
1348
+
1349
+ _CONFIG_FOR_DOC = "LlamaConfig"
1350
+
1351
+
1352
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
1353
+ def _make_causal_mask(
1354
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
1355
+ ):
1356
+ """
1357
+ Make the additive causal mask used for uni-directional (decoder) self-attention.
1358
+ """
1359
+ bsz, tgt_len = input_ids_shape
1360
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
1361
+ mask_cond = torch.arange(mask.size(-1), device=device)
1362
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
1363
+ mask = mask.to(dtype)
1364
+
1365
+ if past_key_values_length > 0:
1366
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
1367
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
1368
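+ # Worked example: for tgt_len=3 and no cached past, the additive mask is
+ #   [[0, min, min],
+ #    [0,   0, min],
+ #    [0,   0,   0]]
+ # with min = torch.finfo(dtype).min, broadcast to [bsz, 1, 3, 3], so each token
+ # attends only to itself and earlier positions.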
+
1369
+
1370
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
1371
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
1372
+ """
1373
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
1374
+ """
1375
+ bsz, src_len = mask.size()
1376
+ tgt_len = tgt_len if tgt_len is not None else src_len
1377
+
1378
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
1379
+
1380
+ inverted_mask = 1.0 - expanded_mask
1381
+
1382
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
1383
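+ # Example: a padding mask [1, 1, 0] expands to shape [bsz, 1, tgt_len, 3] where the
+ # masked (0) position becomes torch.finfo(dtype).min and kept positions become 0,
+ # so the result can simply be added to the attention scores.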
+
1384
+
1385
+ class LlamaRMSNorm(nn.Module):
1386
+ def __init__(self, hidden_size, eps=1e-6):
1387
+ """
1388
+ LlamaRMSNorm is equivalent to T5LayerNorm
1389
+ """
1390
+ super().__init__()
1391
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1392
+ self.variance_epsilon = eps
1393
+
1394
+ def forward(self, hidden_states):
1395
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
1396
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1397
+
1398
+ # convert into half-precision if necessary
1399
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
1400
+ hidden_states = hidden_states.to(self.weight.dtype)
1401
+
1402
+ return self.weight * hidden_states
1403
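+ # RMSNorm computes y = weight * x / sqrt(mean(x^2, dim=-1) + eps); unlike LayerNorm
+ # there is no mean subtraction and no bias term.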
+
1404
+
1405
+ class LlamaRotaryEmbedding(torch.nn.Module):
1406
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
1407
+ super().__init__()
1408
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
1409
+ self.register_buffer("inv_freq", inv_freq)
1410
+
1411
+ # Build here to make `torch.jit.trace` work.
1412
+ self.max_seq_len_cached = max_position_embeddings
1413
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
1414
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1415
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
1416
+ emb = torch.cat((freqs, freqs), dim=-1)
1417
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
1418
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
1419
+
1420
+ def forward(self, x, seq_len=None):
1421
+ # x: [bs, num_attention_heads, seq_len, head_size]
1422
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
1423
+ if seq_len > self.max_seq_len_cached:
1424
+ self.max_seq_len_cached = seq_len
1425
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
1426
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1427
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
1428
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1429
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
1430
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
1431
+ return (
1432
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
1433
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
1434
+ )
1435
+
1436
+
1437
+ def rotate_half(x):
1438
+ """Rotates half the hidden dims of the input."""
1439
+ x1 = x[..., : x.shape[-1] // 2]
1440
+ x2 = x[..., x.shape[-1] // 2 :]
1441
+ return torch.cat((-x2, x1), dim=-1)
1442
+
1443
+
1444
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
1445
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
1446
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
1447
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
1448
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
1449
+ q_embed = (q * cos) + (rotate_half(q) * sin)
1450
+ k_embed = (k * cos) + (rotate_half(k) * sin)
1451
+ return q_embed, k_embed
1452
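+ # Rotary position embedding: with cos/sin gathered at `position_ids`,
+ #   q' = q * cos + rotate_half(q) * sin
+ #   k' = k * cos + rotate_half(k) * sin
+ # This pairs channel i with channel i + head_dim//2 and rotates each pair by a
+ # position-dependent angle (a permutation of the original RoPE formulation).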
+
1453
+
1454
+
1455
+
1456
+ class LlamaMLP(nn.Module):
1457
+ def __init__(
1458
+ self,
1459
+ hidden_size: int,
1460
+ intermediate_size: int,
1461
+ hidden_act: str,
1462
+ ):
1463
+ super().__init__()
1464
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
1465
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
1466
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
1467
+ self.act_fn = ACT2FN[hidden_act]
1468
+
1469
+ def forward(self, x):
1470
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
1471
+
1472
+
1473
+ class LlamaAttention(nn.Module):
1474
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
1475
+
1476
+ def __init__(self, config: LlamaConfig):
1477
+ super().__init__()
1478
+ self.config = config
1479
+ self.hidden_size = config.hidden_size
1480
+ self.num_heads = config.num_attention_heads
1481
+ self.head_dim = self.hidden_size // self.num_heads
1482
+ self.max_position_embeddings = config.max_position_embeddings
1483
+
1484
+ if (self.head_dim * self.num_heads) != self.hidden_size:
1485
+ raise ValueError(
1486
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
1487
+ f" and `num_heads`: {self.num_heads})."
1488
+ )
1489
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
1490
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
1491
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
1492
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
1493
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
1494
+
1495
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
1496
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1497
+
1498
+ def forward(
1499
+ self,
1500
+ hidden_states: torch.Tensor,
1501
+ attention_mask: Optional[torch.Tensor] = None,
1502
+ position_ids: Optional[torch.LongTensor] = None,
1503
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1504
+ output_attentions: bool = False,
1505
+ use_cache: bool = False,
1506
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1507
+ bsz, q_len, _ = hidden_states.size()
1508
+
1509
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1510
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1511
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1512
+
1513
+ kv_seq_len = key_states.shape[-2]
1514
+ if past_key_value is not None:
1515
+ kv_seq_len += past_key_value[0].shape[-2]
1516
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1517
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1518
+ # [bsz, nh, t, hd]
1519
+
1520
+ if past_key_value is not None:
1521
+ # reuse k, v, self_attention
1522
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
1523
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
1524
+
1525
+ past_key_value = (key_states, value_states) if use_cache else None
1526
+
1527
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
1528
+
1529
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
1530
+ raise ValueError(
1531
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
1532
+ f" {attn_weights.size()}"
1533
+ )
1534
+
1535
+ if attention_mask is not None:
1536
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1537
+ raise ValueError(
1538
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1539
+ )
1540
+ attn_weights = attn_weights + attention_mask
1541
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
1542
+
1543
+ # upcast attention to fp32
1544
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
1545
+ attn_output = torch.matmul(attn_weights, value_states)
1546
+
1547
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
1548
+ raise ValueError(
1549
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
1550
+ f" {attn_output.size()}"
1551
+ )
1552
+
1553
+ attn_output = attn_output.transpose(1, 2)
1554
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1555
+
1556
+ attn_output = self.o_proj(attn_output)
1557
+
1558
+ if not output_attentions:
1559
+ attn_weights = None
1560
+
1561
+ return attn_output, attn_weights, past_key_value
1562
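+ # Standard scaled dot-product attention per head:
+ #   attn = softmax(Q K^T / sqrt(head_dim) + mask) V
+ # with the softmax computed in fp32 for numerical stability and then cast back.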
+
1563
+
1564
+
1565
+ class LlamaDecoderLayer(nn.Module):
1566
+ def __init__(self, config: LlamaConfig):
1567
+ super().__init__()
1568
+ self.hidden_size = config.hidden_size
1569
+ self.self_attn = LlamaAttention(config=config)
1570
+ self.mlp = LlamaMLP(
1571
+ hidden_size=self.hidden_size,
1572
+ intermediate_size=config.intermediate_size,
1573
+ hidden_act=config.hidden_act,
1574
+ )
1575
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1576
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1577
+
1578
+ def forward(
1579
+ self,
1580
+ hidden_states: torch.Tensor,
1581
+ attention_mask: Optional[torch.Tensor] = None,
1582
+ position_ids: Optional[torch.LongTensor] = None,
1583
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1584
+ output_attentions: Optional[bool] = False,
1585
+ use_cache: Optional[bool] = False,
1586
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1587
+ """
1588
+ Args:
1589
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1590
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1591
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1592
+ output_attentions (`bool`, *optional*):
1593
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1594
+ returned tensors for more detail.
1595
+ use_cache (`bool`, *optional*):
1596
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1597
+ (see `past_key_values`).
1598
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1599
+ """
1600
+
1601
+ residual = hidden_states
1602
+
1603
+ hidden_states = self.input_layernorm(hidden_states)
1604
+
1605
+ # Self Attention
1606
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1607
+ hidden_states=hidden_states,
1608
+ attention_mask=attention_mask,
1609
+ position_ids=position_ids,
1610
+ past_key_value=past_key_value,
1611
+ output_attentions=output_attentions,
1612
+ use_cache=use_cache,
1613
+ )
1614
+ hidden_states = residual + hidden_states
1615
+
1616
+ # Fully Connected
1617
+ residual = hidden_states
1618
+ hidden_states = self.post_attention_layernorm(hidden_states)
1619
+ hidden_states = self.mlp(hidden_states)
1620
+ hidden_states = residual + hidden_states
1621
+
1622
+ outputs = (hidden_states,)
1623
+
1624
+ if output_attentions:
1625
+ outputs += (self_attn_weights,)
1626
+
1627
+ if use_cache:
1628
+ outputs += (present_key_value,)
1629
+
1630
+ return outputs
1631
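+ # Pre-norm residual block:
+ #   h = x + SelfAttention(RMSNorm(x))
+ #   y = h + MLP(RMSNorm(h))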
+
1632
+
1633
+ @add_start_docstrings(
1634
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1635
+ LLAMA_START_DOCSTRING,
1636
+ )
1637
+ class LlamaPreTrainedModel(PreTrainedModel):
1638
+ config_class = LlamaConfig
1639
+ base_model_prefix = "model"
1640
+ supports_gradient_checkpointing = True
1641
+ _no_split_modules = ["LlamaDecoderLayer"]
1642
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1643
+
1644
+ def _init_weights(self, module):
1645
+ std = self.config.initializer_range
1646
+ if isinstance(module, nn.Linear):
1647
+ module.weight.data.normal_(mean=0.0, std=std)
1648
+ if module.bias is not None:
1649
+ module.bias.data.zero_()
1650
+ elif isinstance(module, nn.Embedding):
1651
+ module.weight.data.normal_(mean=0.0, std=std)
1652
+ if module.padding_idx is not None:
1653
+ module.weight.data[module.padding_idx].zero_()
1654
+
1655
+ def _set_gradient_checkpointing(self, module, value=False):
1656
+ if isinstance(module, LlamaModel):
1657
+ module.gradient_checkpointing = value
1658
+
1659
+
1660
+ @add_start_docstrings(
1661
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1662
+ LLAMA_START_DOCSTRING,
1663
+ )
1664
+ class LlamaModel(LlamaPreTrainedModel):
1665
+ """
1666
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1667
+
1668
+ Args:
1669
+ config: LlamaConfig
1670
+ """
1671
+
1672
+ def __init__(self, config: LlamaConfig):
1673
+ super().__init__(config)
1674
+ self.padding_idx = config.pad_token_id
1675
+ self.vocab_size = config.vocab_size
1676
+
1677
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1678
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
1679
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1680
+
1681
+ self.gradient_checkpointing = False
1682
+ # Initialize weights and apply final processing
1683
+ self.post_init()
1684
+
1685
+ def get_input_embeddings(self):
1686
+ return self.embed_tokens
1687
+
1688
+ def set_input_embeddings(self, value):
1689
+ self.embed_tokens = value
1690
+
1691
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1692
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
1693
+ # create causal mask
1694
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1695
+ combined_attention_mask = None
1696
+ if input_shape[-1] > 1:
1697
+ combined_attention_mask = _make_causal_mask(
1698
+ input_shape,
1699
+ inputs_embeds.dtype,
1700
+ device=inputs_embeds.device,
1701
+ past_key_values_length=past_key_values_length,
1702
+ )
1703
+
1704
+ if attention_mask is not None:
1705
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1706
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1707
+ inputs_embeds.device
1708
+ )
1709
+ combined_attention_mask = (
1710
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1711
+ )
1712
+
1713
+ return combined_attention_mask
1714
+
1715
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1716
+ def forward(
1717
+ self,
1718
+ input_ids: torch.LongTensor = None,
1719
+ attention_mask: Optional[torch.Tensor] = None,
1720
+ position_ids: Optional[torch.LongTensor] = None,
1721
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1722
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1723
+ query_embeds: Optional[torch.FloatTensor] = None,
1724
+ use_cache: Optional[bool] = None,
1725
+ output_attentions: Optional[bool] = None,
1726
+ output_hidden_states: Optional[bool] = None,
1727
+ return_dict: Optional[bool] = None,
1728
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1729
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1730
+ output_hidden_states = (
1731
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1732
+ )
1733
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1734
+
1735
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1736
+
1737
+ # retrieve input_ids and inputs_embeds
1738
+ if input_ids is not None and inputs_embeds is not None:
1739
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1740
+ elif input_ids is not None:
1741
+ batch_size, seq_length = input_ids.shape
1742
+ elif inputs_embeds is not None:
1743
+ batch_size, seq_length, _ = inputs_embeds.shape
1744
+ else:
1745
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1746
+
1747
+ if inputs_embeds is None:
1748
+ inputs_embeds = self.embed_tokens(input_ids)
1749
+ if query_embeds is not None:
1750
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
1751
+ batch_size, seq_length, _ = inputs_embeds.shape
1752
+
1753
+ seq_length_with_past = seq_length
1754
+ past_key_values_length = 0
1755
+
1756
+ if past_key_values is not None:
1757
+ past_key_values_length = past_key_values[0][0].shape[2]
1758
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1759
+
1760
+ if position_ids is None:
1761
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1762
+ position_ids = torch.arange(
1763
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1764
+ )
1765
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1766
+ else:
1767
+ position_ids = position_ids.view(-1, seq_length).long()
1768
+
1769
+ # embed positions
1770
+ if attention_mask is None:
1771
+ attention_mask = torch.ones(
1772
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1773
+ )
1774
+ attention_mask = self._prepare_decoder_attention_mask(
1775
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1776
+ )
1777
+
1778
+ hidden_states = inputs_embeds
1779
+
1780
+ if self.gradient_checkpointing and self.training:
1781
+ if use_cache:
1782
+ logger.warning_once(
1783
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1784
+ )
1785
+ use_cache = False
1786
+
1787
+ # decoder layers
1788
+ all_hidden_states = () if output_hidden_states else None
1789
+ all_self_attns = () if output_attentions else None
1790
+ next_decoder_cache = () if use_cache else None
1791
+
1792
+ for idx, decoder_layer in enumerate(self.layers):
1793
+ if output_hidden_states:
1794
+ all_hidden_states += (hidden_states,)
1795
+
1796
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1797
+
1798
+ if self.gradient_checkpointing and self.training:
1799
+
1800
+ def create_custom_forward(module):
1801
+ def custom_forward(*inputs):
1802
+ # None for past_key_value
1803
+ return module(*inputs, output_attentions, None)
1804
+
1805
+ return custom_forward
1806
+
1807
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1808
+ create_custom_forward(decoder_layer),
1809
+ hidden_states,
1810
+ attention_mask,
1811
+ position_ids,
1812
+ None,
1813
+ )
1814
+ else:
1815
+ layer_outputs = decoder_layer(
1816
+ hidden_states,
1817
+ attention_mask=attention_mask,
1818
+ position_ids=position_ids,
1819
+ past_key_value=past_key_value,
1820
+ output_attentions=output_attentions,
1821
+ use_cache=use_cache,
1822
+ )
1823
+
1824
+ hidden_states = layer_outputs[0]
1825
+
1826
+ if use_cache:
1827
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1828
+
1829
+ if output_attentions:
1830
+ all_self_attns += (layer_outputs[1],)
1831
+
1832
+ hidden_states = self.norm(hidden_states)
1833
+
1834
+ # add hidden states from the last decoder layer
1835
+ if output_hidden_states:
1836
+ all_hidden_states += (hidden_states,)
1837
+
1838
+ next_cache = next_decoder_cache if use_cache else None
1839
+ if not return_dict:
1840
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1841
+ return BaseModelOutputWithPast(
1842
+ last_hidden_state=hidden_states,
1843
+ past_key_values=next_cache,
1844
+ hidden_states=all_hidden_states,
1845
+ attentions=all_self_attns,
1846
+ )
1847
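+ # Note: unlike the stock HF LlamaModel, this forward additionally accepts
+ # `query_embeds` (e.g. projected audio embeddings) and prepends them to the token
+ # embeddings, so the decoder attends over [audio tokens, text tokens].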
+
1848
+
1849
+
1850
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1851
+ def __init__(self, config):
1852
+ super().__init__(config)
1853
+ self.model = LlamaModel(config)
1854
+
1855
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1856
+
1857
+ # Initialize weights and apply final processing
1858
+ self.post_init()
1859
+
1860
+ def get_input_embeddings(self):
1861
+ return self.model.embed_tokens
1862
+
1863
+ def set_input_embeddings(self, value):
1864
+ self.model.embed_tokens = value
1865
+
1866
+ def get_output_embeddings(self):
1867
+ return self.lm_head
1868
+
1869
+ def set_output_embeddings(self, new_embeddings):
1870
+ self.lm_head = new_embeddings
1871
+
1872
+ def set_decoder(self, decoder):
1873
+ self.model = decoder
1874
+
1875
+ def get_decoder(self):
1876
+ return self.model
1877
+
1878
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1879
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1880
+ def forward(
1881
+ self,
1882
+ input_ids: torch.LongTensor = None,
1883
+ attention_mask: Optional[torch.Tensor] = None,
1884
+ position_ids: Optional[torch.LongTensor] = None,
1885
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1886
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1887
+ query_embeds: Optional[torch.FloatTensor] = None,
1888
+ labels: Optional[torch.LongTensor] = None,
1889
+ use_cache: Optional[bool] = None,
1890
+ output_attentions: Optional[bool] = None,
1891
+ output_hidden_states: Optional[bool] = None,
1892
+ return_dict: Optional[bool] = None,
1893
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1894
+ r"""
1895
+ Args:
1896
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1897
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1898
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1899
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1900
+
1901
+ Returns:
1902
+
1903
+ Example:
1904
+
1905
+ ```python
1906
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1907
+
1908
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1909
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1910
+
1911
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
1912
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1913
+
1914
+ >>> # Generate
1915
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1916
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1917
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1918
+ ```"""
1919
+
1920
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1921
+ output_hidden_states = (
1922
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1923
+ )
1924
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1925
+
1926
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1927
+ outputs = self.model(
1928
+ input_ids=input_ids,
1929
+ attention_mask=attention_mask,
1930
+ position_ids=position_ids,
1931
+ past_key_values=past_key_values,
1932
+ inputs_embeds=inputs_embeds,
1933
+ query_embeds=query_embeds,
1934
+ use_cache=use_cache,
1935
+ output_attentions=output_attentions,
1936
+ output_hidden_states=output_hidden_states,
1937
+ return_dict=return_dict,
1938
+ )
1939
+
1940
+ hidden_states = outputs[0]
1941
+ logits = self.lm_head(hidden_states)
1942
+
1943
+ loss = None
1944
+ if labels is not None:
1945
+ # Shift so that tokens < n predict n
1946
+ shift_logits = logits[..., :-1, :].contiguous()
1947
+ shift_labels = labels[..., 1:].contiguous()
1948
+ # Flatten the tokens
1949
+ loss_fct = CrossEntropyLoss()
1950
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1951
+ shift_labels = shift_labels.view(-1)
1952
+ # Enable model parallelism
1953
+ shift_labels = shift_labels.to(shift_logits.device)
1954
+ loss = loss_fct(shift_logits, shift_labels)
1955
+
1956
+ if not return_dict:
1957
+ output = (logits,) + outputs[1:]
1958
+ return (loss,) + output if loss is not None else output
1959
+
1960
+ return CausalLMOutputWithPast(
1961
+ loss=loss,
1962
+ logits=logits,
1963
+ past_key_values=outputs.past_key_values,
1964
+ hidden_states=outputs.hidden_states,
1965
+ attentions=outputs.attentions,
1966
+ )
1967
+
1968
+ def prepare_inputs_for_generation(
1969
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1970
+ ):
1971
+ if past_key_values:
1972
+ input_ids = input_ids[:, -1:]
1973
+
1974
+ position_ids = kwargs.get("position_ids", None)
1975
+ if attention_mask is not None and position_ids is None:
1976
+ # create position_ids on the fly for batch generation
1977
+ position_ids = attention_mask.long().cumsum(-1) - 1
1978
+ position_ids.masked_fill_(attention_mask == 0, 1)
1979
+ if past_key_values:
1980
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1981
+ query_embeds = None
1982
+
1983
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1984
+ if inputs_embeds is not None and past_key_values is None:
1985
+ model_inputs = {"inputs_embeds": inputs_embeds}
1986
+ else:
1987
+ model_inputs = {"input_ids": input_ids}
1988
+
1989
+ model_inputs.update(
1990
+ {
1991
+ "position_ids": position_ids,
1992
+ "query_embeds": query_embeds,
1993
+ "past_key_values": past_key_values,
1994
+ "use_cache": kwargs.get("use_cache"),
1995
+ "attention_mask": attention_mask,
1996
+ }
1997
+ )
1998
+ return model_inputs
1999
+
2000
+ @staticmethod
2001
+ def _reorder_cache(past_key_values, beam_idx):
2002
+ reordered_past = ()
2003
+ for layer_past in past_key_values:
2004
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2005
+ return reordered_past
2006
+
2007
+
2008
+ @registry.register_model("musilingo")
2009
+ class MusiLingo(BaseModel):
2010
+ """
2011
+ MusiLingo model: MERT audio encoder bridged to a LLaMA (Vicuna) language decoder.
2012
+ """
2013
+
2014
+ PRETRAINED_MODEL_CONFIG_DICT = {
2015
+ "pretrain_vicuna": "configs/models/musilingo.yaml",
2016
+ }
2017
+
2018
+ def __init__(
2019
+ self,
2020
+ mert_model,
2021
+ llama_model,
2022
+ config=None,  # not used directly in __init__; optional so from_config() need not pass it
2023
+ prompt_path="",
2024
+ prompt_template="",
2025
+ max_txt_len=32,
2026
+ end_sym='\n',
2027
+ low_resource=False, # use 8 bit and put the audio encoder on cpu
2028
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
2029
+ ):
2030
+ super().__init__()
2031
+
2032
+ self.low_resource = low_resource
2033
+
2034
+ print('Loading Audio Encoder')
2035
+ self.audio_encoder = AutoModel.from_pretrained(mert_model, trust_remote_code=True)
2036
+ # loading the corresponding preprocessor config
2037
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(mert_model, trust_remote_code=True)
2038
+
2039
+ for name, param in self.audio_encoder.named_parameters():
2040
+ param.requires_grad = False
2041
+ self.audio_encoder = self.audio_encoder.eval()
2042
+
2043
+ print('Loading Audio Encoder Done')
2044
+
2045
+ print('Loading LLAMA')
2046
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
2047
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
2048
+
2049
+ if self.low_resource:
2050
+ self.llama_model = LlamaForCausalLM.from_pretrained(
2051
+ llama_model,
2052
+ torch_dtype=torch.float16,
2053
+ load_in_8bit=True,
2054
+ device_map={'': device_8bit}
2055
+ )
2056
+ else:
2057
+ self.llama_model = LlamaForCausalLM.from_pretrained(
2058
+ llama_model,
2059
+ torch_dtype=torch.float16,
2060
+ )
2061
+
2062
+ for name, param in self.llama_model.named_parameters():
2063
+ param.requires_grad = False
2064
+ print('Loading LLAMA Done')
2065
+
2066
+ self.llama_proj = nn.Linear(
2067
+ self.audio_encoder.config.hidden_size, self.llama_model.config.hidden_size
2068
+ )
2069
+ self.max_txt_len = max_txt_len
2070
+ self.end_sym = end_sym
2071
+
2072
+ self.prompt_template = prompt_template
2073
+
2074
+ if prompt_path:
2075
+ with open(prompt_path, 'r') as f:
2076
+ raw_prompts = f.read().splitlines()
2077
+ filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<AudioHere>" in raw_prompt]
2078
+ self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
2079
+ print('Load {} training prompts'.format(len(self.prompt_list)))
2080
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
2081
+ else:
2082
+ self.prompt_list = []
2083
+
2084
+ def audioenc_to_cpu(self):
2085
+ self.audio_encoder.to("cpu")
2086
+ self.audio_encoder.float()
2087
+
2088
+ def encode_audio(self, audio, attn=None):
2089
+ device = audio.device
2090
+ if self.low_resource:
2091
+ self.audioenc_to_cpu()
2092
+ audio = audio.to("cpu")
2093
+
2094
+ if attn is None:
2095
+
2096
+ audio_embeds = torch.stack(self.audio_encoder(input_values=audio,
2097
+ output_hidden_states=True).hidden_states) # [25, B, T, 1024]
2098
+ audio_embeds = audio_embeds.transpose(0, 1).mean(-3) #[B, T, 1024]
2099
+
2100
+ else:
2101
+
2102
+ audio_embeds = torch.stack(self.audio_encoder(input_values=audio,
2103
+ output_hidden_states=True,
2104
+ attention_mask=attn).hidden_states) # [25, B, T, 1024]
2105
+ audio_embeds = audio_embeds.transpose(0, 1).mean(-3) #[B, T, 1024]
2106
+
2107
+ # Average time steps:
2108
+ t = 325
2109
+ B, T, D = audio_embeds.shape
2110
+ avg_tmp = audio_embeds[:, :T//t*t].reshape(B, T//t, t, D).mean(2)
2111
+
2112
+ # Average the remaining steps
2113
+ if T % t > 0:
2114
+ avg_last = audio_embeds[:, T//t*t:].reshape(B, 1, T%t, D).mean(2)
2115
+ audio_embeds = torch.concat([avg_tmp, avg_last], dim=1)
2116
+ else:
2117
+ audio_embeds = avg_tmp
2118
+ audio_embeds = audio_embeds.to(device)
2119
+ inputs_llama = self.llama_proj(audio_embeds)
2120
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(audio.device)
2121
+ return inputs_llama, atts_llama
2122
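+ # The 25 MERT hidden-state layers are averaged, then every 325 time frames are
+ # mean-pooled into one vector before the linear projection to the LLaMA embedding
+ # space. E.g. T = 1000 frames -> 3 full windows of 325 plus one remainder window of
+ # 25 -> 4 audio tokens fed to the language model.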
+
2123
+ def prompt_wrap(self, audio_embeds, atts_audio, prompt):
2124
+ if prompt:
2125
+ batch_size = audio_embeds.shape[0]
2126
+ p_before, p_after = prompt.split('<AudioHere>')
2127
+ p_before_tokens = self.llama_tokenizer(
2128
+ p_before, return_tensors="pt", add_special_tokens=False).to(audio_embeds.device)
2129
+ p_after_tokens = self.llama_tokenizer(
2130
+ p_after, return_tensors="pt", add_special_tokens=False).to(audio_embeds.device)
2131
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
2132
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
2133
+ wrapped_audio_embeds = torch.cat([p_before_embeds, audio_embeds, p_after_embeds], dim=1)
2134
+ wrapped_atts_audio = atts_audio[:, :1].expand(-1, wrapped_audio_embeds.shape[1])
2135
+ return wrapped_audio_embeds, wrapped_atts_audio
2136
+ else:
2137
+ return audio_embeds, atts_audio
2138
+
2139
+ def instruction_prompt_wrap(self, audio_embeds, atts_audio, prompt):
2140
+ if prompt:
2141
+ batch_size = audio_embeds.shape[0]
2142
+ p_before = []
2143
+ p_after = []
2144
+
2145
+ for i in range(batch_size):
2146
+ p_b, p_a = prompt[i].split('<AudioHere>')
2147
+ p_before.append(p_b)
2148
+ p_after.append(p_a)
2149
+
2150
+ p_before_tokens = self.llama_tokenizer(
2151
+ p_before, return_tensors="pt", padding='longest', add_special_tokens=False).to(audio_embeds.device)
2152
+ p_after_tokens = self.llama_tokenizer(
2153
+ p_after, return_tensors="pt", padding='longest', add_special_tokens=False).to(audio_embeds.device)
2154
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids)
2155
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids)
2156
+ wrapped_audio_embeds = torch.cat([p_before_embeds, audio_embeds, p_after_embeds], dim=1)
2157
+ wrapped_atts_audio = torch.cat([p_before_tokens.attention_mask, atts_audio, p_after_tokens.attention_mask], dim=1)
2158
+ return wrapped_audio_embeds, wrapped_atts_audio
2159
+ else:
2160
+ return audio_embeds, atts_audio
2161
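+ # Prompt wrapping: each prompt contains an "<AudioHere>" placeholder; the text on
+ # either side is tokenized and embedded, and the audio embeddings are spliced in
+ # between, giving [prompt-before, audio, prompt-after] embeddings plus a matching
+ # attention mask.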
+
2162
+
2163
+ def forward(self, samples):
2164
+ audio = samples["audio"]
2165
+ attn = samples["attention_mask"] if "attention_mask" in samples else None
2166
+ audio_embeds, atts_audio = self.encode_audio(audio, attn)
2167
+
2168
+ if 'instruction_input' in samples: # instruction tuning dataset
2169
+ instruction_prompt = []
2170
+ for instruction in samples['instruction_input']:
2171
+ prompt = '<Audio><AudioHere></Audio> ' + instruction
2172
+ instruction_prompt.append(self.prompt_template.format(prompt))
2173
+ audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
2174
+
2175
+ elif self.prompt_list:
2176
+ prompt = random.choice(self.prompt_list)
2177
+ audio_embeds, atts_audio = self.prompt_wrap(audio_embeds, atts_audio, prompt)
2178
+
2179
+ self.llama_tokenizer.padding_side = "right"
2180
+
2181
+ text = [t + self.end_sym for t in samples["text_input"]]
2182
+
2183
+ to_regress_tokens = self.llama_tokenizer(
2184
+ text,
2185
+ return_tensors="pt",
2186
+ padding="longest",
2187
+ truncation=True,
2188
+ max_length=self.max_txt_len,
2189
+ add_special_tokens=False
2190
+ ).to(audio.device)
2191
+
2192
+ targets = to_regress_tokens.input_ids.masked_fill(
2193
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
2194
+ )
2195
+
2196
+ empty_targets = (
2197
+ torch.ones([atts_audio.shape[0], atts_audio.shape[1]+1],
2198
+ dtype=torch.long).to(audio.device).fill_(-100) # plus one for bos
2199
+ )
2200
+ targets = torch.cat([empty_targets, targets], dim=1)
2201
+
2202
+ batch_size = audio_embeds.shape[0]
2203
+ bos = torch.ones([batch_size, 1],
2204
+ dtype=to_regress_tokens.input_ids.dtype,
2205
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
2206
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
2207
+ atts_bos = atts_audio[:, :1]
2208
+
2209
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
2210
+ inputs_embeds = torch.cat([bos_embeds, audio_embeds, to_regress_embeds], dim=1)
2211
+ attention_mask = torch.cat([atts_bos, atts_audio, to_regress_tokens.attention_mask], dim=1)
2212
+
2213
+ outputs = self.llama_model(
2214
+ inputs_embeds=inputs_embeds,
2215
+ attention_mask=attention_mask,
2216
+ return_dict=True,
2217
+ labels=targets,
2218
+ )
2219
+ loss = outputs.loss
2220
+
2221
+ return {"loss": loss}
2222
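+ # Training sequence layout: [BOS] + prompt-wrapped audio embeddings + target text.
+ # Labels are set to -100 over the BOS/audio/prompt span (and over padding), so the
+ # cross-entropy loss is computed only on the target text tokens.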
+
2223
+ @classmethod
2224
+ def from_config(cls, cfg):
2225
+ mert_model = cfg.get("mert_model", "")
2226
+ llama_model = cfg.get("llama_model")
2227
+
2228
+ low_resource = cfg.get("low_resource", False)
2229
+ device_8bit = cfg.get("device_8bit", 0)
2230
+
2231
+ prompt_path = cfg.get("prompt_path", "")
2232
+ prompt_template = cfg.get("prompt_template", "")
2233
+ max_txt_len = cfg.get("max_txt_len", 32)
2234
+ end_sym = cfg.get("end_sym", '\n')
2235
+
2236
+ model = cls(
2237
+ mert_model=mert_model,
2238
+ llama_model=llama_model,
2239
+ prompt_path=prompt_path,
2240
+ prompt_template=prompt_template,
2241
+ max_txt_len=max_txt_len,
2242
+ end_sym=end_sym,
2243
+ low_resource=low_resource,
2244
+ device_8bit=device_8bit,
2245
+ )
2246
+
2247
+ ckpt_path = cfg.get("ckpt", "") # load ckpt weights of MusiLingo
2248
+ if ckpt_path:
2249
+ print("Load MERT-LLM Checkpoint: {}".format(ckpt_path))
2250
+ ckpt = torch.load(ckpt_path, map_location="cpu")
2251
+ msg = model.load_state_dict(ckpt['model'], strict=False)
2252
+
2253
+ return model
2254
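+ # Usage sketch (illustrative; the local paths are placeholders, not part of this file):
+ #   cfg = OmegaConf.create({
+ #       "mert_model": "m-a-p/MERT-v1-300M",
+ #       "llama_model": "/path/to/vicuna-7b",
+ #       "ckpt": "/path/to/musilingo_checkpoint.pth",
+ #   })
+ #   model = MusiLingo.from_config(cfg)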
+
2255
+
2256
+ class MusilingoModel(PreTrainedModel):
2257
+ config_class = MusiLingoConfig
2258
+ def __init__(self, config):
2259
+ super().__init__(config)
2260
+ self.model = MusiLingo(
2261
+ mert_model=config.mert_model,
2262
+ llama_model=config.llama_model,
2263
+ config=config,
2264
+ prompt_path=config.prompt_path,
2265
+ prompt_template=config.prompt_template,
2266
+ max_txt_len=config.max_txt_len,
2267
+ end_sym=config.end_sym,
2268
+ low_resource=config.low_resource,
2269
+ device_8bit=config.device_8bit
2270
+ # self.linear_ckpt_path = config.linear_ckpt_path
2271
+ )
2272
+
2273
+
2274
+ def forward(self, tensor):
2275
+ return self.model.forward(tensor)
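+ # MusilingoModel is a thin transformers.PreTrainedModel wrapper (config_class =
+ # MusiLingoConfig) around MusiLingo, so this checkpoint can be loaded through the
+ # standard PreTrainedModel.from_pretrained(...) API.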