shethjenil commited on
Commit
4d4478c
·
verified ·
1 Parent(s): e143a6e

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +20 -0
  2. requirements.txt +3 -0
  3. spleeter.py +223 -0
app.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from spleeter import Splitter
2
+ import torchaudio
3
+ from torchaudio.transforms import Resample
4
+ import torch
5
+ import gradio as gr
6
# Cached model instance: Splitter() downloads and loads checkpoint weights,
# so building it once per process (instead of once per request) matters.
_MODEL = None


def _get_model():
    """Return the shared 2-stem Splitter, creating it on first use."""
    global _MODEL
    if _MODEL is None:
        _MODEL = Splitter(2)
    return _MODEL


def separate(audio_path):
    """Split an audio file into vocals and accompaniment.

    Args:
        audio_path: path to the input audio file (any format torchaudio can load).

    Returns:
        Tuple of two file paths: ("vocals.mp3", "accompaniment.mp3"),
        written into the current working directory.
    """
    model = _get_model()
    wav, sr = torchaudio.load(audio_path)
    # The model's STFT configuration assumes 44.1 kHz input.
    target_sr = 44100
    if sr != target_sr:
        wav = Resample(sr, target_sr)(wav)
        sr = target_sr
    with torch.no_grad():
        results = model(wav)
    torchaudio.save("vocals.mp3", results["vocals"], sr, format="mp3")
    torchaudio.save("accompaniment.mp3", results["accompaniment"], sr, format="mp3")
    return "vocals.mp3", "accompaniment.mp3"


gr.Interface(separate, gr.Audio(type="filepath"), [gr.Audio(type="filepath"), gr.Audio(type="filepath")]).launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
gradio
huggingface_hub
torch
torchaudio
spleeter.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, Tuple
3
+ from huggingface_hub import hf_hub_download
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+
8
def batchify(tensor: Tensor, T: int) -> Tensor:
    """
    Partition ``tensor`` into equal segments of length ``T`` along its last
    axis, zero-padding the ragged tail so every segment is full, then stack
    the segments along the batch axis.

    Args:
        tensor (Tensor): B x C x F x L input.
        T (int): segment length along the last dimension.

    Returns:
        Tensor of size (B * ceil(L / T)) x C x F x T.
    """
    # Pad the trailing axis up to the next multiple of T.
    length = tensor.size(-1)
    padded_length = math.ceil(length / T) * T
    tensor = F.pad(tensor, [0, padded_length - length])
    # Cut into T-wide pieces and stack them as additional batch entries.
    segments = torch.split(tensor, T, dim=-1)
    return torch.cat(segments, dim=0)
22
+
23
+
24
class EncoderBlock(nn.Module):
    """One down-sampling U-Net step: strided 5x5 conv -> BatchNorm -> LeakyReLU.

    ``forward`` returns both the raw convolution output (kept for the
    decoder's skip connection) and the normalised/activated tensor that
    feeds the next encoder block.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        # eps/momentum chosen to match the reference TF implementation.
        self.bn = nn.BatchNorm2d(
            out_channels,
            eps=0.001,
            momentum=0.01,
            track_running_stats=True,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Asymmetric "same"-style padding (1 before, 2 after on each axis)
        # before the stride-2 convolution halves the spatial dimensions.
        padded = F.pad(input, (1, 2, 1, 2), "constant", 0)
        down = self.conv(padded)
        return down, self.relu(self.bn(down))
39
+
40
+
41
class DecoderBlock(nn.Module):
    """One up-sampling U-Net step: transposed conv -> ReLU -> BatchNorm (-> dropout).

    The transposed convolution over-produces by the same (1, 2) margins the
    encoder padded with, so the output is cropped back before normalisation.
    """

    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, eps=1e-3, momentum=0.01, track_running_stats=True
        )
        # Dropout is only active on blocks constructed with prob > 0.
        self.dropout = nn.Identity() if dropout_prob <= 0 else nn.Dropout(dropout_prob)

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # Undo the encoder's asymmetric padding: trim 1 leading and 2
        # trailing elements from each spatial axis.
        up = up[:, :, 1:-2, 1:-2]
        return self.dropout(self.bn(self.relu(up)))
61
+
62
+
63
class UNet(nn.Module):
    """U-Net that soft-masks one stem out of a magnitude spectrogram.

    The sigmoid mask is multiplied with the input before returning, so
    ``forward`` yields the masked spectrogram rather than the mask itself.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()

        # Encoder channel progression: in_channels -> 16 -> 32 -> ... -> 2**(n_layers + 3).
        enc_channels = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            EncoderBlock(in_channels=c_in, out_channels=c_out)
            for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:])
        )

        # Decoder mirrors the encoder back down to a single channel.
        dec_channels = list(reversed([1] + [2 ** (i + 4) for i in range(n_layers)]))
        self.decoder_layers = nn.ModuleList(
            DecoderBlock(
                # After the first layer, inputs are concatenated with a skip
                # connection, which doubles the channel count.
                in_channels=c_in if i == 0 else c_in * 2,
                out_channels=c_out,
                # 50% dropout on the first three decoder layers only.
                dropout_prob=0.5 if i < 3 else 0,
            )
            for i, (c_in, c_out) in enumerate(zip(dec_channels[:-1], dec_channels[1:]))
        )

        # Final conv maps the 1-channel decoder output back to in_channels
        # mask logits, squashed by a sigmoid.
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        # Each encoder returns (pre-activation, activated); the pre-activation
        # tensors are stacked for use as skip connections.
        skips = []
        x = input
        for encoder in self.encoder_layers:
            pre_act, x = encoder(x)
            skips.append(pre_act)

        for i, decoder in enumerate(self.decoder_layers):
            skip = skips.pop()
            # The deepest layer has no earlier decoder output to merge with.
            x = decoder(skip if i == 0 else torch.cat([skip, x], dim=1))

        mask = self.sigmoid(self.up_final(x))

        # Crop mask and input to a common size before multiplying; the
        # conv/crop pipeline can leave them off by a few rows/columns.
        min_f = min(mask.size(-2), input.size(-2))
        min_t = min(mask.size(-1), input.size(-1))
        mask = mask[..., :min_f, :min_t]
        input = input[..., :min_f, :min_t]

        return mask * input
124
+
125
+
126
+
127
+
128
class Splitter(nn.Module):
    """Spleeter-style source separator.

    Holds one 2-channel U-Net per stem; ``forward`` takes a stereo waveform
    and returns a dict mapping stem name -> separated waveform. Pretrained
    weights are downloaded from the ``shethjenil/spleeter-torch`` HF repo.

    Args:
        stem_num: number of stems to separate; one of 2, 4 or 5.

    Raises:
        ValueError: if ``stem_num`` is not a supported configuration
            (previously an unsupported value crashed later with NameError).
    """

    # Supported configurations: stem count -> ordered stem names.
    STEM_NAMES = {
        2: ["vocals", "accompaniment"],
        4: ["vocals", "drums", "bass", "other"],
        5: ["vocals", "piano", "drums", "bass", "other"],
    }

    def __init__(self, stem_num=2):
        super().__init__()
        # Fail fast on unsupported stem counts.
        if stem_num not in self.STEM_NAMES:
            raise ValueError(
                f"stem_num must be one of {sorted(self.STEM_NAMES)}, got {stem_num!r}"
            )
        stem_names = self.STEM_NAMES[stem_num]
        # STFT config (matches original Spleeter): 4096-sample Hann window,
        # 1024 hop; only the first F=1024 frequency bins are modelled, and
        # the spectrogram is processed in T=512-frame segments.
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
        # NOTE(review): torch.load unpickles the downloaded checkpoint; the
        # repo is pinned, but weights_only=True would be safer where available.
        self.load_state_dict(
            torch.load(hf_hub_download("shethjenil/spleeter-torch", f"{stem_num}.pt"))
        )
        self.eval()

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes STFT feature from wav
        Args:
            wav (Tensor): B x L or 2 x L for stereo
        Returns:
            stft (Tensor): B x F x T x 2 (real+imag)
            mag (Tensor): B x F x T magnitude
        """
        # return_complex=False keeps the legacy (..., 2) real/imag layout the
        # rest of this class expects (deprecated upstream but still supported).
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,
            pad_mode="constant",
        )

        # Keep only the first F frequency bins (the U-Nets model 0..F-1).
        stft = stft[:, :self.F, :, :]

        # Magnitude with a small epsilon so sqrt is well-behaved at zero.
        real = stft[:, :, :, 0]
        imag = stft[:, :, :, 1]
        mag = torch.sqrt(real**2 + imag**2 + 1e-10)

        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverse STFT from real+imag tensor (B x F x T x 2)."""
        # Zero-pad the discarded upper frequency bins back to n_fft/2 + 1.
        target_F = self.win_length // 2 + 1
        if stft.size(1) < target_F:
            pad = target_F - stft.size(1)
            stft = F.pad(stft, (0, 0, 0, 0, 0, pad))  # pad along freq dim

        # istft requires a complex tensor in current torch versions.
        stft_complex = torch.view_as_complex(stft)

        wav = torch.istft(
            stft_complex,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.win,
        )

        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """Separate ``wav`` (stereo, 2 x L) into per-stem waveforms."""
        # stft: 2 x F x L x 2 (real/imag), stft_mag: 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)
        # 1 x 2 x F x L, then batch into T-frame segments: B x 2 x F x T
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # Each stem's U-Net predicts a masked magnitude spectrogram.
        masks = {name: net(stft_mag) for name, net in self.stems.items()}
        # Wiener-like normalisation denominator across all stems.
        mask_sum = sum([m**2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # Ratio mask; note the epsilon is 1e-10 / 2 added to mask**2,
            # matching the original Spleeter epsilon placement.
            mask = (mask**2 + 1e-10 / 2) / (mask_sum)
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # Un-batchify: re-join the T-frame segments along time and trim
            # the zero padding back to the original L frames.
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
222
+
223
+