🐛 Bug
When training a sharded model with Flash Attention using segment_ids, the segment_ids are not sharded, resulting in a size mismatch. We attempted to resolve this by modifying custom_kernel.py (PR #8333), which successfully addresses the mismatch. However, with this fix, the loss does not converge to zero when training with dummy data; instead, it stalls at 0.2.
To Reproduce
Run any training job that uses Flash Attention with segment_ids on a sharded (SPMD) model; a minimal sketch of such a setup is shown below.
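For reference, this is a minimal sketch of the kind of configuration that hits the mismatch, assuming a TPU SPMD mesh. The tensor shapes, mesh layout, and partition_spec are illustrative and are not taken from our actual training run.

```python
# Illustrative repro sketch (assumed shapes/mesh), not our actual training script.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# Shard the batch dimension across all devices ("data" axis is an assumed name).
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

device = xm.xla_device()
batch, heads, seq, dim = 8, 4, 1024, 128  # illustrative sizes
q = torch.randn(batch, heads, seq, dim, device=device)
k = torch.randn(batch, heads, seq, dim, device=device)
v = torch.randn(batch, heads, seq, dim, device=device)
# segment_ids mark packed-sequence membership per token, shape (batch, seq);
# a single segment per example is enough to exercise the code path.
segment_ids = torch.ones(batch, seq, dtype=torch.int32, device=device)

# q/k/v are sharded along the batch axis via partition_spec, but the
# segment_ids passed below were not sharded by the kernel wrapper,
# which is what produced the size mismatch that PR #8333 addresses.
out = flash_attention(
    q, k, v,
    causal=True,
    q_segment_ids=segment_ids,
    kv_segment_ids=segment_ids,
    partition_spec=('data', None, None, None),
    mesh=mesh)
```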
Expected behavior
The loss is expected to converge when training a sharded model with this fix applied (Flash Attention with segment_ids).
Environment
Reproducible on XLA backend [CPU/TPU/CUDA]:
torch_xla version: 2.4 / 2.6