import math
from typing import Dict, Tuple

import torch
from huggingface_hub import hf_hub_download
from torch import nn, Tensor
from torch.nn import functional as F
from tqdm import tqdm


def batchify(tensor: Tensor, T: int) -> Tensor:
    """Pad the time axis to a multiple of T, then fold it into the batch dim."""
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)


class EncoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=0.001,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Asymmetric (1, 2) padding emulates TensorFlow-style "same" padding
        # for a 5x5 kernel with stride 2.
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        # Also return the pre-activation output for the skip connection.
        return down, self.relu(self.bn(down))


class DecoderBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # Crop to reverse the encoder's asymmetric padding:
        # (top, bottom) on the frequency axis, (left, right) on the time axis.
        t, b, l, r = 1, 2, 1, 2
        up = up[:, :, t:-b, l:-r]
        return self.dropout(self.bn(self.relu(up)))


class UNet(nn.Module):
    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()
        # Downsampling path: channel widths 16, 32, ..., 2 ** (n_layers + 3).
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )

        # Upsampling path (mirrored channel widths).
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # Input channels are doubled for concatenated skip
                    # connections, except at the bottleneck (i == 0).
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout on the first three decoder layers.
                    dropout_prob=0.5 if i < 3 else 0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )

        # Reconstruct the final mask with the same channel count as the input.
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)

        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                x = up(encoder_outputs_pre_act.pop())
            else:
                # Merge the skip connection from the matching encoder layer.
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))

        mask = self.sigmoid(self.up_final(x))

        # Crop both mask and input to a common size before masking.
        min_f = min(mask.size(-2), input.size(-2))
        min_t = min(mask.size(-1), input.size(-1))
        mask = mask[..., :min_f, :min_t]
        input = input[..., :min_f, :min_t]

        return mask * input


class Splitter(nn.Module):
    """Spleeter-style source separator: one magnitude-masking U-Net per stem."""

    def __init__(self, stem_num: int = 2) -> None:
        super().__init__()
        if stem_num == 2:
            stem_names = ["vocals", "accompaniment"]
        elif stem_num == 4:
            stem_names = ["vocals", "drums", "bass", "other"]
        elif stem_num == 5:
            stem_names = ["vocals", "piano", "drums", "bass", "other"]
        else:
            raise ValueError(f"stem_num must be 2, 4 or 5, got {stem_num}")

        # STFT configuration.
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(
            torch.hann_window(self.win_length), requires_grad=False
        )

        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})

        # Load pretrained weights from the Hugging Face Hub.
        self.load_state_dict(
            torch.load(
                hf_hub_download("shethjenil/spleeter-torch", f"{stem_num}.pt"),
                map_location="cpu",
            )
        )
        self.eval()

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the STFT of a waveform.

        Args:
            wav (Tensor): B x L waveform (2 x L for stereo).

        Returns:
            stft (Tensor): B x F x T x 2 (real and imaginary parts).
            mag (Tensor): B x F x T magnitude spectrogram.
        """
        # torch.stft with return_complex=False is deprecated, so compute the
        # complex STFT and convert it back to the real+imag layout used below.
        stft = torch.view_as_real(
            torch.stft(
                wav,
                n_fft=self.win_length,
                hop_length=self.hop_length,
                window=self.win,
                center=True,
                return_complex=True,
                pad_mode="constant",
            )
        )

        # Keep only the first F frequency bins.
        stft = stft[:, : self.F, :, :]

        real = stft[:, :, :, 0]
        imag = stft[:, :, :, 1]
        mag = torch.sqrt(real**2 + imag**2 + 1e-10)

        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverse STFT from a real+imag tensor (B x F x T x 2)."""
        # Zero-pad the frequency axis back up to n_fft // 2 + 1 bins.
        target_F = self.win_length // 2 + 1
        if stft.size(1) < target_F:
            pad = target_F - stft.size(1)
            stft = F.pad(stft, (0, 0, 0, 0, 0, pad))

        # Convert real+imag back to a complex tensor for istft
        # (view_as_complex requires a contiguous last dimension).
        stft_complex = torch.view_as_complex(stft.contiguous())
        wav = torch.istft(
            stft_complex,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor, batch_size: int = 16) -> Dict[str, Tensor]:
        # stft:     2 x F x L x 2
        # stft_mag: 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())

        L = stft.size(2)

        # 1 x 2 x F x L
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # Compute each stem's soft mask.
        masks = self.infer_with_batches(stft_mag, batch_size)

        # Denominator for Wiener-like mask normalization.
        mask_sum = sum(m**2 for m in masks.values()) + 1e-10

        def apply_mask(mask: Tensor) -> Tensor:
            # Normalize: each stem's squared mask over the sum of all stems'.
            mask = (mask**2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # Unfold the batch dimension back into the time axis.
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}

    def infer_with_batches(
        self, stft_mag: Tensor, batch_size: int
    ) -> Dict[str, Tensor]:
        masks = {name: [] for name in self.stems.keys()}
        with torch.inference_mode():
            for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
                batch = stft_mag[i : i + batch_size]
                batch_outputs = {name: net(batch) for name, net in self.stems.items()}
                for name in self.stems.keys():
                    masks[name].append(batch_outputs[name])
        return {name: torch.cat(chunks, dim=0) for name, chunks in masks.items()}
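

# --- Example usage (sketch) ---
# A minimal, hedged example of running the Splitter end to end. It assumes
# torchaudio is installed and that a file named "song.wav" exists; neither is
# part of the module above. The original Spleeter models operate on 44.1 kHz
# stereo audio, so the waveform is resampled and duplicated to two channels
# if needed.
if __name__ == "__main__":
    import torchaudio

    wav, sr = torchaudio.load("song.wav")  # wav: channels x samples
    if sr != 44100:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=44100)
    if wav.size(0) == 1:
        wav = wav.repeat(2, 1)  # duplicate mono to stereo

    splitter = Splitter(stem_num=2)
    stems = splitter(wav)  # {"vocals": 2 x L, "accompaniment": 2 x L}

    for name, audio in stems.items():
        torchaudio.save(f"{name}.wav", audio.cpu(), 44100)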