From c8b81ce6cbb5967aff402e3ab445b2ce4f4854a1 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 25 Jan 2023 17:23:00 +0000
Subject: [PATCH 1/3] add small patch

---
 src/transformers/generation/logits_process.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 82dff5eab952..af4c697d8b39 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -933,18 +933,19 @@ def __init__(self, generate_config):  # support for the kwargs
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1

-        self.begin_index = len(generate_config.forced_decoder_ids) + 1
+        self.begin_index = len(generate_config.forced_decoder_ids) + 2
         if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
             self.begin_index -= 1
-        if generate_config.is_multilingual:
-            self.begin_index += 1
         self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

     def __call__(self, input_ids, scores):
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")
-        if input_ids.shape[1] == self.begin_index:
+
+        if input_ids.shape[1] == self.begin_index - 1:
+            scores[:, :] = -float("inf")
             scores[:, self.timestamp_begin] = 0
+            return scores

         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):

From 9663ced50766e9586aeed9f0d333b0416c53819e Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 25 Jan 2023 18:02:16 +0000
Subject: [PATCH 2/3] update tests, forced decoder ids do not take priority
 over the generation config

---
 .../whisper/test_modeling_tf_whisper.py       | 20 +++++++++----------
 tests/models/whisper/test_modeling_whisper.py | 20 ++++++++-----------
 2 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py
index 7facdd28743d..f64dc4388ea6 100644
--- a/tests/models/whisper/test_modeling_tf_whisper.py
+++ b/tests/models/whisper/test_modeling_tf_whisper.py
@@ -699,8 +699,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
         input_speech = _load_datasamples(1)
         input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
@@ -728,26 +729,25 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
         input_speech = next(iter(ds))["audio"]["array"]
         input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
         unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " Kimura-san called me."
         unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"

diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index cf2cb1d966d2..a634f545a632 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -945,11 +945,8 @@ def test_large_generation(self):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -971,26 +968,25 @@ def test_large_generation_multilingual(self):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " Kimura-san called me."
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"

From 9cc10e773b7e661075cef12c6477425de65878a2 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 25 Jan 2023 18:41:09 +0000
Subject: [PATCH 3/3] fix two new tests

---
 tests/models/whisper/test_modeling_tf_whisper.py | 14 +++++++-------
 tests/models/whisper/test_modeling_whisper.py    | 14 +++++++-------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py
index f64dc4388ea6..1fd89f2525b8 100644
--- a/tests/models/whisper/test_modeling_tf_whisper.py
+++ b/tests/models/whisper/test_modeling_tf_whisper.py
@@ -779,10 +779,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
         # fmt: off
         EXPECTED_LOGITS = tf.convert_to_tensor(
             [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
             ]
         )
         # fmt: on
@@ -791,10 +791,10 @@
         # fmt: off
         EXPECTED_TRANSCRIPT = [
-            ' Mr. Quilter is the apostle of the middle classes and we are glad to',
+            " Mr. Quilter is the apostle of the middle classes and we are glad",
             " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
+            " He tells us that at this festive season of the year, with Christmas and roast",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
         ]
         # fmt: on

diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index a634f545a632..e9a6f0705f2d 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1005,10 +1005,10 @@ def test_large_batched_generation(self):
         # fmt: off
         EXPECTED_LOGITS = torch.tensor(
             [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
             ]
         )
         # fmt: on
@@ -1017,10 +1017,10 @@
         # fmt: off
         EXPECTED_TRANSCRIPT = [
-            " Mr. Quilter is the apostle of the middle classes and we are glad to",
+            " Mr. Quilter is the apostle of the middle classes and we are glad",
             " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
+            " He tells us that at this festive season of the year, with Christmas and roast",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
         ]
         # fmt: on
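
---

For context, below is a minimal, self-contained sketch of the generate() call
style the updated tests exercise: language and task are passed directly to
generate() instead of being forced through model.config.forced_decoder_ids.
The checkpoint name ("openai/whisper-tiny"), the 16 kHz sampling rate, and the
silent dummy waveform are assumptions for illustration, not values taken from
the patches above.

import numpy as np

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 30 seconds of silence at 16 kHz stands in for real speech in this sketch
speech = np.zeros(30 * 16000, dtype=np.float32)
input_features = processor.feature_extractor(
    raw_speech=speech, sampling_rate=16000, return_tensors="pt"
).input_features

# Greedy decoding; the language/task kwargs mirror the updated tests above
generated_ids = model.generate(
    input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcript)

With the PATCH 1/3 change, begin_index accounts for the decoder start token
plus all forced prompt tokens, and at the first unforced step (sequence length
begin_index - 1) the processor masks every logit except the first timestamp
token, so decoding always opens with a timestamp.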