diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 510a52d213..753428027c 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -78,13 +78,15 @@ python run_speech_recognition_ctc.py \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ --throughput_warmup_steps="3" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference ``` On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**. > If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15. ### Multi-HPU CTC @@ -117,13 +119,15 @@ python ../gaudi_spawn.py \ --use_lazy_mode \ --gaudi_config_name Habana/wav2vec2 \ --throughput_warmup_steps 3 \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference ``` On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**. > If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error. -> Resampling with the `datasets` library is not supported on HPUs yet. +> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15. 
## DeepSpeed @@ -196,7 +200,8 @@ python run_speech_recognition_ctc.py \ --use_habana \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_inference ``` ## Sequence to Sequence diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 9d4e473aab..10d2db216c 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -135,6 +135,7 @@ gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) @@ -161,6 +162,7 @@ def adapt_transformers_to_gaudi(): ) transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward = gaudi_wav2vec2_forward transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward + transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer.forward = gaudi_wav2vec2_tdnnlayer_forward # Generation is modified to run faster in lazy mode diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index d0eb8b2dcd..19d39893b4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -146,4 +146,5 @@ gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py index 3a5bae22b8..84372061b6 100644 --- a/optimum/habana/transformers/models/wav2vec2/__init__.py +++ b/optimum/habana/transformers/models/wav2vec2/__init__.py @@ -5,4 +5,5 @@ gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2_tdnnlayer_forward, + gaudi_wav2vec2forctc_forward, ) diff --git 
a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index 983c5b5375..afef8d580f 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,11 +17,23 @@ from typing import Optional, Tuple, Union import torch +from habana_frameworks.torch.hpu import get_device_name from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, + CausalLMOutput, Wav2Vec2BaseModelOutput, ) +from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION + + +try: + from habana_frameworks.torch.hpex.kernels import CTCLoss + + custom_ctc_loss_fwd = CTCLoss.apply +except ImportError: + print("Could not import Custom CTCLoss kernel. This Kernel is available only for SynapseAI >= 1.15.0") + custom_ctc_loss_fwd = None def _gaudi_wav2vec2_compute_mask_indices( @@ -33,7 +45,8 @@ def _gaudi_wav2vec2_compute_mask_indices( ) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L135 - The only difference is that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers). + The only differences are (1) that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers), (2) epsilon is generated on HPU instead of CPU, (3) check + to ensure indices are not larger than sequence length is re-written to avoid host sync. 
""" batch_size, sequence_length = shape @@ -122,8 +135,13 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + else: + mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) + inverse_mask = torch.bitwise_not(mask) + spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) @@ -374,3 +392,74 @@ def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch hidden_states = self.activation(hidden_states) return hidden_states + + +def gaudi_wav2vec2forctc_forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, +) -> Union[Tuple, CausalLMOutput]: + """ + copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 + only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid + changing flattened_targets tensor shapes across training iterations. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones_like(input_values, dtype=torch.long, device="hpu") + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + # ctc_loss doesn't support fp16 + log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: + flattened_targets = labels.masked_select(labels_mask) + loss = torch.nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + else: + flattened_targets = labels + loss = custom_ctc_loss_fwd( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 
diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index b1071302fa..6792b855ee 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -33,12 +33,12 @@ "eval_batch_size": 8, "distribution": { "multi_card": { - "learning_rate": 3e-4, + "learning_rate": 4e-4, "train_batch_size": 8, - "eval_wer": 0.0531535105117017, - "train_runtime": 356.4723, - "train_samples_per_second": 183.245, - "eval_samples_per_second": 158.985, + "eval_wer": 0.06120587068623562, + "train_runtime": 308.8036, + "train_samples_per_second": 225.572, + "eval_samples_per_second": 196.665, "extra_arguments": [ "--dataset_config_name clean", "--train_split_name train.100", @@ -49,7 +49,9 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--use_hpu_graphs_for_training", + "--use_hpu_graphs_for_inference" ] } }