diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index f2349120280b..418cb23b68b0 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -215,6 +215,29 @@ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
 
         return log_spec
 
+    @staticmethod
+    # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
+    def zero_mean_unit_var_norm(
+        input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+    ) -> List[np.ndarray]:
+        """
+        Every array in the list is normalized to have zero mean and unit variance
+        """
+        if attention_mask is not None:
+            attention_mask = np.array(attention_mask, np.int32)
+            normed_input_values = []
+
+            for vector, length in zip(input_values, attention_mask.sum(-1)):
+                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+                if length < normed_slice.shape[0]:
+                    normed_slice[length:] = padding_value
+
+                normed_input_values.append(normed_slice)
+        else:
+            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+
+        return normed_input_values
+
     def __call__(
         self,
         raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
@@ -225,6 +248,7 @@ def __call__(
         padding: Optional[str] = "max_length",
         max_length: Optional[int] = None,
         sampling_rate: Optional[int] = None,
+        do_normalize: Optional[bool] = None,
         **kwargs,
     ) -> BatchFeature:
         """
@@ -266,6 +290,9 @@ def __call__(
                 pipeline.
             padding_value (`float`, defaults to 0.0):
                 The value that is used to fill the padding values / vectors.
+            do_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+                improve the performance of the model.
         """
 
         if sampling_rate is not None:
@@ -312,6 +339,22 @@ def __call__(
+        # zero-mean and unit-variance normalization of the raw padded waveforms
+        # (must run before `input_features` is read out below and before the
+        # attention mask is downsampled, since it needs sample-level lengths)
+        if do_normalize:
+            padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
+                padded_inputs["input_features"],
+                attention_mask=padded_inputs.get("attention_mask"),
+                padding_value=self.padding_value,
+            )
+            # zero_mean_unit_var_norm returns a list of arrays; stack it back
+            padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
+
         # make sure list is in array format
         input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
 
+        if return_attention_mask:
+            # rescale the attention mask from samples (480000) to features (3000)
+            padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
+
         input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
 
         if isinstance(input_features[0], List):
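For reference — not part of the patch — a minimal usage sketch of the new `do_normalize` flag, assuming this diff is applied on top of `transformers`; the synthetic waveform and printed shapes below are illustrative:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()

# Five seconds of synthetic 16 kHz audio with a large DC offset, so the
# effect of zero-mean unit-variance normalization is easy to see.
waveform = 100.0 + np.random.randn(16000 * 5).astype(np.float32)

inputs = feature_extractor(
    waveform,
    sampling_rate=16000,
    return_attention_mask=True,
    do_normalize=True,  # the flag introduced by this diff
    return_tensors="np",
)

# Whisper pads/truncates to 30 s (480000 samples) and uses hop_length=160,
# so the log-mel features span 480000 / 160 = 3000 frames.
print(inputs.input_features.shape)  # (1, 80, 3000)
print(inputs.attention_mask.shape)  # (1, 3000) after the hop_length rescale
```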
""" if sampling_rate is not None: @@ -312,6 +339,18 @@ def __call__( # make sure list is in array format input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + if return_attention_mask: + # rescale from sample (48000) to feature (3000) + padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] + + # zero-mean and unit-variance normalization + if do_normalize: + padded_inputs["input_features"] = self.zero_mean_unit_var_norm( + padded_inputs["input_features"], + attention_mask=padded_inputs["attention_mask"], + padding_value=self.padding_value, + ) + input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]] if isinstance(input_features[0], List): diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index f1ef36b1f28f..57c12b86ddbc 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -21,6 +21,7 @@ import unittest import numpy as np +from datasets import load_dataset from transformers import is_speech_available from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio @@ -198,8 +199,6 @@ def test_double_precision_pad(self): self.assertTrue(pt_processed.input_features.dtype == torch.float32) def _load_datasamples(self, num_samples): - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] @@ -222,3 +221,12 @@ def test_integration(self): feaure_extractor = WhisperFeatureExtractor() input_features = feaure_extractor(input_speech, return_tensors="pt").input_features self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) + + def test_zero_mean_unit_variance_normalization_trunc_np_longest(self): + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + audio = self._load_datasamples(1)[0] + audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535 # Rescale to [0, 65535] to show issue + audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0] + + self.assertTrue(np.all(np.mean(audio) < 1e-3)) + self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))