From f8eeb794c381f479bb3b245aac81415660549a6d Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Wed, 12 Jun 2024 14:26:08 -0700
Subject: [PATCH] Add option to merge distributed optimizer buckets (#9414)

* Add option to merge distopt buckets in GPT

Signed-off-by: Tim Moon

* Move distopt bucket merge logic to base LLM class

Signed-off-by: Tim Moon

* Apply isort and black reformatting

Signed-off-by: timmoon10

---------

Signed-off-by: Tim Moon
Signed-off-by: timmoon10
Co-authored-by: timmoon10
Co-authored-by: Sangkug Lym
---
 .../models/language_modeling/megatron_base_model.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index e7f2aa805a9c..0828d88a8133 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -861,7 +861,15 @@ def configure_optimizers(self):
 
         # Initialize param buckets if explicitly provided
         if getattr(self, 'distributed_adam_buckets', None) is not None:
-            for bucket in self.distributed_adam_buckets:
+            buckets = self.distributed_adam_buckets
+            if self.cfg.get('distributed_adam_bucket_merge_size', 1) > 1:
+                # Merge buckets if needed
+                stride = self.cfg.get('distributed_adam_bucket_merge_size', 1)
+                buckets = [
+                    list(itertools.chain.from_iterable(buckets[i : i + stride]))
+                    for i in range(0, len(buckets), stride)
+                ]
+            for bucket in buckets:
                 self._optimizer.init_params_bucket(bucket)
             self._optimizer.init_params_bucket(self.parameters())
         if hasattr(self, 'distributed_adam_buckets'):
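
For reference, below is a minimal standalone sketch (not part of the patch) of the merge behavior this change introduces: every `distributed_adam_bucket_merge_size` consecutive parameter buckets are concatenated into one larger bucket before being registered with the distributed optimizer. The `merge_buckets` helper and the toy string-valued buckets are illustrative assumptions only; the real code operates inline on lists of torch parameters inside configure_optimizers().

    # Illustrative sketch of the bucket-merge logic added by this patch.
    # Assumes buckets are plain Python lists; names here are hypothetical.
    import itertools

    def merge_buckets(buckets, stride):
        # Concatenate every `stride` consecutive buckets into one bucket,
        # mirroring the list comprehension added in configure_optimizers().
        if stride <= 1:
            return buckets
        return [
            list(itertools.chain.from_iterable(buckets[i : i + stride]))
            for i in range(0, len(buckets), stride)
        ]

    # Example: four small buckets merged pairwise (stride=2) into two buckets.
    buckets = [["p0", "p1"], ["p2"], ["p3", "p4"], ["p5"]]
    print(merge_buckets(buckets, stride=2))
    # -> [['p0', 'p1', 'p2'], ['p3', 'p4', 'p5']]

Merging buckets this way trades finer-grained communication overlap for fewer, larger reduce-scatter/all-gather calls, which can lower launch overhead when many small buckets are configured.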