laubonghaudoi committed on
Commit 5c7984d · 1 Parent(s): f1f1f55

Add modules

configs/.gitignore ADDED
@@ -0,0 +1 @@
*.yaml
configs/s2.json ADDED
@@ -0,0 +1,91 @@
{
    "train": {
        "log_interval": 100,
        "eval_interval": 500,
        "seed": 1234,
        "epochs": 100,
        "learning_rate": 0.0001,
        "betas": [
            0.8,
            0.99
        ],
        "eps": 1e-09,
        "batch_size": 32,
        "fp16_run": true,
        "lr_decay": 0.999875,
        "segment_size": 20480,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0,
        "text_low_lr_rate": 0.4,
        "grad_ckpt": false
    },
    "data": {
        "max_wav_value": 32768.0,
        "sampling_rate": 32000,
        "filter_length": 2048,
        "hop_length": 640,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null,
        "add_blank": true,
        "n_speakers": 300,
        "cleaned_text": true
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.1,
        "resblock": "1",
        "resblock_kernel_sizes": [
            3,
            7,
            11
        ],
        "resblock_dilation_sizes": [
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ]
        ],
        "upsample_rates": [
            10,
            8,
            2,
            2,
            2
        ],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [
            16,
            16,
            8,
            2,
            2
        ],
        "n_layers_q": 3,
        "use_spectral_norm": false,
        "gin_channels": 512,
        "semantic_frame_rate": "25hz",
        "freeze_quantizer": true
    },
    "s2_ckpt_dir": "logs/s2/big2k1",
    "content_module": "cnhubert"
}
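Note that the product of model.upsample_rates (10 × 8 × 2 × 2 × 2 = 640) matches data.hop_length, so the vocoder upsamples one spectrogram frame back to 640 samples at 32 kHz. Below is a minimal, illustrative sketch (not part of this commit) of loading such a config for attribute-style access; the AttrDict helper is a hypothetical convenience wrapper.

import json


class AttrDict(dict):
    # Read keys as attributes, e.g. cfg.train.batch_size; purely illustrative.
    def __getattr__(self, name):
        value = self[name]
        return AttrDict(value) if isinstance(value, dict) else value


with open("configs/s2.json", "r", encoding="utf-8") as f:
    cfg = AttrDict(json.load(f))

print(cfg.train.batch_size)      # 32
print(cfg.data.hop_length)       # 640
print(cfg.model.upsample_rates)  # [10, 8, 2, 2, 2]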
configs/s2v2Pro.json ADDED
@@ -0,0 +1,91 @@
{
    "train": {
        "log_interval": 100,
        "eval_interval": 500,
        "seed": 1234,
        "epochs": 100,
        "learning_rate": 0.0001,
        "betas": [
            0.8,
            0.99
        ],
        "eps": 1e-09,
        "batch_size": 32,
        "fp16_run": true,
        "lr_decay": 0.999875,
        "segment_size": 20480,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0,
        "text_low_lr_rate": 0.4,
        "grad_ckpt": false
    },
    "data": {
        "max_wav_value": 32768.0,
        "sampling_rate": 32000,
        "filter_length": 2048,
        "hop_length": 640,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null,
        "add_blank": true,
        "n_speakers": 300,
        "cleaned_text": true
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.0,
        "resblock": "1",
        "resblock_kernel_sizes": [
            3,
            7,
            11
        ],
        "resblock_dilation_sizes": [
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ]
        ],
        "upsample_rates": [
            10,
            8,
            2,
            2,
            2
        ],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [
            16,
            16,
            8,
            2,
            2
        ],
        "n_layers_q": 3,
        "use_spectral_norm": false,
        "gin_channels": 1024,
        "semantic_frame_rate": "25hz",
        "freeze_quantizer": true
    },
    "s2_ckpt_dir": "logs/s2/big2k1",
    "content_module": "cnhubert"
}
configs/s2v2ProPlus.json ADDED
@@ -0,0 +1,91 @@
{
    "train": {
        "log_interval": 100,
        "eval_interval": 500,
        "seed": 1234,
        "epochs": 100,
        "learning_rate": 0.0001,
        "betas": [
            0.8,
            0.99
        ],
        "eps": 1e-09,
        "batch_size": 32,
        "fp16_run": true,
        "lr_decay": 0.999875,
        "segment_size": 20480,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0,
        "text_low_lr_rate": 0.4,
        "grad_ckpt": false
    },
    "data": {
        "max_wav_value": 32768.0,
        "sampling_rate": 32000,
        "filter_length": 2048,
        "hop_length": 640,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null,
        "add_blank": true,
        "n_speakers": 300,
        "cleaned_text": true
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.0,
        "resblock": "1",
        "resblock_kernel_sizes": [
            3,
            7,
            11
        ],
        "resblock_dilation_sizes": [
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ]
        ],
        "upsample_rates": [
            10,
            8,
            2,
            2,
            2
        ],
        "upsample_initial_channel": 768,
        "upsample_kernel_sizes": [
            20,
            16,
            8,
            2,
            2
        ],
        "n_layers_q": 3,
        "use_spectral_norm": false,
        "gin_channels": 1024,
        "semantic_frame_rate": "25hz",
        "freeze_quantizer": true
    },
    "s2_ckpt_dir": "logs/s2/big2k1",
    "content_module": "cnhubert"
}
eres2net/ERes2Net.py ADDED
@@ -0,0 +1,264 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = "inplace" if self.inplace else ""
        return self.__class__.__name__ + " (" + inplace_str + ")"


class BasicBlockERes2Net(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(
        self,
        block=BasicBlockERes2Net,
        block_fuse=BasicBlockERes2Net_diff_AFF,
        num_blocks=[3, 4, 6, 3],
        m_channels=32,
        feat_dim=80,
        embedding_size=192,
        pooling_func="TSTP",
        two_emb_layer=False,
    ):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module for each layer
        self.layer1_downsample = nn.Conv2d(
            m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.layer2_downsample = nn.Conv2d(
            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
        )
        self.layer3_downsample = nn.Conv2d(
            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
        )

        # Bottom-up fusion module
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)

        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234


if __name__ == "__main__":
    x = torch.zeros(10, 300, 80)
    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
    model.eval()
    out = model(x)
    print(out.shape)  # torch.Size([10, 192])

    num_params = sum(param.numel() for param in model.parameters())
    print("{} M".format(num_params / 1e6))  # 6.61M
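For context, a minimal sketch of how the ERes2Net above can be used as a speaker-embedding extractor: score two utterances by cosine similarity of their 192-dim embeddings. It assumes the eres2net directory (with the pooling_layers and fusion modules imported above) is on sys.path, 16 kHz mono input files, and torchaudio's Kaldi-compatible fbank as the 80-dim front end; the wav paths are placeholders, the weights are random unless a checkpoint is loaded, and none of this wiring is part of the commit itself.

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from ERes2Net import ERes2Net


def embed(model, wav_path):
    waveform, sr = torchaudio.load(wav_path)  # (1, num_samples), 16 kHz mono assumed
    feats = kaldi.fbank(waveform, num_mel_bins=80, sample_frequency=sr)  # (T, 80)
    with torch.no_grad():
        return model(feats.unsqueeze(0))  # (1, 192)


model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP").eval()
e1, e2 = embed(model, "spk_a.wav"), embed(model, "spk_b.wav")
print(torch.nn.functional.cosine_similarity(e1, e2).item())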
eres2net/ERes2NetV2.py ADDED
@@ -0,0 +1,272 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
both the model parameters and its computational cost.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = "inplace" if self.inplace else ""
        return self.__class__.__name__ + " (" + inplace_str + ")"


class BasicBlockERes2NetV2(nn.Module):
    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2NetV2AFF(nn.Module):
    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width, r=4))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2NetV2(nn.Module):
    def __init__(
        self,
        block=BasicBlockERes2NetV2,
        block_fuse=BasicBlockERes2NetV2AFF,
        num_blocks=[3, 4, 6, 3],
        m_channels=64,
        feat_dim=80,
        embedding_size=192,
        baseWidth=26,
        scale=2,
        expansion=2,
        pooling_func="TSTP",
        two_emb_layer=False,
    ):
        super(ERes2NetV2, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self.baseWidth = baseWidth
        self.scale = scale
        self.expansion = expansion

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module
        self.layer3_ds = nn.Conv2d(
            m_channels * 4 * self.expansion,
            m_channels * 8 * self.expansion,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=False,
        )

        # Bottom-up fusion module
        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)

        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(
                block(
                    self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
                )
            )
            self.in_planes = planes * self.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        stats = self.pool(fuse_out34)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        # print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
        # stats = self.pool(fuse_out34)
        #
        # embed_a = self.seg_1(stats)
        # if self.two_emb_layer:
        #     out = F.relu(embed_a)
        #     out = self.seg_bn_1(out)
        #     embed_b = self.seg_2(out)
        #     return embed_b
        # else:
        #     return embed_a


if __name__ == "__main__":
    x = torch.randn(1, 300, 80)
    model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
    model.eval()
    y = model(x)
    print(y.size())

    from thop import profile  # assumes the thop package is installed for the MACs count below

    macs, num_params = profile(model, inputs=(x,))
    print("Params: {} M".format(num_params / 1e6))  # 17.86 M
    print("MACs: {} G".format(macs / 1e9))  # 12.69 G
eres2net/ERes2Net_huge.py ADDED
@@ -0,0 +1,289 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = "inplace" if self.inplace else ""
        return self.__class__.__name__ + " (" + inplace_str + ")"


class BasicBlockERes2Net(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(
        self,
        block=BasicBlockERes2Net,
        block_fuse=BasicBlockERes2Net_diff_AFF,
        num_blocks=[3, 4, 6, 3],
        m_channels=64,
        feat_dim=80,
        embedding_size=192,
        pooling_func="TSTP",
        two_emb_layer=False,
    ):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        self.layer1_downsample = nn.Conv2d(
            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
        )
        self.layer2_downsample = nn.Conv2d(
            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
        )
        self.layer3_downsample = nn.Conv2d(
            m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
        )

        self.fuse_mode12 = AFF(channels=m_channels * 8)
        self.fuse_mode123 = AFF(channels=m_channels * 16)
        self.fuse_mode1234 = AFF(channels=m_channels * 32)

        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward2(self, x, if_mean):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2)  # bs,20480,T
        if if_mean == False:
            mean = fuse_out1234[0].transpose(1, 0)  # (T,20480),bs=T
        else:
            mean = fuse_out1234.mean(2)  # bs,20480
        mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
        return self.seg_1(mean_std)  # (T,192)

        # stats = self.pool(fuse_out1234)
        # if self.two_emb_layer:
        #     out = F.relu(embed_a)
        #     out = self.seg_bn_1(out)
        #     embed_b = self.seg_2(out)
        #     return embed_b
        # else:
        #     return embed_a

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234
        # print(fuse_out1234.shape)
        # print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
        # pdb.set_trace()
eres2net/fusion.py ADDED
@@ -0,0 +1,27 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn as nn


class AFF(nn.Module):
    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

        return xo
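A small sanity check of the AFF gate defined above: the attention map 1 + tanh(·) lies in (0, 2), and the two branches receive complementary weights that always sum to 2, so the fusion reduces to a plain sum wherever the gate equals 1. A minimal sketch, assuming fusion.py is importable from the working directory; the tensor shapes are arbitrary.

import torch
from fusion import AFF

aff = AFF(channels=64, r=4).eval()  # eval() so BatchNorm uses running statistics
x = torch.randn(2, 64, 10, 25)   # (batch, channels, freq, time)
y = torch.randn(2, 64, 10, 25)   # the downsampled branch, same shape

with torch.no_grad():
    att = 1.0 + torch.tanh(aff.local_att(torch.cat((x, y), dim=1)))
    fused = aff(x, y)

assert torch.allclose(fused, x * att + y * (2.0 - att), atol=1e-6)
print(float(att.min()), float(att.max()))  # both within (0, 2)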
eres2net/kaldi.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
+ def _get_epsilon(device, dtype):
36
+ return EPSILON.to(device=device, dtype=dtype)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
85
+
86
+ def _feature_window_function(
87
+ window_type: str,
88
+ window_size: int,
89
+ blackman_coeff: float,
90
+ device: torch.device,
91
+ dtype: int,
92
+ ) -> Tensor:
93
+ r"""Returns a window function with the given type and size"""
94
+ if window_type == HANNING:
95
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
+ elif window_type == HAMMING:
97
+ return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
+ elif window_type == POVEY:
99
+ # like hanning but goes to zero at edges
100
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
+ elif window_type == RECTANGULAR:
102
+ return torch.ones(window_size, device=device, dtype=dtype)
103
+ elif window_type == BLACKMAN:
104
+ a = 2 * math.pi / (window_size - 1)
105
+ window_function = torch.arange(window_size, device=device, dtype=dtype)
106
+ # can't use torch.blackman_window as they use different coefficients
107
+ return (
108
+ blackman_coeff
109
+ - 0.5 * torch.cos(a * window_function)
110
+ + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
+ ).to(device=device, dtype=dtype)
112
+ else:
113
+ raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
+ def _get_waveform_and_window_properties(
126
+ waveform: Tensor,
127
+ channel: int,
128
+ sample_frequency: float,
129
+ frame_shift: float,
130
+ frame_length: float,
131
+ round_to_power_of_two: bool,
132
+ preemphasis_coefficient: float,
133
+ ) -> Tuple[Tensor, int, int, int]:
134
+ r"""Gets the waveform and window properties"""
135
+ channel = max(channel, 0)
136
+ assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
+ waveform = waveform[channel, :] # size (n)
138
+ window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
+ window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
+ padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
+
142
+ assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
+ window_size, len(waveform)
144
+ )
145
+ assert 0 < window_shift, "`window_shift` must be greater than 0"
146
+ assert padded_window_size % 2 == 0, (
147
+ "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
148
+ )
149
+ assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
+ assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
+ return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
+ def _get_window(
155
+ waveform: Tensor,
156
+ padded_window_size: int,
157
+ window_size: int,
158
+ window_shift: int,
159
+ window_type: str,
160
+ blackman_coeff: float,
161
+ snip_edges: bool,
162
+ raw_energy: bool,
163
+ energy_floor: float,
164
+ dither: float,
165
+ remove_dc_offset: bool,
166
+ preemphasis_coefficient: float,
167
+ ) -> Tuple[Tensor, Tensor]:
168
+ r"""Gets a window and its log energy
169
+
170
+ Returns:
171
+ (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
+ """
173
+ device, dtype = waveform.device, waveform.dtype
174
+ epsilon = _get_epsilon(device, dtype)
175
+
176
+ # size (m, window_size)
177
+ strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
+
179
+ if dither != 0.0:
180
+ rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
181
+ strided_input = strided_input + rand_gauss * dither
182
+
183
+ if remove_dc_offset:
184
+ # Subtract each row/frame by its mean
185
+ row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
186
+ strided_input = strided_input - row_means
187
+
188
+ if raw_energy:
189
+ # Compute the log energy of each row/frame before applying preemphasis and
190
+ # window function
191
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
192
+
193
+ if preemphasis_coefficient != 0.0:
194
+ # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
195
+ offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
196
+ 0
197
+ ) # size (m, window_size + 1)
198
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
199
+
200
+ # Apply window_function to each row/frame
201
+ window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
202
+ 0
203
+ ) # size (1, window_size)
204
+ strided_input = strided_input * window_function # size (m, window_size)
205
+
206
+ # Pad columns with zero until we reach size (m, padded_window_size)
207
+ if padded_window_size != window_size:
208
+ padding_right = padded_window_size - window_size
209
+ strided_input = torch.nn.functional.pad(
210
+ strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
211
+ ).squeeze(0)
212
+
213
+ # Compute energy after window function (not the raw one)
214
+ if not raw_energy:
215
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
216
+
217
+ return strided_input, signal_log_energy
218
+
219
+
220
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
+ # it returns size (m, n)
223
+ if subtract_mean:
224
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
+ tensor = tensor - col_means
226
+ return tensor
227
+
228
+
229
+ def spectrogram(
230
+ waveform: Tensor,
231
+ blackman_coeff: float = 0.42,
232
+ channel: int = -1,
233
+ dither: float = 0.0,
234
+ energy_floor: float = 1.0,
235
+ frame_length: float = 25.0,
236
+ frame_shift: float = 10.0,
237
+ min_duration: float = 0.0,
238
+ preemphasis_coefficient: float = 0.97,
239
+ raw_energy: bool = True,
240
+ remove_dc_offset: bool = True,
241
+ round_to_power_of_two: bool = True,
242
+ sample_frequency: float = 16000.0,
243
+ snip_edges: bool = True,
244
+ subtract_mean: bool = False,
245
+ window_type: str = POVEY,
246
+ ) -> Tensor:
247
+ r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
248
+ compute-spectrogram-feats.
249
+
250
+ Args:
251
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
252
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
253
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
254
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
255
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
256
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
257
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
258
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
259
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
260
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
261
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
262
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
263
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
264
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
265
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
266
+ to FFT. (Default: ``True``)
267
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
268
+ specified there) (Default: ``16000.0``)
269
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
270
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
271
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
272
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
273
+ it this way. (Default: ``False``)
274
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
275
+ (Default: ``'povey'``)
276
+
277
+ Returns:
278
+ Tensor: A spectrogram identical to what Kaldi would output. The shape is
279
+ (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
280
+ """
281
+ device, dtype = waveform.device, waveform.dtype
282
+ epsilon = _get_epsilon(device, dtype)
283
+
284
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
285
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
286
+ )
287
+
288
+ if len(waveform) < min_duration * sample_frequency:
289
+ # signal is too short
290
+ return torch.empty(0)
291
+
292
+ strided_input, signal_log_energy = _get_window(
293
+ waveform,
294
+ padded_window_size,
295
+ window_size,
296
+ window_shift,
297
+ window_type,
298
+ blackman_coeff,
299
+ snip_edges,
300
+ raw_energy,
301
+ energy_floor,
302
+ dither,
303
+ remove_dc_offset,
304
+ preemphasis_coefficient,
305
+ )
306
+
307
+ # size (m, padded_window_size // 2 + 1, 2)
308
+ fft = torch.fft.rfft(strided_input)
309
+
310
+ # Convert the FFT into a power spectrum
311
+ power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
312
+ power_spectrum[:, 0] = signal_log_energy
313
+
314
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
315
+ return power_spectrum
316
+
317
+
318
+ def inverse_mel_scale_scalar(mel_freq: float) -> float:
319
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
320
+
321
+
322
+ def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
323
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
324
+
325
+
326
+ def mel_scale_scalar(freq: float) -> float:
327
+ return 1127.0 * math.log(1.0 + freq / 700.0)
328
+
329
+
330
+ def mel_scale(freq: Tensor) -> Tensor:
331
+ return 1127.0 * (1.0 + freq / 700.0).log()
332
+
333
+
334
+ def vtln_warp_freq(
335
+ vtln_low_cutoff: float,
336
+ vtln_high_cutoff: float,
337
+ low_freq: float,
338
+ high_freq: float,
339
+ vtln_warp_factor: float,
340
+ freq: Tensor,
341
+ ) -> Tensor:
342
+ r"""This computes a VTLN warping function that is not the same as HTK's one,
343
+ but has similar inputs (this function has the advantage of never producing
344
+ empty bins).
345
+
346
+ This function computes a warp function F(freq), defined between low_freq
347
+ and high_freq inclusive, with the following properties:
348
+ F(low_freq) == low_freq
349
+ F(high_freq) == high_freq
350
+ The function is continuous and piecewise linear with two inflection
351
+ points.
352
+ The lower inflection point (measured in terms of the unwarped
353
+ frequency) is at frequency l, determined as described below.
354
+ The higher inflection point is at a frequency h, determined as
355
+ described below.
356
+ If l <= f <= h, then F(f) = f/vtln_warp_factor.
357
+ If the higher inflection point (measured in terms of the unwarped
358
+ frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
359
+ Since (by the last point) F(h) == h/vtln_warp_factor, then
360
+ max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
361
+ h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
362
+ = vtln_high_cutoff * min(1, vtln_warp_factor).
363
+ If the lower inflection point (measured in terms of the unwarped
364
+ frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
365
+ This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
366
+ = vtln_low_cutoff * max(1, vtln_warp_factor)
367
+ Args:
368
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
369
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
370
+ low_freq (float): Lower frequency cutoffs in mel computation
371
+ high_freq (float): Upper frequency cutoffs in mel computation
372
+ vtln_warp_factor (float): Vtln warp factor
373
+ freq (Tensor): given frequency in Hz
374
+
375
+ Returns:
376
+ Tensor: Freq after vtln warp
377
+ """
378
+ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
379
+ assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
380
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
381
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
382
+ scale = 1.0 / vtln_warp_factor
383
+ Fl = scale * l # F(l)
384
+ Fh = scale * h # F(h)
385
+ assert l > low_freq and h < high_freq
386
+ # slope of left part of the 3-piece linear function
387
+ scale_left = (Fl - low_freq) / (l - low_freq)
388
+ # [slope of center part is just "scale"]
389
+
390
+ # slope of right part of the 3-piece linear function
391
+ scale_right = (high_freq - Fh) / (high_freq - h)
392
+
393
+ res = torch.empty_like(freq)
394
+
395
+ outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
396
+ before_l = torch.lt(freq, l) # freq < l
397
+ before_h = torch.lt(freq, h) # freq < h
398
+ after_h = torch.ge(freq, h) # freq >= h
399
+
400
+ # order of operations matter here (since there is overlapping frequency regions)
401
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
402
+ res[before_h] = scale * freq[before_h]
403
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
404
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
405
+
406
+ return res
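A quick numeric sanity check of the warp above (an editorial sketch, not part of this commit; it assumes this module's `vtln_warp_freq` is in scope). Inside the band between l and h the frequency is simply divided by the warp factor, while the end points map to themselves:

    import torch

    freqs = torch.tensor([50.0, 1000.0, 7900.0])
    warped = vtln_warp_freq(
        vtln_low_cutoff=100.0, vtln_high_cutoff=7800.0,
        low_freq=20.0, high_freq=8000.0, vtln_warp_factor=1.1, freq=freqs,
    )
    print(warped)  # ~[46.7, 909.1, 7545.5]; 1000 Hz sits in the middle band: 1000 / 1.1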
407
+
408
+
409
+ def vtln_warp_mel_freq(
410
+ vtln_low_cutoff: float,
411
+ vtln_high_cutoff: float,
412
+ low_freq,
413
+ high_freq: float,
414
+ vtln_warp_factor: float,
415
+ mel_freq: Tensor,
416
+ ) -> Tensor:
417
+ r"""
418
+ Args:
419
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
420
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
421
+ low_freq (float): Lower frequency cutoffs in mel computation
422
+ high_freq (float): Upper frequency cutoffs in mel computation
423
+ vtln_warp_factor (float): Vtln warp factor
424
+ mel_freq (Tensor): Given frequency in Mel
425
+
426
+ Returns:
427
+ Tensor: ``mel_freq`` after vtln warp
428
+ """
429
+ return mel_scale(
430
+ vtln_warp_freq(
431
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
432
+ )
433
+ )
434
+
435
+
436
+ def get_mel_banks(
437
+ num_bins: int,
438
+ window_length_padded: int,
439
+ sample_freq: float,
440
+ low_freq: float,
441
+ high_freq: float,
442
+ vtln_low: float,
443
+ vtln_high: float,
444
+ vtln_warp_factor: float,
445
+ device=None,
446
+ dtype=None,
447
+ ) -> Tensor:
448
+ """
449
+ Returns:
450
+ Tensor: The triangular mel filterbank ``bins`` of size (``num_bins``, ``num_fft_bins``).
451
+ (``center_freqs``, the bin center frequencies of size (``num_bins``,), is commented
452
+ out below and no longer returned.)
453
+ """
454
+ assert num_bins > 3, "Must have at least 3 mel bins"
455
+ assert window_length_padded % 2 == 0
456
+ num_fft_bins = window_length_padded / 2
457
+ nyquist = 0.5 * sample_freq
458
+
459
+ if high_freq <= 0.0:
460
+ high_freq += nyquist
461
+
462
+ assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
463
+ "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
464
+ )
465
+
466
+ # fft-bin width [think of it as Nyquist-freq / half-window-length]
467
+ fft_bin_width = sample_freq / window_length_padded
468
+ mel_low_freq = mel_scale_scalar(low_freq)
469
+ mel_high_freq = mel_scale_scalar(high_freq)
470
+
471
+ # divide by num_bins+1 in next line because of end-effects where the bins
472
+ # spread out to the sides.
473
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
474
+
475
+ if vtln_high < 0.0:
476
+ vtln_high += nyquist
477
+
478
+ assert vtln_warp_factor == 1.0 or (
479
+ (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
480
+ ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
481
+ vtln_low, vtln_high, low_freq, high_freq
482
+ )
483
+
484
+ bin = torch.arange(num_bins).unsqueeze(1)
485
+ left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
486
+ center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
487
+ right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
488
+
489
+ if vtln_warp_factor != 1.0:
490
+ left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
491
+ center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
492
+ right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
493
+
494
+ # center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
495
+ # size(1, num_fft_bins)
496
+ mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
497
+
498
+ # size (num_bins, num_fft_bins)
499
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
500
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
501
+
502
+ if vtln_warp_factor == 1.0:
503
+ # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
504
+ bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
505
+ else:
506
+ # warping can move the order of left_mel, center_mel, right_mel anywhere
507
+ bins = torch.zeros_like(up_slope)
508
+ up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
509
+ down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
510
+ bins[up_idx] = up_slope[up_idx]
511
+ bins[down_idx] = down_slope[down_idx]
512
+
513
+ return bins.to(device=device, dtype=dtype) # , center_freqs
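A minimal shape sketch for the banks above (editorial and hedged; it only assumes `get_mel_banks` as defined in this file). With an 80-bin bank over a 512-point padded window, each row is one triangular filter over the FFT bins:

    banks = get_mel_banks(
        num_bins=80, window_length_padded=512, sample_freq=16000.0,
        low_freq=20.0, high_freq=0.0,   # high_freq <= 0 means "offset from Nyquist"
        vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0,
    )
    print(banks.shape)  # torch.Size([80, 256]) -> (num_bins, num_fft_bins)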
514
+
515
+
516
+ cache = {}
517
+
518
+
519
+ def fbank(
520
+ waveform: Tensor,
521
+ blackman_coeff: float = 0.42,
522
+ channel: int = -1,
523
+ dither: float = 0.0,
524
+ energy_floor: float = 1.0,
525
+ frame_length: float = 25.0,
526
+ frame_shift: float = 10.0,
527
+ high_freq: float = 0.0,
528
+ htk_compat: bool = False,
529
+ low_freq: float = 20.0,
530
+ min_duration: float = 0.0,
531
+ num_mel_bins: int = 23,
532
+ preemphasis_coefficient: float = 0.97,
533
+ raw_energy: bool = True,
534
+ remove_dc_offset: bool = True,
535
+ round_to_power_of_two: bool = True,
536
+ sample_frequency: float = 16000.0,
537
+ snip_edges: bool = True,
538
+ subtract_mean: bool = False,
539
+ use_energy: bool = False,
540
+ use_log_fbank: bool = True,
541
+ use_power: bool = True,
542
+ vtln_high: float = -500.0,
543
+ vtln_low: float = 100.0,
544
+ vtln_warp: float = 1.0,
545
+ window_type: str = POVEY,
546
+ ) -> Tensor:
547
+ r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
548
+ compute-fbank-feats.
549
+
550
+ Args:
551
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
552
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
553
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
554
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
555
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
556
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
557
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
558
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
559
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
560
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
561
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
562
+ (Default: ``0.0``)
563
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
564
+ (need to change other parameters). (Default: ``False``)
565
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
566
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
567
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
568
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
569
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
570
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
571
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
572
+ to FFT. (Default: ``True``)
573
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
574
+ specified there) (Default: ``16000.0``)
575
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
576
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
577
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
578
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
579
+ it this way. (Default: ``False``)
580
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
581
+ use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
582
+ use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
583
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
584
+ negative, offset from high-mel-freq) (Default: ``-500.0``)
585
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
586
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
587
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
588
+ (Default: ``'povey'``)
589
+
590
+ Returns:
591
+ Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
592
+ where m is calculated in _get_strided
593
+ """
594
+ device, dtype = waveform.device, waveform.dtype
595
+
596
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
597
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
598
+ )
599
+
600
+ if len(waveform) < min_duration * sample_frequency:
601
+ # signal is too short
602
+ return torch.empty(0, device=device, dtype=dtype)
603
+
604
+ # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
605
+ strided_input, signal_log_energy = _get_window(
606
+ waveform,
607
+ padded_window_size,
608
+ window_size,
609
+ window_shift,
610
+ window_type,
611
+ blackman_coeff,
612
+ snip_edges,
613
+ raw_energy,
614
+ energy_floor,
615
+ dither,
616
+ remove_dc_offset,
617
+ preemphasis_coefficient,
618
+ )
619
+
620
+ # size (m, padded_window_size // 2 + 1)
621
+ spectrum = torch.fft.rfft(strided_input).abs()
622
+ if use_power:
623
+ spectrum = spectrum.pow(2.0)
624
+
625
+ # size (num_mel_bins, padded_window_size // 2)
626
+ # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
627
+
628
+ cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
629
+ num_mel_bins,
630
+ padded_window_size,
631
+ sample_frequency,
632
+ low_freq,
633
+ high_freq,
634
+ vtln_low,
635
+ vtln_high,
636
+ vtln_warp,
637
+ device,
638
+ dtype,
639
+ )
640
+ if cache_key not in cache:
641
+ mel_energies = get_mel_banks(
642
+ num_mel_bins,
643
+ padded_window_size,
644
+ sample_frequency,
645
+ low_freq,
646
+ high_freq,
647
+ vtln_low,
648
+ vtln_high,
649
+ vtln_warp,
650
+ device,
651
+ dtype,
652
+ )
653
+ cache[cache_key] = mel_energies
654
+ else:
655
+ mel_energies = cache[cache_key]
656
+
657
+ # pad a zero column on the right, size (num_mel_bins, padded_window_size // 2 + 1)
658
+ mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
659
+
660
+ # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
661
+ mel_energies = torch.mm(spectrum, mel_energies.T)
662
+ if use_log_fbank:
663
+ # avoid log of zero (which should be prevented anyway by dithering)
664
+ mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
665
+
666
+ # if use_energy then add it as the last column for htk_compat == true else first column
667
+ if use_energy:
668
+ signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
669
+ # returns size (m, num_mel_bins + 1)
670
+ if htk_compat:
671
+ mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
672
+ else:
673
+ mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
674
+
675
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
676
+ return mel_energies
677
+
678
+
679
+ def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
680
+ # returns a dct matrix of size (num_mel_bins, num_ceps)
681
+ # size (num_mel_bins, num_mel_bins)
682
+ dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
683
+ # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
684
+ # this would be the first column in the dct_matrix for torchaudio as it expects a
685
+ # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
686
+ # expects a left multiply e.g. dct_matrix * vector).
687
+ dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
688
+ dct_matrix = dct_matrix[:, :num_ceps]
689
+ return dct_matrix
690
+
691
+
692
+ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
693
+ # returns size (num_ceps)
694
+ # Compute liftering coefficients (scaling on cepstral coeffs)
695
+ # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
696
+ i = torch.arange(num_ceps)
697
+ return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
698
+
699
+
700
+ def mfcc(
701
+ waveform: Tensor,
702
+ blackman_coeff: float = 0.42,
703
+ cepstral_lifter: float = 22.0,
704
+ channel: int = -1,
705
+ dither: float = 0.0,
706
+ energy_floor: float = 1.0,
707
+ frame_length: float = 25.0,
708
+ frame_shift: float = 10.0,
709
+ high_freq: float = 0.0,
710
+ htk_compat: bool = False,
711
+ low_freq: float = 20.0,
712
+ num_ceps: int = 13,
713
+ min_duration: float = 0.0,
714
+ num_mel_bins: int = 23,
715
+ preemphasis_coefficient: float = 0.97,
716
+ raw_energy: bool = True,
717
+ remove_dc_offset: bool = True,
718
+ round_to_power_of_two: bool = True,
719
+ sample_frequency: float = 16000.0,
720
+ snip_edges: bool = True,
721
+ subtract_mean: bool = False,
722
+ use_energy: bool = False,
723
+ vtln_high: float = -500.0,
724
+ vtln_low: float = 100.0,
725
+ vtln_warp: float = 1.0,
726
+ window_type: str = POVEY,
727
+ ) -> Tensor:
728
+ r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
729
+ compute-mfcc-feats.
730
+
731
+ Args:
732
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
733
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
734
+ cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
735
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
736
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
737
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
738
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
739
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
740
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
741
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
742
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
743
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
744
+ (Default: ``0.0``)
745
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
746
+ features (need to change other parameters). (Default: ``False``)
747
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
748
+ num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
749
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
750
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
751
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
752
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
753
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
754
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
755
+ to FFT. (Default: ``True``)
756
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
757
+ specified there) (Default: ``16000.0``)
758
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
759
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
760
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
761
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
762
+ it this way. (Default: ``False``)
763
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
764
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
765
+ negative, offset from high-mel-freq) (Default: ``-500.0``)
766
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
767
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
768
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
769
+ (Default: ``"povey"``)
770
+
771
+ Returns:
772
+ Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
773
+ where m is calculated in _get_strided
774
+ """
775
+ assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
776
+
777
+ device, dtype = waveform.device, waveform.dtype
778
+
779
+ # The mel_energies should be computed from the power spectrum (use_power=True), not have mean subtracted
780
+ # (subtract_mean=False), and use log (use_log_fbank=True).
781
+ # size (m, num_mel_bins + use_energy)
782
+ feature = fbank(
783
+ waveform=waveform,
784
+ blackman_coeff=blackman_coeff,
785
+ channel=channel,
786
+ dither=dither,
787
+ energy_floor=energy_floor,
788
+ frame_length=frame_length,
789
+ frame_shift=frame_shift,
790
+ high_freq=high_freq,
791
+ htk_compat=htk_compat,
792
+ low_freq=low_freq,
793
+ min_duration=min_duration,
794
+ num_mel_bins=num_mel_bins,
795
+ preemphasis_coefficient=preemphasis_coefficient,
796
+ raw_energy=raw_energy,
797
+ remove_dc_offset=remove_dc_offset,
798
+ round_to_power_of_two=round_to_power_of_two,
799
+ sample_frequency=sample_frequency,
800
+ snip_edges=snip_edges,
801
+ subtract_mean=False,
802
+ use_energy=use_energy,
803
+ use_log_fbank=True,
804
+ use_power=True,
805
+ vtln_high=vtln_high,
806
+ vtln_low=vtln_low,
807
+ vtln_warp=vtln_warp,
808
+ window_type=window_type,
809
+ )
810
+
811
+ if use_energy:
812
+ # size (m)
813
+ signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
814
+ # offset is 0 if htk_compat==True else 1
815
+ mel_offset = int(not htk_compat)
816
+ feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
817
+
818
+ # size (num_mel_bins, num_ceps)
819
+ dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
820
+
821
+ # size (m, num_ceps)
822
+ feature = feature.matmul(dct_matrix)
823
+
824
+ if cepstral_lifter != 0.0:
825
+ # size (1, num_ceps)
826
+ lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
827
+ feature *= lifter_coeffs.to(device=device, dtype=dtype)
828
+
829
+ # if use_energy then replace the last column for htk_compat == true else first column
830
+ if use_energy:
831
+ feature[:, 0] = signal_log_energy
832
+
833
+ if htk_compat:
834
+ energy = feature[:, 0].unsqueeze(1) # size (m, 1)
835
+ feature = feature[:, 1:] # size (m, num_ceps - 1)
836
+ if not use_energy:
837
+ # scale on C0 (actually removing a scale we previously added that's
838
+ # part of one common definition of the cosine transform.)
839
+ energy *= math.sqrt(2)
840
+
841
+ feature = torch.cat((feature, energy), dim=1)
842
+
843
+ feature = _subtract_column_mean(feature, subtract_mean)
844
+ return feature
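A hedged usage sketch (editorial, not part of the commit) for the two extractors above, assuming the helpers defined earlier in this file (`_get_window`, `POVEY`, etc.) are available. One second of mono 16 kHz audio yields 98 frames with the default 25 ms / 10 ms framing:

    import torch

    waveform = torch.randn(1, 16000)  # (channels, samples): 1 s of fake mono audio
    fb = fbank(waveform, num_mel_bins=80, sample_frequency=16000.0)
    mf = mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
    print(fb.shape)  # torch.Size([98, 80])
    print(mf.shape)  # torch.Size([98, 13])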
eres2net/pooling_layers.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class TAP(nn.Module):
11
+ """
12
+ Temporal average pooling, only first-order mean is considered
13
+ """
14
+
15
+ def __init__(self, **kwargs):
16
+ super(TAP, self).__init__()
17
+
18
+ def forward(self, x):
19
+ pooling_mean = x.mean(dim=-1)
20
+ # To be compatible with 2D input
21
+ pooling_mean = pooling_mean.flatten(start_dim=1)
22
+ return pooling_mean
23
+
24
+
25
+ class TSDP(nn.Module):
26
+ """
27
+ Temporal standard deviation pooling, only second-order std is considered
28
+ """
29
+
30
+ def __init__(self, **kwargs):
31
+ super(TSDP, self).__init__()
32
+
33
+ def forward(self, x):
34
+ # The last dimension is the temporal axis
35
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
36
+ pooling_std = pooling_std.flatten(start_dim=1)
37
+ return pooling_std
38
+
39
+
40
+ class TSTP(nn.Module):
41
+ """
42
+ Temporal statistics pooling, concatenate mean and std, which is used in
43
+ x-vector
44
+ Comment: simple concatenation can not make full use of both statistics
45
+ """
46
+
47
+ def __init__(self, **kwargs):
48
+ super(TSTP, self).__init__()
49
+
50
+ def forward(self, x):
51
+ # The last dimension is the temporal axis
52
+ pooling_mean = x.mean(dim=-1)
53
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
54
+ pooling_mean = pooling_mean.flatten(start_dim=1)
55
+ pooling_std = pooling_std.flatten(start_dim=1)
56
+
57
+ stats = torch.cat((pooling_mean, pooling_std), 1)
58
+ return stats
59
+
60
+
61
+ class ASTP(nn.Module):
62
+ """Attentive statistics pooling: Channel- and context-dependent
63
+ statistics pooling, first used in ECAPA_TDNN.
64
+ """
65
+
66
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
67
+ super(ASTP, self).__init__()
68
+ self.global_context_att = global_context_att
69
+
70
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
71
+ # need to transpose inputs.
72
+ if global_context_att:
73
+ self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
74
+ else:
75
+ self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
76
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
77
+
78
+ def forward(self, x):
79
+ """
80
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
81
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
82
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
83
+ """
84
+ if len(x.shape) == 4:
85
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
86
+ assert len(x.shape) == 3
87
+
88
+ if self.global_context_att:
89
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
90
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
91
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
92
+ else:
93
+ x_in = x
94
+
95
+ # DON'T use ReLU here: with ReLU the attention weights may fail to converge.
96
+ alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
97
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
98
+ mean = torch.sum(alpha * x, dim=2)
99
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
100
+ std = torch.sqrt(var.clamp(min=1e-10))
101
+ return torch.cat([mean, std], dim=1)
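A small shape check (editorial sketch, hedged) for the pooling layers above, on a batch of 4 utterances with 80 feature channels and 200 frames:

    import torch

    x = torch.randn(4, 80, 200)      # (B, F, T); the last dim is time
    print(TAP()(x).shape)            # torch.Size([4, 80])   mean only
    print(TSTP()(x).shape)           # torch.Size([4, 160])  mean concat std
    print(ASTP(in_dim=80)(x).shape)  # torch.Size([4, 160])  attentive mean concat std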
f5_tts/model/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # from f5_tts.model.cfm import CFM
2
+ #
3
+ # from f5_tts.model.backbones.unett import UNetT
4
+ from GPT_SoVITS.f5_tts.model.backbones.dit import DiT
5
+ # from f5_tts.model.backbones.dit import DiTNoCond
6
+ # from f5_tts.model.backbones.dit import DiTNoCondNoT
7
+ # from f5_tts.model.backbones.mmdit import MMDiT
8
+
9
+ # from f5_tts.model.trainer import Trainer
10
+
11
+
12
+ # __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
13
+ # __all__ = ["CFM", "UNetT", "DiTNoCond","DiT", "MMDiT"]
f5_tts/model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatted noised_input + masked_cond + embedded_text, linear proj in
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer)
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded and applied an abs pos emb
20
+ - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
f5_tts/model/backbones/dit.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from x_transformers.x_transformers import RotaryEmbedding
17
+
18
+ from GPT_SoVITS.f5_tts.model.modules import (
19
+ TimestepEmbedding,
20
+ ConvNeXtV2Block,
21
+ ConvPositionEmbedding,
22
+ DiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis,
25
+ get_pos_embed_indices,
26
+ )
27
+
28
+ from module.commons import sequence_mask
29
+
30
+
31
+ class TextEmbedding(nn.Module):
32
+ def __init__(self, text_dim, conv_layers=0, conv_mult=2):
33
+ super().__init__()
34
+ if conv_layers > 0:
35
+ self.extra_modeling = True
36
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
37
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
38
+ self.text_blocks = nn.Sequential(
39
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
40
+ )
41
+ else:
42
+ self.extra_modeling = False
43
+
44
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
45
+ batch, text_len = text.shape[0], text.shape[1]
46
+
47
+ if drop_text: # cfg for text
48
+ text = torch.zeros_like(text)
49
+
50
+ # possible extra modeling
51
+ if self.extra_modeling:
52
+ # sinus pos emb
53
+ batch_start = torch.zeros((batch,), dtype=torch.long)
54
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
55
+ text_pos_embed = self.freqs_cis[pos_idx]
56
+
57
+ # print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256])
58
+
59
+ text = text + text_pos_embed
60
+
61
+ # convnextv2 blocks
62
+ text = self.text_blocks(text)
63
+
64
+ return text
65
+
66
+
67
+ # noised input audio and context mixing embedding
68
+
69
+
70
+ class InputEmbedding(nn.Module):
71
+ def __init__(self, mel_dim, text_dim, out_dim):
72
+ super().__init__()
73
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
74
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
75
+
76
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
77
+ if drop_audio_cond: # cfg for cond audio
78
+ cond = torch.zeros_like(cond)
79
+
80
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
81
+ x = self.conv_pos_embed(x) + x
82
+ return x
83
+
84
+
85
+ # Transformer backbone using DiT blocks
86
+
87
+
88
+ class DiT(nn.Module):
89
+ def __init__(
90
+ self,
91
+ *,
92
+ dim,
93
+ depth=8,
94
+ heads=8,
95
+ dim_head=64,
96
+ dropout=0.1,
97
+ ff_mult=4,
98
+ mel_dim=100,
99
+ text_dim=None,
100
+ conv_layers=0,
101
+ long_skip_connection=False,
102
+ ):
103
+ super().__init__()
104
+
105
+ self.time_embed = TimestepEmbedding(dim)
106
+ self.d_embed = TimestepEmbedding(dim)
107
+ if text_dim is None:
108
+ text_dim = mel_dim
109
+ self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers)
110
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
111
+
112
+ self.rotary_embed = RotaryEmbedding(dim_head)
113
+
114
+ self.dim = dim
115
+ self.depth = depth
116
+
117
+ self.transformer_blocks = nn.ModuleList(
118
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
119
+ )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
121
+
122
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
+ self.proj_out = nn.Linear(dim, mel_dim)
124
+
125
+ def ckpt_wrapper(self, module):
126
+ # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
127
+ def ckpt_forward(*inputs):
128
+ outputs = module(*inputs)
129
+ return outputs
130
+
131
+ return ckpt_forward
132
+
133
+ def forward( # x, prompt_x, x_lens, t, style,cond
134
+ self, # d is channel, n is T
135
+ x0: float["b n d"], # noised input audio # noqa: F722
136
+ cond0: float["b n d"], # masked cond audio # noqa: F722
137
+ x_lens,
138
+ time: float["b"] | float[""], # time step # noqa: F821 F722
139
+ dt_base_bootstrap,
140
+ text0, # : int["b nt"] # noqa: F722 ### condition feature
141
+ use_grad_ckpt=False, # bool
142
+ ### unused below
143
+ drop_audio_cond=False, # cfg for cond audio
144
+ drop_text=False, # cfg for text
145
+ # mask: bool["b n"] | None = None, # noqa: F722
146
+ infer=False, # bool
147
+ text_cache=None, # torch tensor as text_embed
148
+ dt_cache=None, # torch tensor as dt
149
+ ):
150
+ x = x0.transpose(2, 1)
151
+ cond = cond0.transpose(2, 1)
152
+ text = text0.transpose(2, 1)
153
+ mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
154
+
155
+ batch, seq_len = x.shape[0], x.shape[1]
156
+ if time.ndim == 0:
157
+ time = time.repeat(batch)
158
+
159
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
160
+ t = self.time_embed(time)
161
+ if infer and dt_cache is not None:
162
+ dt = dt_cache
163
+ else:
164
+ dt = self.d_embed(dt_base_bootstrap)
165
+ t += dt
166
+
167
+ if infer and text_cache is not None:
168
+ text_embed = text_cache
169
+ else:
170
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
171
+
172
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
173
+
174
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
175
+
176
+ if self.long_skip_connection is not None:
177
+ residual = x
178
+
179
+ for block in self.transformer_blocks:
180
+ if use_grad_ckpt:
181
+ x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
182
+ else:
183
+ x = block(x, t, mask=mask, rope=rope)
184
+
185
+ if self.long_skip_connection is not None:
186
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
187
+
188
+ x = self.norm_out(x, t)
189
+ output = self.proj_out(x)
190
+
191
+ if infer:
192
+ return output, text_embed, dt
193
+ else:
194
+ return output
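A hedged end-to-end shape sketch (editorial, not part of the commit) for the DiT above. It assumes the GPT_SoVITS package and its dependencies (x_transformers, module.commons, and the modules imported at the top of this file) are importable; the sizes are illustrative, not the training configuration:

    import torch

    model = DiT(dim=512, depth=4, heads=8, dim_head=64, mel_dim=100, conv_layers=2)

    b, n = 2, 120
    x0 = torch.randn(b, 100, n)         # noised mel, channel-first (b, d, n)
    cond0 = torch.randn(b, 100, n)      # masked conditioning mel
    text0 = torch.randn(b, 100, n)      # conditioning features, same layout
    x_lens = torch.tensor([n, n - 20])  # per-utterance valid lengths
    t = torch.rand(b)                   # flow-matching time step
    dt = torch.rand(b)                  # dt_base_bootstrap

    out = model(x0, cond0, x_lens, t, dt, text0)
    print(out.shape)                    # torch.Size([2, 120, 100]) -> (b, n, mel_dim)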
f5_tts/model/backbones/mmdit.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from x_transformers.x_transformers import RotaryEmbedding
16
+
17
+ from f5_tts.model.modules import (
18
+ TimestepEmbedding,
19
+ ConvPositionEmbedding,
20
+ MMDiTBlock,
21
+ AdaLayerNormZero_Final,
22
+ precompute_freqs_cis,
23
+ get_pos_embed_indices,
24
+ )
25
+
26
+
27
+ # text embedding
28
+
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+
58
+ class AudioEmbedding(nn.Module):
59
+ def __init__(self, in_dim, out_dim):
60
+ super().__init__()
61
+ self.linear = nn.Linear(2 * in_dim, out_dim)
62
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
+
64
+ def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
+ if drop_audio_cond:
66
+ cond = torch.zeros_like(cond)
67
+ x = torch.cat((x, cond), dim=-1)
68
+ x = self.linear(x)
69
+ x = self.conv_pos_embed(x) + x
70
+ return x
71
+
72
+
73
+ # Transformer backbone using MM-DiT blocks
74
+
75
+
76
+ class MMDiT(nn.Module):
77
+ def __init__(
78
+ self,
79
+ *,
80
+ dim,
81
+ depth=8,
82
+ heads=8,
83
+ dim_head=64,
84
+ dropout=0.1,
85
+ ff_mult=4,
86
+ text_num_embeds=256,
87
+ mel_dim=100,
88
+ ):
89
+ super().__init__()
90
+
91
+ self.time_embed = TimestepEmbedding(dim)
92
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
93
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
94
+
95
+ self.rotary_embed = RotaryEmbedding(dim_head)
96
+
97
+ self.dim = dim
98
+ self.depth = depth
99
+
100
+ self.transformer_blocks = nn.ModuleList(
101
+ [
102
+ MMDiTBlock(
103
+ dim=dim,
104
+ heads=heads,
105
+ dim_head=dim_head,
106
+ dropout=dropout,
107
+ ff_mult=ff_mult,
108
+ context_pre_only=i == depth - 1,
109
+ )
110
+ for i in range(depth)
111
+ ]
112
+ )
113
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
+ self.proj_out = nn.Linear(dim, mel_dim)
115
+
116
+ def forward(
117
+ self,
118
+ x: float["b n d"], # noised input audio # noqa: F722
119
+ cond: float["b n d"], # masked cond audio # noqa: F722
120
+ text: int["b nt"], # text # noqa: F722
121
+ time: float["b"] | float[""], # time step # noqa: F821 F722
122
+ drop_audio_cond, # cfg for cond audio
123
+ drop_text, # cfg for text
124
+ mask: bool["b n"] | None = None, # noqa: F722
125
+ ):
126
+ batch = x.shape[0]
127
+ if time.ndim == 0:
128
+ time = time.repeat(batch)
129
+
130
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
+ t = self.time_embed(time)
132
+ c = self.text_embed(text, drop_text=drop_text)
133
+ x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
+
135
+ seq_len = x.shape[1]
136
+ text_len = text.shape[1]
137
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
+
140
+ for block in self.transformer_blocks:
141
+ c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
+
143
+ x = self.norm_out(x, t)
144
+ output = self.proj_out(x)
145
+
146
+ return output
f5_tts/model/backbones/unett.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from x_transformers import RMSNorm
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from f5_tts.model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ Attention,
25
+ AttnProcessor,
26
+ FeedForward,
27
+ precompute_freqs_cis,
28
+ get_pos_embed_indices,
29
+ )
30
+
31
+
32
+ # Text embedding
33
+
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+
40
+ if conv_layers > 0:
41
+ self.extra_modeling = True
42
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(
45
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
+ )
47
+ else:
48
+ self.extra_modeling = False
49
+
50
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
+ batch, text_len = text.shape[0], text.shape[1]
54
+ text = F.pad(text, (0, seq_len - text_len), value=0)
55
+
56
+ if drop_text: # cfg for text
57
+ text = torch.zeros_like(text)
58
+
59
+ text = self.text_embed(text) # b n -> b n d
60
+
61
+ # possible extra modeling
62
+ if self.extra_modeling:
63
+ # sinus pos emb
64
+ batch_start = torch.zeros((batch,), dtype=torch.long)
65
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
66
+ text_pos_embed = self.freqs_cis[pos_idx]
67
+ text = text + text_pos_embed
68
+
69
+ # convnextv2 blocks
70
+ text = self.text_blocks(text)
71
+
72
+ return text
73
+
74
+
75
+ # noised input audio and context mixing embedding
76
+
77
+
78
+ class InputEmbedding(nn.Module):
79
+ def __init__(self, mel_dim, text_dim, out_dim):
80
+ super().__init__()
81
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
+
84
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
+ if drop_audio_cond: # cfg for cond audio
86
+ cond = torch.zeros_like(cond)
87
+
88
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
+ x = self.conv_pos_embed(x) + x
90
+ return x
91
+
92
+
93
+ # Flat UNet Transformer backbone
94
+
95
+
96
+ class UNetT(nn.Module):
97
+ def __init__(
98
+ self,
99
+ *,
100
+ dim,
101
+ depth=8,
102
+ heads=8,
103
+ dim_head=64,
104
+ dropout=0.1,
105
+ ff_mult=4,
106
+ mel_dim=100,
107
+ text_num_embeds=256,
108
+ text_dim=None,
109
+ conv_layers=0,
110
+ skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
+ ):
112
+ super().__init__()
113
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
114
+
115
+ self.time_embed = TimestepEmbedding(dim)
116
+ if text_dim is None:
117
+ text_dim = mel_dim
118
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
+
121
+ self.rotary_embed = RotaryEmbedding(dim_head)
122
+
123
+ # transformer layers & skip connections
124
+
125
+ self.dim = dim
126
+ self.skip_connect_type = skip_connect_type
127
+ needs_skip_proj = skip_connect_type == "concat"
128
+
129
+ self.depth = depth
130
+ self.layers = nn.ModuleList([])
131
+
132
+ for idx in range(depth):
133
+ is_later_half = idx >= (depth // 2)
134
+
135
+ attn_norm = RMSNorm(dim)
136
+ attn = Attention(
137
+ processor=AttnProcessor(),
138
+ dim=dim,
139
+ heads=heads,
140
+ dim_head=dim_head,
141
+ dropout=dropout,
142
+ )
143
+
144
+ ff_norm = RMSNorm(dim)
145
+ ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
+
147
+ skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
+
149
+ self.layers.append(
150
+ nn.ModuleList(
151
+ [
152
+ skip_proj,
153
+ attn_norm,
154
+ attn,
155
+ ff_norm,
156
+ ff,
157
+ ]
158
+ )
159
+ )
160
+
161
+ self.norm_out = RMSNorm(dim)
162
+ self.proj_out = nn.Linear(dim, mel_dim)
163
+
164
+ def forward(
165
+ self,
166
+ x: float["b n d"], # noised input audio # noqa: F722
167
+ cond: float["b n d"], # masked cond audio # noqa: F722
168
+ text: int["b nt"], # text # noqa: F722
169
+ time: float["b"] | float[""], # time step # noqa: F821 F722
170
+ drop_audio_cond, # cfg for cond audio
171
+ drop_text, # cfg for text
172
+ mask: bool["b n"] | None = None, # noqa: F722
173
+ ):
174
+ batch, seq_len = x.shape[0], x.shape[1]
175
+ if time.ndim == 0:
176
+ time = time.repeat(batch)
177
+
178
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
+ t = self.time_embed(time)
180
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
+
183
+ # postfix time t to input x, [b n d] -> [b n+1 d]
184
+ x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
+ if mask is not None:
186
+ mask = F.pad(mask, (1, 0), value=1)
187
+
188
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
+
190
+ # flat unet transformer
191
+ skip_connect_type = self.skip_connect_type
192
+ skips = []
193
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
194
+ layer = idx + 1
195
+
196
+ # skip connection logic
197
+ is_first_half = layer <= (self.depth // 2)
198
+ is_later_half = not is_first_half
199
+
200
+ if is_first_half:
201
+ skips.append(x)
202
+
203
+ if is_later_half:
204
+ skip = skips.pop()
205
+ if skip_connect_type == "concat":
206
+ x = torch.cat((x, skip), dim=-1)
207
+ x = maybe_skip_proj(x)
208
+ elif skip_connect_type == "add":
209
+ x = x + skip
210
+
211
+ # attention and feedforward blocks
212
+ x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
+ x = ff(ff_norm(x)) + x
214
+
215
+ assert len(skips) == 0
216
+
217
+ x = self.norm_out(x)[:, 1:, :] # unpack t from x
218
+
219
+ return self.proj_out(x)
f5_tts/model/modules.py ADDED
@@ -0,0 +1,666 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ from librosa.filters import mel as librosa_mel_fn
19
+ from torch import nn
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+
26
+ mel_basis_cache = {}
27
+ hann_window_cache = {}
28
+
29
+
30
+ def get_bigvgan_mel_spectrogram(
31
+ waveform,
32
+ n_fft=1024,
33
+ n_mel_channels=100,
34
+ target_sample_rate=24000,
35
+ hop_length=256,
36
+ win_length=1024,
37
+ fmin=0,
38
+ fmax=None,
39
+ center=False,
40
+ ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
41
+ device = waveform.device
42
+ key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
43
+
44
+ if key not in mel_basis_cache:
45
+ mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
46
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why do they need .float()?
47
+ hann_window_cache[key] = torch.hann_window(win_length).to(device)
48
+
49
+ mel_basis = mel_basis_cache[key]
50
+ hann_window = hann_window_cache[key]
51
+
52
+ padding = (n_fft - hop_length) // 2
53
+ waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
54
+
55
+ spec = torch.stft(
56
+ waveform,
57
+ n_fft,
58
+ hop_length=hop_length,
59
+ win_length=win_length,
60
+ window=hann_window,
61
+ center=center,
62
+ pad_mode="reflect",
63
+ normalized=False,
64
+ onesided=True,
65
+ return_complex=True,
66
+ )
67
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
68
+
69
+ mel_spec = torch.matmul(mel_basis, spec)
70
+ mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
71
+
72
+ return mel_spec
73
+
74
+
75
+ def get_vocos_mel_spectrogram(
76
+ waveform,
77
+ n_fft=1024,
78
+ n_mel_channels=100,
79
+ target_sample_rate=24000,
80
+ hop_length=256,
81
+ win_length=1024,
82
+ ):
83
+ mel_stft = torchaudio.transforms.MelSpectrogram(
84
+ sample_rate=target_sample_rate,
85
+ n_fft=n_fft,
86
+ win_length=win_length,
87
+ hop_length=hop_length,
88
+ n_mels=n_mel_channels,
89
+ power=1,
90
+ center=True,
91
+ normalized=False,
92
+ norm=None,
93
+ ).to(waveform.device)
94
+ if len(waveform.shape) == 3:
95
+ waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
96
+
97
+ assert len(waveform.shape) == 2
98
+
99
+ mel = mel_stft(waveform)
100
+ mel = mel.clamp(min=1e-5).log()
101
+ return mel
102
+
103
+
104
+ class MelSpec(nn.Module):
105
+ def __init__(
106
+ self,
107
+ n_fft=1024,
108
+ hop_length=256,
109
+ win_length=1024,
110
+ n_mel_channels=100,
111
+ target_sample_rate=24_000,
112
+ mel_spec_type="vocos",
113
+ ):
114
+ super().__init__()
115
+ assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel extraction backends: vocos or bigvgan"
116
+
117
+ self.n_fft = n_fft
118
+ self.hop_length = hop_length
119
+ self.win_length = win_length
120
+ self.n_mel_channels = n_mel_channels
121
+ self.target_sample_rate = target_sample_rate
122
+
123
+ if mel_spec_type == "vocos":
124
+ self.extractor = get_vocos_mel_spectrogram
125
+ elif mel_spec_type == "bigvgan":
126
+ self.extractor = get_bigvgan_mel_spectrogram
127
+
128
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
129
+
130
+ def forward(self, wav):
131
+ if self.dummy.device != wav.device:
132
+ self.to(wav.device)
133
+
134
+ mel = self.extractor(
135
+ waveform=wav,
136
+ n_fft=self.n_fft,
137
+ n_mel_channels=self.n_mel_channels,
138
+ target_sample_rate=self.target_sample_rate,
139
+ hop_length=self.hop_length,
140
+ win_length=self.win_length,
141
+ )
142
+
143
+ return mel
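A hedged usage sketch (editorial, not part of the commit) for MelSpec, assuming torchaudio and librosa are installed. A batch of two 1 s waveforms at 24 kHz yields 100-bin log-mel frames:

    import torch

    mel_spec = MelSpec(mel_spec_type="vocos")
    wav = torch.randn(2, 24000)  # (b, nw) at 24 kHz
    mel = mel_spec(wav)
    print(mel.shape)             # torch.Size([2, 100, 94]) -> (b, n_mel_channels, frames)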
144
+
145
+
146
+ # sinusoidal position embedding
147
+
148
+
149
+ class SinusPositionEmbedding(nn.Module):
150
+ def __init__(self, dim):
151
+ super().__init__()
152
+ self.dim = dim
153
+
154
+ def forward(self, x, scale=1000):
155
+ device = x.device
156
+ half_dim = self.dim // 2
157
+ emb = math.log(10000) / (half_dim - 1)
158
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
159
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
160
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
161
+ return emb
162
+
163
+
164
+ # convolutional position embedding
165
+
166
+
167
+ class ConvPositionEmbedding(nn.Module):
168
+ def __init__(self, dim, kernel_size=31, groups=16):
169
+ super().__init__()
170
+ assert kernel_size % 2 != 0
171
+ self.conv1d = nn.Sequential(
172
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
173
+ nn.Mish(),
174
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
175
+ nn.Mish(),
176
+ )
177
+
178
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
179
+ if mask is not None:
180
+ mask = mask[..., None]
181
+ x = x.masked_fill(~mask, 0.0)
182
+
183
+ x = x.permute(0, 2, 1)
184
+ x = self.conv1d(x)
185
+ out = x.permute(0, 2, 1)
186
+
187
+ if mask is not None:
188
+ out = out.masked_fill(~mask, 0.0)
189
+
190
+ return out
191
+
192
+
193
+ # rotary positional embedding related
194
+
195
+
196
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
197
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
198
+ # has some connection to NTK literature
199
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
200
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
201
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
202
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
203
+ t = torch.arange(end, device=freqs.device) # type: ignore
204
+ freqs = torch.outer(t, freqs).float() # type: ignore
205
+ freqs_cos = torch.cos(freqs) # real part
206
+ freqs_sin = torch.sin(freqs) # imaginary part
207
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
208
+
209
+
210
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
211
+ # length = length if isinstance(length, int) else length.max()
212
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
213
+ pos = (
214
+ start.unsqueeze(1)
215
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
216
+ )
217
+ # avoid extra long error.
218
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
219
+ return pos
220
+
221
+
222
+ # Global Response Normalization layer (Instance Normalization ?)
223
+
224
+
225
+ class GRN(nn.Module):
226
+ def __init__(self, dim):
227
+ super().__init__()
228
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
229
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
230
+
231
+ def forward(self, x):
232
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
233
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
234
+ return self.gamma * (x * Nx) + self.beta + x
235
+
236
+
237
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
238
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
239
+
240
+
241
+ class ConvNeXtV2Block(nn.Module):
242
+ def __init__(
243
+ self,
244
+ dim: int,
245
+ intermediate_dim: int,
246
+ dilation: int = 1,
247
+ ):
248
+ super().__init__()
249
+ padding = (dilation * (7 - 1)) // 2
250
+ self.dwconv = nn.Conv1d(
251
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
252
+ ) # depthwise conv
253
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
254
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
255
+ self.act = nn.GELU()
256
+ self.grn = GRN(intermediate_dim)
257
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
258
+
259
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
260
+ residual = x
261
+ x = x.transpose(1, 2) # b n d -> b d n
262
+ x = self.dwconv(x)
263
+ x = x.transpose(1, 2) # b d n -> b n d
264
+ x = self.norm(x)
265
+ x = self.pwconv1(x)
266
+ x = self.act(x)
267
+ x = self.grn(x)
268
+ x = self.pwconv2(x)
269
+ return residual + x
270
+
271
+
272
+ # AdaLayerNormZero
273
+ # return with modulated x for attn input, and params for later mlp modulation
274
+
275
+
276
+ class AdaLayerNormZero(nn.Module):
277
+ def __init__(self, dim):
278
+ super().__init__()
279
+
280
+ self.silu = nn.SiLU()
281
+ self.linear = nn.Linear(dim, dim * 6)
282
+
283
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
284
+
285
+ def forward(self, x, emb=None):
286
+ emb = self.linear(self.silu(emb))
287
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
288
+
289
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
290
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
+
292
+
293
+ # AdaLayerNormZero for final layer
294
+ # return only with modulated x for attn input, cuz no more mlp modulation
295
+
296
+
297
+ class AdaLayerNormZero_Final(nn.Module):
298
+ def __init__(self, dim):
299
+ super().__init__()
300
+
301
+ self.silu = nn.SiLU()
302
+ self.linear = nn.Linear(dim, dim * 2)
303
+
304
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
305
+
306
+ def forward(self, x, emb):
307
+ emb = self.linear(self.silu(emb))
308
+ scale, shift = torch.chunk(emb, 2, dim=1)
309
+
310
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
311
+ return x
312
+
313
+
314
+ # FeedForward
315
+
316
+
317
+ class FeedForward(nn.Module):
318
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
319
+ super().__init__()
320
+ inner_dim = int(dim * mult)
321
+ dim_out = dim_out if dim_out is not None else dim
322
+
323
+ activation = nn.GELU(approximate=approximate)
324
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
325
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
326
+
327
+ def forward(self, x):
328
+ return self.ff(x)
329
+
330
+
331
+ # Attention with possible joint part
332
+ # modified from diffusers/src/diffusers/models/attention_processor.py
333
+
334
+
335
+ class Attention(nn.Module):
336
+ def __init__(
337
+ self,
338
+ processor: JointAttnProcessor | AttnProcessor,
339
+ dim: int,
340
+ heads: int = 8,
341
+ dim_head: int = 64,
342
+ dropout: float = 0.0,
343
+ context_dim: Optional[int] = None, # if not None -> joint attention
344
+ context_pre_only=None,
345
+ ):
346
+ super().__init__()
347
+
348
+ if not hasattr(F, "scaled_dot_product_attention"):
349
+ raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
350
+
351
+ self.processor = processor
352
+
353
+ self.dim = dim
354
+ self.heads = heads
355
+ self.inner_dim = dim_head * heads
356
+ self.dropout = dropout
357
+
358
+ self.context_dim = context_dim
359
+ self.context_pre_only = context_pre_only
360
+
361
+ self.to_q = nn.Linear(dim, self.inner_dim)
362
+ self.to_k = nn.Linear(dim, self.inner_dim)
363
+ self.to_v = nn.Linear(dim, self.inner_dim)
364
+
365
+ if self.context_dim is not None:
366
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
+ if self.context_pre_only is not None:
369
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
370
+
371
+ self.to_out = nn.ModuleList([])
372
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
373
+ self.to_out.append(nn.Dropout(dropout))
374
+
375
+ if self.context_pre_only is not None and not self.context_pre_only:
376
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
377
+
378
+ def forward(
379
+ self,
380
+ x: float["b n d"], # noised input x # noqa: F722
381
+ c: float["b n d"] = None, # context c # noqa: F722
382
+ mask: bool["b n"] | None = None, # noqa: F722
383
+ rope=None, # rotary position embedding for x
384
+ c_rope=None, # rotary position embedding for c
385
+ ) -> torch.Tensor:
386
+ if c is not None:
387
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
388
+ else:
389
+ return self.processor(self, x, mask=mask, rope=rope)
390
+
391
+
392
+ # Attention processor
393
+
394
+
395
+ # from torch.nn.attention import SDPBackend
396
+ # torch.backends.cuda.enable_flash_sdp(True)
397
+ class AttnProcessor:
398
+ def __init__(self):
399
+ pass
400
+
401
+ def __call__(
402
+ self,
403
+ attn: Attention,
404
+ x: float["b n d"], # noised input x # noqa: F722
405
+ mask: bool["b n"] | None = None, # noqa: F722
406
+ rope=None, # rotary position embedding
407
+ ) -> torch.FloatTensor:
408
+ batch_size = x.shape[0]
409
+
410
+ # `sample` projections.
411
+ query = attn.to_q(x)
412
+ key = attn.to_k(x)
413
+ value = attn.to_v(x)
414
+
415
+ # apply rotary position embedding
416
+ if rope is not None:
417
+ freqs, xpos_scale = rope
418
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
419
+
420
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
421
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
422
+
423
+ # attention
424
+ inner_dim = key.shape[-1]
425
+ head_dim = inner_dim // attn.heads
426
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
428
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
429
+
430
+ # mask: e.g. at inference a batch may contain different target durations, so mask out the padding
431
+ if mask is not None:
432
+ attn_mask = mask
433
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
434
+ # print(3433333333,attn_mask.shape)
435
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
436
+ else:
437
+ attn_mask = None
438
+ # with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
439
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
440
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
441
+ # print(torch.backends.cuda.flash_sdp_enabled())
442
+ # print(torch.backends.cuda.mem_efficient_sdp_enabled())
443
+ # print(torch.backends.cuda.math_sdp_enabled())
444
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
445
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
446
+ x = x.to(query.dtype)
447
+
448
+ # linear proj
449
+ x = attn.to_out[0](x)
450
+ # dropout
451
+ x = attn.to_out[1](x)
452
+
453
+ if mask is not None:
454
+ mask = mask.unsqueeze(-1)
455
+ x = x.masked_fill(~mask, 0.0)
456
+
457
+ return x
458
+
459
+
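A brief illustration (not part of the commit; shapes are arbitrary) of how the boolean padding mask is expanded for F.scaled_dot_product_attention in the processor above: True marks valid frames, and masked key positions are excluded from attention.

import torch
import torch.nn.functional as F

b, heads, n, head_dim = 2, 8, 6, 64
mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])  # (b, n); False = padding
attn_mask = mask.unsqueeze(1).unsqueeze(1)                   # (b, 1, 1, n)
attn_mask = attn_mask.expand(b, heads, n, n)                 # (b, heads, n_query, n_key)
q = k = v = torch.randn(b, heads, n, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (b, heads, n, head_dim)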
460
+ # Joint Attention processor for MM-DiT
461
+ # modified from diffusers/src/diffusers/models/attention_processor.py
462
+
463
+
464
+ class JointAttnProcessor:
465
+ def __init__(self):
466
+ pass
467
+
468
+ def __call__(
469
+ self,
470
+ attn: Attention,
471
+ x: float["b n d"], # noised input x # noqa: F722
472
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
473
+ mask: bool["b n"] | None = None, # noqa: F722
474
+ rope=None, # rotary position embedding for x
475
+ c_rope=None, # rotary position embedding for c
476
+ ) -> torch.FloatTensor:
477
+ residual = x
478
+
479
+ batch_size = c.shape[0]
480
+
481
+ # `sample` projections.
482
+ query = attn.to_q(x)
483
+ key = attn.to_k(x)
484
+ value = attn.to_v(x)
485
+
486
+ # `context` projections.
487
+ c_query = attn.to_q_c(c)
488
+ c_key = attn.to_k_c(c)
489
+ c_value = attn.to_v_c(c)
490
+
491
+ # apply rope for context and noised input independently
492
+ if rope is not None:
493
+ freqs, xpos_scale = rope
494
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
495
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
496
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
497
+ if c_rope is not None:
498
+ freqs, xpos_scale = c_rope
499
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
500
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
501
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
502
+
503
+ # attention
504
+ query = torch.cat([query, c_query], dim=1)
505
+ key = torch.cat([key, c_key], dim=1)
506
+ value = torch.cat([value, c_value], dim=1)
507
+
508
+ inner_dim = key.shape[-1]
509
+ head_dim = inner_dim // attn.heads
510
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
511
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
512
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
513
+
514
+ # mask: e.g. at inference a batch may contain different target durations, so mask out the padding
515
+ if mask is not None:
516
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
517
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
518
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
519
+ else:
520
+ attn_mask = None
521
+
522
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
523
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
524
+ x = x.to(query.dtype)
525
+
526
+ # Split the attention outputs.
527
+ x, c = (
528
+ x[:, : residual.shape[1]],
529
+ x[:, residual.shape[1] :],
530
+ )
531
+
532
+ # linear proj
533
+ x = attn.to_out[0](x)
534
+ # dropout
535
+ x = attn.to_out[1](x)
536
+ if not attn.context_pre_only:
537
+ c = attn.to_out_c(c)
538
+
539
+ if mask is not None:
540
+ mask = mask.unsqueeze(-1)
541
+ x = x.masked_fill(~mask, 0.0)
542
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
543
+
544
+ return x, c
545
+
546
+
547
+ # DiT Block
548
+
549
+
550
+ class DiTBlock(nn.Module):
551
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
552
+ super().__init__()
553
+
554
+ self.attn_norm = AdaLayerNormZero(dim)
555
+ self.attn = Attention(
556
+ processor=AttnProcessor(),
557
+ dim=dim,
558
+ heads=heads,
559
+ dim_head=dim_head,
560
+ dropout=dropout,
561
+ )
562
+
563
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
564
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
565
+
566
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
567
+ # pre-norm & modulation for attention input
568
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
569
+
570
+ # attention
571
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
572
+
573
+ # process attention output for input x
574
+ x = x + gate_msa.unsqueeze(1) * attn_output
575
+
576
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
577
+ ff_output = self.ff(norm)
578
+ x = x + gate_mlp.unsqueeze(1) * ff_output
579
+
580
+ return x
581
+
582
+
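For reference, a minimal forward-pass sketch of DiTBlock (illustrative only; the sizes and tensors are arbitrary and not part of this commit):

import torch

block = DiTBlock(dim=512, heads=8, dim_head=64)
x = torch.randn(2, 100, 512)                 # noised input, (batch, seq_len, dim)
t = torch.randn(2, 512)                      # time embedding, (batch, dim)
mask = torch.ones(2, 100, dtype=torch.bool)  # all frames valid
out = block(x, t, mask=mask)                 # (2, 100, 512)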
583
+ # MMDiT Block https://arxiv.org/abs/2403.03206
584
+
585
+
586
+ class MMDiTBlock(nn.Module):
587
+ r"""
588
+ modified from diffusers/src/diffusers/models/attention.py
589
+
590
+ notes.
591
+ _c: context-related (text, cond, etc.; the left branch in SD3 Fig. 2b)
592
+ _x: noised-input related (the right branch)
593
+ context_pre_only: the last layer only does pre-norm + modulation, since it has no further FFN
594
+ """
595
+
596
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
597
+ super().__init__()
598
+
599
+ self.context_pre_only = context_pre_only
600
+
601
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
602
+ self.attn_norm_x = AdaLayerNormZero(dim)
603
+ self.attn = Attention(
604
+ processor=JointAttnProcessor(),
605
+ dim=dim,
606
+ heads=heads,
607
+ dim_head=dim_head,
608
+ dropout=dropout,
609
+ context_dim=dim,
610
+ context_pre_only=context_pre_only,
611
+ )
612
+
613
+ if not context_pre_only:
614
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
615
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
616
+ else:
617
+ self.ff_norm_c = None
618
+ self.ff_c = None
619
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
620
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
621
+
622
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
623
+ # pre-norm & modulation for attention input
624
+ if self.context_pre_only:
625
+ norm_c = self.attn_norm_c(c, t)
626
+ else:
627
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
628
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
629
+
630
+ # attention
631
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
632
+
633
+ # process attention output for context c
634
+ if self.context_pre_only:
635
+ c = None
636
+ else: # if not last layer
637
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
638
+
639
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
640
+ c_ff_output = self.ff_c(norm_c)
641
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
642
+
643
+ # process attention output for input x
644
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
645
+
646
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
647
+ x_ff_output = self.ff_x(norm_x)
648
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
649
+
650
+ return c, x
651
+
652
+
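Similarly, a minimal sketch of an intermediate MMDiTBlock call (illustrative only, arbitrary sizes): the block consumes the context stream c and the noised stream x and returns updated versions of both; with context_pre_only=True (last layer) the returned c would be None.

import torch

block = MMDiTBlock(dim=512, heads=8, dim_head=64, context_pre_only=False)
x = torch.randn(2, 100, 512)                 # noised input stream
c = torch.randn(2, 30, 512)                  # context (text) stream
t = torch.randn(2, 512)                      # time embedding
mask = torch.ones(2, 100, dtype=torch.bool)  # mask applies to x only; c is never masked
c_out, x_out = block(x, c, t, mask=mask)     # (2, 30, 512), (2, 100, 512)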
653
+ # time step conditioning embedding
654
+
655
+
656
+ class TimestepEmbedding(nn.Module):
657
+ def __init__(self, dim, freq_embed_dim=256):
658
+ super().__init__()
659
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
660
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
661
+
662
+ def forward(self, timestep: float["b"]): # noqa: F821
663
+ time_hidden = self.time_embed(timestep)
664
+ time_hidden = time_hidden.to(timestep.dtype)
665
+ time = self.time_mlp(time_hidden) # b d
666
+ return time
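Finally, a small sketch of the timestep conditioning path (illustrative only; SinusPositionEmbedding is defined earlier in the same module): a batch of scalar timesteps is embedded to the model width and can then be passed as t to the blocks above.

import torch

temb = TimestepEmbedding(dim=512)   # internally: SinusPositionEmbedding -> Linear -> SiLU -> Linear
timestep = torch.rand(2)            # (batch,) scalar diffusion timesteps
t = temb(timestep)                  # (2, 512) conditioning embedding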
prepare_datasets/1-get-text.py ADDED
@@ -0,0 +1,143 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+
5
+ inp_text = os.environ.get("inp_text")
6
+ inp_wav_dir = os.environ.get("inp_wav_dir")
7
+ exp_name = os.environ.get("exp_name")
8
+ i_part = os.environ.get("i_part")
9
+ all_parts = os.environ.get("all_parts")
10
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
12
+ opt_dir = os.environ.get("opt_dir")
13
+ bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
14
+ import torch
15
+
16
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
17
+ version = os.environ.get("version", None)
18
+ import traceback
19
+ import os.path
20
+ from text.cleaner import clean_text
21
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
22
+ from tools.my_utils import clean_path
23
+
24
+ # inp_text=sys.argv[1]
25
+ # inp_wav_dir=sys.argv[2]
26
+ # exp_name=sys.argv[3]
27
+ # i_part=sys.argv[4]
28
+ # all_parts=sys.argv[5]
29
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
30
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
31
+ # bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
32
+
33
+ from time import time as ttime
34
+ import shutil
35
+
36
+
37
+ def my_save(fea, path):  # workaround: torch.save fails on Chinese (non-ASCII) paths, so save to a temp name and move
38
+ dir = os.path.dirname(path)
39
+ name = os.path.basename(path)
40
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
41
+ tmp_path = "%s%s.pth" % (ttime(), i_part)
42
+ torch.save(fea, tmp_path)
43
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
44
+
45
+
46
+ txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
47
+ if os.path.exists(txt_path) == False:
48
+ bert_dir = "%s/3-bert" % (opt_dir)
49
+ os.makedirs(opt_dir, exist_ok=True)
50
+ os.makedirs(bert_dir, exist_ok=True)
51
+ if torch.cuda.is_available():
52
+ device = "cuda:0"
53
+ # elif torch.backends.mps.is_available():
54
+ # device = "mps"
55
+ else:
56
+ device = "cpu"
57
+ if os.path.exists(bert_pretrained_dir):
58
+ ...
59
+ else:
60
+ raise FileNotFoundError(bert_pretrained_dir)
61
+ tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
62
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
63
+ if is_half == True:
64
+ bert_model = bert_model.half().to(device)
65
+ else:
66
+ bert_model = bert_model.to(device)
67
+
68
+ def get_bert_feature(text, word2ph):
69
+ with torch.no_grad():
70
+ inputs = tokenizer(text, return_tensors="pt")
71
+ for i in inputs:
72
+ inputs[i] = inputs[i].to(device)
73
+ res = bert_model(**inputs, output_hidden_states=True)
74
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
75
+
76
+ assert len(word2ph) == len(text)
77
+ phone_level_feature = []
78
+ for i in range(len(word2ph)):
79
+ repeat_feature = res[i].repeat(word2ph[i], 1)
80
+ phone_level_feature.append(repeat_feature)
81
+
82
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
83
+
84
+ return phone_level_feature.T
85
+
86
+ def process(data, res):
87
+ for name, text, lan in data:
88
+ try:
89
+ name = clean_path(name)
90
+ name = os.path.basename(name)
91
+ print(name)
92
+ phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
93
+ path_bert = "%s/%s.pt" % (bert_dir, name)
94
+ if os.path.exists(path_bert) == False and lan == "zh":
95
+ bert_feature = get_bert_feature(norm_text, word2ph)
96
+ assert bert_feature.shape[-1] == len(phones)
97
+ # torch.save(bert_feature, path_bert)
98
+ my_save(bert_feature, path_bert)
99
+ phones = " ".join(phones)
100
+ # res.append([name,phones])
101
+ res.append([name, phones, word2ph, norm_text])
102
+ except:
103
+ print(name, text, traceback.format_exc())
104
+
105
+ todo = []
106
+ res = []
107
+ with open(inp_text, "r", encoding="utf8") as f:
108
+ lines = f.read().strip("\n").split("\n")
109
+
110
+ language_v1_to_language_v2 = {
111
+ "ZH": "zh",
112
+ "zh": "zh",
113
+ "JP": "ja",
114
+ "jp": "ja",
115
+ "JA": "ja",
116
+ "ja": "ja",
117
+ "EN": "en",
118
+ "en": "en",
119
+ "En": "en",
120
+ "KO": "ko",
121
+ "Ko": "ko",
122
+ "ko": "ko",
123
+ "yue": "yue",
124
+ "YUE": "yue",
125
+ "Yue": "yue",
126
+ }
127
+ for line in lines[int(i_part) :: int(all_parts)]:
128
+ try:
129
+ wav_name, spk_name, language, text = line.split("|")
130
+ # todo.append([name,text,"zh"])
131
+ if language in language_v1_to_language_v2.keys():
132
+ todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
133
+ else:
134
+ print(f"\033[33m[Warning] The language {language!r} of {wav_name} is not supported for training.\033[0m")
135
+ except:
136
+ print(line, traceback.format_exc())
137
+
138
+ process(todo, res)
139
+ opt = []
140
+ for name, phones, word2ph, norm_text in res:
141
+ opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
142
+ with open(txt_path, "w", encoding="utf8") as f:
143
+ f.write("\n".join(opt) + "\n")
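The prepare_datasets scripts are driven entirely by environment variables (they are normally launched by the GPT-SoVITS webui rather than run by hand). A minimal manual invocation might look like the sketch below; all paths are placeholders, and the other numbered scripts in this directory follow the same i_part/all_parts sharding convention.

import os
import subprocess

env = dict(
    os.environ,
    inp_text="data/filelist.txt",     # placeholder: wav_name|spk_name|language|text per line
    inp_wav_dir="data/wavs",          # placeholder
    exp_name="my_exp",
    i_part="0",                       # shard index
    all_parts="1",                    # total number of shards
    opt_dir="logs/my_exp",            # placeholder output dir
    bert_pretrained_dir="pretrained_models/chinese-roberta-wwm-ext-large",
    is_half="True",
)
subprocess.run(["python", "prepare_datasets/1-get-text.py"], env=env, check=True)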
prepare_datasets/2-get-hubert-wav32k.py ADDED
@@ -0,0 +1,134 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys
4
+ import os
5
+
6
+ inp_text = os.environ.get("inp_text")
7
+ inp_wav_dir = os.environ.get("inp_wav_dir")
8
+ exp_name = os.environ.get("exp_name")
9
+ i_part = os.environ.get("i_part")
10
+ all_parts = os.environ.get("all_parts")
11
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
13
+ from feature_extractor import cnhubert
14
+
15
+ opt_dir = os.environ.get("opt_dir")
16
+ cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
17
+ import torch
18
+
19
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
20
+
21
+ import traceback
22
+ import numpy as np
23
+ from scipy.io import wavfile
24
+ import librosa
25
+
26
+ now_dir = os.getcwd()
27
+ sys.path.append(now_dir)
28
+ from tools.my_utils import load_audio, clean_path
29
+
30
+ # from config import cnhubert_base_path
31
+ # cnhubert.cnhubert_base_path=cnhubert_base_path
32
+ # inp_text=sys.argv[1]
33
+ # inp_wav_dir=sys.argv[2]
34
+ # exp_name=sys.argv[3]
35
+ # i_part=sys.argv[4]
36
+ # all_parts=sys.argv[5]
37
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
38
+ # cnhubert.cnhubert_base_path=sys.argv[7]
39
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
40
+
41
+ from time import time as ttime
42
+ import shutil
43
+
44
+
45
+ def my_save(fea, path):  # workaround: torch.save fails on Chinese (non-ASCII) paths, so save to a temp name and move
46
+ dir = os.path.dirname(path)
47
+ name = os.path.basename(path)
48
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
49
+ tmp_path = "%s%s.pth" % (ttime(), i_part)
50
+ torch.save(fea, tmp_path)
51
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
52
+
53
+
54
+ hubert_dir = "%s/4-cnhubert" % (opt_dir)
55
+ wav32dir = "%s/5-wav32k" % (opt_dir)
56
+ os.makedirs(opt_dir, exist_ok=True)
57
+ os.makedirs(hubert_dir, exist_ok=True)
58
+ os.makedirs(wav32dir, exist_ok=True)
59
+
60
+ maxx = 0.95
61
+ alpha = 0.5
62
+ if torch.cuda.is_available():
63
+ device = "cuda:0"
64
+ # elif torch.backends.mps.is_available():
65
+ # device = "mps"
66
+ else:
67
+ device = "cpu"
68
+ model = cnhubert.get_model()
69
+ # is_half=False
70
+ if is_half == True:
71
+ model = model.half().to(device)
72
+ else:
73
+ model = model.to(device)
74
+
75
+ nan_fails = []
76
+
77
+
78
+ def name2go(wav_name, wav_path):
79
+ hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
80
+ if os.path.exists(hubert_path):
81
+ return
82
+ tmp_audio = load_audio(wav_path, 32000)
83
+ tmp_max = np.abs(tmp_audio).max()
84
+ if tmp_max > 2.2:
85
+ print("%s-filtered,%s" % (wav_name, tmp_max))
86
+ return
87
+ tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
88
+ tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
89
+ tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # not a resampling issue
90
+ tensor_wav16 = torch.from_numpy(tmp_audio)
91
+ if is_half == True:
92
+ tensor_wav16 = tensor_wav16.half().to(device)
93
+ else:
94
+ tensor_wav16 = tensor_wav16.to(device)
95
+ ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
96
+ if np.isnan(ssl.detach().numpy()).sum() != 0:
97
+ nan_fails.append((wav_name, wav_path))
98
+ print("nan filtered:%s" % wav_name)
99
+ return
100
+ wavfile.write(
101
+ "%s/%s" % (wav32dir, wav_name),
102
+ 32000,
103
+ tmp_audio32.astype("int16"),
104
+ )
105
+ my_save(ssl, hubert_path)
106
+
107
+
108
+ with open(inp_text, "r", encoding="utf8") as f:
109
+ lines = f.read().strip("\n").split("\n")
110
+
111
+ for line in lines[int(i_part) :: int(all_parts)]:
112
+ try:
113
+ # wav_name,text=line.split("\t")
114
+ wav_name, spk_name, language, text = line.split("|")
115
+ wav_name = clean_path(wav_name)
116
+ if inp_wav_dir != "" and inp_wav_dir != None:
117
+ wav_name = os.path.basename(wav_name)
118
+ wav_path = "%s/%s" % (inp_wav_dir, wav_name)
119
+
120
+ else:
121
+ wav_path = wav_name
122
+ wav_name = os.path.basename(wav_name)
123
+ name2go(wav_name, wav_path)
124
+ except:
125
+ print(line, traceback.format_exc())
126
+
127
+ if len(nan_fails) > 0 and is_half == True:
128
+ is_half = False
129
+ model = model.float()
130
+ for wav in nan_fails:
131
+ try:
132
+ name2go(wav[0], wav[1])
133
+ except:
134
+ print(wav_name, traceback.format_exc())
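For clarity, the write path above blends a peak-normalized copy of the signal (scaled to maxx) with the original signal before converting to int16. A pure-function restatement of that expression (illustrative only; the script itself keeps the inline form and a hypothetical helper name is used here):

import numpy as np

def blend_normalize(audio: np.ndarray, maxx: float = 0.95, alpha: float = 0.5) -> np.ndarray:
    # equivalent to: audio / peak * (maxx * alpha * 32768) + (1 - alpha) * 32768 * audio
    peak = np.abs(audio).max()
    mixed = alpha * maxx * (audio / peak) + (1 - alpha) * audio
    return (mixed * 32768).astype("int16")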
prepare_datasets/2-get-sv.py ADDED
@@ -0,0 +1,115 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys
4
+ import os
5
+
6
+ inp_text = os.environ.get("inp_text")
7
+ inp_wav_dir = os.environ.get("inp_wav_dir")
8
+ exp_name = os.environ.get("exp_name")
9
+ i_part = os.environ.get("i_part")
10
+ all_parts = os.environ.get("all_parts")
11
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
13
+
14
+ opt_dir = os.environ.get("opt_dir")
15
+ sv_path = os.environ.get("sv_path")
16
+ import torch
17
+
18
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
19
+
20
+ import traceback
21
+ import torchaudio
22
+
23
+ now_dir = os.getcwd()
24
+ sys.path.append(now_dir)
25
+ sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
26
+ from tools.my_utils import clean_path
27
+ from time import time as ttime
28
+ import shutil
29
+ from ERes2NetV2 import ERes2NetV2
30
+ import kaldi as Kaldi
31
+
32
+
33
+ def my_save(fea, path):  # workaround: torch.save fails on Chinese (non-ASCII) paths, so save to a temp name and move
34
+ dir = os.path.dirname(path)
35
+ name = os.path.basename(path)
36
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
37
+ tmp_path = "%s%s.pth" % (ttime(), i_part)
38
+ torch.save(fea, tmp_path)
39
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
40
+
41
+
42
+ sv_cn_dir = "%s/7-sv_cn" % (opt_dir)
43
+ wav32dir = "%s/5-wav32k" % (opt_dir)
44
+ os.makedirs(opt_dir, exist_ok=True)
45
+ os.makedirs(sv_cn_dir, exist_ok=True)
46
+ os.makedirs(wav32dir, exist_ok=True)
47
+
48
+ maxx = 0.95
49
+ alpha = 0.5
50
+ if torch.cuda.is_available():
51
+ device = "cuda:0"
52
+ # elif torch.backends.mps.is_available():
53
+ # device = "mps"
54
+ else:
55
+ device = "cpu"
56
+
57
+
58
+ class SV:
59
+ def __init__(self, device, is_half):
60
+ pretrained_state = torch.load(sv_path, map_location="cpu")
61
+ embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
62
+ embedding_model.load_state_dict(pretrained_state)
63
+ embedding_model.eval()
64
+ self.embedding_model = embedding_model
65
+ self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
66
+ if is_half == False:
67
+ self.embedding_model = self.embedding_model.to(device)
68
+ else:
69
+ self.embedding_model = self.embedding_model.half().to(device)
70
+ self.is_half = is_half
71
+
72
+ def compute_embedding3(self, wav): # (1,x)#-1~1
73
+ with torch.no_grad():
74
+ wav = self.res(wav)
75
+ if self.is_half == True:
76
+ wav = wav.half()
77
+ feat = torch.stack(
78
+ [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
79
+ )
80
+ sv_emb = self.embedding_model.forward3(feat)
81
+ return sv_emb
82
+
83
+
84
+ sv = SV(device, is_half)
85
+
86
+
87
+ def name2go(wav_name, wav_path):
88
+ sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
89
+ if os.path.exists(sv_cn_path):
90
+ return
91
+ wav_path = "%s/%s" % (wav32dir, wav_name)
92
+ wav32k, sr0 = torchaudio.load(wav_path)
93
+ assert sr0 == 32000
94
+ wav32k = wav32k.to(device)
95
+ emb = sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480])
96
+ my_save(emb, sv_cn_path)
97
+
98
+
99
+ with open(inp_text, "r", encoding="utf8") as f:
100
+ lines = f.read().strip("\n").split("\n")
101
+
102
+ for line in lines[int(i_part) :: int(all_parts)]:
103
+ try:
104
+ wav_name, spk_name, language, text = line.split("|")
105
+ wav_name = clean_path(wav_name)
106
+ if inp_wav_dir != "" and inp_wav_dir != None:
107
+ wav_name = os.path.basename(wav_name)
108
+ wav_path = "%s/%s" % (inp_wav_dir, wav_name)
109
+
110
+ else:
111
+ wav_path = wav_name
112
+ wav_name = os.path.basename(wav_name)
113
+ name2go(wav_name, wav_path)
114
+ except:
115
+ print(line, traceback.format_exc())
prepare_datasets/3-get-semantic.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+
3
+ inp_text = os.environ.get("inp_text")
4
+ exp_name = os.environ.get("exp_name")
5
+ i_part = os.environ.get("i_part")
6
+ all_parts = os.environ.get("all_parts")
7
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
8
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
9
+ opt_dir = os.environ.get("opt_dir")
10
+ pretrained_s2G = os.environ.get("pretrained_s2G")
11
+ s2config_path = os.environ.get("s2config_path")
12
+
13
+ if os.path.exists(pretrained_s2G):
14
+ ...
15
+ else:
16
+ raise FileNotFoundError(pretrained_s2G)
17
+ # version=os.environ.get("version","v2")
18
+ size = os.path.getsize(pretrained_s2G)
19
+ if size < 82978 * 1024:
20
+ version = "v1"
21
+ elif size < 100 * 1024 * 1024:
22
+ version = "v2"
23
+ elif size < 103520 * 1024:
24
+ version = "v1"
25
+ elif size < 700 * 1024 * 1024:
26
+ version = "v2"
27
+ else:
28
+ version = "v3"
29
+ import torch
30
+
31
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
32
+ import traceback
33
+ import sys
34
+
35
+ now_dir = os.getcwd()
36
+ sys.path.append(now_dir)
37
+ import logging
38
+ import utils
39
+
40
+ if version != "v3":
41
+ from module.models import SynthesizerTrn
42
+ else:
43
+ from module.models import SynthesizerTrnV3 as SynthesizerTrn
44
+ from tools.my_utils import clean_path
45
+
46
+ logging.getLogger("numba").setLevel(logging.WARNING)
47
+ # from config import pretrained_s2G
48
+
49
+ # inp_text=sys.argv[1]
50
+ # exp_name=sys.argv[2]
51
+ # i_part=sys.argv[3]
52
+ # all_parts=sys.argv[4]
53
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
54
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
55
+
56
+
57
+ hubert_dir = "%s/4-cnhubert" % (opt_dir)
58
+ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
59
+ if os.path.exists(semantic_path) == False:
60
+ os.makedirs(opt_dir, exist_ok=True)
61
+
62
+ if torch.cuda.is_available():
63
+ device = "cuda"
64
+ # elif torch.backends.mps.is_available():
65
+ # device = "mps"
66
+ else:
67
+ device = "cpu"
68
+ hps = utils.get_hparams_from_file(s2config_path)
69
+ vq_model = SynthesizerTrn(
70
+ hps.data.filter_length // 2 + 1,
71
+ hps.train.segment_size // hps.data.hop_length,
72
+ n_speakers=hps.data.n_speakers,
73
+ version=version,
74
+ **hps.model,
75
+ )
76
+ if is_half == True:
77
+ vq_model = vq_model.half().to(device)
78
+ else:
79
+ vq_model = vq_model.to(device)
80
+ vq_model.eval()
81
+ # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
82
+ # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
83
+ print(
84
+ vq_model.load_state_dict(
85
+ torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
86
+ )
87
+ )
88
+
89
+ def name2go(wav_name, lines):
90
+ hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
91
+ if os.path.exists(hubert_path) == False:
92
+ return
93
+ ssl_content = torch.load(hubert_path, map_location="cpu")
94
+ if is_half == True:
95
+ ssl_content = ssl_content.half().to(device)
96
+ else:
97
+ ssl_content = ssl_content.to(device)
98
+ codes = vq_model.extract_latent(ssl_content)
99
+ semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
100
+ lines.append("%s\t%s" % (wav_name, semantic))
101
+
102
+ with open(inp_text, "r", encoding="utf8") as f:
103
+ lines = f.read().strip("\n").split("\n")
104
+
105
+ lines1 = []
106
+ for line in lines[int(i_part) :: int(all_parts)]:
107
+ # print(line)
108
+ try:
109
+ # wav_name,text=line.split("\t")
110
+ wav_name, spk_name, language, text = line.split("|")
111
+ wav_name = clean_path(wav_name)
112
+ wav_name = os.path.basename(wav_name)
113
+ # name2go(name,lines1)
114
+ name2go(wav_name, lines1)
115
+ except:
116
+ print(line, traceback.format_exc())
117
+ with open(semantic_path, "w", encoding="utf8") as f:
118
+ f.write("\n".join(lines1))
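Each shard writes 6-name2semantic-{i_part}.tsv, where every row pairs a wav file name with its space-separated semantic token ids produced by the quantizer, e.g. (token values illustrative only):

vocal_0001.wav	101 57 57 892 14 603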