Spaces: Running on Zero
Commit · 5c7984d
Parent(s): f1f1f55
Add modules
Files changed:
- configs/.gitignore +1 -0
- configs/s2.json +91 -0
- configs/s2v2Pro.json +91 -0
- configs/s2v2ProPlus.json +91 -0
- eres2net/ERes2Net.py +264 -0
- eres2net/ERes2NetV2.py +272 -0
- eres2net/ERes2Net_huge.py +289 -0
- eres2net/fusion.py +27 -0
- eres2net/kaldi.py +844 -0
- eres2net/pooling_layers.py +101 -0
- f5_tts/model/__init__.py +13 -0
- f5_tts/model/backbones/README.md +20 -0
- f5_tts/model/backbones/dit.py +194 -0
- f5_tts/model/backbones/mmdit.py +146 -0
- f5_tts/model/backbones/unett.py +219 -0
- f5_tts/model/modules.py +666 -0
- prepare_datasets/1-get-text.py +143 -0
- prepare_datasets/2-get-hubert-wav32k.py +134 -0
- prepare_datasets/2-get-sv.py +115 -0
- prepare_datasets/3-get-semantic.py +118 -0
configs/.gitignore
ADDED
@@ -0,0 +1 @@
+*.yaml
configs/s2.json
ADDED
@@ -0,0 +1,91 @@
+{
+  "train": {
+    "log_interval": 100,
+    "eval_interval": 500,
+    "seed": 1234,
+    "epochs": 100,
+    "learning_rate": 0.0001,
+    "betas": [
+      0.8,
+      0.99
+    ],
+    "eps": 1e-09,
+    "batch_size": 32,
+    "fp16_run": true,
+    "lr_decay": 0.999875,
+    "segment_size": 20480,
+    "init_lr_ratio": 1,
+    "warmup_epochs": 0,
+    "c_mel": 45,
+    "c_kl": 1.0,
+    "text_low_lr_rate": 0.4,
+    "grad_ckpt": false
+  },
+  "data": {
+    "max_wav_value": 32768.0,
+    "sampling_rate": 32000,
+    "filter_length": 2048,
+    "hop_length": 640,
+    "win_length": 2048,
+    "n_mel_channels": 128,
+    "mel_fmin": 0.0,
+    "mel_fmax": null,
+    "add_blank": true,
+    "n_speakers": 300,
+    "cleaned_text": true
+  },
+  "model": {
+    "inter_channels": 192,
+    "hidden_channels": 192,
+    "filter_channels": 768,
+    "n_heads": 2,
+    "n_layers": 6,
+    "kernel_size": 3,
+    "p_dropout": 0.1,
+    "resblock": "1",
+    "resblock_kernel_sizes": [
+      3,
+      7,
+      11
+    ],
+    "resblock_dilation_sizes": [
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ]
+    ],
+    "upsample_rates": [
+      10,
+      8,
+      2,
+      2,
+      2
+    ],
+    "upsample_initial_channel": 512,
+    "upsample_kernel_sizes": [
+      16,
+      16,
+      8,
+      2,
+      2
+    ],
+    "n_layers_q": 3,
+    "use_spectral_norm": false,
+    "gin_channels": 512,
+    "semantic_frame_rate": "25hz",
+    "freeze_quantizer": true
+  },
+  "s2_ckpt_dir": "logs/s2/big2k1",
+  "content_module": "cnhubert"
+}
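
For orientation (not part of the commit), the numbers in this config fit together arithmetically: the mel spectrogram runs at sampling_rate / hop_length frames per second, each training segment covers segment_size / hop_length frames, and "25hz" semantic tokens land every two mel frames. A minimal sketch that checks these relationships from the file itself:

import json

# Sanity-check the frame-rate relationships implied by configs/s2.json.
# Only keys shown in the diff above are used; the printed numbers are
# arithmetic consequences of the config, not values read from the repo.
with open("configs/s2.json") as f:
    cfg = json.load(f)

data = cfg["data"]
mel_fps = data["sampling_rate"] / data["hop_length"]                     # 32000 / 640 = 50 mel frames/s
frames_per_segment = cfg["train"]["segment_size"] / data["hop_length"]   # 20480 / 640 = 32 frames
semantic_fps = float(cfg["model"]["semantic_frame_rate"].rstrip("hz"))   # "25hz" -> 25.0
print(mel_fps, frames_per_segment, mel_fps / semantic_fps)               # 50.0 32.0 2.0
# i.e. two mel frames per semantic token at the 25 Hz semantic frame rate.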
configs/s2v2Pro.json
ADDED
@@ -0,0 +1,91 @@
+{
+  "train": {
+    "log_interval": 100,
+    "eval_interval": 500,
+    "seed": 1234,
+    "epochs": 100,
+    "learning_rate": 0.0001,
+    "betas": [
+      0.8,
+      0.99
+    ],
+    "eps": 1e-09,
+    "batch_size": 32,
+    "fp16_run": true,
+    "lr_decay": 0.999875,
+    "segment_size": 20480,
+    "init_lr_ratio": 1,
+    "warmup_epochs": 0,
+    "c_mel": 45,
+    "c_kl": 1.0,
+    "text_low_lr_rate": 0.4,
+    "grad_ckpt": false
+  },
+  "data": {
+    "max_wav_value": 32768.0,
+    "sampling_rate": 32000,
+    "filter_length": 2048,
+    "hop_length": 640,
+    "win_length": 2048,
+    "n_mel_channels": 128,
+    "mel_fmin": 0.0,
+    "mel_fmax": null,
+    "add_blank": true,
+    "n_speakers": 300,
+    "cleaned_text": true
+  },
+  "model": {
+    "inter_channels": 192,
+    "hidden_channels": 192,
+    "filter_channels": 768,
+    "n_heads": 2,
+    "n_layers": 6,
+    "kernel_size": 3,
+    "p_dropout": 0.0,
+    "resblock": "1",
+    "resblock_kernel_sizes": [
+      3,
+      7,
+      11
+    ],
+    "resblock_dilation_sizes": [
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ]
+    ],
+    "upsample_rates": [
+      10,
+      8,
+      2,
+      2,
+      2
+    ],
+    "upsample_initial_channel": 512,
+    "upsample_kernel_sizes": [
+      16,
+      16,
+      8,
+      2,
+      2
+    ],
+    "n_layers_q": 3,
+    "use_spectral_norm": false,
+    "gin_channels": 1024,
+    "semantic_frame_rate": "25hz",
+    "freeze_quantizer": true
+  },
+  "s2_ckpt_dir": "logs/s2/big2k1",
+  "content_module": "cnhubert"
+}
configs/s2v2ProPlus.json
ADDED
@@ -0,0 +1,91 @@
+{
+  "train": {
+    "log_interval": 100,
+    "eval_interval": 500,
+    "seed": 1234,
+    "epochs": 100,
+    "learning_rate": 0.0001,
+    "betas": [
+      0.8,
+      0.99
+    ],
+    "eps": 1e-09,
+    "batch_size": 32,
+    "fp16_run": true,
+    "lr_decay": 0.999875,
+    "segment_size": 20480,
+    "init_lr_ratio": 1,
+    "warmup_epochs": 0,
+    "c_mel": 45,
+    "c_kl": 1.0,
+    "text_low_lr_rate": 0.4,
+    "grad_ckpt": false
+  },
+  "data": {
+    "max_wav_value": 32768.0,
+    "sampling_rate": 32000,
+    "filter_length": 2048,
+    "hop_length": 640,
+    "win_length": 2048,
+    "n_mel_channels": 128,
+    "mel_fmin": 0.0,
+    "mel_fmax": null,
+    "add_blank": true,
+    "n_speakers": 300,
+    "cleaned_text": true
+  },
+  "model": {
+    "inter_channels": 192,
+    "hidden_channels": 192,
+    "filter_channels": 768,
+    "n_heads": 2,
+    "n_layers": 6,
+    "kernel_size": 3,
+    "p_dropout": 0.0,
+    "resblock": "1",
+    "resblock_kernel_sizes": [
+      3,
+      7,
+      11
+    ],
+    "resblock_dilation_sizes": [
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ]
+    ],
+    "upsample_rates": [
+      10,
+      8,
+      2,
+      2,
+      2
+    ],
+    "upsample_initial_channel": 768,
+    "upsample_kernel_sizes": [
+      20,
+      16,
+      8,
+      2,
+      2
+    ],
+    "n_layers_q": 3,
+    "use_spectral_norm": false,
+    "gin_channels": 1024,
+    "semantic_frame_rate": "25hz",
+    "freeze_quantizer": true
+  },
+  "s2_ckpt_dir": "logs/s2/big2k1",
+  "content_module": "cnhubert"
+}
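
The three s2*.json files are near-identical: relative to s2.json, s2v2Pro sets "p_dropout" to 0.0 and doubles "gin_channels" to 1024; s2v2ProPlus additionally widens "upsample_initial_channel" to 768 and the first upsample kernel from 16 to 20. A small sketch (not part of the commit) that surfaces exactly these differences by flattening and diffing two configs:

import json

def flatten(d, prefix=""):
    # Flatten nested dicts into dotted keys so configs compare key-by-key.
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten(v, f"{prefix}{k}."))
        else:
            out[f"{prefix}{k}"] = v
    return out

def load(path):
    with open(path) as f:
        return flatten(json.load(f))

a, b = load("configs/s2.json"), load("configs/s2v2ProPlus.json")
for key in sorted(set(a) | set(b)):
    if a.get(key) != b.get(key):
        print(key, a.get(key), "->", b.get(key))
# Expected output: model.gin_channels 512 -> 1024, model.p_dropout 0.1 -> 0.0,
# model.upsample_initial_channel 512 -> 768,
# model.upsample_kernel_sizes [16, 16, 8, 2, 2] -> [20, 16, 8, 2, 2]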
eres2net/ERes2Net.py
ADDED
@@ -0,0 +1,264 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""
+Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
+The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
+The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import pooling_layers as pooling_layers
+from fusion import AFF
+
+
+class ReLU(nn.Hardtanh):
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__(0, 20, inplace)
+
+    def __repr__(self):
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"
+
+
+class BasicBlockERes2Net(nn.Module):
+    expansion = 2
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+        super(BasicBlockERes2Net, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+    expansion = 2
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+        super(BasicBlockERes2Net_diff_AFF, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        fuse_models = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        for j in range(self.nums - 1):
+            fuse_models.append(AFF(channels=width))
+
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.fuse_models = nn.ModuleList(fuse_models)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = self.fuse_models[i - 1](sp, spx[i])
+
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ERes2Net(nn.Module):
+    def __init__(
+        self,
+        block=BasicBlockERes2Net,
+        block_fuse=BasicBlockERes2Net_diff_AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=32,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
+        super(ERes2Net, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
+
+        # Downsampling module for each layer
+        self.layer1_downsample = nn.Conv2d(
+            m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.layer2_downsample = nn.Conv2d(
+            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer3_downsample = nn.Conv2d(
+            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
+        )
+
+        # Bottom-up fusion module
+        self.fuse_mode12 = AFF(channels=m_channels * 4)
+        self.fuse_mode123 = AFF(channels=m_channels * 8)
+        self.fuse_mode1234 = AFF(channels=m_channels * 16)
+
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+        stats = self.pool(fuse_out1234)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a
+
+    def forward3(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
+        return fuse_out1234
+
+
+if __name__ == "__main__":
+    x = torch.zeros(10, 300, 80)
+    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
+    model.eval()
+    out = model(x)
+    print(out.shape)  # torch.Size([10, 192])
+
+    num_params = sum(param.numel() for param in model.parameters())
+    print("{} M".format(num_params / 1e6))  # 6.61M
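
A shape walkthrough may help when reusing this model (a sketch, not part of the commit; it assumes eres2net/ is on sys.path, since the file uses the bare imports "pooling_layers" and "fusion", and that TSTP pooling returns concatenated mean and std as in the 3D-Speaker code it is adapted from). With the defaults (m_channels=32, feat_dim=80), layer4 emits 512 channels over 10 frequency bins, so forward3 yields a 5120-d vector while forward projects to the 192-d embedding:

import torch
from ERes2Net import ERes2Net  # eres2net/ERes2Net.py

model = ERes2Net(feat_dim=80, embedding_size=192).eval()
feats = torch.randn(4, 200, 80)              # (batch, frames, mel bins)
with torch.no_grad():
    emb = model(feats)                       # utterance embedding via TSTP pooling + seg_1
    fused = model.forward3(feats)            # fused multi-scale map, time-averaged, no projection
print(emb.shape)    # torch.Size([4, 192])
print(fused.shape)  # torch.Size([4, 5120])  (512 channels x 10 freq bins)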
eres2net/ERes2NetV2.py
ADDED
@@ -0,0 +1,272 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""
+To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
+within each stage. However, this modification also increases the number of model parameters and computational complexity.
+To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
+both the model parameters and its computational cost.
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import pooling_layers as pooling_layers
+from fusion import AFF
+
+
+class ReLU(nn.Hardtanh):
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__(0, 20, inplace)
+
+    def __repr__(self):
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"
+
+
+class BasicBlockERes2NetV2(nn.Module):
+    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
+        super(BasicBlockERes2NetV2, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+        self.expansion = expansion
+
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockERes2NetV2AFF(nn.Module):
+    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
+        super(BasicBlockERes2NetV2AFF, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+        self.expansion = expansion
+
+        convs = []
+        fuse_models = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        for j in range(self.nums - 1):
+            fuse_models.append(AFF(channels=width, r=4))
+
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.fuse_models = nn.ModuleList(fuse_models)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = self.fuse_models[i - 1](sp, spx[i])
+
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ERes2NetV2(nn.Module):
+    def __init__(
+        self,
+        block=BasicBlockERes2NetV2,
+        block_fuse=BasicBlockERes2NetV2AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=64,
+        feat_dim=80,
+        embedding_size=192,
+        baseWidth=26,
+        scale=2,
+        expansion=2,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
+        super(ERes2NetV2, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+        self.baseWidth = baseWidth
+        self.scale = scale
+        self.expansion = expansion
+
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
+
+        # Downsampling module
+        self.layer3_ds = nn.Conv2d(
+            m_channels * 4 * self.expansion,
+            m_channels * 8 * self.expansion,
+            kernel_size=3,
+            padding=1,
+            stride=2,
+            bias=False,
+        )
+
+        # Bottom-up fusion module
+        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
+
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(
+                block(
+                    self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
+                )
+            )
+            self.in_planes = planes * self.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out3 = self.layer3(out2)
+        out4 = self.layer4(out3)
+        out3_ds = self.layer3_ds(out3)
+        fuse_out34 = self.fuse34(out4, out3_ds)
+        stats = self.pool(fuse_out34)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a
+
+    def forward3(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out3 = self.layer3(out2)
+        out4 = self.layer4(out3)
+        out3_ds = self.layer3_ds(out3)
+        fuse_out34 = self.fuse34(out4, out3_ds)
+        # print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
+        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
+        # stats = self.pool(fuse_out34)
+        #
+        # embed_a = self.seg_1(stats)
+        # if self.two_emb_layer:
+        #     out = F.relu(embed_a)
+        #     out = self.seg_bn_1(out)
+        #     embed_b = self.seg_2(out)
+        #     return embed_b
+        # else:
+        #     return embed_a
+
+
+if __name__ == "__main__":
+    from thop import profile  # not imported in the original file; needed for the MACs/params measurement below
+
+    x = torch.randn(1, 300, 80)
+    model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
+    model.eval()
+    y = model(x)
+    print(y.size())
+    macs, num_params = profile(model, inputs=(x,))
+    print("Params: {} M".format(num_params / 1e6))  # 17.86 M
+    print("MACs: {} G".format(macs / 1e9))  # 12.69 G
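
Note the V2 simplification: instead of fusing all four stages, only layer3 is downsampled and fused with layer4, which is part of the pruning the docstring describes. A usage sketch (not part of the commit; assumes eres2net/ is on sys.path): with the __main__ defaults above (m_channels=64, expansion=2), layer4 emits 1024 channels over 10 frequency bins, so forward3 returns a 10240-d fused vector while forward returns the 192-d embedding:

import torch
from ERes2NetV2 import ERes2NetV2

model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2).eval()
feats = torch.randn(2, 300, 80)       # (batch, frames, mel bins)
with torch.no_grad():
    emb = model(feats)                # torch.Size([2, 192])
    fused = model.forward3(feats)     # torch.Size([2, 10240]), skips pooling and seg_1
print(emb.shape, fused.shape)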
eres2net/ERes2Net_huge.py
ADDED
@@ -0,0 +1,289 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
+The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
+The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
+ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
+recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import pooling_layers as pooling_layers
+from fusion import AFF
+
+
+class ReLU(nn.Hardtanh):
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__(0, 20, inplace)
+
+    def __repr__(self):
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"
+
+
+class BasicBlockERes2Net(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+        super(BasicBlockERes2Net, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+        super(BasicBlockERes2Net_diff_AFF, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        fuse_models = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
+            bns.append(nn.BatchNorm2d(width))
+        for j in range(self.nums - 1):
+            fuse_models.append(AFF(channels=width))
+
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.fuse_models = nn.ModuleList(fuse_models)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = self.fuse_models[i - 1](sp, spx[i])
+
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ERes2Net(nn.Module):
+    def __init__(
+        self,
+        block=BasicBlockERes2Net,
+        block_fuse=BasicBlockERes2Net_diff_AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=64,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
+        super(ERes2Net, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
+
+        self.layer1_downsample = nn.Conv2d(
+            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer2_downsample = nn.Conv2d(
+            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer3_downsample = nn.Conv2d(
+            m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
+        )
+
+        self.fuse_mode12 = AFF(channels=m_channels * 8)
+        self.fuse_mode123 = AFF(channels=m_channels * 16)
+        self.fuse_mode1234 = AFF(channels=m_channels * 32)
+
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+        stats = self.pool(fuse_out1234)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a
+
+    def forward2(self, x, if_mean):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2)  # bs,20480,T
+        if if_mean == False:
+            mean = fuse_out1234[0].transpose(1, 0)  # (T,20480),bs=T
+        else:
+            mean = fuse_out1234.mean(2)  # bs,20480
+        mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
+        return self.seg_1(mean_std)  # (T,192)
+
+        # stats = self.pool(fuse_out1234)
+        # if self.two_emb_layer:
+        #     out = F.relu(embed_a)
+        #     out = self.seg_bn_1(out)
+        #     embed_b = self.seg_2(out)
+        #     return embed_b
+        # else:
+        #     return embed_a
+
+    def forward3(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
+        return fuse_out1234
+        # print(fuse_out1234.shape)
+        # print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
+        # pdb.set_trace()
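
forward2 is the per-frame variant: the fused map is flattened to (bs, 20480, T); with if_mean=False the single utterance's time axis becomes the batch axis, and the std half of the pooled statistics is zero-filled so seg_1 still sees the 2x layout it expects under TSTP pooling. A sketch (not part of the commit; assumes eres2net/ is on sys.path):

import torch
from ERes2Net_huge import ERes2Net  # this file also names its class ERes2Net

model = ERes2Net(feat_dim=80, embedding_size=192).eval()
feats = torch.randn(1, 400, 80)  # one utterance of 80-dim features
with torch.no_grad():
    per_frame = model.forward2(feats, if_mean=False)  # torch.Size([50, 192]): one embedding per 8x-downsampled frame
    per_utt = model.forward2(feats, if_mean=True)     # torch.Size([1, 192]): time-averaged embedding
print(per_frame.shape, per_utt.shape)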
eres2net/fusion.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import torch
+import torch.nn as nn
+
+
+class AFF(nn.Module):
+    def __init__(self, channels=64, r=4):
+        super(AFF, self).__init__()
+        inter_channels = int(channels // r)
+
+        self.local_att = nn.Sequential(
+            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(inter_channels),
+            nn.SiLU(inplace=True),
+            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(channels),
+        )
+
+    def forward(self, x, ds_y):
+        xa = torch.cat((x, ds_y), dim=1)
+        x_att = self.local_att(xa)
+        x_att = 1.0 + torch.tanh(x_att)
+        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
+
+        return xo
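
The fusion gate here is worth spelling out: the attention logits pass through tanh, so the gate g = 1 + tanh(logits) lies in (0, 2), and the two inputs receive the complementary weights g and 2 - g, which always sum to 2. A numeric illustration (not part of the commit):

import torch

logits = torch.tensor([-3.0, 0.0, 3.0])
g = 1.0 + torch.tanh(logits)   # tensor([0.0049, 1.0000, 1.9951])
print(g, 2.0 - g)              # per-element weight on x vs. weight on ds_y
# logits = 0 -> the two inputs are weighted equally (g = 1);
# large positive logits favor x, large negative logits favor ds_y.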
eres2net/kaldi.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
from typing import Tuple

import torch
import torchaudio
from torch import Tensor

__all__ = [
    "get_mel_banks",
    "inverse_mel_scale",
    "inverse_mel_scale_scalar",
    "mel_scale",
    "mel_scale_scalar",
    "spectrogram",
    "fbank",
    "mfcc",
    "vtln_warp_freq",
    "vtln_warp_mel_freq",
]

# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 millisecond = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001

# window types
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]


def _get_epsilon(device, dtype):
    return EPSILON.to(device=device, dtype=dtype)


def _next_power_of_2(x: int) -> int:
    r"""Returns the smallest power of 2 that is greater than x"""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)


def _feature_window_function(
    window_type: str,
    window_size: int,
    blackman_coeff: float,
    device: torch.device,
    dtype: int,
) -> Tensor:
    r"""Returns a window function with the given type and size"""
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
    elif window_type == POVEY:
        # like hanning but goes to zero at edges
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, device=device, dtype=dtype)
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = torch.arange(window_size, device=device, dtype=dtype)
        # can't use torch.blackman_window as they use different coefficients
        return (
            blackman_coeff
            - 0.5 * torch.cos(a * window_function)
            + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
        ).to(device=device, dtype=dtype)
    else:
        raise Exception("Invalid window type " + window_type)


def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
    device, dtype = strided_input.device, strided_input.dtype
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    if energy_floor == 0.0:
        return log_energy
    return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Gets the waveform and window properties"""
    channel = max(channel, 0)
    assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
    waveform = waveform[channel, :]  # size (n)
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size

    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
        window_size, len(waveform)
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Gets a window and its log energy

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

    if dither != 0.0:
        rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
        strided_input = strided_input + rand_gauss * dither

    if remove_dc_offset:
        # Subtract each row/frame by its mean
        row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
        strided_input = strided_input - row_means

    if raw_energy:
        # Compute the log energy of each row/frame before applying preemphasis and
        # window function
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
        offset_strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (1, 0), mode="replicate"
        ).squeeze(0)  # size (m, window_size + 1)
        strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]

    # Apply window_function to each row/frame
    window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
        0
    )  # size (1, window_size)
    strided_input = strided_input * window_function  # size (m, window_size)

    # Pad columns with zero until we reach size (m, padded_window_size)
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    # Compute energy after window function (not the raw one)
    if not raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    return strided_input, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
    # it returns size (m, n)
    if subtract_mean:
        col_means = torch.mean(tensor, dim=0).unsqueeze(0)
        tensor = tensor - col_means
    return tensor


def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    fft = torch.fft.rfft(strided_input)

    # Convert the FFT into a power spectrum
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
    power_spectrum[:, 0] = signal_log_energy

    power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
    return power_spectrum


def inverse_mel_scale_scalar(mel_freq: float) -> float:
    return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)


def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)


def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()


def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""This computes a VTLN warping function that is not the same as HTK's one,
    but has similar inputs (this function has the advantage of never producing
    empty bins).

    This function computes a warp function F(freq), defined between low_freq
    and high_freq inclusive, with the following properties:
        F(low_freq) == low_freq
        F(high_freq) == high_freq
    The function is continuous and piecewise linear with two inflection
    points.
    The lower inflection point (measured in terms of the unwarped
    frequency) is at frequency l, determined as described below.
    The higher inflection point is at a frequency h, determined as
    described below.
    If l <= f <= h, then F(f) = f/vtln_warp_factor.
    If the higher inflection point (measured in terms of the unwarped
    frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
    Since (by the last point) F(h) == h/vtln_warp_factor, then
    max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
    h = vtln_high_cutoff / max(1, 1/vtln_warp_factor)
      = vtln_high_cutoff * min(1, vtln_warp_factor).
    If the lower inflection point (measured in terms of the unwarped
    frequency) is at l, then min(l, F(l)) == vtln_low_cutoff.
    This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
                        = vtln_low_cutoff * max(1, vtln_warp_factor)

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
    l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    Fl = scale * l  # F(l)
    Fh = scale * h  # F(h)
    assert l > low_freq and h < high_freq
    # slope of left part of the 3-piece linear function
    scale_left = (Fl - low_freq) / (l - low_freq)
    # [slope of center part is just "scale"]

    # slope of right part of the 3-piece linear function
    scale_right = (high_freq - Fh) / (high_freq - h)

    res = torch.empty_like(freq)

    outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq)  # freq < low_freq || freq > high_freq
    before_l = torch.lt(freq, l)  # freq < l
    before_h = torch.lt(freq, h)  # freq < h
    after_h = torch.ge(freq, h)  # freq >= h

    # order of operations matter here (since there is overlapping frequency regions)
    res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
    res[before_h] = scale * freq[before_h]
    res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
    res[outside_low_high_freq] = freq[outside_low_high_freq]

    return res


def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""
    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    return mel_scale(
        vtln_warp_freq(
            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
        )
    )


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
    device=None,
    dtype=None,
) -> Tensor:
    """
    Returns:
        Tensor: the melbank ``bins`` of size (``num_bins``, ``num_fft_bins``). The
        ``center_freqs`` output of the original torchaudio implementation is
        commented out in this adaptation.
    """
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded / 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
        "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
    )

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    bin = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    # center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins.to(device=device, dtype=dtype)  # , center_freqs


cache = {}


def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq). (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)

    cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
        num_mel_bins,
        padded_window_size,
        sample_frequency,
        low_freq,
        high_freq,
        vtln_low,
        vtln_high,
        vtln_warp,
        device,
        dtype,
    )
    if cache_key not in cache:
        mel_energies = get_mel_banks(
            num_mel_bins,
            padded_window_size,
            sample_frequency,
            low_freq,
            high_freq,
            vtln_low,
            vtln_high,
            vtln_warp,
            device,
            dtype,
        )
        cache[cache_key] = mel_energies
    else:
        mel_energies = cache[cache_key]

    # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy then add it as the last column for htk_compat == true else first column
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        # returns size (m, num_mel_bins + 1)
        if htk_compat:
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies


def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    # returns a dct matrix of size (num_mel_bins, num_ceps)
    # size (num_mel_bins, num_mel_bins)
    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
    # this would be the first column in the dct_matrix for torchaudio as it expects a
    # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
    # expects a left multiply e.g. dct_matrix * vector).
    dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
    dct_matrix = dct_matrix[:, :num_ceps]
    return dct_matrix


def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
    # returns size (num_ceps)
    # Compute liftering coefficients (scaling on cepstral coeffs)
    # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
    i = torch.arange(num_ceps)
    return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)


def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq). (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should not be squared (use_power=True), not have mean subtracted
    # (subtract_mean=False), and use log (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
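For orientation, a minimal usage sketch of the `fbank` and `mfcc` functions defined above (not part of the commit); the file name `example.wav` and the parameter choices are placeholders for illustration:

```python
# Hypothetical usage of the Kaldi-compatible extractors defined above;
# "example.wav" and the bin counts are illustrative assumptions.
import torchaudio

waveform, sr = torchaudio.load("example.wav")  # waveform: (channels, num_samples)
fbank_feats = fbank(waveform, num_mel_bins=80, sample_frequency=sr)  # (num_frames, 80)
mfcc_feats = mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=sr)  # (num_frames, 13)
```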
eres2net/pooling_layers.py
ADDED
@@ -0,0 +1,101 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""

import torch
import torch.nn as nn


class TAP(nn.Module):
    """
    Temporal average pooling, only first-order mean is considered
    """

    def __init__(self, **kwargs):
        super(TAP, self).__init__()

    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean


class TSDP(nn.Module):
    """
    Temporal standard deviation pooling, only second-order std is considered
    """

    def __init__(self, **kwargs):
        super(TSDP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std


class TSTP(nn.Module):
    """
    Temporal statistics pooling, concatenate mean and std, which is used in
    x-vector
    Comment: simple concatenation can not make full use of both statistics
    """

    def __init__(self, **kwargs):
        super(TSTP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_mean = x.mean(dim=-1)
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)

        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats


class ASTP(nn.Module):
    """Attentive statistics pooling: Channel- and context-dependent
    statistics pooling, first used in ECAPA_TDNN.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
           or a 4-dimensional tensor in resnet architecture (B,C,F,T)
           0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! ReLU may be hard to converge.
        alpha = torch.tanh(self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
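A minimal sketch (not part of the commit) of how these layers collapse frame-level features into utterance-level embeddings; the shapes are illustrative assumptions:

```python
# Illustrative only: dummy TDNN-style features of shape (batch, feat_dim, time).
import torch

feats = torch.randn(4, 256, 200)
print(TAP()(feats).shape)             # torch.Size([4, 256]) - mean over time
print(TSTP()(feats).shape)            # torch.Size([4, 512]) - mean and std concatenated
print(ASTP(in_dim=256)(feats).shape)  # torch.Size([4, 512]) - attention-weighted mean and std
```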
f5_tts/model/__init__.py
ADDED
@@ -0,0 +1,13 @@
# from f5_tts.model.cfm import CFM
#
# from f5_tts.model.backbones.unett import UNetT
from GPT_SoVITS.f5_tts.model.backbones.dit import DiT
# from f5_tts.model.backbones.dit import DiTNoCond
# from f5_tts.model.backbones.dit import DiTNoCondNoT
# from f5_tts.model.backbones.mmdit import MMDiT

# from f5_tts.model.trainer import Trainer


# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
# __all__ = ["CFM", "UNetT", "DiTNoCond","DiT", "MMDiT"]
f5_tts/model/backbones/README.md
ADDED
@@ -0,0 +1,20 @@
## Backbones quick introduction


### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat

### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatted noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)

### mmdit.py
- sd3 structure
- timestep as condition
- left stream: text embedded and applied an abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
f5_tts/model/backbones/dit.py
ADDED
@@ -0,0 +1,194 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from x_transformers.x_transformers import RotaryEmbedding

from GPT_SoVITS.f5_tts.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis,
    get_pos_embed_indices,
)

from module.commons import sequence_mask


class TextEmbedding(nn.Module):
    def __init__(self, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer(
                "freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False
            )
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        batch, text_len = text.shape[0], text.shape[1]

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]

            # print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256])

            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using DiT blocks


class DiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_dim=None,
        conv_layers=0,
        long_skip_connection=False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.d_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def ckpt_wrapper(self, module):
        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    def forward(  # x, prompt_x, x_lens, t, style, cond
        self,  # d is channel, n is T
        x0: float["b n d"],  # noised input audio  # noqa: F722
        cond0: float["b n d"],  # masked cond audio  # noqa: F722
        x_lens,
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        dt_base_bootstrap,
        text0,  # : int["b nt"]  # noqa: F722  ##### condition feature
        use_grad_ckpt=False,  # bool
        ### no-use
        drop_audio_cond=False,  # cfg for cond audio
        drop_text=False,  # cfg for text
        # mask: bool["b n"] | None = None,  # noqa: F722
        infer=False,  # bool
        text_cache=None,  # torch tensor as text_embed
        dt_cache=None,  # torch tensor as dt
    ):
        x = x0.transpose(2, 1)
        cond = cond0.transpose(2, 1)
        text = text0.transpose(2, 1)
        mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)

        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        if infer and dt_cache is not None:
            dt = dt_cache
        else:
            dt = self.d_embed(dt_base_bootstrap)
        t += dt

        if infer and text_cache is not None:
            text_embed = text_cache
        else:
            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ### need to change

        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            if use_grad_ckpt:
                x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
            else:
                x = block(x, t, mask=mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        if infer:
            return output, text_embed, dt
        else:
            return output
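A smoke-test sketch (not part of the commit), assuming the GPT_SoVITS package and its `modules`/`module.commons` dependencies are importable. The shapes are inferred from the transposes at the top of `DiT.forward` (inputs are channel-first), and `dt_base_bootstrap` is assumed to be a per-batch tensor like `time`, since both feed a `TimestepEmbedding`:

```python
# Hypothetical smoke test; dims and lengths are illustrative assumptions.
import torch

model = DiT(dim=512, depth=8, heads=8, dim_head=64, mel_dim=100, text_dim=256, conv_layers=4)
b, n = 2, 120
x0 = torch.randn(b, 100, n)         # noised mel, (b, mel_dim, n)
cond0 = torch.randn(b, 100, n)      # masked cond mel, (b, mel_dim, n)
text0 = torch.randn(b, 256, n)      # condition features, (b, text_dim, n)
x_lens = torch.tensor([n, n // 2])  # valid lengths per batch item
t = torch.rand(b)                   # diffusion time step
dt = torch.rand(b)                  # dt_base_bootstrap, embedded like t
out = model(x0, cond0, x_lens, t, dt, text0)
print(out.shape)                    # expected: (b, n, mel_dim)
```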
f5_tts/model/backbones/mmdit.py
ADDED
@@ -0,0 +1,146 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn

from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    MMDiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis,
    get_pos_embed_indices,
)


# text embedding


class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
        text = text + 1
        if drop_text:
            text = torch.zeros_like(text)
        text = self.text_embed(text)

        # sinus pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        return text


# noised input & masked cond audio embedding


class AudioEmbedding(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim=-1)
        x = self.linear(x)
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using MM-DiT blocks


class MMDiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        text_num_embeds=256,
        mel_dim=100,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    context_pre_only=i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        batch = x.shape[0]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c = self.text_embed(text, drop_text=drop_text)
        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
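A minimal shape check for the MMDiT backbone defined above; the sizes here are illustrative assumptions, not values from this repo's configs:

import torch
from f5_tts.model.backbones.mmdit import MMDiT

model = MMDiT(dim=512, depth=8, heads=8, dim_head=64, text_num_embeds=256, mel_dim=100)
x = torch.randn(2, 300, 100)           # noised mel frames
cond = torch.randn(2, 300, 100)        # masked conditioning mel
text = torch.randint(0, 256, (2, 50))  # character/token ids
time = torch.rand(2)
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
assert out.shape == (2, 300, 100)      # proj_out maps the hidden dim back to mel_dim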
f5_tts/model/backbones/unett.py
ADDED
@@ -0,0 +1,219 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Literal

import torch
from torch import nn
import torch.nn.functional as F

from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    Attention,
    AttnProcessor,
    FeedForward,
    precompute_freqs_cis,
    get_pos_embed_indices,
)


# Text embedding


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x


# Flat UNet Transformer backbone


class UNetT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        conv_layers=0,
        skip_connect_type: Literal["add", "concat", "none"] = "concat",
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        # transformer layers & skip connections

        self.dim = dim
        self.skip_connect_type = skip_connect_type
        needs_skip_proj = skip_connect_type == "concat"

        self.depth = depth
        self.layers = nn.ModuleList([])

        for idx in range(depth):
            is_later_half = idx >= (depth // 2)

            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor=AttnProcessor(),
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None

            self.layers.append(
                nn.ModuleList(
                    [
                        skip_proj,
                        attn_norm,
                        attn,
                        ff_norm,
                        ff,
                    ]
                )
            )

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        # postfix time t to input x, [b n d] -> [b n+1 d]
        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
        if mask is not None:
            mask = F.pad(mask, (1, 0), value=1)

        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)

        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            layer = idx + 1

            # skip connection logic
            is_first_half = layer <= (self.depth // 2)
            is_later_half = not is_first_half

            if is_first_half:
                skips.append(x)

            if is_later_half:
                skip = skips.pop()
                if skip_connect_type == "concat":
                    x = torch.cat((x, skip), dim=-1)
                    x = maybe_skip_proj(x)
                elif skip_connect_type == "add":
                    x = x + skip

            # attention and feedforward blocks
            x = attn(attn_norm(x), rope=rope, mask=mask) + x
            x = ff(ff_norm(x)) + x

        assert len(skips) == 0

        x = self.norm_out(x)[:, 1:, :]  # unpack t from x

        return self.proj_out(x)
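The UNet-style variant admits the same kind of smoke test; note the packed time token raises the internal sequence length to n + 1 and is stripped again before proj_out. Sizes below are illustrative assumptions:

import torch
from f5_tts.model.backbones.unett import UNetT

model = UNetT(dim=256, depth=8, heads=8, dim_head=32, mel_dim=100, text_num_embeds=256, conv_layers=2)
x = torch.randn(2, 120, 100)
cond = torch.randn(2, 120, 100)
text = torch.randint(0, 256, (2, 40))  # padded internally to the 120 mel frames
time = torch.rand(2)
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
assert out.shape == (2, 120, 100)      # the time token is dropped by the norm_out slice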
f5_tts/model/modules.py
ADDED
@@ -0,0 +1,666 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec


mel_basis_cache = {}
hann_window_cache = {}


def get_bigvgan_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
    fmin=0,
    fmax=None,
    center=False,
):  # Copied from https://github.com/NVIDIA/BigVGAN/tree/main
    device = waveform.device
    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why do they need .float()?
        hann_window_cache[key] = torch.hann_window(win_length).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    padding = (n_fft - hop_length) // 2
    waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)

    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec


def get_vocos_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)
    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    mel = mel.clamp(min=1e-5).log()
    return mel


class MelSpec(nn.Module):
    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type="vocos",
    ):
        super().__init__()
        assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel extraction backends: vocos or bigvgan"

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.target_sample_rate = target_sample_rate

        if mel_spec_type == "vocos":
            self.extractor = get_vocos_mel_spectrogram
        elif mel_spec_type == "bigvgan":
            self.extractor = get_bigvgan_mel_spectrogram

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, wav):
        if self.dummy.device != wav.device:
            self.to(wav.device)

        mel = self.extractor(
            waveform=wav,
            n_fft=self.n_fft,
            n_mel_channels=self.n_mel_channels,
            target_sample_rate=self.target_sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

        return mel


# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out


# rotary positional embedding related


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # avoid an extra-long error
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos


# Global Response Normalization layer (Instance Normalization ?)


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x


# AdaLayerNormZero
# returns modulated x for attn input, plus params for later mlp modulation


class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for final layer
# returns only the modulated x for attn input, since there is no further mlp modulation


class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


# FeedForward


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.ff(x)


# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor


# from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None
        # with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x


# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c


# DiT Block


class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x


# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: the last layer only does prenorm + modulation, since there is no further ffn
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x


# time step conditioning embedding


class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
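As a quick sanity check on the mel front end above: with the vocos backend (center=True), the frame count comes out to len // hop_length + 1. A minimal sketch using the default parameters:

import torch
from f5_tts.model.modules import MelSpec

spec = MelSpec(n_fft=1024, hop_length=256, win_length=1024, n_mel_channels=100, target_sample_rate=24_000)
wav = torch.randn(1, 24_000)  # one second of audio at 24 kHz
mel = spec(wav)
assert mel.shape == (1, 100, 24_000 // 256 + 1)  # (batch, n_mels, frames)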
prepare_datasets/1-get-text.py
ADDED
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-

import os

inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch

is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get("version", None)
import traceback
import os.path
from text.cleaner import clean_text
from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.my_utils import clean_path

# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"

from time import time as ttime
import shutil


def my_save(fea, path):  # fix issue: torch.save doesn't support Chinese paths
    dir = os.path.dirname(path)
    name = os.path.basename(path)
    # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
    tmp_path = "%s%s.pth" % (ttime(), i_part)
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, "%s/%s" % (dir, name))


txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if os.path.exists(txt_path) == False:
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    if torch.cuda.is_available():
        device = "cuda:0"
    # elif torch.backends.mps.is_available():
    #     device = "mps"
    else:
        device = "cpu"
    if os.path.exists(bert_pretrained_dir):
        ...
    else:
        raise FileNotFoundError(bert_pretrained_dir)
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    if is_half == True:
        bert_model = bert_model.half().to(device)
    else:
        bert_model = bert_model.to(device)

    def get_bert_feature(text, word2ph):
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt")
            for i in inputs:
                inputs[i] = inputs[i].to(device)
            res = bert_model(**inputs, output_hidden_states=True)
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]

        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            repeat_feature = res[i].repeat(word2ph[i], 1)
            phone_level_feature.append(repeat_feature)

        phone_level_feature = torch.cat(phone_level_feature, dim=0)

        return phone_level_feature.T

    def process(data, res):
        for name, text, lan in data:
            try:
                name = clean_path(name)
                name = os.path.basename(name)
                print(name)
                phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
                path_bert = "%s/%s.pt" % (bert_dir, name)
                if os.path.exists(path_bert) == False and lan == "zh":
                    bert_feature = get_bert_feature(norm_text, word2ph)
                    assert bert_feature.shape[-1] == len(phones)
                    # torch.save(bert_feature, path_bert)
                    my_save(bert_feature, path_bert)
                phones = " ".join(phones)
                # res.append([name,phones])
                res.append([name, phones, word2ph, norm_text])
            except:
                print(name, text, traceback.format_exc())

    todo = []
    res = []
    with open(inp_text, "r", encoding="utf8") as f:
        lines = f.read().strip("\n").split("\n")

    language_v1_to_language_v2 = {
        "ZH": "zh",
        "zh": "zh",
        "JP": "ja",
        "jp": "ja",
        "JA": "ja",
        "ja": "ja",
        "EN": "en",
        "en": "en",
        "En": "en",
        "KO": "ko",
        "Ko": "ko",
        "ko": "ko",
        "yue": "yue",
        "YUE": "yue",
        "Yue": "yue",
    }
    for line in lines[int(i_part) :: int(all_parts)]:
        try:
            wav_name, spk_name, language, text = line.split("|")
            # todo.append([name,text,"zh"])
            if language in language_v1_to_language_v2.keys():
                todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
            else:
                print(f"\033[33m[Warning] The {language = } of {wav_name} is not supported for training.\033[0m")
        except:
            print(line, traceback.format_exc())

    process(todo, res)
    opt = []
    for name, phones, word2ph, norm_text in res:
        opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
    with open(txt_path, "w", encoding="utf8") as f:
        f.write("\n".join(opt) + "\n")
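For reference, each line of the inp_text list file must carry four |-separated fields to satisfy the line.split("|") above; a hypothetical entry (file name, speaker, language tag, transcript):

vocal_0001.wav|speaker1|EN|The quick brown fox jumps over the lazy dog.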
prepare_datasets/2-get-hubert-wav32k.py
ADDED
@@ -0,0 +1,134 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
inp_text = os.environ.get("inp_text")
|
| 7 |
+
inp_wav_dir = os.environ.get("inp_wav_dir")
|
| 8 |
+
exp_name = os.environ.get("exp_name")
|
| 9 |
+
i_part = os.environ.get("i_part")
|
| 10 |
+
all_parts = os.environ.get("all_parts")
|
| 11 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
| 12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
| 13 |
+
from feature_extractor import cnhubert
|
| 14 |
+
|
| 15 |
+
opt_dir = os.environ.get("opt_dir")
|
| 16 |
+
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
| 20 |
+
|
| 21 |
+
import traceback
|
| 22 |
+
import numpy as np
|
| 23 |
+
from scipy.io import wavfile
|
| 24 |
+
import librosa
|
| 25 |
+
|
| 26 |
+
now_dir = os.getcwd()
|
| 27 |
+
sys.path.append(now_dir)
|
| 28 |
+
from tools.my_utils import load_audio, clean_path
|
| 29 |
+
|
| 30 |
+
# from config import cnhubert_base_path
|
| 31 |
+
# cnhubert.cnhubert_base_path=cnhubert_base_path
|
| 32 |
+
# inp_text=sys.argv[1]
|
| 33 |
+
# inp_wav_dir=sys.argv[2]
|
| 34 |
+
# exp_name=sys.argv[3]
|
| 35 |
+
# i_part=sys.argv[4]
|
| 36 |
+
# all_parts=sys.argv[5]
|
| 37 |
+
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
|
| 38 |
+
# cnhubert.cnhubert_base_path=sys.argv[7]
|
| 39 |
+
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
| 40 |
+
|
| 41 |
+
from time import time as ttime
|
| 42 |
+
import shutil
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
| 46 |
+
dir = os.path.dirname(path)
|
| 47 |
+
name = os.path.basename(path)
|
| 48 |
+
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
| 49 |
+
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
| 50 |
+
torch.save(fea, tmp_path)
|
| 51 |
+
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
hubert_dir = "%s/4-cnhubert" % (opt_dir)
|
| 55 |
+
wav32dir = "%s/5-wav32k" % (opt_dir)
|
| 56 |
+
os.makedirs(opt_dir, exist_ok=True)
|
| 57 |
+
os.makedirs(hubert_dir, exist_ok=True)
|
| 58 |
+
os.makedirs(wav32dir, exist_ok=True)
|
| 59 |
+
|
| 60 |
+
maxx = 0.95
|
| 61 |
+
alpha = 0.5
|
| 62 |
+
if torch.cuda.is_available():
|
| 63 |
+
device = "cuda:0"
|
| 64 |
+
# elif torch.backends.mps.is_available():
|
| 65 |
+
# device = "mps"
|
| 66 |
+
else:
|
| 67 |
+
device = "cpu"
|
| 68 |
+
model = cnhubert.get_model()
|
| 69 |
+
# is_half=False
|
| 70 |
+
if is_half == True:
|
| 71 |
+
model = model.half().to(device)
|
| 72 |
+
else:
|
| 73 |
+
model = model.to(device)
|
| 74 |
+
|
| 75 |
+
nan_fails = []
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def name2go(wav_name, wav_path):
|
| 79 |
+
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
|
| 80 |
+
if os.path.exists(hubert_path):
|
| 81 |
+
return
|
| 82 |
+
tmp_audio = load_audio(wav_path, 32000)
|
| 83 |
+
tmp_max = np.abs(tmp_audio).max()
|
| 84 |
+
if tmp_max > 2.2:
|
| 85 |
+
print("%s-filtered,%s" % (wav_name, tmp_max))
|
| 86 |
+
return
|
+    tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
+    tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
+    tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # not a resampling issue
+    tensor_wav16 = torch.from_numpy(tmp_audio)
+    if is_half == True:
+        tensor_wav16 = tensor_wav16.half().to(device)
+    else:
+        tensor_wav16 = tensor_wav16.to(device)
+    ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()  # torch.Size([1, 768, 215])
+    if np.isnan(ssl.detach().numpy()).sum() != 0:
+        nan_fails.append((wav_name, wav_path))
+        print("nan filtered:%s" % wav_name)
+        return
+    wavfile.write(
+        "%s/%s" % (wav32dir, wav_name),
+        32000,
+        tmp_audio32.astype("int16"),
+    )
+    my_save(ssl, hubert_path)
+
+
+with open(inp_text, "r", encoding="utf8") as f:
+    lines = f.read().strip("\n").split("\n")
+
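+# each of the all_parts workers processes every all_parts-th line, offset by i_part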
+for line in lines[int(i_part) :: int(all_parts)]:
+    try:
+        # wav_name,text=line.split("\t")
+        wav_name, spk_name, language, text = line.split("|")
+        wav_name = clean_path(wav_name)
+        if inp_wav_dir != "" and inp_wav_dir != None:
+            wav_name = os.path.basename(wav_name)
+            wav_path = "%s/%s" % (inp_wav_dir, wav_name)
+
+        else:
+            wav_path = wav_name
+            wav_name = os.path.basename(wav_name)
+        name2go(wav_name, wav_path)
+    except:
+        print(line, traceback.format_exc())
+
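+# clips that produced NaN features under fp16 get one retry in fp32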
+if len(nan_fails) > 0 and is_half == True:
+    is_half = False
+    model = model.float()
+    for wav in nan_fails:
+        try:
+            name2go(wav[0], wav[1])
+        except:
+            print(wav[0], traceback.format_exc())
prepare_datasets/2-get-sv.py ADDED
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+
+inp_text = os.environ.get("inp_text")
+inp_wav_dir = os.environ.get("inp_wav_dir")
+exp_name = os.environ.get("exp_name")
+i_part = os.environ.get("i_part")
+all_parts = os.environ.get("all_parts")
+if "_CUDA_VISIBLE_DEVICES" in os.environ:
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+
+opt_dir = os.environ.get("opt_dir")
+sv_path = os.environ.get("sv_path")
+import torch
+
+is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+
+import traceback
+import torchaudio
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
+from tools.my_utils import clean_path
+from time import time as ttime
+import shutil
+from ERes2NetV2 import ERes2NetV2
+import kaldi as Kaldi
+
+
+def my_save(fea, path):  ##### fix issue: torch.save doesn't support Chinese paths
+    dir = os.path.dirname(path)
+    name = os.path.basename(path)
+    # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
+    tmp_path = "%s%s.pth" % (ttime(), i_part)
+    torch.save(fea, tmp_path)
+    shutil.move(tmp_path, "%s/%s" % (dir, name))
+
+
+sv_cn_dir = "%s/7-sv_cn" % (opt_dir)
+wav32dir = "%s/5-wav32k" % (opt_dir)
+os.makedirs(opt_dir, exist_ok=True)
+os.makedirs(sv_cn_dir, exist_ok=True)
+os.makedirs(wav32dir, exist_ok=True)
+
+maxx = 0.95
+alpha = 0.5
+if torch.cuda.is_available():
+    device = "cuda:0"
+# elif torch.backends.mps.is_available():
+#     device = "mps"
+else:
+    device = "cpu"
+
+
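+# speaker-verification embedder: ERes2NetV2 weights loaded from sv_path; input
+# is resampled 32 kHz -> 16 kHz, turned into 80-bin Kaldi fbank features, then
+# embedded with forward3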
+class SV:
+    def __init__(self, device, is_half):
+        pretrained_state = torch.load(sv_path, map_location="cpu")
+        embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
+        embedding_model.load_state_dict(pretrained_state)
+        embedding_model.eval()
+        self.embedding_model = embedding_model
+        self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
+        if is_half == False:
+            self.embedding_model = self.embedding_model.to(device)
+        else:
+            self.embedding_model = self.embedding_model.half().to(device)
+        self.is_half = is_half
+
+    def compute_embedding3(self, wav):  # (1, x), values in -1~1
+        with torch.no_grad():
+            wav = self.res(wav)
+            if self.is_half == True:
+                wav = wav.half()
+            feat = torch.stack(
+                [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
+            )
+            sv_emb = self.embedding_model.forward3(feat)
+            return sv_emb
+
+
+sv = SV(device, is_half)
+
+
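+# computes a speaker embedding for one utterance and caches it under 7-sv_cn;
+# the wav_path argument is immediately overwritten with the 5-wav32k copy
+# produced by 2-get-hubert-wav32k.py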
+def name2go(wav_name, wav_path):
+    sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
+    if os.path.exists(sv_cn_path):
+        return
+    wav_path = "%s/%s" % (wav32dir, wav_name)
+    wav32k, sr0 = torchaudio.load(wav_path)
+    assert sr0 == 32000
+    wav32k = wav32k.to(device)
+    emb = sv.compute_embedding3(wav32k).cpu()  # torch.Size([1, 20480])
+    my_save(emb, sv_cn_path)
+
+
+with open(inp_text, "r", encoding="utf8") as f:
+    lines = f.read().strip("\n").split("\n")
+
+for line in lines[int(i_part) :: int(all_parts)]:
+    try:
+        wav_name, spk_name, language, text = line.split("|")
+        wav_name = clean_path(wav_name)
+        if inp_wav_dir != "" and inp_wav_dir != None:
+            wav_name = os.path.basename(wav_name)
+            wav_path = "%s/%s" % (inp_wav_dir, wav_name)
+
+        else:
+            wav_path = wav_name
+            wav_name = os.path.basename(wav_name)
+        name2go(wav_name, wav_path)
+    except:
+        print(line, traceback.format_exc())
prepare_datasets/3-get-semantic.py ADDED
@@ -0,0 +1,118 @@
+import os
+
+inp_text = os.environ.get("inp_text")
+exp_name = os.environ.get("exp_name")
+i_part = os.environ.get("i_part")
+all_parts = os.environ.get("all_parts")
+if "_CUDA_VISIBLE_DEVICES" in os.environ:
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+opt_dir = os.environ.get("opt_dir")
+pretrained_s2G = os.environ.get("pretrained_s2G")
+s2config_path = os.environ.get("s2config_path")
+
+if os.path.exists(pretrained_s2G):
+    ...
+else:
+    raise FileNotFoundError(pretrained_s2G)
+# version=os.environ.get("version","v2")
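+# infer which SoVITS generation (v1/v2/v3) the pretrained s2G checkpoint
+# belongs to from its file size alone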
+size = os.path.getsize(pretrained_s2G)
+if size < 82978 * 1024:
+    version = "v1"
+elif size < 100 * 1024 * 1024:
+    version = "v2"
+elif size < 103520 * 1024:
+    version = "v1"
+elif size < 700 * 1024 * 1024:
+    version = "v2"
+else:
+    version = "v3"
+import torch
+
+is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+import traceback
+import sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+import logging
+import utils
+
+if version != "v3":
+    from module.models import SynthesizerTrn
+else:
+    from module.models import SynthesizerTrnV3 as SynthesizerTrn
+from tools.my_utils import clean_path
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+# from config import pretrained_s2G
+
+# inp_text=sys.argv[1]
+# exp_name=sys.argv[2]
+# i_part=sys.argv[3]
+# all_parts=sys.argv[4]
+# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
+# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
+
+
+hubert_dir = "%s/4-cnhubert" % (opt_dir)
+semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
+if os.path.exists(semantic_path) == False:
+    os.makedirs(opt_dir, exist_ok=True)
+
+    if torch.cuda.is_available():
+        device = "cuda"
+    # elif torch.backends.mps.is_available():
+    #     device = "mps"
+    else:
+        device = "cpu"
+    hps = utils.get_hparams_from_file(s2config_path)
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        version=version,
+        **hps.model,
+    )
+    if is_half == True:
+        vq_model = vq_model.half().to(device)
+    else:
+        vq_model = vq_model.to(device)
+    vq_model.eval()
+    # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
+    # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
+    print(
+        vq_model.load_state_dict(
+            torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
+        )
+    )
+
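+    # load the cached cnhubert features for wav_name, quantize them with the
+    # s2G VQ model, and append "<wav_name>\t<space-separated token ids>" as a row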
+    def name2go(wav_name, lines):
+        hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
+        if os.path.exists(hubert_path) == False:
+            return
+        ssl_content = torch.load(hubert_path, map_location="cpu")
+        if is_half == True:
+            ssl_content = ssl_content.half().to(device)
+        else:
+            ssl_content = ssl_content.to(device)
+        codes = vq_model.extract_latent(ssl_content)
+        semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
+        lines.append("%s\t%s" % (wav_name, semantic))
+
+    with open(inp_text, "r", encoding="utf8") as f:
+        lines = f.read().strip("\n").split("\n")
+
+    lines1 = []
+    for line in lines[int(i_part) :: int(all_parts)]:
+        # print(line)
+        try:
+            # wav_name,text=line.split("\t")
+            wav_name, spk_name, language, text = line.split("|")
+            wav_name = clean_path(wav_name)
+            wav_name = os.path.basename(wav_name)
+            # name2go(name,lines1)
+            name2go(wav_name, lines1)
+        except:
+            print(line, traceback.format_exc())
+    with open(semantic_path, "w", encoding="utf8") as f:
+        f.write("\n".join(lines1))