Skip to content

Commit

Permalink
Account for SP + CP case
Browse files Browse the repository at this point in the history
Signed-off-by: Valerie Sarge <[email protected]>
  • Loading branch information
vysarge committed Apr 18, 2024
1 parent 26b6252 commit eab5ecb
Showing 1 changed file with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,11 @@ def _build_dataset(self, data_cfg, is_train=True):

# TE requires that the first input dim is divisible by 8 and the second by 16 for fp8
# When using sequence parallel, sequence will further be split by TP size
# When using context parallel, sequence is split by CP size instead
if self.cfg.get('sequence_parallel', False):
pad_seq_length_to_mult = 8 * self.cfg.get('tensor_model_parallel_size', 1)
elif self.cfg.get('context_parallel_size', 1) > 1:
pad_seq_length_to_mult = 16 * self.cfg.get('context_parallel_size', 1)
else:
pad_seq_length_to_mult = 16
# When using context parallel, sequence is split by CP size as well
pad_seq_length_to_mult = (
8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16
)
pad_seq_length_to_mult *= self.cfg.get('context_parallel_size', 1)

dataset_kwargs = {}
for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
Expand Down

0 comments on commit eab5ecb

Please sign in to comment.