[GDN] Optimize b_dg computation in chunk_bwd_kernel_dqkwg#USE_G#823
[GDN] Optimize b_dg computation in chunk_bwd_kernel_dqkwg#USE_G#823MzeroMiko wants to merge 2 commits intofla-org:mainfrom
b_dg computation in chunk_bwd_kernel_dqkwg#USE_G#823Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
WalkthroughRefactors Changes
Sequence Diagram(s)(omitted) Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request refactors the calculation of b_dg within the chunk_bwd_kernel_dqkwg Triton kernel by moving its initialization and consolidating its updates at the end of the block. A review comment suggests further simplifying the b_dg assignment into a single expression to improve code clarity and efficiency.
| b_dg = tl.zeros([BT], dtype=tl.float32) | ||
| b_dg += tl.sum(b_dq * b_q, axis=1) | ||
| b_dg -= tl.sum(b_dk * b_k, axis=1) |
There was a problem hiding this comment.
The initialization of b_dg to zero followed by incremental updates can be simplified into a single expression. This improves code clarity and avoids redundant operations.
| b_dg = tl.zeros([BT], dtype=tl.float32) | |
| b_dg += tl.sum(b_dq * b_q, axis=1) | |
| b_dg -= tl.sum(b_dk * b_k, axis=1) | |
| b_dg = tl.sum(b_dq * b_q, axis=1) - tl.sum(b_dk * b_k, axis=1) |
Math behind modification
before
after
Corresponding code:
fla.ops.common.chunk_o -> chunk_bwd_dqkwgBenchmark Results
Tests
gated delta rule,comba,simple-glaas they are the only ops that useschunk_bwd_dqkwgwithg.pytest tests/ops/test_gated_delta.py
pytest tests/ops/test_comba.py
pytest tests/ops/test_simple_gla.py
Benchmark Code
Summary by CodeRabbit