Bugfix for BF16 grad reductions with distopt (NVIDIA#6340)
* Debug distopt support for BF16 grad reductions

Signed-off-by: Tim Moon <[email protected]>

* Dump and load FP32 main params

Signed-off-by: Tim Moon <[email protected]>

* Style tweaks

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Mikołaj Błaż <[email protected]>
timmoon10 and mikolajblaz committed Apr 5, 2023
1 parent 4fc86b6 commit 2c8813d
Showing 2 changed files with 119 additions and 47 deletions.
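
For orientation before the diff: the change teaches NeMo's wrapper around Apex's DistributedFusedAdam to keep FP32 "main" copies of any parameter flagged with _with_fp32_optimizer, update those copies with a plain AdamW, and copy the results back into the BF16 model parameters, while all other parameters stay in the BF16 distributed optimizer. The snippet below is a minimal single-process sketch of that pattern only; the toy Linear model, the use of torch.optim.Adam in place of DistributedFusedAdam, and the omission of the data-parallel all-reduce of the FP32 grads are simplifications, not the actual NeMo code.

    import torch

    # Hypothetical toy setup: flag one parameter as needing a full-precision optimizer,
    # the way NeMo tags selected params with _with_fp32_optimizer.
    model = torch.nn.Linear(8, 8).bfloat16()
    model.bias._with_fp32_optimizer = True

    # Split the params the way the patched __init__ does: flagged params get FP32 main copies.
    fp32_main_params = {}  # model param -> FP32 main copy
    bf16_params = []
    for p in model.parameters():
        if getattr(p, '_with_fp32_optimizer', False):
            fp32_main_params[p] = p.detach().clone().float()
        else:
            bf16_params.append(p)

    # Stand-ins for the two optimizers (the real code pairs apex DistributedFusedAdam with AdamW).
    bf16_optim = torch.optim.Adam(bf16_params, lr=1e-3)
    fp32_optim = torch.optim.AdamW(list(fp32_main_params.values()), lr=1e-3)

    # One training step.
    loss = model(torch.randn(4, 8, dtype=torch.bfloat16)).float().pow(2).mean()
    loss.backward()

    # Accumulate the BF16 grads into the FP32 main grads (the patch also all-reduce-averages
    # these across the data-parallel group, which a single-process sketch cannot show).
    for model_param, main_param in fp32_main_params.items():
        if main_param.grad is None:
            main_param.grad = torch.zeros_like(main_param)
        if model_param.grad is not None:
            main_param.grad += model_param.grad.detach().float()

    bf16_optim.step()
    fp32_optim.step()

    # Copy the updated FP32 values back into the BF16 model params after the step.
    with torch.no_grad():
        for model_param, main_param in fp32_main_params.items():
            model_param.copy_(main_param)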
3 changes: 1 addition & 2 deletions nemo/collections/nlp/modules/common/megatron/clip_grads.py
@@ -178,9 +178,8 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2):
     # Filter parameters based on:
     # - parameter should not be shared
     # - should not be a replica due to tensor model parallelism
-    params = itertools.chain.from_iterable(param_group['params'] for param_group in optimizer.param_groups)
     params_for_norm = []
-    for param in params:
+    for param in optimizer.parameters(with_fp32_optim_params=True):
         is_not_shared = param_is_not_shared(param)
         is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
         if is_not_shared and is_not_tp_duplicate:
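
The clipping code above now leans on a parameters(with_fp32_optim_params=True) iterator added to the optimizer in the second file, which also yields the model params whose updates come from the side FP32 optimizer. A rough standalone stand-in for that helper (taking the FP32 main-param mapping as an argument instead of reading it off self) might look like:

    import itertools

    # Standalone stand-in for the new MegatronDistributedFusedAdam.parameters() helper;
    # the real method lives on the optimizer and reads self._fp32_optim_main_params
    # (an OrderedDict mapping model params to FP32 main params) instead of an argument.
    def parameters(optimizer, fp32_optim_main_params, with_fp32_optim_params=False):
        distopt_params = itertools.chain.from_iterable(
            param_group['params'] for param_group in optimizer.param_groups
        )
        if with_fp32_optim_params:
            # Also yield the model params whose updates come from the explicit FP32
            # optimizer, so gradient clipping sees every parameter exactly once.
            return itertools.chain(distopt_params, fp32_optim_main_params.keys())
        return distopt_params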
163 changes: 118 additions & 45 deletions nemo/core/optim/distributed_adam.py
@@ -12,8 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
+import itertools
+
 import torch
-from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook
+from apex.contrib.optimizers.distributed_fused_adam import (
+    DistributedFusedAdam,
+    _coalescing_manager,
+    _disable_pre_forward_hook,
+)
 from apex.transformer import parallel_state


@@ -58,36 +65,52 @@ def __init__(self, params, disable_distributed_parameters=False, **kwargs):
             if keyword in kwargs:
                 kwargs[keyword] = _str_to_dtype(kwargs[keyword])
 
-        # Check if any parameters require an explicit FP32 optimizer
+        # Make sure params are in consistent format (list of param group dicts)
+        param_groups = list(params)
+        assert param_groups
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{'params': param_groups}]
+
+        # Check if explicit FP32 optimizer is needed
         self._fp32_optim = None
-        distopt_params = params
+        distopt_param_groups = param_groups
         dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32
         grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype
-        if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and any(
-            getattr(param, '_with_fp32_optimizer', False) for param in params
-        ):
+        needs_fp32_optimizer = any(
+            getattr(param, '_with_fp32_optimizer', False)
+            for param in itertools.chain.from_iterable(param_group['params'] for param_group in param_groups)
+        )
+        if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and needs_fp32_optimizer:
 
             # Find params that require explicit FP32 optimizer
-            self._fp32_optim_model_params = []
-            self._fp32_optim_main_params = []
-            distopt_params = []
-            for model_param in params:
-                if getattr(param, '_with_fp32_optimizer', False):
-                    main_param = param.detach().clone().float()
-                    self._fp32_optim_model_params.append(model_param)
-                    self._fp32_optim_main_params.append(main_param)
-                else:
-                    distopt_params.append(model_param)
+            distopt_param_groups = []
+            fp32_param_groups = []
+            self._fp32_optim_main_params = collections.OrderedDict()
+            for param_group in param_groups:
+                distopt_param_group = {key: val for key, val in param_group.items() if key != 'params'}
+                distopt_param_group['params'] = []
+                fp32_param_group = {key: val for key, val in param_group.items() if key != 'params'}
+                fp32_param_group['params'] = []
+                for model_param in param_group['params']:
+                    if getattr(model_param, '_with_fp32_optimizer', False):
+                        main_param = model_param.detach().clone().float()
+                        fp32_param_group['params'].append(main_param)
+                        self._fp32_optim_main_params[model_param] = main_param
+                    else:
+                        distopt_param_group['params'].append(model_param)
+                distopt_param_groups.append(distopt_param_group)
+                fp32_param_groups.append(fp32_param_group)
 
             # Construct explicit FP32 optimizer
             adamw_kwargs = {}
             for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
                 if name in kwargs:
                     adamw_kwargs[name] = kwargs[name]
-            self.fp32_optim = torch.optim.AdamW(self._fp32_optim_main_params, **adamw_kwargs)
+            self._fp32_optim = torch.optim.AdamW(fp32_param_groups, **adamw_kwargs)
+            self._fp32_optim_grad_sync_needed = True
 
         # Construct distributed optimizer
-        super().__init__(distopt_params, **kwargs)
+        super().__init__(distopt_param_groups, **kwargs)
 
     def _make_post_backward_hook(self, param, param_group_id, param_id):
         def hook(*unused):
@@ -112,29 +135,69 @@ def hook(*unused):
 
         return hook
 
+    def _filter_distopt_params(self, params):
+        if self._fp32_optim is None:
+            return params
+        if params is None:
+            return None
+        if isinstance(params, torch.Tensor):
+            params = [params]
+        return filter(lambda param: param not in self._fp32_optim_main_params, params)
+
+    def parameters(self, with_fp32_optim_params=False):
+        if with_fp32_optim_params and self._fp32_optim is not None:
+            return itertools.chain(super().parameters(), self._fp32_optim_main_params.keys())
+        else:
+            return super().parameters()
+
+    def init_params(self, params=None):
+        super().init_params(self._filter_distopt_params(params))
+
+    def init_params_bucket(self, params):
+        super().init_params_bucket(self._filter_distopt_params(params))
+
     def try_grad_sync(self, params):
+        params = self._filter_distopt_params(params)
         params = [p for p in params if not getattr(p, '_disable_greedy_grad_copy', False)]
         params = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)]
        for p in params:
             self._grad_copy(p)
         self._try_start_bucket_grad_sync(params=params)
 
+    def _try_start_bucket_param_sync(self, params=None):
+        super()._try_start_bucket_param_sync(self._filter_distopt_params(params))
+
     def _fp32_optim_grad_sync(self):
-        if self._fp32_optim is None:
+        if self._fp32_optim is None or not self._fp32_optim_grad_sync_needed:
             return
-        for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params):
-            if main_param.grad is None:
-                main_param.grad = model_param.grad.detach().clone().float()
-            torch.distributed.all_reduce(main_param.grad, group=self.process_group)
+        for model_param, main_param in self._fp32_optim_main_params.items():
+            if model_param.grad is not None:
+                main_param.grad += model_param.grad.detach()
+        sync_requests = []
+        with _coalescing_manager(self.process_group, self.device, sync_requests):
+            for main_param in self._fp32_optim_main_params.values():
+                sync_requests.append(
+                    torch.distributed.all_reduce(
+                        main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True,
+                    )
+                )
+        for req in sync_requests:
+            req.wait()
+        self._fp32_optim_grad_sync_needed = False
 
     def zero_grad(self, *args, **kwargs):
         super().zero_grad(*args, **kwargs)
 
         # Reset grads for explicit FP32 optimizer
         if self._fp32_optim is not None:
-            self._fp32_optim.zero_grad(set_to_none=True)
-            for param in self._fp32_optim_model_params:
-                param.grad = None
+            self._fp32_optim_grad_sync_needed = True
+            self._fp32_optim.zero_grad(set_to_none=False)
+            for model_param, main_param in self._fp32_optim_main_params.items():
+                if main_param.grad is None:
+                    main_param.grad = torch.zeros_like(main_param)
+                if model_param.grad is not None:
+                    model_param.grad.zero_()
+                model_param.main_grad = main_param.grad
 
         # Reset main grads
         if self.contiguous_grad_buffer:
@@ -145,30 +208,31 @@ def zero_grad(self, *args, **kwargs):
     def grad_norm(self, parameters=None, norm_type=2.0, force=False):
         assert norm_type == 2
 
+        if parameters is not None:
+            # Make sure we can access iterable multiple times
+            parameters = list(parameters)
+
         # Compute grad norm
         if force or self._grad_norm is None:
 
-            # Identify params for explicit FP32 optimizer
-            fp32_optim_params = []
-            if self._fp32_optim is not None:
-                if parameters is None:
-                    fp32_optim_params = self._fp32_optim_model_params
-                else:
-                    fp32_optim_params = [param for param in parameters if param in self._fp32_optim_model_params]
-                    parameters = [param for param in parameters if param not in self._fp32_optim_model_params]
-
             # Compute norm of local gradients for distributed optimizer
-            grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type)
+            grad_norm_sq = self._local_grad_norm(
+                parameters=self._filter_distopt_params(parameters), norm_type=norm_type,
+            )
             if self.redundant_size > 1:
                 grad_norm_sq /= self.redundant_size
 
             # Compute norm of local gradients for explicit FP32 optimizer
             if self._fp32_optim is not None:
-                _fp32_optim_grad_sync()
-                for model_param in fp32_optim_params:
-                    i = self._fp32_optim_model_params.index(model_param)
-                    main_param = self._fp32_optim_main_params[i]
-                    grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
+                self._fp32_optim_grad_sync()
+                if parameters is None:
+                    for main_param in self._fp32_optim_main_params.values():
+                        grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
+                else:
+                    for model_param in parameters:
+                        if model_param in self._fp32_optim_main_params:
+                            main_param = self._fp32_optim_main_params[model_param]
+                            grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
 
             # Sum over all procs to get grad norm
             torch.distributed.all_reduce(
@@ -193,25 +257,34 @@ def step(self, closure=None, *, grad_scaler=None):
             if found_inf.item():
                 return loss
 
+            # Update learning rate
+            for distopt_group, fp32_optim_group in zip(self.param_groups, self._fp32_optim.param_groups):
+                fp32_optim_group['lr'] = distopt_group['lr']
+
             # Apply explicit FP32 optimizer
             self._fp32_optim_grad_sync()
-            for main_param in self._fp32_optim_main_params:
+            for main_param in self._fp32_optim_main_params.values():
                 main_param.grad *= self._grad_scale
             self._fp32_optim.step()
-            for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params):
-                main_param.grad = None
-                model_param.copy_(main_param.detach())
+            for model_param, main_param in self._fp32_optim_main_params.items():
+                model_param.detach().copy_(main_param.detach())
 
         return loss
 
     def state_dict(self, *args, **kwargs):
         state_dict = super().state_dict(*args, **kwargs)
         if self._fp32_optim is not None and state_dict is not None:
             state_dict['fp32_optim'] = self._fp32_optim.state_dict()
+            state_dict['fp32_optim_fp32_params'] = list(self._fp32_optim_main_params.values())
         return state_dict
 
     def load_state_dict(self, state_dict):
         if self._fp32_optim is not None and 'fp32_optim' in state_dict:
             self._fp32_optim.load_state_dict(state_dict['fp32_optim'])
             del state_dict['fp32_optim']
+            for old_main_param, new_main_param in zip(
+                self._fp32_optim_main_params.values(), state_dict['fp32_optim_fp32_params']
+            ):
+                old_main_param.copy_(new_main_param.detach())
+            del state_dict['fp32_optim_fp32_params']
         return super().load_state_dict(state_dict)
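
As a rough illustration of the "Dump and load FP32 main params" part of the change, the round trip below mimics the two new state-dict entries with stand-in tensors and a bare AdamW; in the real code these entries ride along on the distributed optimizer's own state dict and the loaded tensors come from a checkpoint.

    import torch

    # Stand-in mapping from a BF16 model param to its FP32 main param.
    model_param = torch.zeros(4, dtype=torch.bfloat16)
    fp32_optim_main_params = {model_param: model_param.detach().float()}
    fp32_optim = torch.optim.AdamW(list(fp32_optim_main_params.values()), lr=1e-3)

    # Save: dump the FP32 optimizer state plus the FP32 main params themselves.
    state_dict = {
        'fp32_optim': fp32_optim.state_dict(),
        'fp32_optim_fp32_params': [p.detach().clone() for p in fp32_optim_main_params.values()],
    }

    # Load: restore the FP32 optimizer state, then copy the saved FP32 values back into
    # the live main params (the patch does both before delegating to the base class).
    fp32_optim.load_state_dict(state_dict.pop('fp32_optim'))
    for old_main_param, new_main_param in zip(
        fp32_optim_main_params.values(), state_dict.pop('fp32_optim_fp32_params')
    ):
        old_main_param.copy_(new_main_param)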
