src/transformers/generation/logits_process.py (9 changes: 5 additions & 4 deletions)

@@ -933,18 +933,19 @@ def __init__(self, generate_config): # support for the kwargs
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1

-        self.begin_index = len(generate_config.forced_decoder_ids) + 1
-        if generate_config.is_multilingual:
-            self.begin_index += 1
+        self.begin_index = len(generate_config.forced_decoder_ids) + 2
+        if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
+            self.begin_index -= 1
         self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

     def __call__(self, input_ids, scores):
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")
-        if input_ids.shape[1] == self.begin_index:
+
+        if input_ids.shape[1] == self.begin_index - 1:
             scores[:, :] = -float("inf")
             scores[:, self.timestamp_begin] = 0
             return scores

         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
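Note on the index arithmetic: with <|startoftranscript|> at position 0 and the forced ids occupying positions 1 through len(forced_decoder_ids), the first freely generated position is len(forced_decoder_ids) + 1, which is exactly the begin_index - 1 that __call__ now checks before forcing the first timestamp. A standalone sketch of the new rule (the token ids are illustrative stand-ins, not the real Whisper vocabulary):

    # Sketch of the new begin_index rule; ids are made up for illustration.
    no_timestamps_token_id = 50363
    forced_decoder_ids = [(1, 50259), (2, 50359)]  # e.g. <|en|>, <|transcribe|>

    begin_index = len(forced_decoder_ids) + 2
    if forced_decoder_ids[-1][1] == no_timestamps_token_id:
        begin_index -= 1  # a forced <|notimestamps|> would occupy the first free slot

    first_free_position = len(forced_decoder_ids) + 1  # position 0 is <|startoftranscript|>
    assert begin_index - 1 == first_free_position

This also removes the need for the old is_multilingual bump: the prompt length can now be read off forced_decoder_ids directly, whatever the checkpoint.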
tests/models/whisper/test_modeling_tf_whisper.py (34 changes: 17 additions & 17 deletions)

@@ -699,8 +699,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
     input_speech = _load_datasamples(1)
     input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-    generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+    generated_ids = model.generate(
+        input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
+    )
     transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"

@@ -728,26 +729,25 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
     input_speech = next(iter(ds))["audio"]["array"]
     input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-    generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+    generated_ids = model.generate(
+        input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+    )
     transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
     unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
     generated_ids = model.generate(
-        input_features,
-        do_sample=False,
-        max_length=20,
+        input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
     )
     transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     EXPECTED_TRANSCRIPT = " Kimura-san called me."
     unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-    generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+    generated_ids = model.generate(
+        input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+    )
     transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"

@@ -779,10 +779,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
     # fmt: off
     EXPECTED_LOGITS = tf.convert_to_tensor(
         [
-            [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-            [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-            [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-            [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+            [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+            [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+            [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+            [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
         ]
     )
     # fmt: on

@@ -791,10 +791,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):

     # fmt: off
     EXPECTED_TRANSCRIPT = [
-        ' Mr. Quilter is the apostle of the middle classes and we are glad to',
+        " Mr. Quilter is the apostle of the middle classes and we are glad",
         " Nor is Mr. Quilter's manner less interesting than his matter.",
-        " He tells us that at this festive season of the year, with Christmas and roast beef",
-        " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
+        " He tells us that at this festive season of the year, with Christmas and roast",
+        " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
     ]
     # fmt: on
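The test changes above all follow one pattern: instead of mutating model.config.forced_decoder_ids via processor.get_decoder_prompt_ids(...), the language and task are passed straight to generate(). A minimal usage sketch, assuming the openai/whisper-large checkpoint implied by the test names (the silent one-second clip is a stand-in for real 16 kHz audio):

    from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-large")
    model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

    audio = [0.0] * 16000  # stand-in for one second of 16 kHz mono audio
    input_features = processor.feature_extractor(raw_speech=audio, return_tensors="tf").input_features

    # Old style (removed above):
    #   model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
    # New style: generate() assembles the <|startoftranscript|><|ja|><|transcribe|> prompt itself.
    generated_ids = model.generate(
        input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
    )
    transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]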
tests/models/whisper/test_modeling_whisper.py (34 changes: 15 additions & 19 deletions)

@@ -945,11 +945,8 @@ def test_large_generation(self):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -971,26 +968,25 @@ def test_large_generation_multilingual(self):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " Kimura-san called me."
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"

@@ -1009,10 +1005,10 @@ def test_large_batched_generation(self):
         # fmt: off
        EXPECTED_LOGITS = torch.tensor(
             [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
             ]
         )
         # fmt: on

@@ -1021,10 +1017,10 @@ def test_large_batched_generation(self):

         # fmt: off
         EXPECTED_TRANSCRIPT = [
-            " Mr. Quilter is the apostle of the middle classes and we are glad to",
+            " Mr. Quilter is the apostle of the middle classes and we are glad",
             " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
+            " He tells us that at this festive season of the year, with Christmas and roast",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
         ]
         # fmt: on
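The expected tokens and transcripts change in lockstep: every row gains one token immediately after 50258 (<|startoftranscript|>), presumably the <|en|> language token now emitted by the new prompt handling (the hunk itself does not label it), and with max_length=20 unchanged the former last token falls off the end. That is why each reference transcript loses its final word. The relationship, checked on the first row:

    # Old vs. new expected ids for row 0 of test_large_batched_generation.
    OLD = [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281]
    NEW = [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404]

    # One token (50259) is inserted at index 1; truncating back to 20 tokens
    # drops the old tail token 281, matching the transcript change
    # " ...we are glad to" -> " ...we are glad".
    assert NEW == (OLD[:1] + [50259] + OLD[1:])[:20]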