fix: ep clipping with no ep grads #1541
Conversation
tianyu-l
left a comment
Thanks, had a minor comment.
```python
                ep_grads, norm_type, error_if_nonfinite, foreach
            ).full_tensor()
        )
        # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
```
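A minimal sketch of the failure mode under discussion, assuming PyTorch's public `torch.nn.utils.get_total_norm` (the empty-input behavior shown is exactly what the comment above describes):

```python
from torch.nn.utils import get_total_norm

# With no gradients to reduce, get_total_norm returns a plain scalar
# Tensor, not a DTensor, so it has no .full_tensor() method.
total_norm = get_total_norm([], norm_type=2.0)
print(total_norm)  # tensor(0.)
print(hasattr(total_norm, "full_tensor"))  # False -> .full_tensor() would raise AttributeError
```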
> This edge case can occur in PP+EP setups when the model uses some fully dense and some MoE layers (like DSv3), in which case some PP ranks may not be assigned any MoE layers.

Oh, makes sense to me. Could you actually put this example edge case in the comment too? I think it'd be very helpful.
> I suppose it is possible that `non_ep_grads` could also be empty, but I can only imagine this happening in extreme cases, so I did not change the `non_ep_grads` code.
I think this is not possible if a PP stage always
- contains any non-MoE params
- contains full MoE modules -- the shared expert and router.gate will be `non_ep_params` anyways
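A hypothetical sketch of the split being described, assuming a name-based partition (`split_grads` and the `.experts.` substring check are illustrative only; torchtitan's actual partitioning logic may differ):

```python
# Hypothetical illustration: partition gradients by whether the
# parameter belongs to a routed expert (expert-parallel) or not.
def split_grads(named_params):
    ep_grads, non_ep_grads = [], []
    for name, param in named_params:
        if param.grad is None:
            continue
        # Only routed experts are expert-parallel; the shared expert and
        # router.gate are replicated, so they land in non_ep_grads
        # whenever a stage owns a full MoE module.
        if ".experts." in name:
            ep_grads.append(param.grad)
        else:
            non_ep_grads.append(param.grad)
    return ep_grads, non_ep_grads
```

Under this split, `non_ep_grads` is empty only if a stage owns neither dense params nor a full MoE module, matching the reasoning above.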
Expanded the comment.
> I think this is not possible if a PP stage always [...]

Yeah, I was imagining very extreme cases where PP is very granularly applied and somehow a PP rank only ends up owning MoE layers and nothing else. Can't happen for any model or parallelism you could set up with torchtitan:main today, for sure. I was mostly just explaining why I only touched the `ep_grads` code.
Need anything else from me on this one, @tianyu-l?
oh sorry forgot to merge :)
Np, thanks!
The current EP grad clipping logic assumes that when using EP, all of the norms returned by `torch.nn.utils.get_total_norm` are DTensors. This assumption can be violated, and the subsequent `full_tensor` call can correspondingly fail, in the edge case where the `ep_grads` list is empty, in which case `get_total_norm` returns `tensor(0.)`, a non-DTensor.

torchtitan/torchtitan/distributed/utils.py, lines 421 to 423 in a1fdd7e
This edge case can occur in PP+EP setups when the model uses some fully dense and some MoE layers (like DSv3), in which case some PP ranks may not be assigned any MoE layers.
I suppose it is possible that `non_ep_grads` could also be empty, but I can only imagine this happening in extreme cases, so I did not change the `non_ep_grads` code.

CC @tianyu-l
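A minimal sketch of a guarded clipping path, assuming the fix takes the form of a DTensor type check (the helper name `_ep_total_norm` is hypothetical; only `ep_grads` and the `get_total_norm(...).full_tensor()` call come from the snippet referenced above):

```python
from torch.distributed.tensor import DTensor
from torch.nn.utils import get_total_norm

def _ep_total_norm(ep_grads, norm_type=2.0, error_if_nonfinite=False, foreach=None):
    total_norm = get_total_norm(ep_grads, norm_type, error_if_nonfinite, foreach)
    # ep_grads may be an empty list, in which case get_total_norm returns
    # tensor(0.), a plain Tensor with no .full_tensor() method. This can
    # happen in PP+EP when a PP rank is assigned only dense layers.
    if isinstance(total_norm, DTensor):
        total_norm = total_norm.full_tensor()
    return total_norm
```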