From 3820d0fa4e401bc2bf446e50eff8a216ee93a163 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Thu, 29 Feb 2024 20:22:33 +0530 Subject: [PATCH 1/9] enable hpu_graph support for wav2vec2-asr (#59) --- examples/speech-recognition/README.md | 11 +- optimum/habana/transformers/modeling_utils.py | 2 + .../habana/transformers/models/__init__.py | 1 + .../transformers/models/wav2vec2/__init__.py | 1 + .../models/wav2vec2/modeling_wav2vec2.py | 161 +++++++++++++----- tests/baselines/wav2vec2_large_lv60.json | 8 +- 6 files changed, 137 insertions(+), 47 deletions(-) diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 510a52d213..9fe8fc720e 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -78,7 +78,9 @@ python run_speech_recognition_ctc.py \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ --throughput_warmup_steps="3" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference ``` On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**. @@ -117,7 +119,9 @@ python ../gaudi_spawn.py \ --use_lazy_mode \ --gaudi_config_name Habana/wav2vec2 \ --throughput_warmup_steps 3 \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference ``` On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**. 
@@ -196,7 +200,8 @@ python run_speech_recognition_ctc.py \ --use_habana \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_inference ``` ## Sequence to Sequence diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f12e7d3540..721a16a0f9 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -120,6 +120,7 @@ gaudi_vit_self_attention_forward, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, ) @@ -143,6 +144,7 @@ def adapt_transformers_to_gaudi(): ) transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward = gaudi_wav2vec2_forward transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward + transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward # Generation is modified to run faster in lazy mode transformers.generation.GenerationMixin.generate = GaudiGenerationMixin.generate diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index aa04c6f3c3..a4b1bba1ca 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -126,4 +126,5 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py index e38a0ec0fa..3a60ce43f6 100644 --- a/optimum/habana/transformers/models/wav2vec2/__init__.py +++ b/optimum/habana/transformers/models/wav2vec2/__init__.py @@ -4,4 +4,5 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, ) diff --git 
a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index b38af4b1b4..bb8640cb2e 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,13 +17,18 @@ from typing import Optional, Tuple, Union import torch +from habana_frameworks.torch.hpex.kernels import CTCLoss from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, + CausalLMOutput, Wav2Vec2BaseModelOutput, ) +ctc_loss_fwd = CTCLoss.apply + + def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -33,7 +38,8 @@ def _gaudi_wav2vec2_compute_mask_indices( ) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L135 - The only difference is that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers). + The only differences are (1) that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers), (2) epsilon is generated on HPU instead of CPU, (3) check + to ensure indices are not larger than sequence length is re-written to avoid host sync. 
""" batch_size, sequence_length = shape @@ -122,8 +128,9 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) + inverse_mask = torch.bitwise_not(mask) + spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) @@ -172,6 +179,63 @@ def _gaudi_wav2vec2_sample_negative_indices( return sampled_negative_indices +def gaudi_wav2vec2_forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 + The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _gaudi_wav2vec2_mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -300,58 +364,71 @@ def gaudi_wav2vec2_encoder_forward( ) -def gaudi_wav2vec2_forward( +_HIDDEN_STATES_START_POSITION = 2 + + +def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: 
Optional[bool] = None, -) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + labels: Optional[torch.Tensor] = None, +) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 - The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 + only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid + changing flattened_targets tensor shapes across training iterations. 
+ """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - extract_features = self.feature_extractor(input_values) - extract_features = extract_features.transpose(1, 2) - - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states( - hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) - - encoder_outputs = self.encoder( - hidden_states, + outputs = self.wav2vec2( + input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones_like(input_values, dtype=torch.long, device="hpu") + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels + # ctc_loss doesn't support fp16 + log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + with torch.backends.cudnn.flags(enabled=False): + loss = ctc_loss_fwd( + log_probs, + flattened_targets, 
+ input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) if not return_dict: - return (hidden_states, extract_features) + encoder_outputs[1:] - - return Wav2Vec2BaseModelOutput( - last_hidden_state=hidden_states, - extract_features=extract_features, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index b1071302fa..86fa3b92b5 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -21,7 +21,9 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--use_hpu_graphs_for_training", + "--use_hpu_graphs_for_inference" ] } } @@ -49,7 +51,9 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--use_hpu_graphs_for_training", + "--use_hpu_graphs_for_inference" ] } } From af78afe2590de480809b9b3261a9e8de02a74595 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Thu, 7 Mar 2024 09:29:24 +0530 Subject: [PATCH 2/9] Run custom ctc_loss only for Gaudi2 (#95) --- .../models/wav2vec2/modeling_wav2vec2.py | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index bb8640cb2e..4e428829fb 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ 
b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -18,6 +18,7 @@ import torch from habana_frameworks.torch.hpex.kernels import CTCLoss +from habana_frameworks.torch.hpu import get_device_name from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, @@ -128,9 +129,13 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) - inverse_mask = torch.bitwise_not(mask) - spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask + if get_device_name() == "GAUDI": + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + else: + mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) + inverse_mask = torch.bitwise_not(mask) + spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) @@ -414,19 +419,32 @@ def gaudi_wav2vec2forctc_forward( # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) - flattened_targets = labels # ctc_loss doesn't support fp16 log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) - with torch.backends.cudnn.flags(enabled=False): - loss = ctc_loss_fwd( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - self.config.pad_token_id, - self.config.ctc_loss_reduction, - self.config.ctc_zero_infinity, - ) + if get_device_name() == "GAUDI": + flattened_targets = labels.masked_select(labels_mask) + with torch.backends.cudnn.flags(enabled=False): + loss = torch.nn.functional.ctc_loss( + log_probs, + 
flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + else: + flattened_targets = labels + with torch.backends.cudnn.flags(enabled=False): + loss = ctc_loss_fwd( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] From 937533bd5ece60b3a3c198167c12904a01d41af9 Mon Sep 17 00:00:00 2001 From: Vivek Date: Sat, 9 Mar 2024 09:57:35 +0200 Subject: [PATCH 3/9] Update test baseline --- tests/baselines/wav2vec2_large_lv60.json | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index 86fa3b92b5..6792b855ee 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -21,9 +21,7 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'", - "--use_hpu_graphs_for_training", - "--use_hpu_graphs_for_inference" + "--chars_to_ignore ',?.!-;:\"“%‘”'" ] } } @@ -35,12 +33,12 @@ "eval_batch_size": 8, "distribution": { "multi_card": { - "learning_rate": 3e-4, + "learning_rate": 4e-4, "train_batch_size": 8, - "eval_wer": 0.0531535105117017, - "train_runtime": 356.4723, - "train_samples_per_second": 183.245, - "eval_samples_per_second": 158.985, + "eval_wer": 0.06120587068623562, + "train_runtime": 308.8036, + "train_samples_per_second": 225.572, + "eval_samples_per_second": 196.665, "extra_arguments": [ "--dataset_config_name clean", "--train_split_name train.100", From 392ff55bb4b14e10bf61fbf78acf0e7f771044fa Mon Sep 17 00:00:00 2001 From: Vivek Date: Tue, 26 Mar 2024 09:37:19 +0200 Subject: [PATCH 4/9] Fix backward compatibility issue --- 
.../models/wav2vec2/modeling_wav2vec2.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index 4e428829fb..bbd403a788 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,7 +17,6 @@ from typing import Optional, Tuple, Union import torch -from habana_frameworks.torch.hpex.kernels import CTCLoss from habana_frameworks.torch.hpu import get_device_name from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( @@ -25,10 +24,15 @@ CausalLMOutput, Wav2Vec2BaseModelOutput, ) +from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION -ctc_loss_fwd = CTCLoss.apply - +try: + from habana_frameworks.torch.hpex.kernels import CTCLoss + custom_ctc_loss_fwd = CTCLoss.apply +except ImportError: + print("Could not import Custom CTCLoss kernel. 
This Kernel is available only for SynapseAI >= 1.15.0") + custom_ctc_loss_fwd = None def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], @@ -129,7 +133,7 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - if get_device_name() == "GAUDI": + if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: if spec_aug_mask_idxs.max() > sequence_length - 1: spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 else: @@ -368,10 +372,6 @@ def gaudi_wav2vec2_encoder_forward( attentions=all_self_attentions, ) - -_HIDDEN_STATES_START_POSITION = 2 - - def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], @@ -381,13 +381,6 @@ def gaudi_wav2vec2forctc_forward( return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): - Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to - the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. - All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., - config.vocab_size - 1]`. 
- """ """ copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid @@ -421,7 +414,7 @@ def gaudi_wav2vec2forctc_forward( target_lengths = labels_mask.sum(-1) # ctc_loss doesn't support fp16 log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) - if get_device_name() == "GAUDI": + if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: flattened_targets = labels.masked_select(labels_mask) with torch.backends.cudnn.flags(enabled=False): loss = torch.nn.functional.ctc_loss( @@ -436,7 +429,7 @@ def gaudi_wav2vec2forctc_forward( else: flattened_targets = labels with torch.backends.cudnn.flags(enabled=False): - loss = ctc_loss_fwd( + loss = custom_ctc_loss_fwd( log_probs, flattened_targets, input_lengths, From b9006651fbfbd616c924f589c68a36acb0010d0c Mon Sep 17 00:00:00 2001 From: Vivek Date: Tue, 26 Mar 2024 10:33:28 +0200 Subject: [PATCH 5/9] Fix code formatting --- optimum/habana/transformers/modeling_utils.py | 2 +- optimum/habana/transformers/models/__init__.py | 2 +- optimum/habana/transformers/models/wav2vec2/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 0f1269a6fa..10d2db216c 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -134,8 +134,8 @@ gaudi_vit_self_attention_forward, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py 
index cda62409c0..19d39893b4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -145,6 +145,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py index df43104ce5..84372061b6 100644 --- a/optimum/habana/transformers/models/wav2vec2/__init__.py +++ b/optimum/habana/transformers/models/wav2vec2/__init__.py @@ -4,6 +4,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) From 0521d3a15eb9a26323d7f89080bb23ff06946c0a Mon Sep 17 00:00:00 2001 From: Vivek Date: Tue, 26 Mar 2024 11:01:28 +0200 Subject: [PATCH 6/9] Fix code formatting in modeling_wav2vec.py --- .../models/wav2vec2/modeling_wav2vec2.py | 192 +++++++++--------- 1 file changed, 97 insertions(+), 95 deletions(-) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index a34eb41745..afef8d580f 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -29,11 +29,13 @@ try: from habana_frameworks.torch.hpex.kernels import CTCLoss + custom_ctc_loss_fwd = CTCLoss.apply except ImportError: print("Could not import Custom CTCLoss kernel. 
This Kernel is available only for SynapseAI >= 1.15.0") custom_ctc_loss_fwd = None + def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -188,63 +190,6 @@ def _gaudi_wav2vec2_sample_negative_indices( return sampled_negative_indices -def gaudi_wav2vec2_forward( - self, - input_values: Optional[torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - mask_time_indices: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, Wav2Vec2BaseModelOutput]: - """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 - The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - extract_features = self.feature_extractor(input_values) - extract_features = extract_features.transpose(1, 2) - - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states( - hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - 
return_dict=return_dict, - ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) - - if not return_dict: - return (hidden_states, extract_features) + encoder_outputs[1:] - - return Wav2Vec2BaseModelOutput( - last_hidden_state=hidden_states, - extract_features=extract_features, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def _gaudi_wav2vec2_mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -372,6 +317,83 @@ def gaudi_wav2vec2_encoder_forward( attentions=all_self_attentions, ) + +def gaudi_wav2vec2_forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 + The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 + v4.38.2 implementation caused accuracy issue to run pytest Wav2Vec2RobustModelTest. 
+ """ + hidden_states = hidden_states.unsqueeze(1) + hidden_states = torch.nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], @@ -416,48 +438,28 @@ def gaudi_wav2vec2forctc_forward( log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: flattened_targets = labels.masked_select(labels_mask) - with torch.backends.cudnn.flags(enabled=False): - loss = torch.nn.functional.ctc_loss( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - blank=self.config.pad_token_id, - reduction=self.config.ctc_loss_reduction, - zero_infinity=self.config.ctc_zero_infinity, - ) + loss = torch.nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) else: flattened_targets = labels - with torch.backends.cudnn.flags(enabled=False): - loss = custom_ctc_loss_fwd( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - self.config.pad_token_id, - self.config.ctc_loss_reduction, - self.config.ctc_zero_infinity, - ) + loss = custom_ctc_loss_fwd( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 
- -def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 - v4.38.2 implementation caused accuracy issue to run pytest Wav2Vec2RobustModelTest. - """ - hidden_states = hidden_states.unsqueeze(1) - hidden_states = torch.nn.functional.unfold( - hidden_states, - (self.kernel_size, self.in_conv_dim), - stride=(1, self.in_conv_dim), - dilation=(self.dilation, 1), - ) - hidden_states = hidden_states.transpose(1, 2) - hidden_states = self.kernel(hidden_states) - - hidden_states = self.activation(hidden_states) - return hidden_states From ed4cb5597b8a815b847920278afff6428f0e830b Mon Sep 17 00:00:00 2001 From: Vivek Date: Tue, 26 Mar 2024 11:08:09 +0200 Subject: [PATCH 7/9] Update Readme --- examples/speech-recognition/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 9fe8fc720e..13fadf10f5 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -86,7 +86,7 @@ python run_speech_recognition_ctc.py \ On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**. > If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2. ### Multi-HPU CTC @@ -127,7 +127,7 @@ python ../gaudi_spawn.py \ On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**. 
> If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2. ## DeepSpeed From f5c1a1895c7b7bf4251b1e612f4395a6a8099a2a Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Tue, 26 Mar 2024 19:18:52 +0530 Subject: [PATCH 8/9] Update examples/speech-recognition/README.md Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/speech-recognition/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 13fadf10f5..1793da842f 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -86,7 +86,7 @@ python run_speech_recognition_ctc.py \ On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**. > If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15. 
### Multi-HPU CTC From 37040e216fc87fd805d6307c6baec921b837a006 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Tue, 26 Mar 2024 19:19:00 +0530 Subject: [PATCH 9/9] Update examples/speech-recognition/README.md Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/speech-recognition/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 1793da842f..753428027c 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -127,7 +127,7 @@ python ../gaudi_spawn.py \ On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**. > If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15. ## DeepSpeed