import math
from typing import Dict, Tuple

import torch
from huggingface_hub import hf_hub_download
from torch import nn, Tensor
from torch.nn import functional as F
from tqdm import tqdm


def batchify(tensor: Tensor, T: int) -> Tensor:
    """Zero-pad the last (time) dimension to a multiple of T, then stack the
    resulting T-wide chunks along the batch dimension."""
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
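
# For example (illustrative shapes, not from the original file): a magnitude
# tensor of shape (1, 2, 1024, 1300) with T=512 is padded to 1536 frames and
# split into 3 chunks, giving a batch of shape (3, 2, 1024, 512).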


class EncoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=0.001,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Asymmetric "same"-style padding before the strided conv; the
        # pre-activation output is also returned for the skip connection.
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        return down, self.relu(self.bn(down))
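
# With kernel 5, stride 2, and (1, 2) padding on each spatial axis, every
# encoder block halves both dimensions exactly: out = (in + 3 - 5) // 2 + 1
# = in // 2 for even inputs.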


class DecoderBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # Undo the encoder's asymmetric padding: crop 1 from the start and 2
        # from the end of both spatial dimensions.
        up = up[:, :, 1:-2, 1:-2]
        return self.dropout(self.bn(self.relu(up)))
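
# A stride-2 transposed conv with kernel 5 produces (in - 1) * 2 + 5
# = 2 * in + 3 samples per axis; cropping 1 + 2 of them leaves exactly
# 2 * in, so each decoder block doubles both spatial dimensions.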


class UNet(nn.Module):
    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()
        # Downsampling path: channel widths 16, 32, ..., 2 ** (n_layers + 3)
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )
        # Upsampling path: mirror of the encoder, ending in a single channel
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # doubled for concatenated inputs (skip connections)
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout for the first 3 decoder layers
                    dropout_prob=0.5 if i < 3 else 0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )
        # Reconstruct the final mask with the same channel count as the input
        # (dilation 2 with kernel 4 gives an effective 7-tap "same" conv)
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)
        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                x = up(encoder_outputs_pre_act.pop())
            else:
                # merge the skip connection from the matching encoder depth
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))
        mask = self.sigmoid(self.up_final(x))
        # Crop both mask and input so their shapes match before masking
        min_f = min(mask.size(-2), input.size(-2))
        min_t = min(mask.size(-1), input.size(-1))
        mask = mask[..., :min_f, :min_t]
        input = input[..., :min_f, :min_t]
        return mask * input
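
# A minimal shape check (illustrative, not part of the original file): with
# the default 6 layers, a (1, 2, 512, 1024) magnitude patch is encoded down
# to (1, 512, 8, 16) and decoded back to a masked output of the input shape.
# >>> net = UNet(in_channels=2)
# >>> net(torch.rand(1, 2, 512, 1024)).shape
# torch.Size([1, 2, 512, 1024])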


class Splitter(nn.Module):
    def __init__(self, stem_num=2):
        super().__init__()
        if stem_num == 2:
            stem_names = ["vocals", "accompaniment"]
        elif stem_num == 4:
            stem_names = ["vocals", "drums", "bass", "other"]
        elif stem_num == 5:
            stem_names = ["vocals", "piano", "drums", "bass", "other"]
        else:
            raise ValueError(f"stem_num must be 2, 4, or 5, got {stem_num}")
        # STFT config
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(
            torch.hann_window(self.win_length), requires_grad=False
        )
        # one U-Net per stem, each predicting a soft mask over the magnitude
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})
        self.load_state_dict(
            torch.load(hf_hub_download("shethjenil/spleeter-torch", f"{stem_num}.pt"))
        )
        self.eval()
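
    # e.g. Splitter(stem_num=4) builds four U-Nets ("vocals", "drums",
    # "bass", "other") and loads the matching 4.pt checkpoint from
    # shethjenil/spleeter-torch.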

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L (2 x L for stereo)

        Returns:
            stft (Tensor): B x F x T x 2 (real + imaginary parts)
            mag (Tensor): B x F x T magnitude
        """
        # view_as_real keeps the old (..., 2) real+imag layout without the
        # deprecated return_complex=False path
        stft = torch.view_as_real(
            torch.stft(
                wav,
                n_fft=self.win_length,
                hop_length=self.hop_length,
                window=self.win,
                center=True,
                return_complex=True,
                pad_mode="constant",
            )
        )
        # Keep only the first F frequency bins
        stft = stft[:, : self.F, :, :]
        # magnitude, with a small epsilon for numerical stability
        real = stft[:, :, :, 0]
        imag = stft[:, :, :, 1]
        mag = torch.sqrt(real**2 + imag**2 + 1e-10)
        return stft, mag
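
    # Example framing (illustrative, not from the original file): 10 s of
    # stereo audio at 44.1 kHz (2 x 441000 samples) gives
    # 441000 // 1024 + 1 = 431 frames, so compute_stft returns a
    # (2, 1024, 431, 2) stft tensor and a (2, 1024, 431) magnitude.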
    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverse STFT from a real + imaginary tensor (B x F x T x 2)."""
        # Zero-pad the frequency dimension back up to n_fft // 2 + 1 bins,
        # since compute_stft kept only the first F of them.
        target_F = self.win_length // 2 + 1
        if stft.size(1) < target_F:
            pad = target_F - stft.size(1)
            stft = F.pad(stft, (0, 0, 0, 0, 0, pad))  # pad along the freq dim
        # Convert real + imaginary to complex for istft; view_as_complex
        # needs a contiguous layout
        stft_complex = torch.view_as_complex(stft.contiguous())
        wav = torch.istft(
            stft_complex,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.win,
        )
        return wav.detach()
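
    # Round-trip note: inverse_stft(compute_stft(wav)[0]) reconstructs the
    # waveform up to the content above bin self.F, which was discarded and is
    # re-filled with zeros here.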
    def forward(self, wav: Tensor, batch_size: int = 16) -> Dict[str, Tensor]:
        # stft: 2 x F x L x 2, stft_mag: 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)
        # 1 x 2 x F x T
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # compute each stem's mask
        masks = self.infer_with_batches(stft_mag, batch_size)
        # denominator of the Wiener-like soft mask
        mask_sum = sum([m**2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # ratio mask; the epsilon is split evenly across the stems
            # (the original hard-coded / 2, which only matches stem_num=2)
            mask = (mask**2 + 1e-10 / len(masks)) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # undo batchify: re-join the T-wide chunks along the time axis
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: self.inverse_stft(apply_mask(m)) for name, m in masks.items()}
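
    # In effect each stem i receives the soft ratio mask
    #   M_i = (m_i^2 + eps / N) / (sum_j m_j^2 + eps),  eps = 1e-10,
    # so the N per-stem masks sum exactly to 1 and the masked STFTs add back
    # up to the mixture STFT.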
    def infer_with_batches(self, stft_mag, batch_size):
        masks = {name: [] for name in self.stems.keys()}
        with torch.inference_mode():
            for i in tqdm(range(0, stft_mag.shape[0], batch_size)):
                batch = stft_mag[i : i + batch_size]
                batch_outputs = {name: net(batch) for name, net in self.stems.items()}
                for name in self.stems.keys():
                    masks[name].append(batch_outputs[name])
        masks = {name: torch.cat(masks[name], dim=0) for name in masks}
        return masks
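

# Usage sketch (an assumption, not part of the original file: torchaudio is
# used for I/O and "mix.wav" is a stereo 44.1 kHz file):
if __name__ == "__main__":
    import torchaudio

    wav, sr = torchaudio.load("mix.wav")  # 2 x L stereo waveform
    splitter = Splitter(stem_num=2)  # downloads 2.pt from the Hugging Face Hub
    stems = splitter(wav)  # {"vocals": Tensor, "accompaniment": Tensor}
    for name, audio in stems.items():
        torchaudio.save(f"{name}.wav", audio.cpu(), sr)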