diff --git a/megatron/arguments.py b/megatron/arguments.py
index 5f4e2b53f..304652adc 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -170,6 +170,7 @@ def parse_args(extra_args_provider=None, defaults={},
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
+    args.gigaflos_no_embeds = 0
 
     # Iteration-based training.
     if args.train_iters:
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 829fb1101..8a73d1d5b 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -353,6 +353,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         update_num_microbatches(consumed_samples=args.consumed_train_samples)
         args.consumed_valid_samples = getattr(checkpoint_args,
                                               'consumed_valid_samples', 0)
+        args.gigaflos_no_embeds = getattr(checkpoint_args,
+                                          'gigaflos_no_embeds', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
 
diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index e6c64e975..e605c216e 100644
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -237,10 +237,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                 last_epoch_num_samples = num_samples - \
                                          num_samples_from_epochs_minus_one
                 assert last_epoch_num_samples >= 0, \
-                    'last epoch number of samples should be non-negative.'
+                    f'last epoch number of samples {last_epoch_num_samples} should be non-negative.'
                 num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
-                assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
-                    'last epoch number of samples exceeded max value.'
+                assert last_epoch_num_samples <= num_samples_per_epoch, \
+                    f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.'
                 # If we have less than 80% of the samples for the last epoch,
                 # seperate out the epoch and treat it differently.
                 # Note: the 80% number is just based on common sense and can
diff --git a/megatron/training.py b/megatron/training.py
index 21ef13b94..df300fac8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -43,7 +43,7 @@ from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import check_adlr_autoresume_termination, get_parameters_in_billions
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
@@ -113,6 +113,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    print(f'estimated model parameters: {get_parameters_in_billions(model)}')
+    print(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -545,7 +547,7 @@ def add_to_logging(name):
                        total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
         writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples)
         writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration)
@@ -561,6 +563,8 @@ def add_to_logging(name):
             writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration)
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
+            writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key],
+                              args.gigaflos_no_embeds)
         if args.log_loss_scale_to_tensorboard:
             writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
@@ -647,6 +651,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
+    print(f"Number of parameters: {get_parameters_in_billions(model)} billion")
+    print(f"Number of parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)} billion")
     args = get_args()
     timers = get_timers()
 
@@ -683,9 +689,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        optimizer, lr_scheduler)
         iteration += 1
-        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+        new_samples = mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
                                        get_num_microbatches()
+        args.consumed_train_samples += new_samples
+        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))
 
         # Logging.
         if args.deepspeed:
@@ -827,11 +835,16 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar(f'lm-loss-validation/{key} validation vs samples',
                               total_loss_dict[key].item(),
                               args.consumed_train_samples)
+            writer.add_scalar(f'lm-loss-validation/{key} validation vs gigaflos (without embeddings)',
+                              total_loss_dict[key].item(),
+                              args.gigaflos_no_embeds)
             if args.log_validation_ppl_to_tensorboard:
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl', ppl,
                                   iteration)
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs samples',
                                   ppl, args.consumed_train_samples)
+                writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs gigaflos (without embeddings)',
+                                  ppl, args.gigaflos_no_embeds)
 
     length = len(string) + 1
     print_rank_last('-' * length)
diff --git a/megatron/utils.py b/megatron/utils.py
index ed7047c95..65148d8e3 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -16,8 +16,10 @@
 """General utilities."""
 
 import sys
+import warnings
 
 import torch
+from torch import nn
 from torch.nn.parallel import DistributedDataParallel as torchDDP
 
 from apex.multi_tensor_apply import multi_tensor_applier
@@ -28,7 +30,7 @@ from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
 from megatron import get_num_microbatches
 
 
 def unwrap_model(model, module_instances=(torchDDP)):
@@ -204,11 +206,32 @@ def get_ltor_masks_and_position_ids(data,
 
     return attention_mask, loss_mask, position_ids
 
 
-def get_parameters_in_billions(model):
+def param_size(parameter):
+    return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
+
+
+def unique_param_count(param_list):
+    return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())
+
+
+def non_embedding_params(module):
+    embedding_param_names = [
+        f"{name}.weight" for name, module_type in module.named_modules() if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
+    ]
+    non_embedding_parameters = [
+        parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
+    ]
+    return unique_param_count(non_embedding_parameters)
+
+
+def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())
 
-    approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                                        for model_module in model])
+    if exclude_embeddings:
+        approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
+    else:
+        warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")
+        approx_parameters_in_billions = unique_param_count([p for model_module in model for p in model_module.parameters()])
 
     return approx_parameters_in_billions*gpus_per_model/(1e9)
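
Note (not part of the patch): `args.gigaflos_no_embeds` accumulates the common 6 * params * tokens estimate of training compute (forward plus backward) per step. Because `get_parameters_in_billions` returns the non-embedding parameter count divided by 1e9, the running total lands directly in gigaFLOPs, hence the name. A minimal standalone sketch of that bookkeeping follows; the helper name `gigaflos_for_step` and all concrete numbers are illustrative assumptions, not values from the patch.

```python
# Sketch of the 6 * N * D FLOPs bookkeeping the patch performs each step.
# All concrete numbers below are illustrative assumptions.

def gigaflos_for_step(new_samples: int, seq_length: int,
                      non_embed_params_in_billions: float) -> float:
    """Approximate training gigaFLOPs for one optimizer step.

    Uses the common 6 * params * tokens estimate (forward + backward);
    params are already expressed in billions, so the result is in gigaFLOPs.
    """
    tokens = new_samples * seq_length
    return 6 * tokens * non_embed_params_in_billions


if __name__ == "__main__":
    # e.g. a global batch of 512 samples, 2048-token sequences,
    # 1.3B non-embedding parameters:
    step_gigaflos = gigaflos_for_step(512, 2048, 1.3)
    print(f"{step_gigaflos:.3e} gigaFLOPs this step")  # ~8.2e6
```

In the patch itself this happens inline in `train()`, where `new_samples` is formed from the data-parallel world size, `args.micro_batch_size` and `get_num_microbatches()`.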
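Likewise, a sketch of why `unique_param_count` keys on `data_ptr()` instead of summing `parameters()` directly: when the same storage is reachable from more than one module in the `model` list (e.g. tied input/output embeddings, or the embedding copy held by both the first and last pipeline stage, as the added `warnings.warn` notes), a naive sum double-counts it. The toy modules below are hypothetical and use plain `torch.nn` only.

```python
from torch import nn


def param_size(p):
    # Same fallback logic as the patch: use ds_numel only when the
    # parameter has been partitioned by DeepSpeed ZeRO.
    return p.ds_numel if hasattr(p, "ds_id") else p.nelement()


def unique_param_count(params):
    # Key by storage pointer so a parameter shared across modules is
    # counted exactly once.
    return sum({p.data_ptr(): param_size(p) for p in params}.values())


# Two separate modules sharing one weight, standing in for the tied
# input/output embeddings that different pipeline stages can both hold.
embed = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
head.weight = embed.weight  # weight tying across module boundaries

modules = [embed, head]
naive = sum(p.nelement() for m in modules for p in m.parameters())
deduped = unique_param_count(p for m in modules for p in m.parameters())
print(naive, deduped)  # 3200 vs 1600: the shared weight is counted once
```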