import math
from typing import Dict, Tuple
from huggingface_hub import hf_hub_download
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from tqdm import tqdm


def batchify(tensor: Tensor, T: int) -> Tensor:
    """Pad the last dimension to a multiple of T, then split it into
    T-sized chunks stacked along the batch dimension."""
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
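
# Shape walkthrough (illustrative values): with T = 512, a 1 x 2 x F x 700
# input is padded on the last axis to 1 x 2 x F x 1024, split into two
# 512-frame chunks along time, and stacked to 2 x 2 x F x 512.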


class EncoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=0.001,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # asymmetric (1, 2, 1, 2) padding emulates TF "same" padding
        # for a stride-2, 5x5 convolution
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        # return the pre-activation output (for skip connections)
        # alongside the activated one
        return down, self.relu(self.bn(down))


class DecoderBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # crop away the transposed conv's overshoot, mirroring the encoder's
        # (1, 2, 1, 2) padding on both spatial dims
        up = up[:, :, 1:-2, 1:-2]
        return self.dropout(self.bn(self.relu(up)))


class UNet(nn.Module):
    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()
        # downsample layers
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )
        # upsample layers
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # doubled for concatenated inputs (skip connections)
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout for the first 3 layers
                    dropout_prob=0.5 if i < 3 else 0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )
        # project the final feature map back to the input channel count
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)
        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                x = up(encoder_outputs_pre_act.pop())
            else:
                # merge skip connection
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))
        mask = self.sigmoid(self.up_final(x))
        # crop both mask and input to a common size before masking
        min_f = min(mask.size(-2), input.size(-2))
        min_t = min(mask.size(-1), input.size(-1))
        mask = mask[..., :min_f, :min_t]
        input = input[..., :min_f, :min_t]
        return mask * input
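
# Shape check (illustrative): for a 1 x 2 x 512 x 1024 magnitude batch, each
# encoder stage halves both spatial dims (512 -> 8 over six stages) and each
# decoder stage doubles them back, so the output mask matches the input shape
# and the final crop is a no-op when both dims are multiples of 2**n_layers.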


class Splitter(nn.Module):
    def __init__(self, stem_num: int = 2):
        super().__init__()
        if stem_num == 2:
            stem_names = ["vocals", "accompaniment"]
        elif stem_num == 4:
            stem_names = ["vocals", "drums", "bass", "other"]
        elif stem_num == 5:
            stem_names = ["vocals", "piano", "drums", "bass", "other"]
        else:
            raise ValueError(f"stem_num must be 2, 4 or 5, got {stem_num}")
        # STFT config
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
        # load pretrained weights onto CPU first, then move with .to() if desired
        self.load_state_dict(
            torch.load(
                hf_hub_download("shethjenil/spleeter-torch", f"{stem_num}.pt"),
                map_location="cpu",
            )
        )
        self.eval()

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L, or 2 x L for stereo

        Returns:
            stft (Tensor): B x F x T x 2 (real + imag)
            mag (Tensor): B x F x T magnitude
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,  # keep the old real+imag layout
            pad_mode="constant",
        )
        # keep only the first F frequency bins
        stft = stft[:, : self.F, :, :]
        # magnitude
        real = stft[:, :, :, 0]
        imag = stft[:, :, :, 1]
        mag = torch.sqrt(real**2 + imag**2 + 1e-10)
        return stft, mag
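
    # Example shapes (illustrative): a 10 s stereo clip at 44.1 kHz (2 x 441000)
    # yields stft of 2 x 1024 x 431 x 2 and mag of 2 x 1024 x 431
    # (frames = 441000 // 1024 + 1 with center=True).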

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverse STFT from a real+imag tensor (B x F x T x 2)."""
        # ensure the frequency dimension matches n_fft // 2 + 1
        target_F = self.win_length // 2 + 1
        if stft.size(1) < target_F:
            pad = target_F - stft.size(1)
            stft = F.pad(stft, (0, 0, 0, 0, 0, pad))  # zero-pad along the freq dim
        # convert real+imag to complex for istft
        stft_complex = torch.view_as_complex(stft)
        wav = torch.istft(
            stft_complex,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor, batch_size: int = 16) -> Dict[str, Tensor]:
        # stft     - 2 x F x L x 2
        # stft_mag - 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)
        # 1 x 2 x F x L
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # compute each stem's mask
        masks = self.infer_with_batches(stft_mag, batch_size)
        # compute the shared denominator for the ratio masks
        mask_sum = sum([m**2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # Wiener-style ratio mask; 1e-10 / 2 is the epsilon term
            mask = (mask**2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # undo batchify: re-concatenate the T-sized chunks along time
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}

    def infer_with_batches(self, stft_mag: Tensor, batch_size: int) -> Dict[str, Tensor]:
        # run every stem's UNet over the magnitude chunks in mini-batches
        masks = {name: [] for name in self.stems.keys()}
        with torch.inference_mode():
            for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
                batch = stft_mag[i:i + batch_size]
                batch_outputs = {name: net(batch) for name, net in self.stems.items()}
                for name in self.stems.keys():
                    masks[name].append(batch_outputs[name])
        return {name: torch.cat(chunks, dim=0) for name, chunks in masks.items()}
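

# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes torchaudio is available and "mix.wav" is a 44.1 kHz stereo file;
# the pretrained Spleeter weights expect 44.1 kHz input.
if __name__ == "__main__":
    import torchaudio

    wav, sr = torchaudio.load("mix.wav")  # 2 x L
    model = Splitter(stem_num=2)
    stems = model(wav)  # {"vocals": Tensor, "accompaniment": Tensor}
    for name, audio in stems.items():
        torchaudio.save(f"{name}.wav", audio.cpu(), sr)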