from typing import Optional, Tuple, Union
from dataclasses import dataclass
# Minimal transformers version: 4.46.2
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers import WhisperPreTrainedModel
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer, shift_tokens_right
from transformers import WhisperModel, WhisperForConditionalGeneration
class WhisperMultitaskConfig(WhisperConfig):
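    """
    Configuration for `WhisperMultitask`. Extends `WhisperConfig` with the settings of the
    auxiliary encoder-side heads: char-level CTC, phoneme-level CTC, voice activity detection
    (VAD) and speaker diarization (vocab sizes, dropout, number of extra encoder layers per
    head, which encoder hidden layer each head reads from, and the CTC loss options).
    """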
model_type = "whisper"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_key_value_heads": "encoder_attention_heads",
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
}
def __init__(
self,
ctc_char_dropout=0.1,
ctc_char_vocab_size=33,
ctc_char_layers=0,
ctc_char_hidden_layer=-1,
ctc_phoneme_dropout=0.1,
ctc_phoneme_vocab_size=33,
ctc_phoneme_layers=0,
ctc_phoneme_hidden_layer=-1,
vad_hidden_layer=-1,
vad_layers=0,
diarization_hidden_layer=-1,
diarization_max_speakers=5,
diarization_layers=0,
ctc_loss_reduction='mean',
ctc_zero_infinity=True,
**kwargs,
):
self.ctc_char_dropout = ctc_char_dropout
self.ctc_char_vocab_size = ctc_char_vocab_size
self.ctc_char_layers = ctc_char_layers
self.ctc_char_hidden_layer = ctc_char_hidden_layer
self.ctc_phoneme_dropout = ctc_phoneme_dropout
self.ctc_phoneme_vocab_size = ctc_phoneme_vocab_size
self.ctc_phoneme_layers = ctc_phoneme_layers
self.ctc_phoneme_hidden_layer = ctc_phoneme_hidden_layer
self.vad_hidden_layer = vad_hidden_layer
self.vad_layers = vad_layers
self.diarization_hidden_layer = diarization_hidden_layer
self.diarization_max_speakers = diarization_max_speakers
self.diarization_layers = diarization_layers
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
        super().__init__(**kwargs)
@dataclass
class Seq2SeqMultitaskLMOutput(Seq2SeqLMOutput):
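    """
    Same fields as `Seq2SeqLMOutput`, extended with the logits of the auxiliary heads:
    char-level CTC, phoneme-level CTC, voice activity detection (VAD) and speaker diarization.
    """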
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
ctc_char_logits: Optional[torch.FloatTensor] = None
ctc_phoneme_logits: Optional[torch.FloatTensor] = None
vad_logits: Optional[torch.FloatTensor] = None
diarization_logits: Optional[torch.FloatTensor] = None
class WhisperMultitask(WhisperForConditionalGeneration):
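    """
    Whisper for conditional generation with additional encoder-side heads: char-level CTC,
    phoneme-level CTC, frame-level VAD and speaker diarization. The auxiliary logits are only
    produced when the forward pass is run with `output_hidden_states=True` and `return_dict=True`.
    """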
config_class = WhisperMultitaskConfig
base_model_prefix = "model"
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: WhisperMultitaskConfig):
super().__init__(config)
self.model = WhisperModel(config)
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.max_target_positions = config.max_target_positions
# ctc char level
        ctc_char_dropout = getattr(config, "ctc_char_dropout", 0.1)
        self.dropout_ctc_char = nn.Dropout(ctc_char_dropout)
        ctc_char_vocab_size = getattr(config, "ctc_char_vocab_size", 33)
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )
        self.ctc_char_lm_head = nn.Linear(output_hidden_size, ctc_char_vocab_size, bias=True)
        self.ctc_char_hidden_layer = getattr(config, "ctc_char_hidden_layer", -1)
        self.ctc_char_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.ctc_char_layers)])
        self.ctc_char_layer_norm = nn.LayerNorm(config.d_model)
        # ctc phoneme level
        ctc_phoneme_dropout = getattr(config, "ctc_phoneme_dropout", 0.1)
        self.dropout_ctc_phoneme = nn.Dropout(ctc_phoneme_dropout)
        ctc_phoneme_vocab_size = getattr(config, "ctc_phoneme_vocab_size", 33)
        self.ctc_phoneme_lm_head = nn.Linear(output_hidden_size, ctc_phoneme_vocab_size, bias=True)
        self.ctc_phoneme_hidden_layer = getattr(config, "ctc_phoneme_hidden_layer", -1)
        self.ctc_phoneme_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.ctc_phoneme_layers)])
        self.ctc_phoneme_layer_norm = nn.LayerNorm(config.d_model)
        # vad classification (frame-level voice activity detection)
        self.vad_classifier = nn.Linear(config.hidden_size, 1, bias=True)
        self.vad_hidden_layer = getattr(config, "vad_hidden_layer", -1)
        self.vad_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.vad_layers)])
        self.vad_layer_norm = nn.LayerNorm(config.d_model)
        # diarization (per-frame speaker activity, up to diarization_max_speakers speakers)
        self.diarization_max_speakers = getattr(config, "diarization_max_speakers", 5)
        self.diarization_classifier = nn.Linear(config.hidden_size, self.diarization_max_speakers, bias=True)
        self.diarization_hidden_layer = getattr(config, "diarization_hidden_layer", -1)
        self.diarization_layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.diarization_layers)])
        self.diarization_layer_norm = nn.LayerNorm(config.d_model)
        # ctc loss settings (shared by the char and phoneme heads)
        self.ctc_loss_reduction = getattr(config, "ctc_loss_reduction", "mean")
        self.ctc_zero_infinity = getattr(config, "ctc_zero_infinity", True)
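        # With use_weighted_layer_sum, each auxiliary head learns its own softmax-normalized
        # weighting over all encoder hidden states (input embeddings + every encoder layer)
        # instead of reading a single fixed hidden layer.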
if config.use_weighted_layer_sum:
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
self.ctc_char_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.ctc_phoneme_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.vad_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.diarization_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if labels.shape[1] > self.max_target_positions:
raise ValueError(
f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
)
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
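        # The auxiliary heads (char/phoneme CTC, VAD, diarization) read intermediate encoder
        # states, so their logits are only computed when the model output is a dict and the
        # encoder hidden states were requested.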
        if return_dict and output_hidden_states and outputs.encoder_hidden_states:
if self.config.use_weighted_layer_sum:
ctc_hidden_states = outputs.encoder_hidden_states
ctc_hidden_states = torch.stack(ctc_hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.ctc_char_layer_weights, dim=-1)
ctc_char_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
norm_weights = nn.functional.softmax(self.ctc_phoneme_layer_weights, dim=-1)
ctc_phoneme_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
norm_weights = nn.functional.softmax(self.vad_layer_weights, dim=-1)
vad_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
norm_weights = nn.functional.softmax(self.diarization_layer_weights, dim=-1)
diarization_hidden_states = (ctc_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
ctc_char_hidden_states = outputs.encoder_hidden_states[self.ctc_char_hidden_layer]
ctc_phoneme_hidden_states = outputs.encoder_hidden_states[self.ctc_phoneme_hidden_layer]
vad_hidden_states = outputs.encoder_hidden_states[self.vad_hidden_layer]
diarization_hidden_states = outputs.encoder_hidden_states[self.diarization_hidden_layer]
            # ctc char layers
            for idx, encoder_layer in enumerate(self.ctc_char_layers):
                layer_outputs = encoder_layer(
                    ctc_char_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                ctc_char_hidden_states = layer_outputs[0]
            ctc_char_hidden_states = self.ctc_char_layer_norm(ctc_char_hidden_states)
            # ctc phoneme layers
            for idx, encoder_layer in enumerate(self.ctc_phoneme_layers):
                layer_outputs = encoder_layer(
                    ctc_phoneme_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                ctc_phoneme_hidden_states = layer_outputs[0]
            ctc_phoneme_hidden_states = self.ctc_phoneme_layer_norm(ctc_phoneme_hidden_states)
            # ctc dropout and projection heads (char and phoneme)
ctc_char_hidden_states = self.dropout_ctc_char(ctc_char_hidden_states)
ctc_phoneme_hidden_states = self.dropout_ctc_phoneme(ctc_phoneme_hidden_states)
ctc_char_logits = self.ctc_char_lm_head(ctc_char_hidden_states)
ctc_phoneme_logits = self.ctc_phoneme_lm_head(ctc_phoneme_hidden_states)
            # vad layers
            for idx, encoder_layer in enumerate(self.vad_layers):
                layer_outputs = encoder_layer(
                    vad_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                vad_hidden_states = layer_outputs[0]
            vad_hidden_states = self.vad_layer_norm(vad_hidden_states)
            vad_logits = torch.sigmoid(self.vad_classifier(vad_hidden_states))
            # diarization layers
            for idx, encoder_layer in enumerate(self.diarization_layers):
                layer_outputs = encoder_layer(
                    diarization_hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )
                diarization_hidden_states = layer_outputs[0]
            diarization_hidden_states = self.diarization_layer_norm(diarization_hidden_states)
            diarization_logits = torch.sigmoid(self.diarization_classifier(diarization_hidden_states))
lm_logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
if output_hidden_states and outputs.encoder_hidden_states:
return Seq2SeqMultitaskLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
ctc_char_logits=ctc_char_logits,
ctc_phoneme_logits=ctc_phoneme_logits,
vad_logits=vad_logits,
diarization_logits=diarization_logits,
)
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
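

# Minimal smoke-test sketch, not part of the model definition. It builds a small,
# randomly-initialized model just to show the expected input/output shapes; the
# hyperparameter values below are illustrative assumptions, not values from any released
# checkpoint. A real setup would load trained weights via WhisperMultitask.from_pretrained(...).
if __name__ == "__main__":
    config = WhisperMultitaskConfig(
        vocab_size=51865,
        d_model=384,
        encoder_layers=4,
        decoder_layers=4,
        encoder_attention_heads=6,
        decoder_attention_heads=6,
        ctc_char_vocab_size=33,
        ctc_phoneme_vocab_size=45,
        diarization_max_speakers=5,
    )
    model = WhisperMultitask(config).eval()

    # Whisper expects log-mel features of shape (batch, num_mel_bins, 3000) for 30 s of audio.
    input_features = torch.randn(1, config.num_mel_bins, 3000)
    decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])

    with torch.no_grad():
        outputs = model(
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,  # required for the auxiliary logits to be returned
        )

    print(outputs.logits.shape)              # decoder LM logits
    print(outputs.ctc_char_logits.shape)     # (batch, encoder frames, ctc_char_vocab_size)
    print(outputs.ctc_phoneme_logits.shape)  # (batch, encoder frames, ctc_phoneme_vocab_size)
    print(outputs.vad_logits.shape)          # (batch, encoder frames, 1)
    print(outputs.diarization_logits.shape)  # (batch, encoder frames, diarization_max_speakers)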