Tags: Feature Extraction · Transformers · Safetensors · English · usad · automatic-speech-recognition · audio-classification · audio · speech · music · custom_code
Instructions to use MIT-SLS/USAD-Large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.

- Libraries
  - Transformers
- Notebooks
  - Google Colab
  - Kaggle

How to use MIT-SLS/USAD-Large with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MIT-SLS/USAD-Large", trust_remote_code=True)

# Or load the model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MIT-SLS/USAD-Large", trust_remote_code=True, dtype="auto")
```
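Because the checkpoint ships custom modeling code (the `custom_code` tag), `trust_remote_code=True` is required. Below is a minimal local-inference sketch, assuming the remote code exposes the interface of the `UsadModel` class reproduced further down (`load_audio` plus a `forward` that returns a dict); `example.wav` is a placeholder path:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("MIT-SLS/USAD-Large", trust_remote_code=True, dtype="auto")
model.eval()

wav = model.load_audio("example.wav")  # placeholder path; mono 16 kHz, on the model's device
with torch.no_grad():
    out = model(wav.unsqueeze(0))      # batch the single waveform: (1, wav_len)

print(out["x"].shape)                  # final encoder output: (1, seq_len, encoder_dim)
print(len(out["hidden_states"]))       # one tensor per Conformer layer
```

The repository's custom modeling code is reproduced below, starting with the fbank feature extraction: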
```python
from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

MAX_MEL_LENGTH = 3000  # 30 seconds


def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert waveforms to fbank features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): Mel dimension. Defaults to 128.
        norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
        norm_std (float, optional): Std for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    dtype = wavs.dtype
    wavs = wavs.to(torch.float32)
    wavs = wavs - wavs.mean(dim=-1, keepdim=True)
    feats = [
        torchaudio.compliance.kaldi.fbank(
            wavs[i : i + 1],
            htk_compat=True,
            sample_frequency=16000,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        ).to(dtype=dtype)
        for i in range(wavs.shape[0])
    ]
    mels = torch.stack(feats, dim=0)
    # EAT-style normalization: scale by twice the dataset std.
    mels = (mels - norm_mean) / (norm_std * 2)
    return mels
```
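A quick shape check for `wav_to_fbank` (dummy input; with Kaldi's default 25 ms window and the 10 ms frame shift used above, one second of 16 kHz audio yields 98 frames):

```python
import torch

wavs = torch.randn(2, 16000)  # two 1-second dummy clips at 16 kHz
mels = wav_to_fbank(wavs)
print(mels.shape)             # torch.Size([2, 98, 128])
```

The encoder wrapper follows: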
```python
class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object containing model parameters.
        """
        super().__init__()
        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH
        # NOTE: max_mel_length defaults to 3000, which corresponds to
        # 30 seconds of audio at the 100 Hz mel frame rate.

    # Exposed as properties so call sites such as `self.sample_rate`
    # and `self.device` below work without parentheses.
    @property
    def sample_rate(self) -> int:
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        return 50  # Hz

    @property
    def mel_dim(self) -> int:
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1 s, got {seconds} seconds."
        self.max_mel_length = int(seconds * 100)  # 100 Hz mel frame rate
```
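For example, capping chunks at 10 seconds (a sketch, assuming `model` is an instantiated `UsadModel`):

```python
model.set_audio_chunk_size(10.0)
# 10 s * 100 mel frames/s = 1000 mel frames per chunk
assert model.max_mel_length == 1000
```

The class continues with an audio loading helper: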
```python
    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a waveform tensor.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If stereo, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Remove the channel dimension
        return waveform.to(self.device)  # Move to the model's device
```
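A quick sketch of the helper (hypothetical file; resampling to 16 kHz and downmixing to mono happen automatically):

```python
wav = model.load_audio("stereo_44k.wav")  # hypothetical 44.1 kHz stereo file
print(wav.shape, wav.device)              # (wav_len,) at 16 kHz, on the model's device
```

The forward pass converts waveforms to fbank features and runs the Conformer encoder, chunking any input longer than `max_mel_length`: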
```python
    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor):
                Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
            norm_std (float, optional): Standard deviation for normalization. Defaults to 4.569.

        Returns:
            dict: A dictionary containing the model's outputs.
        """
        # wavs: (batch_size, wav_len)
        mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Truncate to an even number of frames (the encoder downsamples by 2).
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
        if mel.shape[1] <= self.max_mel_length:
            x, x_len, layer_results = self.encoder(mel, return_hidden=True)
            result = {
                "x": x,
                "mel": mel,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result
        # Long inputs: encode in chunks of max_mel_length frames, then
        # concatenate the per-chunk outputs along the time axis.
        result = {
            "x": [],
            "mel": mel,
            "hidden_states": [[] for _ in range(self.cfg.num_layers)],
            "ffn": [[] for _ in range(self.cfg.num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            if mel.shape[1] - i < 10:
                break  # skip a trailing chunk that is too short to encode
            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])
        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
        # result["x"]: final encoder output (batch_size, seq_len, encoder_dim)
        # result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
        # result["hidden_states"]: list of (batch_size, seq_len, encoder_dim)
        # result["ffn"]: list of (batch_size, seq_len, encoder_dim)
        return result
```
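To illustrate the chunked path (dummy data, shapes only): with the default 30-second chunk size, a 60-second input is split into two chunks whose outputs are concatenated along the time axis.

```python
wavs = torch.randn(1, 60 * 16000)  # 60 s of dummy audio at 16 kHz
with torch.no_grad():
    out = model(wavs)
# 60 s at the 50 Hz encoder frame rate -> roughly 3000 output frames
print(out["x"].shape)              # (1, ~3000, encoder_dim)
```

Finally, a loader for fairseq-format checkpoints: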
```python
    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint."""
        checkpoint = torch.load(ckpt_path, weights_only=False)
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Keep only encoder weights; drop everything else (e.g. pretraining heads).
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
```
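Usage sketch (hypothetical checkpoint path; note that `weights_only=False` unpickles arbitrary objects, so only load checkpoints you trust):

```python
model = UsadModel.load_from_fairseq_ckpt("usad_large.pt")  # hypothetical path
model.eval()
print(model.sample_rate, model.encoder_frame_rate)  # 16000 50
```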