ptuning oom fix (#6916)
* oom wip

Signed-off-by: arendu <[email protected]>

* minor

Signed-off-by: arendu <[email protected]>

* comments

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] authored Jun 27, 2023
1 parent 8204483 commit 7e3739b
Showing 2 changed files with 10 additions and 2 deletions.
@@ -327,6 +327,9 @@ def __len__(self):
     def __getitem__(self, idx):
         return self.examples[idx]
 
+    def _ceil_to_nearest(self, n, m):
+        return (n + m - 1) // m * m
+
     def collate_fn(self, batch, tp_workers=0):
         """ Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch """
         taskname_ids, input_ids, answer_starts = zip(*batch)
@@ -350,11 +353,16 @@ def collate_fn(self, batch, tp_workers=0):
         else:
             resi_padding = 0
         batch_max += resi_padding
+        ceil_batch_max = self._ceil_to_nearest(
+            batch_max, 8
+        )  # @adithyare this padding does not conflict with the tp_workers padding above,
+        # since tp_workers is always a multiple of 2. The padding to a multiple of 8 is to ensure a mem-optimized softmax is used.
+        batch_max = ceil_batch_max + 1
         input_ids, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, batch_max, answer_starts)
         # Should be a label for every token in batch, label is the next token
         labels = input_ids[:, 1:].contiguous()
         input_ids = input_ids[:, :-1].contiguous()
-        batch_max -= 1
+        batch_max -= 1  # @adithyare I *think* this subtraction is done to account for the two lines above, which remove one item from the input_ids seq.
 
         # Loss mask should align with labels
         loss_mask = loss_mask[:, 1:].contiguous()
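The collate_fn change above rounds the longest sequence length in the batch up to a multiple of 8 and then adds 1, so that after the next-token shift (which drops one position from input_ids and labels) the length the model actually sees is again a multiple of 8, the alignment the commit's comment says the mem-optimized softmax needs. A minimal standalone sketch of that arithmetic, using a hypothetical batch_max value not taken from the commit:

    def _ceil_to_nearest(n, m):
        # Round n up to the nearest multiple of m.
        return (n + m - 1) // m * m

    batch_max = 173                                 # hypothetical longest sequence in the micro-batch
    batch_max = _ceil_to_nearest(batch_max, 8) + 1  # 176 + 1 = 177; the +1 feeds the shift below
    # labels = input_ids[:, 1:] and input_ids = input_ids[:, :-1] each drop one position,
    # so the effective sequence length becomes batch_max - 1.
    seq_len = batch_max - 1                         # 176, a multiple of 8 again
    assert seq_len % 8 == 0

Keeping the post-shift length at a multiple of 8 is what lets the memory-optimized softmax path kick in, which appears to be where the OOM saving in this fix comes from.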
@@ -346,7 +346,7 @@ def __init__(
             AdapterName.LORA_KQV_ADAPTER,
         ]
         lora_cfg = cfg.peft.lora_tuning
-        if cfg.kv_channels is None:
+        if cfg.get("kv_channels", None) is None:
             assert (
                 cfg.hidden_size % cfg.num_attention_heads == 0
             ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
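The second change swaps direct attribute access for a dict-style get, so a config that omits the kv_channels key altogether falls into the defaulting branch instead of raising on the lookup itself. A minimal sketch, assuming cfg behaves like an OmegaConf DictConfig in struct mode (as Hydra-composed NeMo configs typically do) and using hypothetical values for hidden_size and num_attention_heads:

    from omegaconf import OmegaConf

    # Hypothetical config with no kv_channels key at all.
    cfg = OmegaConf.create({"hidden_size": 4096, "num_attention_heads": 32})
    OmegaConf.set_struct(cfg, True)

    # cfg.kv_channels would raise here because the key is absent;
    # cfg.get("kv_channels", None) returns the default instead.
    if cfg.get("kv_channels", None) is None:
        assert cfg.hidden_size % cfg.num_attention_heads == 0, \
            'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = cfg.hidden_size // cfg.num_attention_heads  # 128 channels per head

The .get lookup is the conventional way to read optional keys from these configs, which is why the one-line change is enough to keep LoRA setup from failing on configs that never define kv_channels.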
