Merged
torchtitan/distributed/utils.py (6 additions, 1 deletion)
@@ -420,7 +420,12 @@ def _clip_grad_norm_with_ep(
             non_ep_grads.append(p.grad)
     ep_grads_total_norm = torch.nn.utils.get_total_norm(
         ep_grads, norm_type, error_if_nonfinite, foreach
-    ).full_tensor()
+    )
+    # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor.
Contributor:

> This edge case can occur in PP+EP setups when the model uses some fully dense and some MoE layers (like DSv3), in which case some PP ranks may not be assigned any MoE layers.

Oh, makes sense to me. Could you actually put this example edge case in the comment too? I think it'd be very helpful.

> I suppose it is possible that non_ep_grads could also be empty, but I can only imagine this happening in extreme cases, so I did not change the non_ep_grads code.

I think this is not possible if a PP stage always

  1. contains any non-MoE params
  2. contains full MoE modules -- the shared expert and router.gate will be non_ep_params anyway

Contributor Author:

Expanded the comment.

> I think this is not possible if a PP stage always [...]

Yeah, I was imagining very extreme cases where PP is applied very granularly and somehow a PP rank only ends up owning MoE layers and nothing else. That can't happen for any model or parallelism you could set up with torchtitan:main today, for sure. I was mostly just explaining why I only touched the ep_grads code.

Contributor Author:

Need anything else from me on this one, @tianyu-l?

Contributor:

Oh sorry, forgot to merge :)

+    # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
+    if isinstance(ep_grads_total_norm, DTensor):
+        ep_grads_total_norm = ep_grads_total_norm.full_tensor()
+
     non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
         non_ep_grads, norm_type, error_if_nonfinite, foreach
     ).full_tensor()
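
For reference, a minimal standalone sketch of the edge case the guard handles, assuming PyTorch 2.5+ (where torch.nn.utils.get_total_norm and torch.distributed.tensor.DTensor are available); the empty-list setup is hypothetical and only mirrors the situation discussed in the thread above:

    import torch
    from torch.distributed.tensor import DTensor

    # Stand-in for a PP rank that owns only dense (non-MoE) layers and thus
    # collects no expert-parallel gradients.
    ep_grads = []

    # On an empty list, get_total_norm returns tensor(0.), a plain Tensor rather
    # than a DTensor, so calling .full_tensor() on it unconditionally would raise
    # AttributeError. The isinstance guard from the diff avoids that.
    ep_grads_total_norm = torch.nn.utils.get_total_norm(ep_grads, norm_type=2.0)
    if isinstance(ep_grads_total_norm, DTensor):
        ep_grads_total_norm = ep_grads_total_norm.full_tensor()

    print(ep_grads_total_norm)  # tensor(0.)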