from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import WhisperForConditionalGeneration, WhisperModel, WhisperPreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer, shift_tokens_right


class WhisperMultitaskConfig(WhisperConfig):
    model_type = "whisper"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_key_value_heads": "encoder_attention_heads",
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        ctc_char_dropout=0.1,
        ctc_char_vocab_size=33,
        ctc_char_layers=0,
        ctc_char_hidden_layer=-1,
        ctc_phoneme_dropout=0.1,
        ctc_phoneme_vocab_size=33,
        ctc_phoneme_layers=0,
        ctc_phoneme_hidden_layer=-1,
        vad_hidden_layer=-1,
        vad_layers=0,
        diarization_hidden_layer=-1,
        diarization_max_speakers=5,
        diarization_layers=0,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        **kwargs,
    ):
        self.ctc_char_dropout = ctc_char_dropout
        self.ctc_char_vocab_size = ctc_char_vocab_size
        self.ctc_char_layers = ctc_char_layers
        self.ctc_char_hidden_layer = ctc_char_hidden_layer
        self.ctc_phoneme_dropout = ctc_phoneme_dropout
        self.ctc_phoneme_vocab_size = ctc_phoneme_vocab_size
        self.ctc_phoneme_layers = ctc_phoneme_layers
        self.ctc_phoneme_hidden_layer = ctc_phoneme_hidden_layer
        self.vad_hidden_layer = vad_hidden_layer
        self.vad_layers = vad_layers
        self.diarization_hidden_layer = diarization_hidden_layer
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_layers = diarization_layers
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        super().__init__(**kwargs)

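
# --- Illustrative sketch (not part of the original module) -------------------
# Because `model_type` stays "whisper", the multitask config can be built on top
# of a stock Whisper checkpoint: the base values are loaded and the extra
# `ctc_*` / `vad_*` / `diarization_*` kwargs are applied afterwards. The
# checkpoint name and head sizes below are placeholder assumptions, not values
# this repository prescribes.
def build_multitask_config(
    base_checkpoint: str = "openai/whisper-small",
    ctc_char_vocab_size: int = 33,
    ctc_phoneme_vocab_size: int = 40,
    diarization_max_speakers: int = 5,
) -> WhisperMultitaskConfig:
    return WhisperMultitaskConfig.from_pretrained(
        base_checkpoint,
        ctc_char_vocab_size=ctc_char_vocab_size,
        ctc_phoneme_vocab_size=ctc_phoneme_vocab_size,
        diarization_max_speakers=diarization_max_speakers,
    )
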

@dataclass
class Seq2SeqMultitaskLMOutput(Seq2SeqLMOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    ctc_char_logits: Optional[torch.FloatTensor] = None
    ctc_phoneme_logits: Optional[torch.FloatTensor] = None
    vad_logits: Optional[torch.FloatTensor] = None
    diarization_logits: Optional[torch.FloatTensor] = None

class WhisperMultitask(WhisperForConditionalGeneration):
    config_class = WhisperMultitaskConfig
    base_model_prefix = "model"
    _tied_weights_keys = ["proj_out.weight"]

    def __init__(self, config: WhisperMultitaskConfig):
        super().__init__(config)
        self.model = WhisperModel(config)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.max_target_positions = config.max_target_positions

        # When an adapter is configured the encoder output size differs from d_model.
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )

        # Character-level CTC head.
        self.dropout_ctc_char = nn.Dropout(getattr(config, "ctc_char_dropout", None) or 0.1)
        ctc_char_vocab_size = getattr(config, "ctc_char_vocab_size", None) or 33
        self.ctc_char_lm_head = nn.Linear(output_hidden_size, ctc_char_vocab_size, bias=True)
        self.ctc_char_hidden_layer = getattr(config, "ctc_char_hidden_layer", None) or -1
        self.ctc_char_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.ctc_char_layers)])
        self.ctc_char_layer_norm = nn.LayerNorm(config.d_model)

        # Phoneme-level CTC head.
        self.dropout_ctc_phoneme = nn.Dropout(getattr(config, "ctc_phoneme_dropout", None) or 0.1)
        ctc_phoneme_vocab_size = getattr(config, "ctc_phoneme_vocab_size", None) or 33
        self.ctc_phoneme_lm_head = nn.Linear(output_hidden_size, ctc_phoneme_vocab_size, bias=True)
        self.ctc_phoneme_hidden_layer = getattr(config, "ctc_phoneme_hidden_layer", None) or -1
        self.ctc_phoneme_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.ctc_phoneme_layers)])
        self.ctc_phoneme_layer_norm = nn.LayerNorm(config.d_model)

        # Voice-activity-detection head.
        self.vad_classifier = nn.Linear(config.hidden_size, 1, bias=True)
        self.vad_hidden_layer = getattr(config, "vad_hidden_layer", None) or -1
        self.vad_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.vad_layers)])
        self.vad_layer_norm = nn.LayerNorm(config.d_model)

        # Speaker-diarization head.
        self.diarization_max_speakers = getattr(config, "diarization_max_speakers", None) or 5
        self.diarization_classifier = nn.Linear(config.hidden_size, self.diarization_max_speakers, bias=True)
        self.diarization_hidden_layer = getattr(config, "diarization_hidden_layer", None) or -1
        self.diarization_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.diarization_layers)])
        self.diarization_layer_norm = nn.LayerNorm(config.d_model)

        # CTC loss settings; an explicit `ctc_zero_infinity=False` in the config is respected.
        self.ctc_loss_reduction = getattr(config, "ctc_loss_reduction", None) or "mean"
        self.ctc_zero_infinity = getattr(config, "ctc_zero_infinity", True)

        if config.use_weighted_layer_sum:
            # One weight per encoder hidden state (input embedding + each encoder layer).
            num_layers = config.num_hidden_layers + 1
            self.ctc_char_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.ctc_phoneme_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.vad_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.diarization_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be
            smaller than or equal to `config.max_target_positions`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if labels.shape[1] > self.max_target_positions:
                raise ValueError(
                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length "
                    f"of {self.max_target_positions} tokens."
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        ctc_char_logits = None
        ctc_phoneme_logits = None
        vad_logits = None
        diarization_logits = None

        # The auxiliary heads need the per-layer encoder states, so they only run
        # when hidden states were requested and are available.
        if output_hidden_states and outputs.encoder_hidden_states:
            if self.config.use_weighted_layer_sum:
                ctc_hidden_states = torch.stack(outputs.encoder_hidden_states, dim=1)
                norm_weights = nn.functional.softmax(self.ctc_char_layer_weights, dim=-1)
                ctc_char_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
                norm_weights = nn.functional.softmax(self.ctc_phoneme_layer_weights, dim=-1)
                ctc_phoneme_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
                norm_weights = nn.functional.softmax(self.vad_layer_weights, dim=-1)
                vad_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
                norm_weights = nn.functional.softmax(self.diarization_layer_weights, dim=-1)
                diarization_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
            else:
                ctc_char_hidden_states = outputs.encoder_hidden_states[self.ctc_char_hidden_layer]
                ctc_phoneme_hidden_states = outputs.encoder_hidden_states[self.ctc_phoneme_hidden_layer]
                vad_hidden_states = outputs.encoder_hidden_states[self.vad_hidden_layer]
                diarization_hidden_states = outputs.encoder_hidden_states[self.diarization_hidden_layer]

            # Character-level CTC head.
            for idx, encoder_layer in enumerate(self.ctc_char_layers):
                layer_outputs = encoder_layer(
                    ctc_char_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                ctc_char_hidden_states = layer_outputs[0]
            ctc_char_hidden_states = self.ctc_char_layer_norm(ctc_char_hidden_states)

            # Phoneme-level CTC head.
            for idx, encoder_layer in enumerate(self.ctc_phoneme_layers):
                layer_outputs = encoder_layer(
                    ctc_phoneme_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                ctc_phoneme_hidden_states = layer_outputs[0]
            ctc_phoneme_hidden_states = self.ctc_phoneme_layer_norm(ctc_phoneme_hidden_states)

            ctc_char_hidden_states = self.dropout_ctc_char(ctc_char_hidden_states)
            ctc_phoneme_hidden_states = self.dropout_ctc_phoneme(ctc_phoneme_hidden_states)

            ctc_char_logits = self.ctc_char_lm_head(ctc_char_hidden_states)
            ctc_phoneme_logits = self.ctc_phoneme_lm_head(ctc_phoneme_hidden_states)

            # Voice-activity-detection head.
            for idx, encoder_layer in enumerate(self.vad_layers):
                layer_outputs = encoder_layer(
                    vad_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                vad_hidden_states = layer_outputs[0]
            vad_hidden_states = self.vad_layer_norm(vad_hidden_states)

            # The sigmoid turns the classifier outputs into per-frame speech probabilities.
            vad_logits = torch.sigmoid(self.vad_classifier(vad_hidden_states))

            # Speaker-diarization head.
            for idx, encoder_layer in enumerate(self.diarization_layers):
                layer_outputs = encoder_layer(
                    diarization_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                diarization_hidden_states = layer_outputs[0]
            diarization_hidden_states = self.diarization_layer_norm(diarization_hidden_states)

            # Per-frame, per-speaker activity probabilities.
            diarization_logits = torch.sigmoid(self.diarization_classifier(diarization_hidden_states))

        lm_logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        if output_hidden_states and outputs.encoder_hidden_states:
            return Seq2SeqMultitaskLMOutput(
                loss=loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
                ctc_char_logits=ctc_char_logits,
                ctc_phoneme_logits=ctc_phoneme_logits,
                vad_logits=vad_logits,
                diarization_logits=diarization_logits,
            )

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

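
# --- Illustrative sketch (not part of the original module) -------------------
# The config carries `ctc_loss_reduction` and `ctc_zero_infinity`, but no CTC
# loss is computed inside `forward`. The helper below is a minimal sketch of how
# a character-level CTC loss could be derived from `ctc_char_logits`, following
# the usual Wav2Vec2-style recipe. The blank index (0), the convention that
# padded label positions are marked with -100, and the assumption that every
# encoder frame is a valid input frame are placeholder choices, not something
# this module defines.
def ctc_char_loss_sketch(
    model: WhisperMultitask,
    ctc_char_logits: torch.FloatTensor,
    char_labels: torch.LongTensor,
    blank_id: int = 0,
) -> torch.Tensor:
    batch_size, time_steps, _ = ctc_char_logits.shape
    input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)

    labels_mask = char_labels >= 0  # -100 marks padding
    target_lengths = labels_mask.sum(-1)
    flattened_targets = char_labels.masked_select(labels_mask)

    # F.ctc_loss expects log-probabilities shaped (time, batch, vocab).
    log_probs = nn.functional.log_softmax(ctc_char_logits, dim=-1).transpose(0, 1)

    with torch.backends.cudnn.flags(enabled=False):
        return nn.functional.ctc_loss(
            log_probs,
            flattened_targets,
            input_lengths,
            target_lengths,
            blank=blank_id,
            reduction=model.ctc_loss_reduction,
            zero_infinity=model.ctc_zero_infinity,
        )
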
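
if __name__ == "__main__":
    # Smoke-test sketch: the configuration values below are arbitrary small
    # numbers chosen so the example runs quickly on CPU; they are not the
    # settings of any real checkpoint. `output_hidden_states=True` is required,
    # otherwise the auxiliary heads are skipped and a plain Seq2SeqLMOutput is
    # returned.
    cfg = WhisperMultitaskConfig(
        vocab_size=51865,
        num_mel_bins=80,
        d_model=64,
        encoder_layers=2,
        encoder_attention_heads=4,
        decoder_layers=2,
        decoder_attention_heads=4,
        encoder_ffn_dim=128,
        decoder_ffn_dim=128,
        max_source_positions=1500,
        max_target_positions=448,
        ctc_char_vocab_size=33,
        ctc_phoneme_vocab_size=40,
    )
    model = WhisperMultitask(cfg).eval()

    # Whisper expects log-mel features of length 2 * max_source_positions.
    input_features = torch.randn(1, cfg.num_mel_bins, 2 * cfg.max_source_positions)
    decoder_input_ids = torch.tensor([[cfg.decoder_start_token_id]])

    with torch.no_grad():
        out = model(
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            return_dict=True,
        )

    print(out.logits.shape)              # (1, 1, vocab_size)
    print(out.ctc_char_logits.shape)     # (1, 1500, ctc_char_vocab_size)
    print(out.ctc_phoneme_logits.shape)  # (1, 1500, ctc_phoneme_vocab_size)
    print(out.vad_logits.shape)          # (1, 1500, 1)
    print(out.diarization_logits.shape)  # (1, 1500, diarization_max_speakers)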