diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index dd6ebcb2c8cc..12f6e7db5eef 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -1283,7 +1283,7 @@ def call( >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf") >>> input_features = inputs.input_features - >>> generated_ids = model.generate(input_ids=input_features) + >>> generated_ids = model.generate(input_features=input_features) >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> transcription diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index bf092f9d3511..a092a79817ae 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1158,7 +1158,7 @@ def test_large_batched_generation(self): input_speech = self._load_datasamples(4) input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features - generated_ids = model.generate(input_features, max_length=20) + generated_ids = model.generate(input_features, max_length=20, task="translate") # fmt: off EXPECTED_LOGITS = torch.tensor(