From 956705de1e2f7f040069eefebc1721c45b18a9d9 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 19:32:51 +0300 Subject: [PATCH 1/7] fix missing _shape --- .../models/speecht5/modeling_speecht5.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 596d4c5305..990b069a97 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -42,6 +42,9 @@ def gaudi_SpeechT5Attention_forward( bsz, tgt_len, _ = hidden_states.size() + def _reshape_for_scores(t: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: + return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj @@ -51,12 +54,12 @@ def gaudi_SpeechT5Attention_forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = _reshape_for_scores(self.k_proj(key_value_states), -1, bsz) + value_states = _reshape_for_scores(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _reshape_for_scores(self.k_proj(hidden_states), -1, bsz) + value_states = _reshape_for_scores(self.v_proj(hidden_states), -1, bsz) if token_idx is not None: past_key_value[0].index_copy_(2, token_idx - 1, key_states) past_key_value[1].index_copy_(2, token_idx - 1, value_states) @@ -67,8 +70,8 @@ def gaudi_SpeechT5Attention_forward( value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _reshape_for_scores(self.k_proj(hidden_states), -1, bsz) + value_states = _reshape_for_scores(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -81,7 +84,7 @@ def gaudi_SpeechT5Attention_forward( past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = _reshape_for_scores(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) From bf1be30c4a0f2417f6bfd35f76593d35f01d674d Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 20:54:41 +0300 Subject: [PATCH 2/7] adjust to transformers 4.55.4 --- .../models/speecht5/modeling_speecht5.py | 314 +++++++----------- 1 file changed, 120 insertions(+), 194 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 990b069a97..40671e4739 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -3,6 +3,7 @@ import torch import torch.utils.checkpoint from torch import nn +from transformers.cache_utils import Cache, EncoderDecoderCache from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -22,18 +23,19 @@ def gaudi_SpeechT5Attention_forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional["Cache"] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + layer_idx: Optional[int] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ - Copied from SpeechT5Attention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py + Copied from SpeechT5Attention.forward (transformers 4.55.4) The only differences are: - - add new args token_idx + - add new arg `token_idx` """ # if key_value_states are provided this layer is used as a cross-attention layer @@ -42,59 +44,51 @@ def gaudi_SpeechT5Attention_forward( bsz, tgt_len, _ = hidden_states.size() - def _reshape_for_scores(t: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: - return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - # get query proj + # get query projection query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = _reshape_for_scores(self.k_proj(key_value_states), -1, bsz) - value_states = _reshape_for_scores(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = _reshape_for_scores(self.k_proj(hidden_states), -1, bsz) - value_states = _reshape_for_scores(self.v_proj(hidden_states), -1, bsz) - if token_idx is not None: - past_key_value[0].index_copy_(2, token_idx - 1, key_states) - past_key_value[1].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value[0] - value_states = past_key_value[1] + + # retrieve cache entry for this layer + if past_key_value is not None: + if is_cross_attention: + curr_past = past_key_value.cross_attention_cache else: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + curr_past = past_key_value.self_attention_cache + else: + curr_past = None + + # compute key/value + current_states = key_value_states if is_cross_attention else hidden_states + if curr_past is not None and curr_past.is_updated.get(layer_idx, False) and is_cross_attention: + key_states = curr_past.layers[layer_idx].keys + value_states = curr_past.layers[layer_idx].values else: - # self_attention - key_states = _reshape_for_scores(self.k_proj(hidden_states), -1, bsz) - value_states = _reshape_for_scores(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # update Cache (new HF 4.55+ mechanism) + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past.update( + key_states, value_states, layer_idx, {"cache_position": cache_position} + ) + if is_cross_attention: + past_key_value.is_updated[layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = _reshape_for_scores(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, " + f"but is {attn_weights.size()}" ) # relative attention bias @@ -125,35 +119,25 @@ def _reshape_for_scores(t: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, " + f"but is {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped def gaudi_SpeechT5DecoderLayer_forward( @@ -164,50 +148,43 @@ def gaudi_SpeechT5DecoderLayer_forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional["Cache"] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ - Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py + Copied from SpeechT5DecoderLayer.forward (transformers 4.55.4) The only differences are: - - add token_idx in self-attention + - add token_idx argument in self-attention """ residual = hidden_states - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + # Self-Attention (HF 4.55.4 style) + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, - token_idx=token_idx, + token_idx=token_idx, # Gaudi extension ) + hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, ) @@ -215,21 +192,13 @@ def gaudi_SpeechT5DecoderLayer_forward( hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -241,7 +210,7 @@ def gaudi_SpeechT5Decoder_forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional["Cache"] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -250,9 +219,9 @@ def gaudi_SpeechT5Decoder_forward( token_idx: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: """ - Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py + Copied from SpeechT5Decoder.forward (transformers 4.55.4) The only differences are: - - add token_idx args + - add token_idx args for Gaudi - use _gaudi_prepare_4d_causal_attention_mask """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -263,16 +232,13 @@ def gaudi_SpeechT5Decoder_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = hidden_states.size()[:-1] - - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_usable_length(cache_position) if past_key_values is not None else 0 attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, input_shape, hidden_states, past_key_values_length + attention_mask, input_shape, hidden_states, past_seen_tokens ) - # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] ) @@ -286,26 +252,14 @@ def gaudi_SpeechT5Decoder_forward( ) use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) skip_the_layer = False if self.training: dropout_probability = torch.rand([]) @@ -313,46 +267,36 @@ def gaudi_SpeechT5Decoder_forward( if skip_the_layer and not synced_gpus: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( + hidden_states, self_attn_weights, cross_attn_weights = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - token_idx=token_idx, + token_idx=token_idx, # Gaudi extension ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - + all_self_attentions = all_self_attentions + (self_attn_weights,) if encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + (cross_attn_weights,) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None + v for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -370,21 +314,19 @@ def gaudi_generate_speech( vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, return_output_lengths: bool = False, -) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]: +): """ - Copied from _generate_speech: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py - The only differences are: - - add hpu graph wrap - - add static shape support in kv-cache in _generate_speech - - disable speech_decoder_prenet_dropout to avoid variable output length + Copied from _generate_speech (transformers 4.55.4) + Differences: + - wrapped with HPU graphs + - static-shape kv-cache (Cache API) + - disable dropout in prenet """ if speaker_embeddings is None: raise ValueError( - """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following - the code snippet provided in this link: - https://huggingface.co/datasets/regisss/cmu-arctic-xvectors - """ + "`speaker_embeddings` must be provided (e.g. from https://huggingface.co/datasets/regisss/cmu-arctic-xvectors)." ) + from habana_frameworks.torch.hpu import wrap_in_hpu_graph if not hasattr(model.speecht5.encoder, "clear_cache"): @@ -394,10 +336,9 @@ def gaudi_generate_speech( if not hasattr(model.speecht5.decoder.prenet, "clear_cache"): model.speecht5.decoder.prenet = wrap_in_hpu_graph(model.speecht5.decoder.prenet) - if attention_mask is None: - encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int() - else: - encoder_attention_mask = attention_mask + encoder_attention_mask = ( + 1 - (input_values == model.config.pad_token_id).int() if attention_mask is None else attention_mask + ) bsz = input_values.size(0) encoder_out = model.speecht5.encoder( @@ -405,40 +346,42 @@ def gaudi_generate_speech( attention_mask=encoder_attention_mask, return_dict=True, ) + encoder_hidden_states = encoder_out.last_hidden_state - encoder_last_hidden_state = encoder_out.last_hidden_state - - # downsample encoder attention mask + # Downsample attention mask if prenet used if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( - encoder_out[0].shape[1], encoder_attention_mask + encoder_hidden_states.shape[1], encoder_attention_mask ) - maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor) - minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor) + maxlen = int(encoder_hidden_states.size(1) * maxlenratio / model.config.reduction_factor) + minlen = int(encoder_hidden_states.size(1) * minlenratio / model.config.reduction_factor) - # Start the output sequence with a mel spectrum that is all zeros. - output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins) + output_sequence = encoder_hidden_states.new_zeros(bsz, 1, model.config.num_mel_bins) output_sequence = torch.nn.functional.pad(output_sequence, (0, 0, 0, maxlen - 1), value=model.config.pad_token_id) - spectrogram = [] - cross_attentions = [] - past_key_values = None - idx = 0 - result_spectrogram = {} + spectrogram, cross_attentions, result_spectrogram = [], [], {} token_idx = torch.tensor(1, device=output_sequence.device) attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) + + # ✅ new Cache object + past_key_values = EncoderDecoderCache.init(model.speecht5.decoder.config, batch_size=bsz) + + idx = 0 while True: idx += 1 attention_mask.index_fill_(1, token_idx - 1, 1) - # Run the decoder prenet on the entire output sequence. + decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) - # Run the decoder layers on the last element of the prenet output. + decoder_inputs = ( + decoder_hidden_states + if past_key_values.get_seq_length() == 0 + else torch.index_select(decoder_hidden_states, 1, token_idx - 1) + ) + decoder_out = model.speecht5.decoder.wrapped_decoder( - hidden_states=decoder_hidden_states - if past_key_values is None - else torch.index_select(decoder_hidden_states, 1, token_idx - 1), + hidden_states=decoder_inputs, attention_mask=attention_mask, - encoder_hidden_states=encoder_last_hidden_state, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=True, @@ -450,44 +393,33 @@ def gaudi_generate_speech( if output_cross_attentions: cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) - last_decoder_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1) - past_key_values = decoder_out.past_key_values - # Predict the new mel spectrum for this step in the sequence. - spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) + last_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1) + spectrum = model.speech_decoder_postnet.feat_out(last_output) spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) spectrogram.append(spectrum) + output_sequence.index_copy_(1, token_idx, spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)) - # Predict the probability that this is the stop token. - prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) + prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_output)) token_idx.add_(1) - # Finished when stop token or maximum length is reached. + if idx < minlen: continue - else: - # If the generation loop is less than maximum length time, check the ones in the batch that have met - # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch. - if idx < maxlen: - meet_thresholds = torch.sum(prob, dim=-1) >= threshold - meet_indexes = torch.where(meet_thresholds)[0].tolist() - else: - meet_indexes = range(len(prob)) - meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] - if len(meet_indexes) > 0: - spectrograms = torch.stack(spectrogram) - spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) - spectrograms = model.speech_decoder_postnet.postnet(spectrograms) - for meet_index in meet_indexes: - result_spectrogram[meet_index] = spectrograms[meet_index] - if len(result_spectrogram) >= bsz: - break + meet_indexes = ( + torch.where(torch.sum(prob, dim=-1) >= threshold)[0].tolist() if idx < maxlen else range(len(prob)) + ) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if meet_indexes: + spectrograms = torch.stack(spectrogram).transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for mi in meet_indexes: + result_spectrogram[mi] = spectrograms[mi] + if len(result_spectrogram) >= bsz: + break spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] if not return_output_lengths: spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - if vocoder is not None: - outputs = vocoder(spectrogram) - else: - outputs = spectrogram + outputs = vocoder(spectrogram) if vocoder is not None else spectrogram if output_cross_attentions: cross_attentions = torch.cat(cross_attentions, dim=2) if bsz > 1: @@ -496,22 +428,16 @@ def gaudi_generate_speech( ) outputs = (outputs, cross_attentions) else: - # batched return values should also include the spectrogram/waveform lengths - spectrogram_lengths = [] - for i in range(bsz): - spectrogram_lengths.append(spectrograms[i].size(0)) + lengths = [s.size(0) for s in spectrograms] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) if vocoder is None: - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - outputs = (spectrograms, spectrogram_lengths) + outputs = (spectrograms, lengths) else: - waveforms = [] - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) waveforms = vocoder(spectrograms) - waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + waveform_lengths = [int(waveforms.size(1) / max(lengths)) * i for i in lengths] outputs = (waveforms, waveform_lengths) if output_cross_attentions: - cross_attentions = torch.cat(cross_attentions, dim=2) - cross_attentions = cross_attentions.view( + cross_attentions = torch.cat(cross_attentions, dim=2).view( bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] ) outputs = (*outputs, cross_attentions) From 4f872a640b8c4c4f41e412a2a517282124a6f315 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 21:06:05 +0300 Subject: [PATCH 3/7] adjust to transformers 4.55.4 - part.2 --- .../habana/transformers/models/speecht5/modeling_speecht5.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 40671e4739..c9b5641640 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -363,8 +363,7 @@ def gaudi_generate_speech( token_idx = torch.tensor(1, device=output_sequence.device) attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) - # ✅ new Cache object - past_key_values = EncoderDecoderCache.init(model.speecht5.decoder.config, batch_size=bsz) + past_key_values = EncoderDecoderCache(model.speecht5.decoder.config, batch_size=bsz) idx = 0 while True: From 5aa02fee2dbc8a9e38de40772c8d001f5d14b865 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 21:26:05 +0300 Subject: [PATCH 4/7] adjust to transformers 4.55.4 - part.3 --- .../habana/transformers/models/speecht5/modeling_speecht5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index c9b5641640..fb35db0d53 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -363,7 +363,7 @@ def gaudi_generate_speech( token_idx = torch.tensor(1, device=output_sequence.device) attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) - past_key_values = EncoderDecoderCache(model.speecht5.decoder.config, batch_size=bsz) + past_key_values = EncoderDecoderCache(model.speecht5.decoder.config) idx = 0 while True: From 3ac640fe64ee950f415e296d7d63c0484ddf7c72 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 21:47:10 +0300 Subject: [PATCH 5/7] adjust to transformers 4.55.4 - part.4 --- .../transformers/models/speecht5/modeling_speecht5.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index fb35db0d53..9726e20dce 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -3,7 +3,7 @@ import torch import torch.utils.checkpoint from torch import nn -from transformers.cache_utils import Cache, EncoderDecoderCache +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -363,7 +363,9 @@ def gaudi_generate_speech( token_idx = torch.tensor(1, device=output_sequence.device) attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) - past_key_values = EncoderDecoderCache(model.speecht5.decoder.config) + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) idx = 0 while True: From d44ff62460c87b1603db164ec45ebb589f2c8a75 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 22:23:48 +0300 Subject: [PATCH 6/7] adjust to transformers 4.55.4 - part.5 --- .../models/speecht5/modeling_speecht5.py | 121 +++++++++--------- 1 file changed, 58 insertions(+), 63 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 9726e20dce..157a8f00a3 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -30,55 +30,34 @@ def gaudi_SpeechT5Attention_forward( output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, - layer_idx: Optional[int] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ - Copied from SpeechT5Attention.forward (transformers 4.55.4) The only differences are: - - add new arg `token_idx` + - add new args token_idx + - update to HF 4.55.4 Cache API (no explicit past_key_value tuple in/out) """ - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - # get query projection + # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # retrieve cache entry for this layer - if past_key_value is not None: - if is_cross_attention: - curr_past = past_key_value.cross_attention_cache - else: - curr_past = past_key_value.self_attention_cache - else: - curr_past = None - - # compute key/value current_states = key_value_states if is_cross_attention else hidden_states - if curr_past is not None and curr_past.is_updated.get(layer_idx, False) and is_cross_attention: - key_states = curr_past.layers[layer_idx].keys - value_states = curr_past.layers[layer_idx].values - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # update Cache (new HF 4.55+ mechanism) - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past.update( - key_states, value_states, layer_idx, {"cache_position": cache_position} - ) - if is_cross_attention: - past_key_value.is_updated[layer_idx] = True + + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + if past_key_value is not None: + _cache_pos = None if is_cross_attention else cache_position + key_states, value_states = past_key_value.update( + key_states, value_states, getattr(self, "layer_idx", None), {"cache_position": _cache_pos} + ) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.reshape(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2).reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -87,8 +66,7 @@ def gaudi_SpeechT5Attention_forward( if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, " - f"but is {attn_weights.size()}" + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" ) # relative attention bias @@ -129,8 +107,7 @@ def gaudi_SpeechT5Attention_forward( if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, " - f"but is {attn_output.size()}" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2) @@ -148,56 +125,61 @@ def gaudi_SpeechT5DecoderLayer_forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional["Cache"] = None, + past_key_value: Optional["Cache"] = None, # 4.55.4 output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ - Copied from SpeechT5DecoderLayer.forward (transformers 4.55.4) + Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/speecht5/modeling_speecht5.py The only differences are: - - add token_idx argument in self-attention + - add token_idx in self-attention + - align with HF 4.55.4: no present_key_value returned (cache is updated in-place) """ residual = hidden_states - # Self-Attention (HF 4.55.4 style) - hidden_states, self_attn_weights = self.self_attn( + # Self Attention + self_attn_outputs = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, - token_idx=token_idx, # Gaudi extension + token_idx=token_idx, ) - - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(self_attn_outputs[0]) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) + # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights = self.encoder_attn( + cross_outputs = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_value.cross_attention_cache if past_key_value is not None else None, output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(cross_outputs[0]) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) + if output_attentions: + cross_attn_weights = cross_outputs[1] + # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) + # self-attn weights + outputs += (self_attn_outputs[1], cross_attn_weights) return outputs @@ -219,10 +201,11 @@ def gaudi_SpeechT5Decoder_forward( token_idx: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: """ - Copied from SpeechT5Decoder.forward (transformers 4.55.4) + Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/speecht5/modeling_speecht5.py The only differences are: - - add token_idx args for Gaudi + - add token_idx args - use _gaudi_prepare_4d_causal_attention_mask + - align with HF 4.55.4 Cache API (no next_decoder_cache) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -232,12 +215,14 @@ def gaudi_SpeechT5Decoder_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = hidden_states.size()[:-1] - past_seen_tokens = past_key_values.get_usable_length(cache_position) if past_key_values is not None else 0 + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, input_shape, hidden_states, past_seen_tokens + attention_mask, input_shape, hidden_states, past_key_values_length ) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: encoder_attention_mask = _prepare_4d_attention_mask( encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] @@ -256,6 +241,13 @@ def gaudi_SpeechT5Decoder_forward( all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -267,7 +259,7 @@ def gaudi_SpeechT5Decoder_forward( if skip_the_layer and not synced_gpus: continue - hidden_states, self_attn_weights, cross_attn_weights = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, @@ -278,25 +270,28 @@ def gaudi_SpeechT5Decoder_forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - token_idx=token_idx, # Gaudi extension + token_idx=token_idx, ) + hidden_states = layer_outputs[0] if output_attentions: - all_self_attentions = all_self_attentions + (self_attn_weights,) + all_self_attentions = all_self_attentions + (layer_outputs[1],) if encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (cross_attn_weights,) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( - v for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None + v + for v in [hidden_states, None, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=None, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, From 75a95a7d3d86298d320df5c46a472caa2c91de76 Mon Sep 17 00:00:00 2001 From: gplutop7 Date: Tue, 14 Oct 2025 23:01:40 +0300 Subject: [PATCH 7/7] adjust to transformers 4.55.4 - part.6 --- .../models/speecht5/modeling_speecht5.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py index 157a8f00a3..e175f91ddd 100644 --- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py +++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py @@ -311,11 +311,12 @@ def gaudi_generate_speech( return_output_lengths: bool = False, ): """ - Copied from _generate_speech (transformers 4.55.4) - Differences: - - wrapped with HPU graphs - - static-shape kv-cache (Cache API) - - disable dropout in prenet + Copied and adapted from `_generate_speech` (transformers v4.55.4) + Differences introduced for Habana Gaudi: + - wrapped encoder / decoder / prenet with HPU graphs + - use static-shape kv-cache via Cache API (DynamicCache + EncoderDecoderCache) + - disable dropout in prenet for deterministic output lengths + - adjust attention_mask update order (fixes off-by-one shape mismatch) """ if speaker_embeddings is None: raise ValueError( @@ -324,6 +325,7 @@ def gaudi_generate_speech( from habana_frameworks.torch.hpu import wrap_in_hpu_graph + # Wrap model components with HPU graph to enable static compilation if not hasattr(model.speecht5.encoder, "clear_cache"): model.speecht5.encoder = wrap_in_hpu_graph(model.speecht5.encoder) if not hasattr(model.speecht5.decoder.wrapped_decoder, "clear_cache"): @@ -331,11 +333,14 @@ def gaudi_generate_speech( if not hasattr(model.speecht5.decoder.prenet, "clear_cache"): model.speecht5.decoder.prenet = wrap_in_hpu_graph(model.speecht5.decoder.prenet) + # Prepare encoder attention mask encoder_attention_mask = ( 1 - (input_values == model.config.pad_token_id).int() if attention_mask is None else attention_mask ) bsz = input_values.size(0) + + # Run encoder encoder_out = model.speecht5.encoder( input_values=input_values, attention_mask=encoder_attention_mask, @@ -343,37 +348,46 @@ def gaudi_generate_speech( ) encoder_hidden_states = encoder_out.last_hidden_state - # Downsample attention mask if prenet used + # Downsample attention mask if using speech prenet if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( encoder_hidden_states.shape[1], encoder_attention_mask ) + # Determine dynamic decoding length bounds maxlen = int(encoder_hidden_states.size(1) * maxlenratio / model.config.reduction_factor) minlen = int(encoder_hidden_states.size(1) * minlenratio / model.config.reduction_factor) + # Initialize decoder inputs output_sequence = encoder_hidden_states.new_zeros(bsz, 1, model.config.num_mel_bins) output_sequence = torch.nn.functional.pad(output_sequence, (0, 0, 0, maxlen - 1), value=model.config.pad_token_id) - spectrogram, cross_attentions, result_spectrogram = [], [], {} - token_idx = torch.tensor(1, device=output_sequence.device) - attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) + # Prepare attention and cache structures + attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device) self_attention_cache = DynamicCache() cross_attention_cache = DynamicCache() past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + # Internal buffers + spectrogram, cross_attentions, result_spectrogram = [], [], {} + token_idx = torch.tensor(1, device=output_sequence.device) idx = 0 + + # Generation loop while True: idx += 1 - attention_mask.index_fill_(1, token_idx - 1, 1) + # Prenet (disable dropout for HPU determinism) decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) + + # Use last step or full input depending on cache decoder_inputs = ( decoder_hidden_states if past_key_values.get_seq_length() == 0 else torch.index_select(decoder_hidden_states, 1, token_idx - 1) ) + # Decoder forward with caching decoder_out = model.speecht5.decoder.wrapped_decoder( hidden_states=decoder_inputs, attention_mask=attention_mask, @@ -386,18 +400,26 @@ def gaudi_generate_speech( token_idx=token_idx, ) + attention_mask.index_fill_(1, token_idx - 1, 1) + + # Optional cross-attention collection if output_cross_attentions: cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) + # Extract decoder output last_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1) + + # Predict mel spectrum spectrum = model.speech_decoder_postnet.feat_out(last_output) spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) spectrogram.append(spectrum) + # Update output sequence and token index output_sequence.index_copy_(1, token_idx, spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)) prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_output)) token_idx.add_(1) + # Early exit logic if idx < minlen: continue meet_indexes = ( @@ -412,7 +434,9 @@ def gaudi_generate_speech( if len(result_spectrogram) >= bsz: break + # Combine all generated spectrograms spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) outputs = vocoder(spectrogram) if vocoder is not None else spectrogram @@ -437,4 +461,5 @@ def gaudi_generate_speech( bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] ) outputs = (*outputs, cross_attentions) + return outputs