Correct padding for SFT input data to account for sequence parallel + TE's fp8 op dimension requirements (#8240)

* Alter GPTSFTDataset / GPTSFTPackedDataset to account for SP when padding sequences to ensure divisibility by 8/16 for TE with fp8

Signed-off-by: Valerie Sarge <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Valerie Sarge <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and yaoyu-33 committed Jan 31, 2024
1 parent 47b0126 commit 84d8d6b
Showing 2 changed files with 11 additions and 2 deletions.
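Both collate functions below round the padded batch length up with _ceil_to_nearest before capping it at max_seq_length. A minimal sketch of that rounding, assuming _ceil_to_nearest is a plain ceiling-to-multiple helper (the standalone name ceil_to_nearest and its body here are inferred for illustration, not copied from NeMo):

def ceil_to_nearest(n: int, m: int) -> int:
    # Round n up to the nearest multiple of m.
    return (n + m - 1) // m * m

# Default case (no sequence parallel): pad to a multiple of 16.
assert ceil_to_nearest(1000, 16) == 1008
# Sequence parallel with tensor_model_parallel_size=4: pad to a multiple of 8 * 4 = 32.
assert ceil_to_nearest(1000, 32) == 1024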
@@ -39,6 +39,7 @@ def __init__(
         tokenizer: TokenizerSpec,
         max_seq_length: int = 1024,
         min_seq_length: int = 1,
+        pad_seq_length_to_mult: int = 16,
         add_bos: bool = False,
         add_eos: bool = True,
         add_sep: bool = False,
@@ -88,6 +89,7 @@ def __init__(
         self.file_path = file_path
         self.max_seq_length = max_seq_length
         self.min_seq_length = min_seq_length
+        self.pad_seq_length_to_mult = pad_seq_length_to_mult
         self.add_bos = add_bos
         self.add_eos = add_eos
         self.add_sep = add_sep
@@ -440,7 +442,7 @@ def collate_fn(self, batch):
         if self.pad_to_max_length:
             max_length = self.max_seq_length
         else:
-            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
         assert max_length <= self.max_seq_length

         attention_mask = [self._create_attention_mask(max_length) for _ in batch]
@@ -534,7 +536,7 @@ def collate_fn(self, batch):
         # for many datasets in practice, all packed sequence lengths are very close to the
         # target length (2048, 4096, 8192), so there is very minimal padding
         max_length = max(len(l) for l in input_ids)
-        max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+        max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
         assert max_length <= self.max_seq_length

         position_ids: List[List[int]] = []
@@ -252,6 +252,12 @@ def _build_dataset(self, data_cfg, is_train=True):
             )
             data_cfg.max_seq_length = self.cfg.max_position_embeddings

+        # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8
+        # When using sequence parallel, sequence will further be split by TP size
+        pad_seq_length_to_mult = (
+            8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16
+        )
+
         for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
             if self.cfg.data.get("chat", False):
                 dataset_cls = GPTSFTChatDataset
@@ -265,6 +271,7 @@ def _build_dataset(self, data_cfg, is_train=True):
                 tokenizer=self.tokenizer,
                 max_seq_length=data_cfg.max_seq_length,
                 min_seq_length=data_cfg.min_seq_length,
+                pad_seq_length_to_mult=pad_seq_length_to_mult,
                 add_bos=data_cfg.get('add_bos', False),
                 add_eos=data_cfg.get('add_eos', True),
                 add_sep=data_cfg.get('add_sep', False),
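As a concrete check of the rule encoded in _build_dataset above (a hedged sketch; the TP size of 4 and the length 1000 are arbitrary illustrations, not values from this commit):

# Standalone recomputation of the padding rule from _build_dataset above.
tensor_model_parallel_size = 4   # assumed TP size, for illustration only
sequence_parallel = True

pad_seq_length_to_mult = 8 * tensor_model_parallel_size if sequence_parallel else 16  # -> 32

padded_len = (1000 + pad_seq_length_to_mult - 1) // pad_seq_length_to_mult * pad_seq_length_to_mult
shard_len = padded_len // tensor_model_parallel_size  # each sequence-parallel rank's slice
assert padded_len == 1024 and shard_len == 256 and shard_len % 8 == 0

With sequence parallel enabled, each TP rank sees only a 1/TP slice of the sequence, so padding to a multiple of 8 * TP keeps every slice divisible by 8; without sequence parallel, the original multiple of 16 already satisfies TE's fp8 dimension requirements.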