Skip to content

Commit 4cd2de9

Browse files
yidong72 and zhehuaichen
authored and committed
fix tab text gen (NVIDIA#7022)
Signed-off-by: Yi Dong <[email protected]>
1 parent c1ea67b commit 4cd2de9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

nemo/collections/nlp/modules/common/text_generation_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def synced_generate(
393393
context_length_tensor,
394394
tokens_to_generate,
395395
all_probs,
396+
compute_attention_mask=compute_attention_mask,
396397
temperature=temperature,
397398
)
398399
else:
@@ -822,6 +823,7 @@ def tab_sample_sequence_batch(
822823
context_lengths,
823824
tokens_to_generate,
824825
all_probs=True,
826+
compute_attention_mask=True,
825827
type_ids=None,
826828
temperature=None,
827829
):
@@ -845,7 +847,7 @@ def tab_sample_sequence_batch(
845847
# initialize the batch
846848
with torch.no_grad():
847849
context_length = context_lengths.min().item()
848-
inference_strategy.init_batch(context_tokens, context_length)
850+
inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask)
849851
context = context_tokens[:, :context_length]
850852
# the context may start in the middle of the row,
851853
# calculate the offset according to the position of '\n' or '<|endoftext|>'
@@ -879,7 +881,7 @@ def tab_sample_sequence_batch(
879881

880882
while context_length < maxlen:
881883
batch, tensor_shape = inference_strategy.prepare_batch_at_step(
882-
tokens, maxlen, micro_batch_size, counter, context_length
884+
tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask
883885
)
884886
output = inference_strategy.forward_step(batch, tensor_shape)
885887

0 commit comments

Comments (0)