diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 3df172c8e..738717d55 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -44,7 +44,20 @@ def _get_params_for_weight_decay_optimization(modules):
                     [p for n, p in list(module_._parameters.items())
                      if p is not None and n == 'bias'])
 
-    return weight_decay_params, no_weight_decay_params
+    # XXX: temp hack to workaround the crash in apex FusedAdam's multi_tensor_applier
+    #
+    # it crashes when the param count is larger than a certain size which we hit at 200B over 80
+    # A100 gpus - I think around 2.7B per gpu, so halving it works around the issue
+    param_count = len(weight_decay_params['params'])
+    first_half = weight_decay_params['params'][:param_count // 2]
+    second_half = weight_decay_params['params'][param_count // 2:]
+
+    first_half = { 'params': first_half }
+    second_half = { 'params': second_half }
+
+    return first_half, second_half, no_weight_decay_params
+
+    #return weight_decay_params, no_weight_decay_params
 
 
 def get_megatron_optimizer(model):
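For context, here is a minimal sketch (not part of the patch) of the same workaround generalized to an arbitrary number of chunks, in case two halves are still too large. The helper name split_param_group and the num_chunks knob are illustrative assumptions, not existing Megatron or apex APIs:

def split_param_group(group, num_chunks=2):
    """Split one optimizer param-group dict into num_chunks smaller groups,
    copying any non-'params' options (e.g. weight_decay) into each chunk,
    so that no single group is large enough to trip multi_tensor_applier."""
    params = group['params']
    opts = {k: v for k, v in group.items() if k != 'params'}
    chunk_size = (len(params) + num_chunks - 1) // num_chunks  # ceil division
    return [dict(opts, params=params[i:i + chunk_size])
            for i in range(0, len(params), chunk_size)]

# the two-way split in the patch would then read roughly as:
# param_groups = split_param_group(weight_decay_params, 2) + [no_weight_decay_params]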