Merged
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -452,6 +452,8 @@ def _add_training_args(parser):
help='Run optimizer on CPU')
group.add_argument('--cpu_torch_adam', action='store_true',
help='Use Torch Adam as optimizer on CPU.')
group.add_argument('--codecarbon-dir', type=str, default=None,
help='Write CodeCarbon logs to this directory.')

return parser

10 changes: 8 additions & 2 deletions megatron/checkpointing.py
@@ -22,6 +22,7 @@

import torch

from megatron.global_vars import codecarbon_tracker_flush
from megatron import (get_args,
mpu,
print_rank_0,
@@ -135,7 +136,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
@@ -183,6 +184,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if torch.distributed.is_initialized():
torch.distributed.barrier()

# Since the code can exit or abort in various places, we use checkpoint saving as
# a safe saving point for the codecarbon tracker. If the program doesn't run to its
# normal end, only the data gathered since the last saved checkpoint is lost.
codecarbon_tracker_flush()
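For reference, a standalone sketch of the flush semantics the comment above relies on (not part of the diff; the paths and workload below are illustrative): flush() persists the emissions measured so far without stopping measurement, so it can be called at every checkpoint save.

```python
# Standalone sketch, assuming the same codecarbon API used in this PR
# (OfflineEmissionsTracker with start/flush/stop); paths are illustrative only.
import codecarbon

tracker = codecarbon.OfflineEmissionsTracker(
    output_dir="./codecarbon",        # illustrative directory
    output_file="emissions-000.csv",  # illustrative file name
    log_level="warning",
    country_iso_code="FRA",
)
tracker.start()
# ... do some work ...
tracker.flush()  # write partial emissions data; measurement keeps running
# ... do more work ...
tracker.stop()   # final write at normal exit
```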

def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
# We use a self_attention module but the values extracted aren't
@@ -417,7 +423,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""

59 changes: 58 additions & 1 deletion megatron/global_vars.py
@@ -19,6 +19,8 @@
import sys
import time

from pathlib import Path

import torch

from megatron.tokenizer import build_tokenizer
@@ -29,10 +31,10 @@
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_CODECARBON_TRACKER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
@@ -63,6 +65,10 @@ def get_tensorboard_writer():
to check if it is initialized."""
return _GLOBAL_TENSORBOARD_WRITER

def get_codecarbon_tracker():
"""Return codecarbon tracker. It can be None so no need
to check if it is initialized."""
return _GLOBAL_CODECARBON_TRACKER

def get_adlr_autoresume():
"""ADLR autoresume object. It can be None so no need
@@ -86,6 +92,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
if args.vocab_file or args.tokenizer_name_or_path:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_codecarbon_tracker(args)
_set_adlr_autoresume(args)
_set_timers()

@@ -145,6 +152,56 @@ def _set_tensorboard_writer(args):
'no TensorBoard logs will be written.', flush=True)


def _set_codecarbon_tracker(args):
global _GLOBAL_CODECARBON_TRACKER
if not hasattr(args, 'codecarbon_dir'):
return
Comment on lines +157 to +158
thomasw21 (Member) commented on Aug 26, 2021:
I think this should check if it's None. @TevenLeScao

Basically this line adds the attribute but assigns None when the flag isn't set: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/15/files#diff-5f7d1ddfb0666cb6bb4ec0f07fd2fd7b1cd0354f421df5560489091db2ff5a55R455
So I believe hasattr(args, "codecarbon_dir") will return True even though no path was given.

#74
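A minimal sketch of the guard this comment asks for (hypothetical, not part of the diff; the helper name is made up): since argparse creates args.codecarbon_dir with default=None, the value itself has to be checked rather than hasattr().

```python
# Hypothetical sketch of the check suggested above (see #74); not part of this diff.
# argparse always defines args.codecarbon_dir (default=None), so hasattr() is
# always True and the value must be inspected instead.
def _codecarbon_enabled(args) -> bool:
    return getattr(args, 'codecarbon_dir', None) is not None
```

With such a helper, _set_codecarbon_tracker() would return early whenever it is False.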


import codecarbon
if args.rank == 0:
print('> setting codecarbon ...')

output_dir = args.codecarbon_dir
output_file = f"emissions-{args.rank:03d}.csv"
log_level = "warning"
country_iso_code = "FRA"

Path(output_dir).mkdir(parents=True, exist_ok=True)
_GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
output_dir=output_dir,
output_file=output_file,
log_level=log_level,
country_iso_code=country_iso_code,
)


def codecarbon_tracker_start():
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return

#print("CC START")
_GLOBAL_CODECARBON_TRACKER.start()


def codecarbon_tracker_stop():
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return

#print("CC STOP")
_GLOBAL_CODECARBON_TRACKER.stop()


def codecarbon_tracker_flush():
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return

#print("CC FLUSH")
_GLOBAL_CODECARBON_TRACKER.flush()


def _set_adlr_autoresume(args):
"""Initialize ADLR autoresume."""
global _GLOBAL_ADLR_AUTORESUME
6 changes: 6 additions & 0 deletions megatron/training.py
@@ -51,6 +51,7 @@
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory, flops_calculator
from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop

import deepspeed

@@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider,
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)

codecarbon_tracker_start()

# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
@@ -162,6 +165,9 @@ def pretrain(train_valid_test_dataset_provider,
test_data_iterator, model,
0, True)

codecarbon_tracker_stop()


def update_train_iters(args):

# For iteration-based training, we don't need to do anything
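Taken together, the tracker lifecycle added by this PR is roughly: build the tracker from --codecarbon-dir in set_global_variables(), start it when pretrain() begins, flush it on every checkpoint save, and stop it when pretrain() returns. A condensed sketch of that flow (the function name, loop, and interval below are illustrative, not part of the diff):

```python
# Condensed lifecycle sketch; toy_run(), num_iters, and save_interval are
# illustrative. The tracker helpers are the ones added in megatron/global_vars.py.
from megatron.global_vars import (codecarbon_tracker_start,
                                  codecarbon_tracker_stop,
                                  codecarbon_tracker_flush)

def toy_run(num_iters=3000, save_interval=1000):
    codecarbon_tracker_start()            # called at the top of pretrain()
    for iteration in range(1, num_iters + 1):
        ...                               # one training step
        if iteration % save_interval == 0:
            # save_checkpoint() ends with a flush so emissions data
            # survives an abort between two checkpoints.
            codecarbon_tracker_flush()
    codecarbon_tracker_stop()             # called at the end of pretrain()
```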
2 changes: 2 additions & 0 deletions requirements.txt
@@ -5,3 +5,5 @@ regex
numpy
transformers
# git+https://github.com/microsoft/DeepSpeed.git@big-science
# edit to a higher SHA or future release if needed
git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8
1 change: 1 addition & 0 deletions tests/test_training.py
@@ -97,6 +97,7 @@ def test_training_all(self):
--save {output_dir}/checkpoints
--load {output_dir}/checkpoints
--data-path {data_dir}/meg-gpt2-openwebtext_text_document
--codecarbon-dir {output_dir}/codecarbon
--tensorboard-dir {output_dir}/tensorboard
--tensorboard-queue-size 5
--log-timers-to-tensorboard