diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py index eba2c9e07e..2fcade692b 100644 --- a/examples/audio-classification/run_audio_classification.py +++ b/examples/audio-classification/run_audio_classification.py @@ -275,42 +275,42 @@ def main(): # Max input length max_length = int(round(feature_extractor.sampling_rate * data_args.max_length_seconds)) + model_input_name = feature_extractor.model_input_names[0] + def train_transforms(batch): """Apply train_transforms across a batch.""" - output_batch = {"input_values": []} + subsampled_wavs = [] for audio in batch[data_args.audio_column_name]: wav = random_subsample( audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate ) - preprocessed_audio = feature_extractor( - wav, - max_length=max_length, - sampling_rate=feature_extractor.sampling_rate, - padding="max_length", - truncation=True, - ) - output_batch["input_values"].append(preprocessed_audio["input_values"][0]) - + subsampled_wavs.append(wav) + inputs = feature_extractor( + subsampled_wavs, + max_length=max_length, + sampling_rate=feature_extractor.sampling_rate, + padding="max_length", + truncation=True, + ) + output_batch = {model_input_name: inputs.get(model_input_name)} output_batch["labels"] = list(batch[data_args.label_column_name]) + return output_batch def val_transforms(batch): """Apply val_transforms across a batch.""" - output_batch = {"input_values": []} - - for audio in batch[data_args.audio_column_name]: - wav = audio["array"] - preprocessed_audio = feature_extractor( - wav, - max_length=max_length, - sampling_rate=feature_extractor.sampling_rate, - padding="max_length", - truncation=True, - ) - output_batch["input_values"].append(preprocessed_audio["input_values"][0]) - + wavs = [audio["array"] for audio in batch[data_args.audio_column_name]] + inputs = feature_extractor( + wavs, + max_length=max_length, + sampling_rate=feature_extractor.sampling_rate, + padding="max_length", + truncation=True, + ) + output_batch = {model_input_name: inputs.get(model_input_name)} output_batch["labels"] = list(batch[data_args.label_column_name]) + return output_batch # Prepare label mappings. diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index 329e20d433..09736dd290 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -67,41 +67,31 @@ > # Max input length > max_length = int(round(feature_extractor.sampling_rate * data_args.max_length_seconds)) > -294a281 +296a283 > -299,300c286,293 -< output_batch["input_values"].append(wav) -< output_batch["labels"] = list(batch[data_args.label_column_name]) +302c289,295 +< inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate) --- -> preprocessed_audio = feature_extractor( -> wav, -> max_length=max_length, -> sampling_rate=feature_extractor.sampling_rate, -> padding="max_length", -> truncation=True, -> ) -> output_batch["input_values"].append(preprocessed_audio["input_values"][0]) -301a295 -> output_batch["labels"] = list(batch[data_args.label_column_name]) -306a301 -> -309,310c304,311 -< output_batch["input_values"].append(wav) -< output_batch["labels"] = list(batch[data_args.label_column_name]) +> inputs = feature_extractor( +> subsampled_wavs, +> max_length=max_length, +> sampling_rate=feature_extractor.sampling_rate, +> padding="max_length", +> truncation=True, +> ) +311c304,310 +< inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate) --- -> preprocessed_audio = feature_extractor( -> wav, -> max_length=max_length, -> sampling_rate=feature_extractor.sampling_rate, -> padding="max_length", -> truncation=True, -> ) -> output_batch["input_values"].append(preprocessed_audio["input_values"][0]) -311a313 -> output_batch["labels"] = list(batch[data_args.label_column_name]) -373c375 +> inputs = feature_extractor( +> wavs, +> max_length=max_length, +> sampling_rate=feature_extractor.sampling_rate, +> padding="max_length", +> truncation=True, +> ) +376c375 < trainer = Trainer( --- > trainer = GaudiTrainer( -374a377 +377a377 > gaudi_config=gaudi_config,