@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
-        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )

-        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

         return model_kwargs

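Note on the new update rule: after the prompt pass, `cache_position` collapses to its last entry plus one, so every subsequent decoding step carries a single advancing offset into the cache instead of being re-read from `model_inputs`. A minimal sketch of that arithmetic in plain torch (illustration only, not part of this diff):

```python
import torch

# prompt of length 5 -> positions [0, 1, 2, 3, 4]
cache_position = torch.arange(5)
for step in range(3):
    # keep only the last position, advanced by one
    cache_position = cache_position[-1:] + 1
    print(cache_position)  # tensor([5]), then tensor([6]), then tensor([7])
```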
@@ -1931,10 +1931,15 @@ def _contrastive_search(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
-        batch_size = input_ids.shape[0]

         while True:
             if synced_gpus:
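The same initialization recurs in every decoding loop below: `batch_size` and `cur_len` come from the attention mask when one is present (its shape already reflects padding), otherwise from `input_ids`, and `cache_position` starts as the full index range over the prompt. A standalone sketch, with example shapes assumed:

```python
import torch

input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])  # (batch=2, seq=3)
attention_mask = torch.ones_like(input_ids)        # provided by the tokenizer in practice

batch_size, cur_len = (
    attention_mask.shape if attention_mask is not None else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long)
cache_position = torch.arange(cur_len)  # tensor([0, 1, 2]) for the prompt
```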
@@ -1975,7 +1980,6 @@ def _contrastive_search(
                    model_kwargs,
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    standardize_cache_format=True,
-                    model_inputs=model_inputs,
                )
            if not sequential:
                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2170,7 +2174,9 @@ def _contrastive_search(
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished
@@ -2389,7 +2395,13 @@ def _greedy_search(
        )

        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
@@ -2459,7 +2471,6 @@ def _greedy_search(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
-                model_inputs=model_inputs,
            )

            # if eos_token was found in one sentence, set sentence to finished
@@ -2688,7 +2699,13 @@ def _sample(
        )

        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
@@ -2758,7 +2775,9 @@ def _sample(
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished
@@ -3003,6 +3022,7 @@ def _beam_search(
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
@@ -3156,7 +3176,9 @@ def _beam_search(
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3397,6 +3419,7 @@ def _beam_sample(
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
@@ -3510,7 +3533,9 @@ def _beam_sample(
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3747,6 +3772,7 @@ def _group_beam_search(
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -3916,7 +3942,9 @@ def _group_beam_search(
            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
@@ -4275,7 +4304,9 @@ def _constrained_beam_search(

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
        )

        # keep track of which sequences are already finished
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        # other auxiliary variables
        max_len = stopping_criteria[0].max_length
@@ -4555,6 +4592,14 @@ def _assisted_decoding(
                candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
            )
            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+            if "cache_position" in candidate_kwargs:
+                candidate_kwargs["cache_position"] = torch.cat(
+                    (
+                        candidate_kwargs["cache_position"],
+                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
+                    ),
+                    dim=0,
+                )

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

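Because the candidate forward pass scores `candidate_length` drafted tokens in one shot, the position tensor has to cover those extra cache slots, hence the concatenation above. A toy illustration with assumed values:

```python
import torch

# assumed state: cur_len tokens processed so far, 3 tokens drafted by the assistant
cur_len, candidate_length = 7, 3
cache_position = torch.tensor([6])  # the single advancing offset kept by the update rule
cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([6, 7, 8, 9])
```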
@@ -4673,7 +4718,9 @@ def _assisted_decoding(
            )

            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished