fix tab text gen (#7022)
Signed-off-by: Yi Dong <[email protected]>
yidong72 authored and web-flow committed Jul 13, 2023
1 parent f7e33fc commit 4a72d2e
Showing 1 changed file with 4 additions and 2 deletions.
nemo/collections/nlp/modules/common/text_generation_utils.py (4 additions, 2 deletions)
@@ -396,6 +396,7 @@ def synced_generate(
             context_length_tensor,
             tokens_to_generate,
             all_probs,
+            compute_attention_mask=compute_attention_mask,
             temperature=temperature,
         )
     else:
@@ -825,6 +826,7 @@ def tab_sample_sequence_batch(
     context_lengths,
     tokens_to_generate,
     all_probs=True,
+    compute_attention_mask=True,
     type_ids=None,
     temperature=None,
 ):
@@ -848,7 +850,7 @@
     # initialize the batch
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        inference_strategy.init_batch(context_tokens, context_length)
+        inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask)
         context = context_tokens[:, :context_length]
         # the context may start in the middle of the row,
         # calculate the offset according to the position of '\n' or '<|endoftext|>'
@@ -882,7 +884,7 @@

         while context_length < maxlen:
             batch, tensor_shape = inference_strategy.prepare_batch_at_step(
-                tokens, maxlen, micro_batch_size, counter, context_length
+                tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask
             )
             output = inference_strategy.forward_step(batch, tensor_shape)

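For context: the change threads the existing compute_attention_mask flag through the tabular text-generation path, so the inference strategy's init_batch and prepare_batch_at_step hooks receive it just as they do on the non-tabular path. Below is a minimal, self-contained sketch of that parameter-threading pattern; the names (ToyInferenceStrategy, toy_tab_sample_sequence_batch) are illustrative stand-ins, not the real NeMo API.

# Sketch of the parameter-threading pattern this commit applies (not NeMo code):
# the generation loop forwards a compute_attention_mask flag into both the
# batch-initialization and per-step batch-preparation hooks, so callers can
# skip building attention masks during incremental decoding.

class ToyInferenceStrategy:
    def init_batch(self, context_tokens, context_length, compute_attention_mask):
        # Build the (potentially large) mask only when asked to.
        if compute_attention_mask:
            self.attention_mask = [[1] * context_length for _ in context_tokens]
        else:
            self.attention_mask = None

    def prepare_batch_at_step(
        self, tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask
    ):
        # Reuse the mask from init_batch, or omit it when the flag is off.
        mask = self.attention_mask if compute_attention_mask else None
        batch = {"tokens": tokens, "attention_mask": mask}
        return batch, (micro_batch_size, 1)

def toy_tab_sample_sequence_batch(
    inference_strategy, context_tokens, context_length, maxlen, compute_attention_mask=True
):
    # The fix: pass the flag through init_batch ...
    inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask)
    counter = 0
    batch = None
    while context_length < maxlen:
        # ... and through prepare_batch_at_step on every decoding step.
        batch, tensor_shape = inference_strategy.prepare_batch_at_step(
            context_tokens, maxlen, len(context_tokens), counter, context_length, compute_attention_mask
        )
        context_length += 1
        counter += 1
    return batch

if __name__ == "__main__":
    strategy = ToyInferenceStrategy()
    out = toy_tab_sample_sequence_batch(strategy, [[7, 8, 9]], context_length=3, maxlen=5)
    print(out["attention_mask"])  # [[1, 1, 1]]

Defaulting the new parameter to True in the function signature, as the diff does, keeps any existing callers of tab_sample_sequence_batch working unchanged while letting synced_generate opt out of mask construction.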