From fac7103fd30a61b76e0838f10195ba9e6d660f3e Mon Sep 17 00:00:00 2001 From: Yi Dong <43824965+yidong72@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:06:25 -0400 Subject: [PATCH] fix tab text gen (#7022) Signed-off-by: Yi Dong --- .../collections/nlp/modules/common/text_generation_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index d84d16efb5ba..3a41901f76ce 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -393,6 +393,7 @@ def synced_generate( context_length_tensor, tokens_to_generate, all_probs, + compute_attention_mask=compute_attention_mask, temperature=temperature, ) else: @@ -822,6 +823,7 @@ def tab_sample_sequence_batch( context_lengths, tokens_to_generate, all_probs=True, + compute_attention_mask=True, type_ids=None, temperature=None, ): @@ -845,7 +847,7 @@ def tab_sample_sequence_batch( # initialize the batch with torch.no_grad(): context_length = context_lengths.min().item() - inference_strategy.init_batch(context_tokens, context_length) + inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) context = context_tokens[:, :context_length] # the context may start in the middle of the row, # calculate the offset according to the position of '\n' or '<|endoftext|>' @@ -879,7 +881,7 @@ def tab_sample_sequence_batch( while context_length < maxlen: batch, tensor_shape = inference_strategy.prepare_batch_at_step( - tokens, maxlen, micro_batch_size, counter, context_length + tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask ) output = inference_strategy.forward_step(batch, tensor_shape)