Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions examples/speech-recognition/run_speech_recognition_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ class ModelArguments:
ctc_loss_reduction: Optional[str] = field(
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
)
ctc_zero_infinity: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
" occur when the inputs are too short to be aligned to the targets."
},
)
add_adapter: Optional[bool] = field(
default=False,
metadata={
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very "
"useful to downsample the output length."
},
)


@dataclass
Expand Down Expand Up @@ -315,11 +329,14 @@ class DataCollatorCTCWithPadding:
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None
feature_extractor_input_name: Optional[str] = "input_values"

def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
input_features = [
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
]
label_features = [{"input_ids": feature["labels"]} for feature in features]

batch = self.processor.pad(
Expand Down Expand Up @@ -612,9 +629,11 @@ def remove_special_characters(batch):
"gradient_checkpointing": training_args.gradient_checkpointing,
"layerdrop": model_args.layerdrop,
"ctc_loss_reduction": model_args.ctc_loss_reduction,
"ctc_zero_infinity": model_args.ctc_zero_infinity,
"pad_token_id": tokenizer.pad_token_id,
"vocab_size": len(tokenizer),
"activation_dropout": model_args.activation_dropout,
"add_adapter": model_args.add_adapter,
}
)

Expand Down Expand Up @@ -653,6 +672,7 @@ def remove_special_characters(batch):
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers
feature_extractor_input_name = feature_extractor.model_input_names[0]

# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
phoneme_language = data_args.phoneme_language
Expand All @@ -664,8 +684,9 @@ def prepare_dataset(batch):
sample = batch[audio_column_name]

inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
batch["input_values"] = inputs.input_values[0]
batch["input_length"] = len(batch["input_values"])
batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
# take length of raw audio waveform
batch["input_length"] = len(sample["array"].squeeze())

# encode targets
additional_kwargs = {}
Expand Down Expand Up @@ -748,6 +769,7 @@ def compute_metrics(pred):
# Instantiate custom data collator
data_collator = DataCollatorCTCWithPadding(
processor=processor,
feature_extractor_input_name=feature_extractor_input_name,
pad_to_multiple_of=int(max_input_length),
pad_to_multiple_of_labels=500,
)
Expand Down
39 changes: 21 additions & 18 deletions tests/example_diff/run_speech_recognition_ctc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,37 @@
>
> require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
>
141d147
145c152
< "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
---
> "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very "
155d161
<
251c257
265c271
< "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
---
> "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
390c396
407c413
< parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
---
> parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments))
433a440,445
450a457,462
> gaudi_config = GaudiConfig.from_pretrained(
> training_args.gaudi_config_name,
> cache_dir=model_args.cache_dir,
> use_auth_token=True if data_args.use_auth_token else None,
> )
>
434a447
451a464
> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
436,437c449,451
453,454c466,468
< f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
< f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
---
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
450,456c464,469
467,473c481,486
< if training_args.do_train:
< raw_datasets["train"] = load_dataset(
< data_args.dataset_name,
Expand All @@ -71,7 +75,7 @@
> split=data_args.train_split_name,
> token=data_args.token,
> )
458,463c471,476
475,480c488,493
< if data_args.audio_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
Expand All @@ -85,7 +89,7 @@
> " Make sure to set `--audio_column_name` to the correct audio column - one of"
> f" {', '.join(raw_datasets['train'].column_names)}."
> )
465,470c478,483
482,487c495,500
< if data_args.text_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
Expand All @@ -99,33 +103,32 @@
> "Make sure to set `--text_column_name` to the correct text column - one of "
> f"{', '.join(raw_datasets['train'].column_names)}."
> )
472,473c485,486
489,490c502,503
< if data_args.max_train_samples is not None:
< raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
---
> if data_args.max_train_samples is not None:
> raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
491c504
508c521
< f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
---
> f'[{"".join(data_args.chars_to_ignore).replace(" ", "")}]' if data_args.chars_to_ignore is not None else None
628a642,646
647a661,665
> raise RuntimeError(
> f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one"
> f" ({feature_extractor.sampling_rate}).Data resampling should be done. The Datasets library does not"
> " support it on HPUs yet."
> )
731c749,753
< data_collator = DataCollatorCTCWithPadding(processor=processor)
753c771,774
< processor=processor, feature_extractor_input_name=feature_extractor_input_name
---
> data_collator = DataCollatorCTCWithPadding(
> processor=processor,
> feature_extractor_input_name=feature_extractor_input_name,
> pad_to_multiple_of=int(max_input_length),
> pad_to_multiple_of_labels=500,
> )
734c756
757c778
< trainer = Trainer(
---
> trainer = GaudiTrainer(
735a758
758a780
> gaudi_config=gaudi_config,