diff --git a/megatron/arguments.py b/megatron/arguments.py
index 326c948ee..7f2125cc7 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -452,6 +452,8 @@ def _add_training_args(parser):
                        help='Run optimizer on CPU')
     group.add_argument('--cpu_torch_adam', action='store_true',
                        help='Use Torch Adam as optimizer on CPU.')
+    group.add_argument('--codecarbon-dir', type=str, default=None,
+                       help='Write CodeCarbon logs to this directory.')
 
     return parser
 
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 829fb1101..f7328dcbb 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -22,6 +22,7 @@
 
 import torch
 
+from megatron.global_vars import codecarbon_tracker_flush
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
@@ -135,7 +136,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
             for i in range(len(model)):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
                 state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
-            
+
         # Optimizer stuff.
         if not args.no_save_optim:
             if optimizer is not None:
@@ -183,6 +184,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
+    # Since the code can exit or abort in various places, we use checkpoint saving as a safe
+    # flush point for the codecarbon tracker. If the program doesn't run to its normal end,
+    # only the data collected since the last saved checkpoint is lost.
+    codecarbon_tracker_flush()
+
 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()
     # We use a self_attention module but the values extracted aren't
@@ -417,7 +423,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
 def load_biencoder_checkpoint(model, only_query_model=False,
         only_context_model=False, custom_load_path=None):
     """
-    selectively load retrieval models for indexing/retrieving 
+    selectively load retrieval models for indexing/retrieving
     from saved checkpoints
     """
 
diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index 5f3c28463..b5dcac4d9 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -19,6 +19,8 @@ import sys
 import time
 
+from pathlib import Path
+
 import torch
 
 from megatron.tokenizer import build_tokenizer
@@ -29,10 +31,10 @@
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
+_GLOBAL_CODECARBON_TRACKER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 
-
 def get_args():
     """Return arguments."""
     _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
@@ -63,6 +65,10 @@ def get_tensorboard_writer():
     to check if it is initialized."""
     return _GLOBAL_TENSORBOARD_WRITER
 
+def get_codecarbon_tracker():
+    """Return codecarbon tracker. It can be None so no need
+    to check if it is initialized."""
+    return _GLOBAL_CODECARBON_TRACKER
 
 def get_adlr_autoresume():
     """ADLR autoresume object. It can be None so no need
@@ -86,6 +92,7 @@
     if args.vocab_file or args.tokenizer_name_or_path:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
+    _set_codecarbon_tracker(args)
     _set_adlr_autoresume(args)
     _set_timers()
 
@@ -145,6 +152,56 @@ def _set_tensorboard_writer(args):
                   'no TensorBoard logs will be written.', flush=True)
 
 
+def _set_codecarbon_tracker(args):
+    global _GLOBAL_CODECARBON_TRACKER
+    if not hasattr(args, 'codecarbon_dir') or args.codecarbon_dir is None:
+        return
+
+    import codecarbon
+    if args.rank == 0:
+        print('> setting codecarbon ...')
+
+    output_dir = args.codecarbon_dir
+    output_file = f"emissions-{args.rank:03d}.csv"
+    log_level = "warning"
+    country_iso_code = "FRA"
+
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
+        output_dir=output_dir,
+        output_file=output_file,
+        log_level=log_level,
+        country_iso_code=country_iso_code,
+    )
+
+
+def codecarbon_tracker_start():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    #print("CC START")
+    _GLOBAL_CODECARBON_TRACKER.start()
+
+
+def codecarbon_tracker_stop():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    #print("CC STOP")
+    _GLOBAL_CODECARBON_TRACKER.stop()
+
+
+def codecarbon_tracker_flush():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    #print("CC FLUSH")
+    _GLOBAL_CODECARBON_TRACKER.flush()
+
+
 def _set_adlr_autoresume(args):
     """Initialize ADLR autoresume."""
     global _GLOBAL_ADLR_AUTORESUME
diff --git a/megatron/training.py b/megatron/training.py
index 21ef13b94..f66544dff 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -51,6 +51,7 @@ from megatron.schedules import forward_backward_pipelining_without_interleaving
 from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory, flops_calculator
+from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop
 
 import deepspeed
@@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider,
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
 
+    codecarbon_tracker_start()
+
     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
     # image ... launches.
@@ -162,6 +165,9 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
+    codecarbon_tracker_stop()
+
+
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
diff --git a/requirements.txt b/requirements.txt
index 234a9902a..a96ff42be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,5 @@ regex
 numpy
 transformers
 # git+https://github.com/microsoft/DeepSpeed.git@big-science
+# edit to a newer SHA or a future release if needed
+git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8
diff --git a/tests/test_training.py b/tests/test_training.py
index 7306615f1..c0a8fff69 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -97,6 +97,7 @@ def test_training_all(self):
            --save {output_dir}/checkpoints
            --load {output_dir}/checkpoints
            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+           --codecarbon-dir {output_dir}/codecarbon
            --tensorboard-dir {output_dir}/tensorboard
            --tensorboard-queue-size 5
            --log-timers-to-tensorboard
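
For reference, a minimal sketch of the tracker lifecycle this patch wires into Megatron: start once at the top of pretrain(), flush at every checkpoint save so an aborted run only loses data gathered since the last checkpoint, and stop after the final evaluation. The constructor arguments mirror _set_codecarbon_tracker() above; the output directory and file name below are illustrative stand-ins for the values derived from --codecarbon-dir and the process rank, and the sketch assumes the codecarbon revision pinned in requirements.txt is installed.

    # Sketch only: stand-alone version of the start/flush/stop sequence used by the patch.
    from pathlib import Path

    import codecarbon

    output_dir = "outputs/codecarbon"        # stand-in for args.codecarbon_dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    tracker = codecarbon.OfflineEmissionsTracker(
        output_dir=output_dir,
        output_file="emissions-000.csv",     # the patch writes one csv per rank
        log_level="warning",
        country_iso_code="FRA",
    )

    tracker.start()   # pretrain(): right after initialize_megatron()
    # ... training iterations ...
    tracker.flush()   # save_checkpoint(): persist emissions data gathered so far
    tracker.stop()    # pretrain(): after the final evaluation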