@@ -393,6 +393,7 @@ def synced_generate(
             context_length_tensor,
             tokens_to_generate,
             all_probs,
+            compute_attention_mask=compute_attention_mask,
             temperature=temperature,
         )
     else:
@@ -822,6 +823,7 @@ def tab_sample_sequence_batch(
     context_lengths,
     tokens_to_generate,
     all_probs=True,
+    compute_attention_mask=True,
     type_ids=None,
     temperature=None,
 ):
@@ -845,7 +847,7 @@ def tab_sample_sequence_batch(
     # initialize the batch
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        inference_strategy.init_batch(context_tokens, context_length)
+        inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask)
         context = context_tokens[:, :context_length]
         # the context may start in the middle of the row,
         # calculate the offset according to the position of '\n' or '<|endoftext|>'
@@ -879,7 +881,7 @@ def tab_sample_sequence_batch(

         while context_length < maxlen:
             batch, tensor_shape = inference_strategy.prepare_batch_at_step(
-                tokens, maxlen, micro_batch_size, counter, context_length
+                tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask
             )
             output = inference_strategy.forward_step(batch, tensor_shape)