use correct PG when collecting metrics with HYBRID shard (#551)
epwalsh authored Apr 19, 2024
1 parent 06786a7 commit 7be71cd
Showing 3 changed files with 33 additions and 12 deletions.
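Before the per-file diffs, a condensed view of the pattern this commit adopts may help: `dist.reduce` expects a global destination rank, so when reducing over a non-default process group (here, the FSDP sharding group used under HYBRID sharding), the group's rank 0 has to be translated with `dist.get_global_rank`, and the group must be passed to every collective. The following is a minimal, self-contained sketch of that pattern under those assumptions; it is illustrative only and not the repository's code, and `shard_group`/`reduce_to_group_rank0` are hypothetical names.

# Minimal sketch of the reduction pattern introduced by this commit (not OLMo's actual code).
# Assumes torch.distributed is already initialized; the caller passes the FSDP sharding
# process group (referred to here as `shard_group`) when HYBRID sharding is in use.
from typing import Optional

import torch
import torch.distributed as dist


def reduce_to_group_rank0(metric: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # `dist.reduce` takes a *global* rank for `dst`, so translate the group's local
    # rank 0 into its global rank before reducing within that group.
    dst_rank = 0
    if process_group is not None:
        dst_rank = dist.get_global_rank(process_group, 0)
    dist.reduce(metric, dst_rank, op=dist.ReduceOp.SUM, group=process_group)
    return metric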
33 changes: 24 additions & 9 deletions olmo/optim.py
@@ -39,7 +39,10 @@ def _clean_param_name(self, name: str) -> str:

     @torch.no_grad()
     def clip_grads_and_collect_metrics(
-        self, global_step: int, collect_param_metrics: bool = True
+        self,
+        global_step: int,
+        collect_param_metrics: bool = True,
+        process_group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Clips gradients for every group that has the field `max_grad_norm`.
@@ -69,6 +72,10 @@ def clip_grads_and_collect_metrics(
         per_param_avg_metric_names: List[str] = []
         per_param_norm_metric_names: List[str] = []

+        dst_rank = 0
+        if process_group is not None:
+            dst_rank = dist.get_global_rank(process_group, 0)
+
         # Collect metrics locally.
         for group in self.param_groups:
             if is_distributed():
@@ -144,12 +151,12 @@ def is_grad_norm_metric(metric_name: str) -> bool:
             # Reduce mins.
             if per_param_min_metrics:
                 all_mins = torch.cat(per_param_min_metrics).to(device)
-                dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+                dist.reduce(all_mins, dst_rank, op=dist.ReduceOp.MIN, group=process_group)
                 per_param_min_metrics = all_mins.split(1)
             # Reduce maxs.
             if per_param_max_metrics:
                 all_maxs = torch.cat(per_param_max_metrics).to(device)
-                dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
+                dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
                 per_param_max_metrics = all_maxs.split(1)
             # Reduce sums or just norms.
             all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
@@ -159,13 +166,13 @@ def is_grad_norm_metric(metric_name: str) -> bool:
                 all_sums_norms_numels = torch.cat(
                     [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
                 )
-                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
+                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM, group=process_group)
                 all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
                 # Get averages.
                 # NOTE: could get infs for non-rank0 processes but that's okay.
                 per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
             else:
-                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
+                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM, group=process_group)
             grad_norm_metric_mask = torch.tensor(
                 [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
             )
@@ -325,8 +332,10 @@ def _do_global_fixed_clipping(
             p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
         return num_grads_clipped

-    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
-        del module
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
+        del module, process_group
         return {}

     def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
@@ -356,7 +365,9 @@ def __init__(
         self._update_total_norm: Optional[torch.Tensor] = None
         self._signed_update_total_norm: Optional[torch.Tensor] = None

-    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
         update_total_dot_prod = self._update_total_dot_prod
         update_total_norm = self._update_total_norm
         signed_update_total_norm = self._signed_update_total_norm
@@ -370,7 +381,11 @@ def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
             # Reduce all together to avoid multiple communication calls.
             all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
             # Only need the final result on rank0, since that's where we log from.
-            dist.reduce(all_together, 0)
+            dist.reduce(
+                all_together,
+                0 if process_group is None else dist.get_global_rank(process_group, 0),
+                group=process_group,
+            )
             update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
             update_total_norm = update_total_norm**0.5
             signed_update_total_norm = signed_update_total_norm**0.5
10 changes: 8 additions & 2 deletions olmo/train.py
@@ -704,7 +704,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         # Clip gradient norms and collect param/gradient/optim metrics.
         should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
         optim_metrics = self.optim.clip_grads_and_collect_metrics(
-            self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
+            self.global_step,
+            collect_param_metrics=should_log_optim_metrics_this_step,
+            # passing this process group here ensures metrics are reduced correctly when we're using
+            # HYBRID sharding.
+            process_group=self.fsdp_model.process_group,
         )

         # Adjust the learning rate.
@@ -742,7 +746,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->

         # Maybe collect post-step optimizer-specific metrics.
         if should_log_optim_metrics_this_step:
-            optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
+            optim_metrics = self.optim.get_post_step_metrics(
+                self.fsdp_model, process_group=self.fsdp_model.process_group
+            )
             for key, value in optim_metrics.items():
                 metrics[f"optim/{key}"] = value.item()

2 changes: 1 addition & 1 deletion scripts/train.py
@@ -154,7 +154,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")

         num_nodes = get_world_size() // get_local_world_size()
-        if num_nodes % num_model_replicas != 0:
+        if num_nodes > 1 and num_nodes % num_model_replicas != 0:
             raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")

         device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
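For concreteness, a small worked example of the check above, with illustrative numbers that are not from the commit. The added `num_nodes > 1` guard means single-node runs (where `num_nodes == 1`) can use multiple model replicas without tripping the divisibility check.

# Illustrative check of the divisibility rule above (example values, not from the commit).
world_size = 32            # e.g. 4 nodes x 8 GPUs
local_world_size = 8
num_model_replicas = 4

num_nodes = world_size // local_world_size                            # 4
assert num_nodes == 1 or num_nodes % num_model_replicas == 0          # mirrors the fixed condition
mesh_shape = (num_model_replicas, world_size // num_model_replicas)   # (4, 8): 4 replicas, each sharded over 8 ranks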
