From a0e862d28899d81731b581f73cda5f1c4428e591 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 9 Nov 2022 10:15:37 +0000 Subject: [PATCH 1/3] wav2vec2_en, test=asr --- docs/source/released_model.md | 2 +- .../librispeech/asr3/conf/wav2vec2ASR.yaml | 7 +- paddlespeech/resource/pretrained_models.py | 41 ++++-- paddlespeech/s2t/exps/wav2vec2/model.py | 10 +- .../processing/speech_augmentation.py | 119 ++++++++++-------- 5 files changed, 106 insertions(+), 73 deletions(-) diff --git a/docs/source/released_model.md b/docs/source/released_model.md index 2f3c9d09851..79e8f4f4659 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -22,7 +22,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link | :-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: | [Wav2vec2-large-960h-lv60-self Model](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | - | 1.18 GB |Pre-trained Wav2vec2.0 Model | - | - | - | -[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 1.18 GB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) | +[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 718 MB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) | ### Language Model based on NGram Language Model | Training Data | Token-based | Size | Descriptions diff --git a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml index b19881b70e0..8d5899e28ef 100644 --- a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml +++ b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml @@ -9,6 +9,9 @@ dnn_neurons: 1024 blank_id: 0 ctc_dropout_rate: 0.0 wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams" +speech_augment: + sample_rate: 16000 + speeds: [95, 100, 105] ############################################ # Wav2Vec2.0 # @@ -70,7 +73,6 @@ train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - ########################################### # Dataloader # ########################################### @@ -115,6 +117,3 @@ log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 -augment: True - - diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index df50a6a9d52..93ad30cd531 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -13,18 +13,13 @@ # limitations under the License. __all__ = [ - 'asr_dynamic_pretrained_models', - 'asr_static_pretrained_models', - 'asr_onnx_pretrained_models', - 'cls_dynamic_pretrained_models', - 'cls_static_pretrained_models', - 'st_dynamic_pretrained_models', - 'st_kaldi_bins', - 'text_dynamic_pretrained_models', - 'tts_dynamic_pretrained_models', - 'tts_static_pretrained_models', - 'tts_onnx_pretrained_models', - 'vector_dynamic_pretrained_models', + 'asr_dynamic_pretrained_models', 'asr_static_pretrained_models', + 'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models', + 'cls_static_pretrained_models', 'st_dynamic_pretrained_models', + 'st_kaldi_bins', 'text_dynamic_pretrained_models', + 'tts_dynamic_pretrained_models', 'tts_static_pretrained_models', + 'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models', + 'ssl_pretrained_models' ] # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". @@ -32,6 +27,28 @@ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" +# --------------------------------- +# -------------- SSL -------------- +# --------------------------------- +ssl_pretrained_models = { + "wav2vec2ASR_librispeech-en-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz', + 'md5': + '7d9449a8103ec4b17d6a004e928e0b1f', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/wav2vec2ASR/checkpoints/avg_1', + 'model': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + 'params': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + }, + }, +} + # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 933e268edac..8e5f8d9db15 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -28,6 +28,7 @@ from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment +from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugmentConfig from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -71,7 +72,8 @@ def train_batch(self, batch_index, batch, msg): wavs_lens_rate = wavs_lens / wav.shape[1] target_lens_rate = target_lens / target.shape[1] wav = wav[:, :, 0] - wav = self.speech_augmentation(wav, wavs_lens_rate) + if hasattr(train_conf, 'speech_augment'): + wav = self.speech_augmentation(wav, wavs_lens_rate) loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad @@ -277,7 +279,11 @@ def setup_model(self): logger.info("Setup model!") # setup speech augmentation for wav2vec2 - self.speech_augmentation = TimeDomainSpecAugment() + if hasattr(config, 'speech_augment'): + speechaugment_config = TimeDomainSpecAugmentConfig( + config.speech_augment) + self.speech_augmentation = TimeDomainSpecAugment( + speechaugment_config) if not self.train: return diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py index 78a0782e72b..a9ab251d304 100644 --- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py +++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py @@ -641,15 +641,56 @@ def forward(self, waveforms, lengths): class TimeDomainSpecAugment(nn.Layer): """A time-domain approximation of the SpecAugment algorithm. - + --------- This augmentation module implements three augmentations in the time-domain. - 1. Drop chunks of the audio (zero amplitude or white noise) 2. Drop frequency bands (with band-drop filters) 3. Speed peturbation (via resampling to slightly different rate) - Arguments + Example + ------- + >>> inputs = paddle.randn([10, 16000]) + >>> feature_maker = TimeDomainSpecAugment(speeds=[80]) + >>> feats = feature_maker(inputs, paddle.ones(10)) + >>> feats.shape + paddle.shape([10, 12800]) + """ + + def __init__(self, config): + super().__init__() + self.speed_perturb = SpeedPerturb( + perturb_prob=config.perturb_prob, + orig_freq=config.sample_rate, + speeds=config.speeds) + self.drop_freq = DropFreq( + drop_prob=config.drop_freq_prob, + drop_count_low=config.drop_freq_count_low, + drop_count_high=config.drop_freq_count_high) + self.drop_chunk = DropChunk( + drop_prob=config.drop_chunk_prob, + drop_count_low=config.drop_chunk_count_low, + drop_count_high=config.drop_chunk_count_high, + drop_length_low=config.drop_chunk_length_low, + drop_length_high=config.drop_chunk_length_high, + noise_factor=config.drop_chunk_noise_factor) + + def forward(self, waveforms, lengths): + """Returns the distorted waveforms. + --------- + waveforms : tensor + The waveforms to distort + """ + # Augmentation + with paddle.no_grad(): + waveforms = self.speed_perturb(waveforms) + waveforms = self.drop_freq(waveforms) + waveforms = self.drop_chunk(waveforms, lengths) + return waveforms + + +class TimeDomainSpecAugmentConfig(): + """Augmentation configuration for time domain spectrograms. --------- perturb_prob : float from 0 to 1 The probability that a batch will have speed perturbation applied. @@ -677,56 +718,26 @@ class TimeDomainSpecAugment(nn.Layer): drop_chunk_noise_factor : float The noise factor used to scale the white noise inserted, relative to the average amplitude of the utterance. Default 0 (no noise inserted). - - Example - ------- - >>> inputs = paddle.randn([10, 16000]) - >>> feature_maker = TimeDomainSpecAugment(speeds=[80]) - >>> feats = feature_maker(inputs, paddle.ones(10)) - >>> feats.shape - paddle.shape([10, 12800]) """ - def __init__( - self, - perturb_prob=1.0, - drop_freq_prob=1.0, - drop_chunk_prob=1.0, - speeds=[95, 100, 105], - sample_rate=16000, - drop_freq_count_low=0, - drop_freq_count_high=3, - drop_chunk_count_low=0, - drop_chunk_count_high=5, - drop_chunk_length_low=1000, - drop_chunk_length_high=2000, - drop_chunk_noise_factor=0, ): - super().__init__() - self.speed_perturb = SpeedPerturb( - perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds) - self.drop_freq = DropFreq( - drop_prob=drop_freq_prob, - drop_count_low=drop_freq_count_low, - drop_count_high=drop_freq_count_high, ) - self.drop_chunk = DropChunk( - drop_prob=drop_chunk_prob, - drop_count_low=drop_chunk_count_low, - drop_count_high=drop_chunk_count_high, - drop_length_low=drop_chunk_length_low, - drop_length_high=drop_chunk_length_high, - noise_factor=drop_chunk_noise_factor, ) - - def forward(self, waveforms, lengths): - """Returns the distorted waveforms. - - Arguments - --------- - waveforms : tensor - The waveforms to distort - """ - # Augmentation - with paddle.no_grad(): - waveforms = self.speed_perturb(waveforms) - waveforms = self.drop_freq(waveforms) - waveforms = self.drop_chunk(waveforms, lengths) - return waveforms + def __init__(self, config): + # speedperturb config + self.perturb_prob = getattr(config, 'perturb_prob', 1.0) + self.sample_rate = getattr(config, 'sample_rate', 16000) + self.speeds = getattr(config, 'speeds', [95, 100, 105]) + + # dropfreq config + self.drop_freq_prob = getattr(config, 'drop_freq_prob', 1.0) + self.drop_freq_count_low = getattr(config, 'drop_freq_count_low', 0) + self.drop_freq_count_high = getattr(config, 'drop_freq_count_high', 3) + + # dropchunk config + self.drop_chunk_prob = getattr(config, 'drop_chunk_prob', 1.0) + self.drop_chunk_count_low = getattr(config, 'drop_chunk_count_low', 0) + self.drop_chunk_count_high = getattr(config, 'drop_chunk_count_high', 5) + self.drop_chunk_length_low = getattr(config, 'drop_chunk_length_low', + 1000) + self.drop_chunk_length_high = getattr(config, 'drop_chunk_length_high', + 2000) + self.drop_chunk_noise_factor = getattr(config, + 'drop_chunk_noise_factor', 0) From 0f766848c2a222c36b2ec42bc6f1e556f27a9cba Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 9 Nov 2022 11:51:31 +0000 Subject: [PATCH 2/3] wav2vec2_en, test=asr --- paddlespeech/s2t/exps/wav2vec2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 8e5f8d9db15..47ee1774709 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -279,7 +279,7 @@ def setup_model(self): logger.info("Setup model!") # setup speech augmentation for wav2vec2 - if hasattr(config, 'speech_augment'): + if hasattr(config, 'speech_augment') and self.train: speechaugment_config = TimeDomainSpecAugmentConfig( config.speech_augment) self.speech_augmentation = TimeDomainSpecAugment( From f53598b5c89bf82b36eff65415f9bbb1afc8f93b Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 9 Nov 2022 14:07:38 +0000 Subject: [PATCH 3/3] wav2vec2_en, test=asr --- .../librispeech/asr3/conf/wav2vec2ASR.yaml | 9 +- paddlespeech/resource/pretrained_models.py | 41 ++----- paddlespeech/s2t/exps/wav2vec2/model.py | 7 +- .../processing/speech_augmentation.py | 116 ++++++++---------- 4 files changed, 70 insertions(+), 103 deletions(-) diff --git a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml index 8d5899e28ef..c45bd692a25 100644 --- a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml +++ b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml @@ -9,9 +9,6 @@ dnn_neurons: 1024 blank_id: 0 ctc_dropout_rate: 0.0 wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams" -speech_augment: - sample_rate: 16000 - speeds: [95, 100, 105] ############################################ # Wav2Vec2.0 # @@ -97,6 +94,12 @@ dist_sampler: True shortest_first: True return_lens_rate: True +############################################ +# Data Augmentation # +############################################ +audio_augment: # for raw audio + sample_rate: 16000 + speeds: [95, 100, 105] ########################################### # Training # diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 93ad30cd531..df50a6a9d52 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -13,13 +13,18 @@ # limitations under the License. __all__ = [ - 'asr_dynamic_pretrained_models', 'asr_static_pretrained_models', - 'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models', - 'cls_static_pretrained_models', 'st_dynamic_pretrained_models', - 'st_kaldi_bins', 'text_dynamic_pretrained_models', - 'tts_dynamic_pretrained_models', 'tts_static_pretrained_models', - 'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models', - 'ssl_pretrained_models' + 'asr_dynamic_pretrained_models', + 'asr_static_pretrained_models', + 'asr_onnx_pretrained_models', + 'cls_dynamic_pretrained_models', + 'cls_static_pretrained_models', + 'st_dynamic_pretrained_models', + 'st_kaldi_bins', + 'text_dynamic_pretrained_models', + 'tts_dynamic_pretrained_models', + 'tts_static_pretrained_models', + 'tts_onnx_pretrained_models', + 'vector_dynamic_pretrained_models', ] # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". @@ -27,28 +32,6 @@ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" -# --------------------------------- -# -------------- SSL -------------- -# --------------------------------- -ssl_pretrained_models = { - "wav2vec2ASR_librispeech-en-16k": { - '1.3': { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz', - 'md5': - '7d9449a8103ec4b17d6a004e928e0b1f', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/wav2vec2ASR/checkpoints/avg_1', - 'model': - 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', - 'params': - 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', - }, - }, -} - # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 47ee1774709..4f6bc0c5b87 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -28,7 +28,6 @@ from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment -from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugmentConfig from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -279,11 +278,9 @@ def setup_model(self): logger.info("Setup model!") # setup speech augmentation for wav2vec2 - if hasattr(config, 'speech_augment') and self.train: - speechaugment_config = TimeDomainSpecAugmentConfig( - config.speech_augment) + if hasattr(config, 'audio_augment') and self.train: self.speech_augmentation = TimeDomainSpecAugment( - speechaugment_config) + **config.audio_augment) if not self.train: return diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py index a9ab251d304..ac9bf45dbf1 100644 --- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py +++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py @@ -641,56 +641,12 @@ def forward(self, waveforms, lengths): class TimeDomainSpecAugment(nn.Layer): """A time-domain approximation of the SpecAugment algorithm. - --------- This augmentation module implements three augmentations in the time-domain. 1. Drop chunks of the audio (zero amplitude or white noise) 2. Drop frequency bands (with band-drop filters) 3. Speed peturbation (via resampling to slightly different rate) - - Example - ------- - >>> inputs = paddle.randn([10, 16000]) - >>> feature_maker = TimeDomainSpecAugment(speeds=[80]) - >>> feats = feature_maker(inputs, paddle.ones(10)) - >>> feats.shape - paddle.shape([10, 12800]) - """ - - def __init__(self, config): - super().__init__() - self.speed_perturb = SpeedPerturb( - perturb_prob=config.perturb_prob, - orig_freq=config.sample_rate, - speeds=config.speeds) - self.drop_freq = DropFreq( - drop_prob=config.drop_freq_prob, - drop_count_low=config.drop_freq_count_low, - drop_count_high=config.drop_freq_count_high) - self.drop_chunk = DropChunk( - drop_prob=config.drop_chunk_prob, - drop_count_low=config.drop_chunk_count_low, - drop_count_high=config.drop_chunk_count_high, - drop_length_low=config.drop_chunk_length_low, - drop_length_high=config.drop_chunk_length_high, - noise_factor=config.drop_chunk_noise_factor) - - def forward(self, waveforms, lengths): - """Returns the distorted waveforms. - --------- - waveforms : tensor - The waveforms to distort - """ - # Augmentation - with paddle.no_grad(): - waveforms = self.speed_perturb(waveforms) - waveforms = self.drop_freq(waveforms) - waveforms = self.drop_chunk(waveforms, lengths) - return waveforms - - -class TimeDomainSpecAugmentConfig(): - """Augmentation configuration for time domain spectrograms. + Arguments --------- perturb_prob : float from 0 to 1 The probability that a batch will have speed perturbation applied. @@ -718,26 +674,54 @@ class TimeDomainSpecAugmentConfig(): drop_chunk_noise_factor : float The noise factor used to scale the white noise inserted, relative to the average amplitude of the utterance. Default 0 (no noise inserted). + Example + ------- + >>> inputs = paddle.randn([10, 16000]) + >>> feature_maker = TimeDomainSpecAugment(speeds=[80]) + >>> feats = feature_maker(inputs, paddle.ones(10)) + >>> feats.shape + paddle.shape([10, 12800]) """ - def __init__(self, config): - # speedperturb config - self.perturb_prob = getattr(config, 'perturb_prob', 1.0) - self.sample_rate = getattr(config, 'sample_rate', 16000) - self.speeds = getattr(config, 'speeds', [95, 100, 105]) - - # dropfreq config - self.drop_freq_prob = getattr(config, 'drop_freq_prob', 1.0) - self.drop_freq_count_low = getattr(config, 'drop_freq_count_low', 0) - self.drop_freq_count_high = getattr(config, 'drop_freq_count_high', 3) - - # dropchunk config - self.drop_chunk_prob = getattr(config, 'drop_chunk_prob', 1.0) - self.drop_chunk_count_low = getattr(config, 'drop_chunk_count_low', 0) - self.drop_chunk_count_high = getattr(config, 'drop_chunk_count_high', 5) - self.drop_chunk_length_low = getattr(config, 'drop_chunk_length_low', - 1000) - self.drop_chunk_length_high = getattr(config, 'drop_chunk_length_high', - 2000) - self.drop_chunk_noise_factor = getattr(config, - 'drop_chunk_noise_factor', 0) + def __init__( + self, + perturb_prob=1.0, + drop_freq_prob=1.0, + drop_chunk_prob=1.0, + speeds=[95, 100, 105], + sample_rate=16000, + drop_freq_count_low=0, + drop_freq_count_high=3, + drop_chunk_count_low=0, + drop_chunk_count_high=5, + drop_chunk_length_low=1000, + drop_chunk_length_high=2000, + drop_chunk_noise_factor=0, ): + super().__init__() + self.speed_perturb = SpeedPerturb( + perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds) + self.drop_freq = DropFreq( + drop_prob=drop_freq_prob, + drop_count_low=drop_freq_count_low, + drop_count_high=drop_freq_count_high, ) + self.drop_chunk = DropChunk( + drop_prob=drop_chunk_prob, + drop_count_low=drop_chunk_count_low, + drop_count_high=drop_chunk_count_high, + drop_length_low=drop_chunk_length_low, + drop_length_high=drop_chunk_length_high, + noise_factor=drop_chunk_noise_factor, ) + + def forward(self, waveforms, lengths): + """Returns the distorted waveforms. + Arguments + --------- + waveforms : tensor + The waveforms to distort + """ + # Augmentation + with paddle.no_grad(): + waveforms = self.speed_perturb(waveforms) + waveforms = self.drop_freq(waveforms) + waveforms = self.drop_chunk(waveforms, lengths) + return waveforms