
Commit

Revive "Modularize zero step function and make it customizable" #7233 (
Browse files Browse the repository at this point in the history
…#8332)

Co-authored-by: Can Karakus <[email protected]>
Co-authored-by: Karakus <[email protected]>
Co-authored-by: Shreyas Labhsetwar <[email protected]>
4 people authored Nov 6, 2024
1 parent 81c4caa commit 7d6a4f2
Showing 1 changed file with 82 additions and 48 deletions.
130 changes: 82 additions & 48 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -324,17 +324,20 @@ def _clip_grad_norm(
if p.grad is not None:
p.grad.detach().mul_(clip_value)

@torch.no_grad()
def step(self, closure=None, **kwargs):
"""
Performs a single optimizer step and syncs parameters across all ranks.
"""
assert self.inited, "must call init_zero() first"

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
def _get_sharding_scheme(self, kwargs):
if "sharding_scheme" in kwargs:
return kwargs["sharding_scheme"]
else:
return [
{
"scale_factor": 1.0,
"sharding_group": self.sharding_groups,
"group_size": self.local_world_size,
},
]

def _reduce_gradients(self, **kwargs):
sharding_scheme = self._get_sharding_scheme(kwargs)

# sync to base optimizer
self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
@@ -356,30 +359,34 @@ def step(self, closure=None, **kwargs):
if self.coalesce_cc_reduce_scatter:
padded_grads.append(padded_grad)
else:
grad_shard = xm.reduce_scatter(
xm.REDUCE_SUM,
padded_grad,
scale=1.0 / self.local_world_size,
scatter_dim=0,
shard_count=self.local_world_size,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
)
grad_shard = padded_grad
for step in sharding_scheme:
grad_shard = xm.reduce_scatter(
xm.REDUCE_SUM,
grad_shard,
scale=step['scale_factor'] / step['group_size'],
scatter_dim=0,
shard_count=step['group_size'],
pin_layout=self.pin_layout,
groups=step['sharding_group'],
)
if grad_shard.dtype != self.optimizer_dtype:
grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
shard.grad = grad_shard

if self.coalesce_cc_reduce_scatter:
grad_shards = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
padded_grads,
scale=1.0 / self.local_world_size,
scatter_dim=0,
shard_count=self.local_world_size,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
)
grad_shards = padded_grads
for step in sharding_scheme:
grad_shards = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
grad_shards,
scale=step['scale_factor'] / step['group_size'],
scatter_dim=0,
shard_count=step['group_size'],
pin_layout=self.pin_layout,
groups=step['sharding_group'],
bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
)
index = 0
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
@@ -393,25 +400,48 @@ def step(self, closure=None, **kwargs):
shard.grad = grad_shard
index += 1

if self.grad_clipping:
# Update unscale/clip with sub partitions
self._clip_grad_norm(max_norm=self.max_norm)
def _update_parameters(self, **kwargs):
sharding_scheme = self._get_sharding_scheme(kwargs)
kwargs.pop("sharding_scheme", None)

# Step the wrapped optimizer
# Closure already executed, pass none here
self.base_optimizer.step(closure=None, **kwargs)
# Remove shards' grads
self.base_optimizer.zero_grad(set_to_none=True)

self.allgather_weights_and_update_full_parameter()
self.allgather_weights_and_update_full_parameter(sharding_scheme)

# sync back
self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)

@torch.no_grad()
def step(self, closure=None, **kwargs):
"""
Performs a single optimizer step and syncs parameters across all ranks.
"""
assert self.inited, "must call init_zero() first"

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

self._reduce_gradients(**kwargs)

if self.grad_clipping:
# Update unscale/clip with sub partitions
self._clip_grad_norm(max_norm=self.max_norm)

self._update_parameters(**kwargs)

return loss

def allgather_weights_and_update_full_parameter(self):
def allgather_weights_and_update_full_parameter(self, sharding_scheme=None):

# All gather the new weights across the ranks and assign them to the full parameters
if sharding_scheme is None:
sharding_scheme = self._get_sharding_scheme({})
sharded_data = []
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
@@ -424,22 +454,26 @@ def allgather_weights_and_update_full_parameter(self):
if self.coalesce_cc_all_gather:
sharded_data.append(shard_data)
else:
padded_param = xm.all_gather(
shard_data,
dim=0,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
)
padded_param = shard_data
for step in reversed(sharding_scheme):
padded_param = xm.all_gather(
padded_param,
dim=0,
pin_layout=self.pin_layout,
groups=step['sharding_group'],
)
param.data.copy_(padded_param.data[:param.size(0)])

if self.coalesce_cc_all_gather:
padded_params = xm.all_gather_bucketized(
sharded_data,
dim=0,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb_all_gather,
)
padded_params = sharded_data
for step in reversed(sharding_scheme):
padded_params = xm.all_gather_bucketized(
padded_params,
dim=0,
pin_layout=self.pin_layout,
groups=step['sharding_group'],
bucket_cap_mb=self.bucket_cap_mb_all_gather,
)
index = 0
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):

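With this change, step() delegates gradient reduction to _reduce_gradients() and the parameter update to _update_parameters(), and both consult an optional sharding_scheme keyword argument: a list of stages, each a dict with scale_factor, sharding_group, and group_size. Gradients are reduce-scattered through the stages in order, each stage scaled by scale_factor / group_size, and the updated shards are all-gathered back through the stages in reverse order. The sketch below shows how a caller might pass a custom scheme. It is an illustration only: the constructor arguments, the eight-process launch, and the concrete replica groups are assumptions made for this example, not taken from the diff.

# Hedged usage sketch. Assumes torch_xla with this commit and an 8-process
# distributed launch (e.g. via xmp.spawn); all concrete values are illustrative.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)

# Extra keyword arguments (here lr) are forwarded to the wrapped optimizer class.
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)

# Hypothetical two-stage scheme for 8 ranks: stage one reduce-scatters within
# pairs, stage two across pairs; the group sizes multiply up to the shard
# count (2 * 4 = 8). _update_parameters() replays the stages in reverse order
# for the all-gather.
custom_scheme = [
    {"scale_factor": 1.0, "sharding_group": [[0, 1], [2, 3], [4, 5], [6, 7]], "group_size": 2},
    {"scale_factor": 1.0, "sharding_group": [[0, 2, 4, 6], [1, 3, 5, 7]], "group_size": 4},
]

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()

# Omitting sharding_scheme keeps the previous behaviour: a single stage over
# self.sharding_groups with group_size == self.local_world_size.
optimizer.step(sharding_scheme=custom_scheme)
xm.mark_step()

Choosing groups so that the staged reduce-scatter delivers each rank its original shard is left to the caller; the default built by _get_sharding_scheme is the single-stage configuration that matches the pre-existing behaviour.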