diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index e7f2aa805a9c9..0828d88a81333 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -861,7 +861,15 @@ def configure_optimizers(self): # Initialize param buckets if explicitly provided if getattr(self, 'distributed_adam_buckets', None) is not None: - for bucket in self.distributed_adam_buckets: + buckets = self.distributed_adam_buckets + if self.cfg.get('distributed_adam_bucket_merge_size', 1) > 1: + # Merge buckets if needed + stride = self.cfg.get('distributed_adam_bucket_merge_size', 1) + buckets = [ + list(itertools.chain.from_iterable(buckets[i : i + stride])) + for i in range(0, len(buckets), stride) + ] + for bucket in buckets: self._optimizer.init_params_bucket(bucket) self._optimizer.init_params_bucket(self.parameters()) if hasattr(self, 'distributed_adam_buckets'):