Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions examples/speech-recognition/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ python run_speech_recognition_ctc.py \
--use_lazy_mode \
--gaudi_config_name="Habana/wav2vec2" \
--throughput_warmup_steps="3" \
--bf16
--bf16 \
--use_hpu_graphs_for_training \
--use_hpu_grpahs_for_inference
```

On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**.

> If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error.
> Resampling with the `datasets` library is not supported on HPUs yet.
> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15.

### Multi-HPU CTC

Expand Down Expand Up @@ -117,13 +119,15 @@ python ../gaudi_spawn.py \
--use_lazy_mode \
--gaudi_config_name Habana/wav2vec2 \
--throughput_warmup_steps 3 \
--bf16
--bf16 \
--use_hpu_graphs_for_training \
--use_hpu_graphs_for_inference
```

On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**.

> If your data has a sampling rate which is different from the one of the data the model was trained on, this script will raise an error.
> Resampling with the `datasets` library is not supported on HPUs yet.
> Resampling with the `datasets` library is not supported on HPUs yet. HPU graphs are supported only on Gaudi2 and from SynapseAI v1.15.


## DeepSpeed
Expand Down Expand Up @@ -196,7 +200,8 @@ python run_speech_recognition_ctc.py \
--use_habana \
--use_lazy_mode \
--gaudi_config_name="Habana/wav2vec2" \
--bf16
--bf16 \
--use_hpu_graphs_for_inference
```
## Sequence to Sequence

Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
gaudi_wav2vec2forctc_forward,
)


Expand All @@ -161,6 +162,7 @@ def adapt_transformers_to_gaudi():
)
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward = gaudi_wav2vec2_forward
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward
transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer.forward = gaudi_wav2vec2_tdnnlayer_forward

# Generation is modified to run faster in lazy mode
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,5 @@
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
gaudi_wav2vec2forctc_forward,
)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/wav2vec2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
gaudi_wav2vec2forctc_forward,
)
95 changes: 92 additions & 3 deletions optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,23 @@
from typing import Optional, Tuple, Union

import torch
from habana_frameworks.torch.hpu import get_device_name
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
Wav2Vec2BaseModelOutput,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION


try:
from habana_frameworks.torch.hpex.kernels import CTCLoss

custom_ctc_loss_fwd = CTCLoss.apply
except ImportError:
print("Could not import Custom CTCLoss kernel. This Kernel is available only for SynapseAI >= 1.15.0")
custom_ctc_loss_fwd = None


def _gaudi_wav2vec2_compute_mask_indices(
Expand All @@ -33,7 +45,8 @@ def _gaudi_wav2vec2_compute_mask_indices(
) -> torch.Tensor:
"""
Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L135
The only difference is that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers).
The only differences are (1) that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers), (2) epsilon is generated on HPU instead of CPU, (3) check
to ensure indices are not larger than sequence length is re-written to avoid host sync.
"""
batch_size, sequence_length = shape

Expand Down Expand Up @@ -122,8 +135,13 @@ def compute_num_masked_span(input_length):
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None:
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
else:
mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1)
inverse_mask = torch.bitwise_not(mask)
spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask

# scatter indices to mask
spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1)
Expand Down Expand Up @@ -374,3 +392,74 @@ def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch

hidden_states = self.activation(hidden_states)
return hidden_states


def gaudi_wav2vec2forctc_forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
"""
copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950
only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid
changing flattened_targets tensor shapes across training iterations.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if labels.max() >= self.config.vocab_size:
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask
if attention_mask is not None
else torch.ones_like(input_values, dtype=torch.long, device="hpu")
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
# ctc_loss doesn't support fp16
log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None:
flattened_targets = labels.masked_select(labels_mask)
loss = torch.nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
else:
flattened_targets = labels
loss = custom_ctc_loss_fwd(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
self.config.pad_token_id,
self.config.ctc_loss_reduction,
self.config.ctc_zero_infinity,
)

if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
14 changes: 8 additions & 6 deletions tests/baselines/wav2vec2_large_lv60.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
"eval_batch_size": 8,
"distribution": {
"multi_card": {
"learning_rate": 3e-4,
"learning_rate": 4e-4,
"train_batch_size": 8,
"eval_wer": 0.0531535105117017,
"train_runtime": 356.4723,
"train_samples_per_second": 183.245,
"eval_samples_per_second": 158.985,
"eval_wer": 0.06120587068623562,
"train_runtime": 308.8036,
"train_samples_per_second": 225.572,
"eval_samples_per_second": 196.665,
"extra_arguments": [
"--dataset_config_name clean",
"--train_split_name train.100",
Expand All @@ -49,7 +49,9 @@
"--layerdrop 0.0",
"--freeze_feature_encoder",
"--dataloader_num_workers 8",
"--chars_to_ignore ',?.!-;:\"“%‘”'"
"--chars_to_ignore ',?.!-;:\"“%‘”'",
"--use_hpu_graphs_for_training",
"--use_hpu_graphs_for_inference"
]
}
}
Expand Down