Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
Signed-off-by: arendu <[email protected]>
  • Loading branch information
arendu committed Jun 27, 2023
1 parent 0330665 commit 988bdfb
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def collate_fn(self, batch, tp_workers=0):
else:
resi_padding = 0
batch_max += resi_padding
ceil_batch_max = self._ceil_to_nearest(batch_max, 8)
ceil_batch_max = self._ceil_to_nearest(batch_max, 8) # @adithyare this padding does not conflict with the tp_workers padding above
# since tp_workers is always a multiple of 2. the padding to multiple of 8 is to ensure an mem-optimized softmax is used.
batch_max = ceil_batch_max + 1
input_ids, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, batch_max, answer_starts)
# Should be a label for every token in batch, label is the next token
Expand Down

0 comments on commit 988bdfb

Please sign in to comment.