From 7d6a4f25735307829ed441e45d162d76827aefa9 Mon Sep 17 00:00:00 2001
From: jeffhataws
Date: Tue, 5 Nov 2024 21:27:28 -0800
Subject: [PATCH] Revive "Modularize zero step function and make it
 customizable" #7233 (#8332)

Co-authored-by: Can Karakus
Co-authored-by: Karakus
Co-authored-by: Shreyas Labhsetwar <137557845+slabhs@users.noreply.github.com>
---
 .../distributed/zero_redundancy_optimizer.py | 130 +++++++++++-------
 1 file changed, 82 insertions(+), 48 deletions(-)

diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 7e1e7b6cc10..b76b53ee42c 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -324,17 +324,20 @@ def _clip_grad_norm(
         if p.grad is not None:
           p.grad.detach().mul_(clip_value)
 
-  @torch.no_grad()
-  def step(self, closure=None, **kwargs):
-    """
-    Performs a single optimizer step and syncs parameters across all ranks.
-    """
-    assert self.inited, "must call init_zero() first"
-
-    loss = None
-    if closure is not None:
-      with torch.enable_grad():
-        loss = closure()
+  def _get_sharding_scheme(self, kwargs):
+    if "sharding_scheme" in kwargs:
+      return kwargs["sharding_scheme"]
+    else:
+      return [
+          {
+              "scale_factor": 1.0,
+              "sharding_group": self.sharding_groups,
+              "group_size": self.local_world_size,
+          },
+      ]
+
+  def _reduce_gradients(self, **kwargs):
+    sharding_scheme = self._get_sharding_scheme(kwargs)
 
     # sync to base optimizer
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
@@ -356,30 +359,34 @@ def step(self, closure=None, **kwargs):
           if self.coalesce_cc_reduce_scatter:
             padded_grads.append(padded_grad)
           else:
-            grad_shard = xm.reduce_scatter(
-                xm.REDUCE_SUM,
-                padded_grad,
-                scale=1.0 / self.local_world_size,
-                scatter_dim=0,
-                shard_count=self.local_world_size,
-                pin_layout=self.pin_layout,
-                groups=self.sharding_groups,
-            )
+            grad_shard = padded_grad
+            for step in sharding_scheme:
+              grad_shard = xm.reduce_scatter(
+                  xm.REDUCE_SUM,
+                  grad_shard,
+                  scale=step['scale_factor'] / step['group_size'],
+                  scatter_dim=0,
+                  shard_count=step['group_size'],
+                  pin_layout=self.pin_layout,
+                  groups=step['sharding_group'],
+              )
             if grad_shard.dtype != self.optimizer_dtype:
               grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
             shard.grad = grad_shard
 
     if self.coalesce_cc_reduce_scatter:
-      grad_shards = xm.reduce_scatter_bucketized(
-          xm.REDUCE_SUM,
-          padded_grads,
-          scale=1.0 / self.local_world_size,
-          scatter_dim=0,
-          shard_count=self.local_world_size,
-          pin_layout=self.pin_layout,
-          groups=self.sharding_groups,
-          bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
-      )
+      grad_shards = padded_grads
+      for step in sharding_scheme:
+        grad_shards = xm.reduce_scatter_bucketized(
+            xm.REDUCE_SUM,
+            grad_shards,
+            scale=step['scale_factor'] / step['group_size'],
+            scatter_dim=0,
+            shard_count=step['group_size'],
+            pin_layout=self.pin_layout,
+            groups=step['sharding_group'],
+            bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
+        )
       index = 0
       for param_group, sharded_param_group in zip(
           self.param_groups, self.base_optimizer.param_groups):
@@ -393,9 +400,9 @@ def step(self, closure=None, **kwargs):
             shard.grad = grad_shard
             index += 1
 
-    if self.grad_clipping:
-      # Update unscale/clip with sub partitions
-      self._clip_grad_norm(max_norm=self.max_norm)
+  def _update_parameters(self, **kwargs):
+    sharding_scheme = self._get_sharding_scheme(kwargs)
+    kwargs.pop("sharding_scheme", None)
 
     # Step the wrapped optimizer
     # Closure already executed, pass none here
@@ -403,15 +410,38 @@ def step(self, closure=None, **kwargs):
     # Remove shards' grads
     self.base_optimizer.zero_grad(set_to_none=True)
 
-    self.allgather_weights_and_update_full_parameter()
+    self.allgather_weights_and_update_full_parameter(sharding_scheme)
 
     # sync back
     self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
 
+  @torch.no_grad()
+  def step(self, closure=None, **kwargs):
+    """
+    Performs a single optimizer step and syncs parameters across all ranks.
+    """
+    assert self.inited, "must call init_zero() first"
+
+    loss = None
+    if closure is not None:
+      with torch.enable_grad():
+        loss = closure()
+
+    self._reduce_gradients(**kwargs)
+
+    if self.grad_clipping:
+      # Update unscale/clip with sub partitions
+      self._clip_grad_norm(max_norm=self.max_norm)
+
+    self._update_parameters(**kwargs)
+
     return loss
 
-  def allgather_weights_and_update_full_parameter(self):
+  def allgather_weights_and_update_full_parameter(self, sharding_scheme=None):
+    # All gather the new weights across the ranks and assign them to the full parameters
+    if sharding_scheme is None:
+      sharding_scheme = self._get_sharding_scheme({})
 
     sharded_data = []
     for param_group, sharded_param_group in zip(
         self.param_groups, self.base_optimizer.param_groups):
@@ -424,22 +454,26 @@ def allgather_weights_and_update_full_parameter(self):
           if self.coalesce_cc_all_gather:
             sharded_data.append(shard_data)
           else:
-            padded_param = xm.all_gather(
-                shard_data,
-                dim=0,
-                pin_layout=self.pin_layout,
-                groups=self.sharding_groups,
-            )
+            padded_param = shard_data
+            for step in reversed(sharding_scheme):
+              padded_param = xm.all_gather(
+                  padded_param,
+                  dim=0,
+                  pin_layout=self.pin_layout,
+                  groups=step['sharding_group'],
+              )
             param.data.copy_(padded_param.data[:param.size(0)])
 
     if self.coalesce_cc_all_gather:
-      padded_params = xm.all_gather_bucketized(
-          sharded_data,
-          dim=0,
-          pin_layout=self.pin_layout,
-          groups=self.sharding_groups,
-          bucket_cap_mb=self.bucket_cap_mb_all_gather,
-      )
+      padded_params = sharded_data
+      for step in reversed(sharding_scheme):
+        padded_params = xm.all_gather_bucketized(
+            padded_params,
+            dim=0,
+            pin_layout=self.pin_layout,
+            groups=step['sharding_group'],
+            bucket_cap_mb=self.bucket_cap_mb_all_gather,
+        )
       index = 0
       for param_group, sharded_param_group in zip(
          self.param_groups, self.base_optimizer.param_groups):
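
Example usage (illustrative, not part of the patch): the sketch below assumes an already-constructed ZeroRedundancyOptimizer on an XLA device; the model, the minimal constructor arguments, and the two hypothetical process-group lists are placeholders chosen for the example. Only the sharding_scheme keyword and the per-stage keys ("scale_factor", "sharding_group", "group_size") come from the patch above; when the keyword is omitted, _get_sharding_scheme falls back to a single stage over self.sharding_groups, so existing callers of step() are unaffected.

# Hypothetical two-level scheme: the group lists and sizes below are
# placeholders and must match the real device topology; the per-stage group
# sizes are assumed to multiply out to the ZeRO-1 shard count so that the
# staged reduce-scatter/all-gather shapes line up.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=1e-3)

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()

# Each dict is one collective stage: gradients are reduce-scattered through the
# stages in order, and updated weights are all-gathered in reverse order.
intra_node = [[0, 1, 2, 3], [4, 5, 6, 7]]       # hypothetical replica groups
inter_node = [[0, 4], [1, 5], [2, 6], [3, 7]]   # hypothetical replica groups
two_level_scheme = [
    {"scale_factor": 1.0, "sharding_group": intra_node, "group_size": 4},
    {"scale_factor": 1.0, "sharding_group": inter_node, "group_size": 2},
]
# Omitting sharding_scheme keeps the old behavior: a single stage over
# self.sharding_groups scaled by 1 / local_world_size.
optimizer.step(sharding_scheme=two_level_scheme)
xm.mark_step()

With scale_factor of 1.0 in each stage, the combined scaling is 1 / (4 * 2), matching the old 1 / local_world_size behavior on an 8-way shard group; a non-unit scale_factor can be used to fold additional scaling (e.g. gradient averaging across data-parallel replicas) into a stage.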