From 988bdfb4990943676d978c536072ad526333fa24 Mon Sep 17 00:00:00 2001
From: arendu
Date: Tue, 27 Jun 2023 12:34:09 -0700
Subject: [PATCH] comments

Signed-off-by: arendu
---
 .../language_modeling/megatron/gpt_prompt_learning_dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
index 2892867142c1..db40afb397b8 100755
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
@@ -353,7 +353,8 @@ def collate_fn(self, batch, tp_workers=0):
         else:
             resi_padding = 0
         batch_max += resi_padding
-        ceil_batch_max = self._ceil_to_nearest(batch_max, 8)
+        ceil_batch_max = self._ceil_to_nearest(batch_max, 8)  # @adithyare this padding does not conflict with the tp_workers padding above
+        # since tp_workers is always a multiple of 2. The padding to a multiple of 8 is to ensure a memory-optimized softmax is used.
         batch_max = ceil_batch_max + 1
         input_ids, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, batch_max, answer_starts)
         # Should be a label for every token in batch, label is the next token
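
The sketch below illustrates the padding arithmetic the new comment describes: the batch length is first padded for tp_workers (sequence parallelism), then rounded up to a multiple of 8 for the memory-optimized softmax, then incremented by 1 for the next-token label shift. Only the else: branch and the call to self._ceil_to_nearest are visible in the hunk, so the tp_workers formula, the helper body, and the function name padded_batch_max are reconstructed assumptions for illustration, not the exact NeMo implementation.

# Illustrative sketch only -- not part of the patch.
# Assumption: _ceil_to_nearest(n, m) rounds n up to the nearest multiple of m.
def _ceil_to_nearest(n, m):
    return (n + m - 1) // m * m


def padded_batch_max(batch_max, tp_workers=0):
    # Assumed tp_workers padding (only the else: branch appears in the hunk):
    # make (batch_max - 1) divisible by tp_workers for sequence parallelism.
    if tp_workers > 1:
        resi_padding = (tp_workers - (batch_max - 1) % tp_workers) % tp_workers
    else:
        resi_padding = 0
    batch_max += resi_padding

    # Round up to a multiple of 8 so a memory-optimized (fused) softmax kernel applies.
    # Per the patch comment, this does not conflict with the tp_workers padding
    # because tp_workers is always a multiple of 2.
    ceil_batch_max = _ceil_to_nearest(batch_max, 8)

    # +1 because one token is consumed by the next-token label shift
    # (see the "label is the next token" context comment in the hunk).
    return ceil_batch_max + 1


# Example: batch_max=117, tp_workers=2 -> no tp padding (116 % 2 == 0),
# rounded up to 120, returned as 121; after the label shift the model sees 120 tokens.
print(padded_batch_max(117, tp_workers=2))  # 121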