Bug - Using Sharding in Flash Attention with segment ids. #8334

Open
dudulightricks opened this issue Oct 29, 2024 · 0 comments

🐛 Bug

When training a sharded model with Flash Attention using segment_ids, the segment_ids are not sharded, resulting in a size mismatch. We attempted to resolve this by modifying custom_kernel.py (PR #8333), which successfully addresses the mismatch. However, with this fix, the loss does not converge to zero when training with dummy data; instead, it stalls at 0.2.

To Reproduce

Run any training job that uses Flash Attention with segment_ids under sharded training; a minimal sketch follows.
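
For concreteness, here is a minimal sketch of such a run. The shapes, mesh layout, and dummy segment ids are illustrative assumptions, and the call goes through the flash_attention wrapper in torch_xla.experimental.custom_kernel; the exact keyword names (q_segment_ids, kv_segment_ids, partition_spec, mesh) should be checked against the installed torch_xla version.

```python
# Illustrative repro sketch (assumed shapes and a 1-D "data" mesh);
# not the exact training script behind this report.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
device = xm.xla_device()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))

batch, heads, seq, head_dim = 8, 4, 1024, 64
q = torch.randn(batch, heads, seq, head_dim, device=device)
k = torch.randn(batch, heads, seq, head_dim, device=device)
v = torch.randn(batch, heads, seq, head_dim, device=device)
# segment_ids has shape (batch, seq); it is not sharded along with q/k/v,
# which is where the reported size mismatch shows up.
segment_ids = torch.zeros(batch, seq, dtype=torch.int32, device=device)

# Shard the attention inputs over the batch ("data") axis.
spec = ("data", None, None, None)
xs.mark_sharding(q, mesh, spec)
xs.mark_sharding(k, mesh, spec)
xs.mark_sharding(v, mesh, spec)

out = flash_attention(
    q, k, v,
    causal=True,
    q_segment_ids=segment_ids,
    kv_segment_ids=segment_ids,
    partition_spec=spec,
    mesh=mesh,
)
xm.mark_step()
```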

Expected behavior

With the fix applied, the loss is expected to converge during sharded training with Flash Attention and segment_ids, just as it does in the unsharded case.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]:
  • torch_xla version: 2.4 / 2.6