# NOTE(review): the lines above/below were extracted from a Hugging Face Spaces
# page; this header was page chrome ("Spaces: Running on Zero"), not source code.
| # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio.transforms as trans | |
| from torchaudio.models import Conformer | |
| from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, | |
| list_str_to_tensor, mask_from_frac_lengths) | |
class ResBlock(nn.Module):
    """A chain of dilated Conv1d sub-blocks, each wrapped in a residual skip.

    Input and output are (B, hidden_dim, T); the time length is preserved by
    every sub-block, so the residual additions are always shape-compatible.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
        super().__init__()
        self._n_groups = 8
        # Dilations grow as 3**k (1, 3, 9, ...) to widen the receptive field.
        sub_blocks = [
            self._get_conv(hidden_dim, dilation=3**k, dropout_p=dropout_p)
            for k in range(n_conv)
        ]
        self.blocks = nn.ModuleList(sub_blocks)

    def forward(self, x):
        # Each sub-block's output is added back onto its own input.
        for conv_stack in self.blocks:
            x = conv_stack(x) + x
        return x

    def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
        # The dilated conv keeps the length because padding == dilation for
        # kernel size 3; the second conv is a plain (dilation=1) refinement.
        return nn.Sequential(
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            nn.ReLU(),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
        )
class ConformerDiscirminator(nn.Module):
    """Frame-level discriminator: conv/ResBlock front-end + Conformer encoder.

    Takes a list of feature tensors shaped (B, T, C_i), concatenates them on
    the channel axis, downsamples time by 4x through two ResBlock + avg-pool
    stages, runs a Conformer stack, and emits one logit per remaining frame.

    NOTE: the class name is misspelled but kept for backward compatibility;
    a correctly-spelled alias is defined below the class.

    Args:
        input_dim: total channel count of the concatenated input features.
        channels: hidden width used throughout the network.
        num_layers: total Conformer layers, split as 1 + (num_layers - 1).
        num_heads: attention heads per Conformer layer.
        depthwise_conv_kernel_size: depthwise kernel of the second Conformer
            stage; the first stage uses half this size (must stay odd).
        use_group_norm: forwarded to torchaudio's Conformer.
    """

    def __init__(
        self,
        input_dim,
        channels=512,
        num_layers=3,
        num_heads=8,
        depthwise_conv_kernel_size=15,
        use_group_norm=True,
    ):
        super().__init__()
        self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1)
        self.resblock1 = nn.Sequential(
            ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
        )
        self.resblock2 = nn.Sequential(
            ResBlock(channels), nn.GroupNorm(num_groups=1, num_channels=channels)
        )
        # First Conformer stage runs right after pooling with a smaller
        # depthwise kernel; the remaining layers use the full kernel size.
        self.conformer1 = Conformer(
            input_dim=channels,
            num_heads=num_heads,
            ffn_dim=channels * 2,
            num_layers=1,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size // 2,
            use_group_norm=use_group_norm,
        )
        self.conformer2 = Conformer(
            input_dim=channels,
            num_heads=num_heads,
            ffn_dim=channels * 2,
            num_layers=num_layers - 1,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            use_group_norm=use_group_norm,
        )
        # 1x1 conv acting as a per-frame linear head producing a single logit.
        self.linear = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        """x: list of (B, T, C_i) tensors; returns (B, T // 4) logits."""
        x = torch.cat(x, dim=-1)  # (B, T, sum C_i)
        x = x.transpose(1, 2)  # (B, C, T) for Conv1d
        x = self.input_layer(x)
        x = self.resblock1(x)
        x = F.avg_pool1d(x, 2)  # T -> T/2
        x = self.resblock2(x)
        x = F.avg_pool1d(x, 2)  # T/2 -> T/4
        # Transpose to (B, T, C) for the conformer.
        x = x.transpose(1, 2)
        batch_size, time_steps, _ = x.shape
        # Dummy lengths: all sequences are treated as full length
        # (inputs are presumably pre-padded upstream — TODO confirm).
        lengths = torch.full(
            (batch_size,), time_steps, device=x.device, dtype=torch.int64
        )
        # torchaudio's Conformer returns (output, output_lengths); we discard lengths.
        x, _ = self.conformer1(x, lengths)
        x, _ = self.conformer2(x, lengths)
        # Transpose back to (B, C, T).
        x = x.transpose(1, 2)
        out = self.linear(x).squeeze(1)  # (B, T/4)
        return out


# Correctly-spelled alias; the original (typo'd) name above is kept for callers.
ConformerDiscriminator = ConformerDiscirminator
if __name__ == "__main__":
    # Smoke test: run the discriminator on DiT hidden states + CTC features.
    from f5_tts.model import DiT
    from f5_tts.model.utils import get_tokenizer

    bsz = 2

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
    dataset_name = "Emilia_ZH_EN"
    if tokenizer != "custom":
        # Built-in tokenizers use the dataset name as the vocab lookup key.
        # (The original code had a dead `tokenizer_path = tokenizer_path`
        # self-assignment in the custom branch; it is simply omitted.)
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    fake_unet = DiT(
        dim=1024,
        depth=22,
        heads=16,
        ff_mult=2,
        text_dim=512,
        conv_layers=4,
        text_num_embeds=vocab_size,
        mel_dim=80,
    )
    fake_unet = fake_unet.cuda()

    text = ["hello world"] * bsz
    lens = torch.randint(1, 1000, (bsz,)).cuda()
    inp = torch.randn(bsz, lens.max(), 80).cuda()
    # (The original duplicated this unpacking line; once is enough.)
    batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device

    # handle text as string
    if isinstance(text, list):
        if exists(vocab_char_map):
            text = list_str_to_idx(text, vocab_char_map).to(device)
        else:
            text = list_str_to_tensor(text).to(device)
        assert text.shape[0] == batch

    # lens and mask
    if not exists(lens):
        lens = torch.full((batch,), seq_len, device=device)
    mask = lens_to_mask(
        lens, length=seq_len
    )  # useless here, as collate_fn will pad to max length in batch

    # get a random span to mask out for training conditionally
    frac_lengths_mask = (0.7, 1.0)
    frac_lengths = (
        torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask)
    )
    rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
    if exists(mask):
        rand_span_mask &= mask

    # Sample a flow-matching time step and interpolate noise -> data.
    time = torch.rand((batch,), dtype=dtype, device=device)
    x1 = inp
    x0 = torch.randn_like(x1)
    t = time.unsqueeze(-1).unsqueeze(-1)
    phi = (1 - t) * x0 + t * x1
    flow = x1 - x0
    # Zero out the mel frames inside the masked span to form the condition.
    cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

    # classify_mode=True makes the DiT return its per-layer hidden states.
    layers = fake_unet(
        x=phi,
        cond=cond,
        text=text,
        time=time,
        drop_audio_cond=False,
        drop_text=False,
        classify_mode=True,
    )

    from ctcmodel import ConformerCTC

    ctcmodel = ConformerCTC(
        vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6
    ).cuda()
    real_out, layer = ctcmodel(inp)
    layer = layer[-3:]  # only use the last 3 layers
    # CTC features are at a reduced frame rate; upsample 4x and move channels last.
    layer = [
        F.interpolate(feat, mode="nearest", scale_factor=4).transpose(-1, -2)
        for feat in layer
    ]
    # Pad the CTC features along time if they come up short of the DiT features.
    if layer[0].size(1) < layers[0].size(1):
        pad_len = layers[0].size(1) - layer[0].size(1)
        layer = [F.pad(feat, (0, 0, 0, pad_len)) for feat in layer]
    layers = layer + layers

    # input_dim: presumably 23 DiT states of width 1024 plus 3 CTC layers of
    # width 512 — TODO confirm against DiT's classify_mode output.
    model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, channels=512)
    model = model.cuda()
    print(model)

    out = model(layers)
    print(out.shape)