diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 596d4c5305..e175f91ddd 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -3,6 +3,7 @@ import torch import torch.utils.checkpoint from torch import nn +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -22,76 +23,50 @@ def gaudi_SpeechT5Attention_forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional["Cache"] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ - Copied from SpeechT5Attention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py The only differences are: - add new args token_idx + - update to HF 4.55.4 Cache API (no explicit past_key_value tuple in/out) """ - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if token_idx is not None: - past_key_value[0].index_copy_(2, token_idx - 1, key_states) - past_key_value[1].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + + current_states = key_value_states if is_cross_attention else hidden_states + + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + if past_key_value is not None: + _cache_pos = None if is_cross_attention else cache_position + key_states, value_states = past_key_value.update( + key_states, value_states, getattr(self, "layer_idx", None), {"cache_position": _cache_pos} + ) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2).reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" ) # relative attention bias @@ -122,35 +97,24 @@ def gaudi_SpeechT5Attention_forward( attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped def gaudi_SpeechT5DecoderLayer_forward( @@ -161,71 +125,61 @@ def gaudi_SpeechT5DecoderLayer_forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional["Cache"] = None, # 4.55.4 output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ - Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py + Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/speecht5/modeling_speecht5.py The only differences are: - add token_idx in self-attention + - align with HF 4.55.4: no present_key_value returned (cache is updated in-place) """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + self_attn_outputs = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, token_idx=token_idx, ) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(self_attn_outputs[0]) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + cross_outputs = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value.cross_attention_cache if past_key_value is not None else None, output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(cross_outputs[0]) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value + if output_attentions: + cross_attn_weights = cross_outputs[1] # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value,) + # self-attn weights + outputs += (self_attn_outputs[1], cross_attn_weights) return outputs @@ -238,7 +192,7 @@ def gaudi_SpeechT5Decoder_forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional["Cache"] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -247,10 +201,11 @@ def gaudi_SpeechT5Decoder_forward( token_idx: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: """ - Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py + Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/speecht5/modeling_speecht5.py The only differences are: - add token_idx args - use _gaudi_prepare_4d_causal_attention_mask + - align with HF 4.55.4 Cache API (no next_decoder_cache) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -261,7 +216,7 @@ def gaudi_SpeechT5Decoder_forward( input_shape = hidden_states.size()[:-1] - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length @@ -269,7 +224,6 @@ def gaudi_SpeechT5Decoder_forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] ) @@ -283,26 +237,21 @@ def gaudi_SpeechT5Decoder_forward( ) use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.layers)): raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) skip_the_layer = False if self.training: dropout_probability = torch.rand([]) @@ -310,8 +259,6 @@ def gaudi_SpeechT5Decoder_forward( if skip_the_layer and not synced_gpus: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -319,7 +266,7 @@ def gaudi_SpeechT5Decoder_forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -327,29 +274,24 @@ def gaudi_SpeechT5Decoder_forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if encoder_hidden_states is not None: all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, None, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=None, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -367,23 +309,23 @@ def gaudi_generate_speech( vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, return_output_lengths: bool = False, -) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]: +): """ - Copied from _generate_speech: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py - The only differences are: - - add hpu graph wrap - - add static shape support in kv-cache in _generate_speech - - disable speech_decoder_prenet_dropout to avoid variable output length + Copied and adapted from `_generate_speech` (transformers v4.55.4) + Differences introduced for Habana Gaudi: + - wrapped encoder / decoder / prenet with HPU graphs + - use static-shape kv-cache via Cache API (DynamicCache + EncoderDecoderCache) + - disable dropout in prenet for deterministic output lengths + - adjust attention_mask update order (fixes off-by-one shape mismatch) """ if speaker_embeddings is None: raise ValueError( - """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following - the code snippet provided in this link: - https://huggingface.co/datasets/regisss/cmu-arctic-xvectors - """ + "`speaker_embeddings` must be provided (e.g. from https://huggingface.co/datasets/regisss/cmu-arctic-xvectors)." ) + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + # Wrap model components with HPU graph to enable static compilation if not hasattr(model.speecht5.encoder, "clear_cache"): model.speecht5.encoder = wrap_in_hpu_graph(model.speecht5.encoder) if not hasattr(model.speecht5.decoder.wrapped_decoder, "clear_cache"): @@ -391,51 +333,65 @@ def gaudi_generate_speech( if not hasattr(model.speecht5.decoder.prenet, "clear_cache"): model.speecht5.decoder.prenet = wrap_in_hpu_graph(model.speecht5.decoder.prenet) - if attention_mask is None: - encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int() - else: - encoder_attention_mask = attention_mask + # Prepare encoder attention mask + encoder_attention_mask = ( + 1 - (input_values == model.config.pad_token_id).int() if attention_mask is None else attention_mask + ) bsz = input_values.size(0) + + # Run encoder encoder_out = model.speecht5.encoder( input_values=input_values, attention_mask=encoder_attention_mask, return_dict=True, ) + encoder_hidden_states = encoder_out.last_hidden_state - encoder_last_hidden_state = encoder_out.last_hidden_state - - # downsample encoder attention mask + # Downsample attention mask if using speech prenet if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( - encoder_out[0].shape[1], encoder_attention_mask + encoder_hidden_states.shape[1], encoder_attention_mask ) - maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor) - minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor) + # Determine dynamic decoding length bounds + maxlen = int(encoder_hidden_states.size(1) * maxlenratio / model.config.reduction_factor) + minlen = int(encoder_hidden_states.size(1) * minlenratio / model.config.reduction_factor) - # Start the output sequence with a mel spectrum that is all zeros. - output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins) + # Initialize decoder inputs + output_sequence = encoder_hidden_states.new_zeros(bsz, 1, model.config.num_mel_bins) output_sequence = torch.nn.functional.pad(output_sequence, (0, 0, 0, maxlen - 1), value=model.config.pad_token_id) - spectrogram = [] - cross_attentions = [] - past_key_values = None - idx = 0 - result_spectrogram = {} - token_idx = torch.tensor(1, device=output_sequence.device) + + # Prepare attention and cache structures attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + + # Internal buffers + spectrogram, cross_attentions, result_spectrogram = [], [], {} + token_idx = torch.tensor(1, device=output_sequence.device) + idx = 0 + + # Generation loop while True: idx += 1 - attention_mask.index_fill_(1, token_idx - 1, 1) - # Run the decoder prenet on the entire output sequence. + + # Prenet (disable dropout for HPU determinism) decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) - # Run the decoder layers on the last element of the prenet output. + + # Use last step or full input depending on cache + decoder_inputs = ( + decoder_hidden_states + if past_key_values.get_seq_length() == 0 + else torch.index_select(decoder_hidden_states, 1, token_idx - 1) + ) + + # Decoder forward with caching decoder_out = model.speecht5.decoder.wrapped_decoder( - hidden_states=decoder_hidden_states - if past_key_values is None - else torch.index_select(decoder_hidden_states, 1, token_idx - 1), + hidden_states=decoder_inputs, attention_mask=attention_mask, - encoder_hidden_states=encoder_last_hidden_state, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=True, @@ -444,47 +400,46 @@ def gaudi_generate_speech( token_idx=token_idx, ) + attention_mask.index_fill_(1, token_idx - 1, 1) + + # Optional cross-attention collection if output_cross_attentions: cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) - last_decoder_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1) - past_key_values = decoder_out.past_key_values - # Predict the new mel spectrum for this step in the sequence. - spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) + # Extract decoder output + last_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1) + + # Predict mel spectrum + spectrum = model.speech_decoder_postnet.feat_out(last_output) spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) spectrogram.append(spectrum) + + # Update output sequence and token index output_sequence.index_copy_(1, token_idx, spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)) - # Predict the probability that this is the stop token. - prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) + prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_output)) token_idx.add_(1) - # Finished when stop token or maximum length is reached. + + # Early exit logic if idx < minlen: continue - else: - # If the generation loop is less than maximum length time, check the ones in the batch that have met - # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch. - if idx < maxlen: - meet_thresholds = torch.sum(prob, dim=-1) >= threshold - meet_indexes = torch.where(meet_thresholds)[0].tolist() - else: - meet_indexes = range(len(prob)) - meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] - if len(meet_indexes) > 0: - spectrograms = torch.stack(spectrogram) - spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) - spectrograms = model.speech_decoder_postnet.postnet(spectrograms) - for meet_index in meet_indexes: - result_spectrogram[meet_index] = spectrograms[meet_index] - if len(result_spectrogram) >= bsz: - break - + meet_indexes = ( + torch.where(torch.sum(prob, dim=-1) >= threshold)[0].tolist() if idx < maxlen else range(len(prob)) + ) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if meet_indexes: + spectrograms = torch.stack(spectrogram).transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for mi in meet_indexes: + result_spectrogram[mi] = spectrograms[mi] + if len(result_spectrogram) >= bsz: + break + + # Combine all generated spectrograms spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - if vocoder is not None: - outputs = vocoder(spectrogram) - else: - outputs = spectrogram + outputs = vocoder(spectrogram) if vocoder is not None else spectrogram if output_cross_attentions: cross_attentions = torch.cat(cross_attentions, dim=2) if bsz > 1: @@ -493,23 +448,18 @@ def gaudi_generate_speech( ) outputs = (outputs, cross_attentions) else: - # batched return values should also include the spectrogram/waveform lengths - spectrogram_lengths = [] - for i in range(bsz): - spectrogram_lengths.append(spectrograms[i].size(0)) + lengths = [s.size(0) for s in spectrograms] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) if vocoder is None: - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - outputs = (spectrograms, spectrogram_lengths) + outputs = (spectrograms, lengths) else: - waveforms = [] - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) waveforms = vocoder(spectrograms) - waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + waveform_lengths = [int(waveforms.size(1) / max(lengths)) * i for i in lengths] outputs = (waveforms, waveform_lengths) if output_cross_attentions: - cross_attentions = torch.cat(cross_attentions, dim=2) - cross_attentions = cross_attentions.view( + cross_attentions = torch.cat(cross_attentions, dim=2).view( bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] ) outputs = (*outputs, cross_attentions) + return outputs