Support dynamic length batches with GPT SFT #6510

Merged · 3 commits · May 2, 2023
120 changes: 52 additions & 68 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -330,6 +330,56 @@ def forward(self, tokens, text_position_ids, attention_mask, labels):
output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
return output_tensor

def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()

# TODO @akhattar: remove sync related stuff from config, add num_micro_batches_with_partial_activation_checkpoints when ready
losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel=self.cfg.get('sequence_parallel', False),
enable_autocast=True,
)

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
# Get the total loss since micro batch sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
return loss_sum
else:
# we're not on the last pipeline stage so no losses
if forward_only:
loss_mean = []
else:
loss_mean = torch.tensor(0.0).cuda()

return loss_mean

def training_step(self, dataloader_iter, batch_idx):
"""
We pass the dataloader iterator function to the micro-batch scheduler.
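
The refactor above pulls the forward/backward pass that `training_step` and `validation_step` previously duplicated into a single `fwd_bwd_step(dataloader_iter, batch_idx, forward_only)`. When `validation_drop_last` is False, the last global batch may be smaller than the others, so the method returns stacked `[loss_sum, num_elements]` pairs instead of a plain mean of per-micro-batch averages. A minimal sketch of how those pairs can be reduced afterwards, assuming the two-element `loss_sum_and_ub_size` layout implied by the diff (`reduce_loss_sums` is a hypothetical helper, not NeMo API):

```python
import torch

def reduce_loss_sums(loss_sum_and_ub_size_list):
    # Illustrative only: fold [loss_sum, count] pairs from several
    # micro-batches into one mean, ignoring empty micro-batches.
    valid = [pair for pair in loss_sum_and_ub_size_list if pair[1] > 0]
    if not valid:
        return torch.tensor(0.0)
    total = torch.vstack(valid).sum(dim=0)  # -> [sum_of_losses, total_count]
    return total[0] / total[1]              # mean weighted by element count

# Two micro-batches with different effective sizes: (12.0 + 3.0) / (4 + 2) = 2.5
print(reduce_loss_sums([torch.tensor([12.0, 4.0]), torch.tensor([3.0, 2.0])]))
```

Dividing the summed loss by the summed element count weights every element equally, which a mean of per-micro-batch means would not do when the last micro-batch is short.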
@@ -358,34 +408,7 @@ def training_step(self, dataloader_iter, batch_idx):
for param in module.embedding.parameters():
param.data_ptr()

tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()

# TODO @akhattar: remove sync related stuff from config, add num_micro_batches_with_partial_activation_checkpoints when ready
losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
num_microbatches=get_num_microbatches(),
forward_only=False,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel=self.cfg.get('sequence_parallel', False),
enable_autocast=True,
)

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
loss_mean = torch.tensor(0.0).cuda()
loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False)

# when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
@@ -642,46 +665,7 @@ def validation_step(self, dataloader_iter, batch_idx):
The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""

tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

# run forward passes for an entire global batch
# we do this inside validation_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(validation_step=True),
data_iterator=dataloader_iter,
model=[self.model],
num_microbatches=get_num_microbatches(),
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
sequence_parallel=self.cfg.get('sequence_parallel', False),
enable_autocast=True,
)

# only the last stage of the pipeline returns losses
if losses_reduced_per_micro_batch:
if self.cfg.data.get('validation_drop_last', True):
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
return torch.concat(loss_tensors_list).mean()
else:
# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
return loss_sum
else:
# we're not on the last pipeline stage so no losses
return []
return self.fwd_bwd_step(dataloader_iter, batch_idx, True)

def validation_epoch_end(self, outputs):
if parallel_state.is_pipeline_last_stage():
nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -23,21 +23,29 @@
get_datasets_weights_and_num_samples,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingBatchSampler,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.collections.nlp.modules.common.text_generation_utils import LengthParam, SamplingParam, megatron_gpt_generate
from nemo.utils import AppState, logging

try:
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
_reconfigure_microbatch_calculator,
get_micro_batch_size,
get_num_microbatches,
)

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

try:
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

HAVE_MEGATRON_CORE = True

@@ -237,6 +245,7 @@ def _build_dataset(self, data_cfg, is_train=True):
),
answer_only_loss=self.cfg.get('answer_only_loss', True),
truncation_field=data_cfg.get('truncation_field', 'context'),
pad_to_max_length=False,
index_mapping_dir=data_cfg.get('index_mapping_dir', None),
prompt_template=data_cfg.get('prompt_template', None),
)
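
Setting `pad_to_max_length=False` lets the SFT dataset's collate pad each batch only to the longest example it actually contains, instead of always padding to the configured maximum; that is what makes the per-batch `seq_length` used by `fwd_bwd_step` below vary. A rough sketch of dynamic-length collation under that assumption (`simple_dynamic_collate` and the pad id are illustrative placeholders, not the GPTSFTDataset implementation):

```python
import torch

def simple_dynamic_collate(token_lists, pad_id=0):
    # Pad every example to the longest sequence in *this* batch,
    # not to a fixed global max_seq_length.
    max_len = max(len(toks) for toks in token_lists)
    tokens = torch.full((len(token_lists), max_len), pad_id, dtype=torch.long)
    loss_mask = torch.zeros(len(token_lists), max_len)
    for i, toks in enumerate(token_lists):
        tokens[i, : len(toks)] = torch.as_tensor(toks, dtype=torch.long)
        loss_mask[i, : len(toks)] = 1.0
    return {'tokens': tokens, 'loss_mask': loss_mask}

batch = simple_dynamic_collate([[5, 6, 7], [8, 9]])
print(batch['tokens'].shape)  # torch.Size([2, 3]) -- padded to the batch max, not the model max
```

Shorter batches therefore produce shorter tensors, so compute and activation memory track the actual batch content rather than the worst case.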
@@ -264,6 +273,56 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):
else:
return base_key + f"dataloader{dataloader_idx}"

def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
batch = next(dataloader_iter)
_, seq_length = batch['tokens'].shape
tensor_shape = [seq_length, get_micro_batch_size(), self.cfg.hidden_size]
data_iter = get_iterator_k_split(batch, get_num_microbatches())

fwd_bwd_function = get_forward_backward_func()

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=data_iter,
model=[self.model],
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel=self.cfg.get('sequence_parallel', False),
enable_autocast=True,
)

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
# Get the total loss since micro batch sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
return loss_sum
else:
# we're not on the last pipeline stage so no losses
if forward_only:
loss_mean = []
else:
loss_mean = torch.tensor(0.0).cuda()

return loss_mean

def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0):
return self.inference_step(dataloader_iter, batch_idx, 'validation', dataloader_idx)

@@ -561,7 +620,7 @@ def build_data_loader(self, dataset, data_cfg, consumed_samples=0):
else:
collate_fn = dataset.collate_fn

batch_sampler = MegatronPretrainingSampler(
batch_sampler = MegatronPretrainingBatchSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=data_cfg.micro_batch_size,
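Together with the switch from `MegatronPretrainingSampler` to `MegatronPretrainingBatchSampler`, the SFT override of `fwd_bwd_step` above reads the sequence length from the batch it just pulled and sizes the pipeline communication buffers from it. A minimal sketch of that shape computation with made-up numbers (the `hidden_size` and micro-batch size are assumptions for illustration, not values from this PR):

```python
import torch

# Pretend micro-batch: 4 sequences padded only to this batch's max of 384 tokens
batch = {'tokens': torch.zeros(4, 384, dtype=torch.long)}
hidden_size = 4096        # assumed model hidden size
micro_batch_size = 4      # assumed micro-batch size

_, seq_length = batch['tokens'].shape
tensor_shape = [seq_length, micro_batch_size, hidden_size]
print(tensor_shape)  # [384, 4, 4096] -- tracks the batch, not cfg.encoder_seq_length
```

Each pipeline rank then exchanges activations sized for this batch's 384 tokens rather than for the full configured sequence length.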
16 changes: 12 additions & 4 deletions nemo/collections/nlp/modules/common/megatron/utils.py
@@ -368,8 +368,16 @@ def get_all_params_for_weight_decay_optimization(
return ({'params': weight_decay_params},)


def get_iterator_k_split(batch: List[torch.Tensor], microbatches: int) -> Iterator:
assert batch[0].shape[0] % microbatches == 0, "Issue with batch size configuration!"
split_batch = [torch.tensor_split(item, microbatches, dim=0) for item in batch]
microbatches = [[elem[i] for elem in split_batch] for i in range(microbatches)]
def get_iterator_k_split(batch: List[torch.Tensor], num_microbatches: int) -> Iterator:
if isinstance(batch, dict):
items = list(batch.items())
assert items[0][1].shape[0] % num_microbatches == 0, "Issue with batch size configuration!"
split_batch = [torch.tensor_split(item[1], num_microbatches, dim=0) for item in items]
microbatches = [[(items[i][0], split_batch[i][j]) for i in range(len(items))] for j in range(num_microbatches)]
microbatches = [dict(elem) for elem in microbatches]
else:
assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!"
split_batch = [torch.tensor_split(item, num_microbatches, dim=0) for item in batch]
microbatches = [[elem[i] for elem in split_batch] for i in range(num_microbatches)]

return itertools.chain(microbatches)
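
With the dict branch added above, a dict-shaped global batch can be split into equal micro-batches along the batch dimension before being handed to the pipeline schedule. A small usage sketch assuming the updated signature in this diff (tensor sizes are arbitrary):

```python
import torch
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split

# Global batch of 8 samples, split into 4 micro-batches of 2
batch = {
    'tokens': torch.randint(0, 100, (8, 16)),
    'loss_mask': torch.ones(8, 16),
}
for micro_batch in get_iterator_k_split(batch, num_microbatches=4):
    print(micro_batch['tokens'].shape, micro_batch['loss_mask'].shape)
    # torch.Size([2, 16]) torch.Size([2, 16]) -- each micro-batch keeps the same keys
```

The batch size must divide evenly by `num_microbatches`; the assertion in the function guards against misconfigured micro-batch settings.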