import math
from typing import Dict, Tuple

import torch
from huggingface_hub import hf_hub_download
from torch import nn, Tensor
from torch.nn import functional as F
from tqdm import tqdm


def batchify(tensor: Tensor, T: int) -> Tensor:
    """Zero-pad the last (time) axis up to a multiple of T, then stack the
    resulting T-frame chunks along the batch dimension."""
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
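

# Shape sketch (illustrative values, not from the original source):
#   x = torch.randn(2, 1024, 700)   # 2 channels, 1024 freq bins, 700 frames
#   batchify(x, 512).shape          # -> torch.Size([4, 1024, 512])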


class EncoderBlock(nn.Module):
    """Strided 5x5 conv that halves both spatial dims, followed by BN + LeakyReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=0.001,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Asymmetric "same" padding so the stride-2 conv exactly halves H and W
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        # Also return the pre-activation output, used for the skip connections
        return down, self.relu(self.bn(down))


class DecoderBlock(nn.Module):
    """Transposed 5x5 conv that doubles both spatial dims, followed by
    ReLU + BN and optional dropout."""

    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # Crop away the (1, 2, 1, 2) padding added in the encoder
        top, bottom, left, right = 1, 2, 1, 2
        up = up[:, :, top:-bottom, left:-right]
        return self.dropout(self.bn(self.relu(up)))
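

# Shape sketch (illustrative values, not from the original source): the
# encoder's pre-pad makes the stride-2 conv exactly halve each spatial dim,
# and the decoder's transposed conv plus the crop above undoes it:
#   enc, dec = EncoderBlock(2, 16), DecoderBlock(16, 2)
#   pre, act = enc(torch.randn(1, 2, 512, 1024))  # pre: (1, 16, 256, 512)
#   dec(pre).shape                                # -> (1, 2, 512, 1024)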


class UNet(nn.Module):
    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()
        # Downsampling path: channel widths 16, 32, ..., 2 ** (n_layers + 3)
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )
        # Upsampling path: the same widths in reverse, ending at 1 channel
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # doubled for concatenated inputs (skip connections)
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout for the first 3 decoder layers
                    dropout_prob=0.5 if i < 3 else 0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )
        # Reconstruct the final mask with the same channel count as the input
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        # Encode, keeping each pre-activation output for the skip connections
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)
        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                x = up(encoder_outputs_pre_act.pop())
            else:
                # merge skip connection
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))
        mask = self.sigmoid(self.up_final(x))
        # Crop both mask and input to a common size before masking
        min_f = min(mask.size(-2), input.size(-2))
        min_t = min(mask.size(-1), input.size(-1))
        mask = mask[..., :min_f, :min_t]
        input = input[..., :min_f, :min_t]
        return mask * input
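

# Shape sketch (illustrative values, not from the original source): the UNet
# maps a batch of 2-channel magnitude patches to a soft mask of the same
# shape, squashed into (0, 1) by the final sigmoid and applied to the input:
#   net = UNet(in_channels=2)
#   net(torch.randn(4, 2, 512, 1024)).shape  # -> torch.Size([4, 2, 512, 1024])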


class Splitter(nn.Module):
    def __init__(self, stem_num: int = 2) -> None:
        super().__init__()
        if stem_num == 2:
            stem_names = ["vocals", "accompaniment"]
        elif stem_num == 4:
            stem_names = ["vocals", "drums", "bass", "other"]
        elif stem_num == 5:
            stem_names = ["vocals", "piano", "drums", "bass", "other"]
        else:
            raise ValueError(f"stem_num must be 2, 4 or 5, got {stem_num}")
        # STFT config
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
        # One UNet per stem, operating on 2-channel (stereo) magnitudes
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
        self.load_state_dict(
            torch.load(hf_hub_download("shethjenil/spleeter-torch", f"{stem_num}.pt"))
        )
        self.eval()

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L, or 2 x L for stereo

        Returns:
            stft (Tensor): B x F x T x 2 (real + imag)
            mag (Tensor): B x F x T magnitude
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,  # keep the legacy real+imag layout
            pad_mode="constant",
        )
        # Keep only the first F frequency bins
        stft = stft[:, : self.F, :, :]
        # Magnitude, with a small epsilon for numerical stability
        real = stft[:, :, :, 0]
        imag = stft[:, :, :, 1]
        mag = torch.sqrt(real**2 + imag**2 + 1e-10)
        return stft, mag
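
    # Shape sketch (illustrative values, not from the original source): with
    # center=True, L samples give L // hop_length + 1 frames, and only the
    # first F=1024 of the n_fft // 2 + 1 = 2049 bins are kept:
    #   stft, mag = splitter.compute_stft(torch.randn(2, 44100))
    #   stft.shape  # -> torch.Size([2, 1024, 44, 2])
    #   mag.shape   # -> torch.Size([2, 1024, 44])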

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverse STFT from a real+imag tensor (B x F x T x 2)."""
        # Zero-pad the frequency axis back up to n_fft // 2 + 1 bins
        target_F = self.win_length // 2 + 1
        if stft.size(1) < target_F:
            pad = target_F - stft.size(1)
            stft = F.pad(stft, (0, 0, 0, 0, 0, pad))  # pad along freq dim
        # Convert real+imag to complex for istft
        stft_complex = torch.view_as_complex(stft)
        wav = torch.istft(
            stft_complex,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor, batch_size: int = 16) -> Dict[str, Tensor]:
        # stft     - 2 x F x L x 2
        # stft_mag - 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)
        # 1 x 2 x F x L
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # Compute each stem's soft mask
        masks = self.infer_with_batches(stft_mag, batch_size)
        # Denominator of the ratio mask
        mask_sum = sum(m**2 for m in masks.values())
        mask_sum += 1e-10

        def apply_mask(mask):
            mask = (mask**2 + (1e-10 / 2)) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # Undo batchify: re-join the T-frame chunks along the time axis
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
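
    # Hedged note on the combination above: each stem's soft mask is squared
    # and renormalized so the masks sum to exactly 1 per time-frequency bin
    # (a Wiener-filter-style ratio mask). For two stems v and a:
    #   m_v = (v**2 + eps/2) / (v**2 + a**2 + eps)
    #   m_a = (a**2 + eps/2) / (v**2 + a**2 + eps)
    # so m_v + m_a == 1 and the masked STFTs partition the mixture exactly.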

    def infer_with_batches(self, stft_mag, batch_size):
        masks = {name: [] for name in self.stems.keys()}
        with torch.inference_mode():
            # Run every stem's UNet over the patches in mini-batches
            for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
                batch = stft_mag[i : i + batch_size]
                batch_outputs = {name: net(batch) for name, net in self.stems.items()}
                for name in self.stems.keys():
                    masks[name].append(batch_outputs[name])
        masks = {name: torch.cat(masks[name], dim=0) for name in masks}
        return masks
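

# Minimal end-to-end usage sketch (not part of the original module). It assumes
# torchaudio is available and that "mix.wav" is a stereo 44.1 kHz file; both
# are illustrative stand-ins rather than fixed requirements.
if __name__ == "__main__":
    import torchaudio

    wav, sr = torchaudio.load("mix.wav")  # wav: 2 x L for a stereo file
    splitter = Splitter(stem_num=2)  # downloads the matching checkpoint
    stems = splitter(wav)  # {"vocals": ..., "accompaniment": ...}
    for name, audio in stems.items():
        torchaudio.save(f"{name}.wav", audio.cpu(), sr)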