[kda kernel optimization] implement token-parallel intra-chunk attention#653
sustcsonglin merged 5 commits into main
Conversation
- Replace sub_intra kernel with token-parallel implementation for better performance on large sequences.
- Support both fixed-length and variable-length sequences in the token-parallel kernel.
- Fix `safe_exp` usage and add causal masking in `log_linear_attn` chunk ops.
Walkthrough

Adds a token-parallel execution path to the KDA intra-chunk computation by introducing a conditional `use_token_parallel` flag in `chunk_kda_fwd_intra`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant chunk_kda_fwd_intra
    participant chunk_kda_fwd_intra_token_parallel
    participant Triton Kernel
    Caller->>chunk_kda_fwd_intra: call with use_token_parallel flag
    alt use_token_parallel == True
        chunk_kda_fwd_intra->>chunk_kda_fwd_intra_token_parallel: dispatch to wrapper
        chunk_kda_fwd_intra_token_parallel->>Triton Kernel: configure grid and launch
        Triton Kernel->>Triton Kernel: per-token, per-head-group processing
        Triton Kernel->>Triton Kernel: binary search cu_seqlens (if varlen)
        Triton Kernel->>Triton Kernel: compute Aqk, Akk with gating & masking
        Triton Kernel-->>chunk_kda_fwd_intra_token_parallel: return results
    else use_token_parallel == False
        chunk_kda_fwd_intra->>chunk_kda_fwd_intra: use original sub-intra path
    end
    chunk_kda_fwd_intra-->>Caller: return (output, tril_decomposition)
```
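The "binary search cu_seqlens (if varlen)" step in the diagram maps each global token index to its owning sequence. As a rough illustration (a hypothetical pure-Python stand-in for the in-kernel logic, not the repo's actual code), the lookup can be sketched as:

```python
# Hypothetical sketch of the varlen lookup: given cumulative sequence
# lengths, binary-search which sequence a global token index belongs to,
# and its offset within that sequence.
def locate_token(cu_seqlens, i_tok):
    """Return (sequence index, local token index) for global token i_tok."""
    lo, hi = 0, len(cu_seqlens) - 1
    while lo < hi - 1:
        mid = (lo + hi) // 2
        if cu_seqlens[mid] <= i_tok:
            lo = mid
        else:
            hi = mid
    return lo, i_tok - cu_seqlens[lo]

# Example: three sequences of lengths 3, 5, 2 packed into 10 tokens.
cu = [0, 3, 8, 10]
print(locate_token(cu, 0))   # -> (0, 0): sequence 0, offset 0
print(locate_token(cu, 4))   # -> (1, 1): sequence 1, offset 1
print(locate_token(cu, 9))   # -> (2, 1): sequence 2, offset 1
```

Each kernel program would run this O(log N) search once to recover its sequence boundaries before applying masking.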
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes

Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant optimization for KDA models by implementing a token-parallel strategy for intra-chunk attention computation. This new approach aims to enhance computational efficiency, especially for variable-length sequences, by allowing each token to be processed by its own thread block. Concurrently, the PR cleans up deprecated code related to `fused_chunk_gla`.
Code Review
This pull request introduces a significant optimization for KDA intra-chunk attention by implementing a token-parallel kernel. This new implementation is added as a default execution path in chunk_kda_fwd_intra, with an option to fall back to the original implementation. The PR also deprecates fused_chunk_gla, guiding users towards chunk_gla. Additionally, there are several correctness fixes and refactorings across different kernels, such as adding explicit masking for padded tokens and replacing the implicit causal masking of safe_exp with explicit tl.where conditions, which improves code clarity and robustness.
Overall, the changes are well-implemented and contribute to better performance and maintainability. I have one minor suggestion to improve the documentation for the new token-parallel kernel.
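The review mentions replacing the implicit causal masking of `safe_exp` with explicit `tl.where` conditions. As a rough NumPy stand-in for that `tl.where` pattern (illustrative only; the variable names are not the kernel's actual ones), the explicit-masking idea looks like:

```python
import numpy as np

# Explicit causal masking: rather than relying on a "safe exp" to zero out
# future positions implicitly, the causal condition j <= i is applied as an
# explicit where(), which mirrors tl.where in the Triton kernel.
T = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((T, T)).astype(np.float32)

i = np.arange(T)[:, None]  # query position
j = np.arange(T)[None, :]  # key position
A = np.where(j <= i, np.exp(scores), 0.0)

# The upper triangle (future positions) is exactly zero by construction.
assert np.all(A[np.triu_indices(T, k=1)] == 0.0)
```

Making the mask explicit keeps the numerical trick (`exp` of clamped values) and the causal structure as two separate, readable concerns.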
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 1
📒 Files selected for processing (2)

- `fla/ops/kda/chunk_intra.py` (4 hunks)
- `fla/ops/kda/chunk_intra_token_parallel.py` (1 hunk)
🧰 Additional context used

🧬 Code graph analysis (1)

`fla/ops/kda/chunk_intra.py` (1) references `fla/ops/kda/chunk_intra_token_parallel.py` (2):

- `chunk_kda_fwd_intra_token_parallel` (158-219)
- `grid` (200-202)
⏰ Context from checks skipped due to timeout of 90000ms (1):

- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
```python
if cu_seqlens is not None:
    total_tokens = q.shape[1]
    # Use num_sequences as B for binary search
    B_kernel = len(cu_seqlens) - 1
else:
    total_tokens = B * T
    B_kernel = B
```
Fix varlen token count for token-parallel grid.
Line 193
When cu_seqlens is provided, total_tokens must come from the cumulative lengths (cu_seqlens[-1]). Using q.shape[1] only covers the padded T dimension, so in a batch with several sequences we launch work for at most T_max tokens and leave the remainder of the batch untouched (their Aqk/Akk rows stay zero). Please derive the grid size from the cumulative lengths instead.
```diff
- total_tokens = q.shape[1]
+ total_tokens = int(cu_seqlens[-1].item())
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 if cu_seqlens is not None:
-    total_tokens = q.shape[1]
+    total_tokens = int(cu_seqlens[-1].item())
     # Use num_sequences as B for binary search
     B_kernel = len(cu_seqlens) - 1
 else:
     total_tokens = B * T
     B_kernel = B
```
🤖 Prompt for AI Agents

```text
In fla/ops/kda/chunk_intra_token_parallel.py around lines 192 to 198, when
cu_seqlens is provided total_tokens is incorrectly set from q.shape[1] (padded
T) which undercounts real tokens for varlen batches; replace total_tokens =
q.shape[1] with total_tokens = int(cu_seqlens[-1]) (or cu_seqlens[-1].item() if
it's a tensor) so the kernel grid uses the true total token count derived from
cumulative lengths while keeping B_kernel = len(cu_seqlens) - 1.
```
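The grid-sizing logic under discussion can be illustrated with a small host-side sketch. This is a hypothetical helper (the name `token_parallel_grid` and its signature are not the repo's actual code) showing the corrected derivation: the true token count comes from `cu_seqlens[-1]` in the varlen case and `B * T` otherwise, with the second grid axis covering head blocks of size `BH`:

```python
import math

# Hypothetical sketch of the corrected grid computation for the
# token-parallel kernel: one program per token, heads tiled in blocks of BH.
def token_parallel_grid(H, BH, cu_seqlens=None, B=1, T=None):
    if cu_seqlens is not None:
        # Fixed behavior: true packed length from cumulative sequence lengths,
        # not the padded q.shape[1].
        total_tokens = cu_seqlens[-1]
    else:
        total_tokens = B * T
    return (total_tokens, math.ceil(H / BH))

# Fixed-length: B=4 sequences of T=8 tokens, H=16 heads, BH=4 heads per block.
print(token_parallel_grid(16, 4, B=4, T=8))                      # -> (32, 4)

# Varlen: three sequences of lengths 5, 9, 2 -> 16 real tokens total.
print(token_parallel_grid(16, 4, cu_seqlens=[0, 5, 14, 16]))     # -> (16, 4)
```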
KDA Intra Chunk Attention: Token-Parallel Implementation Benchmark
1. Overview
This report benchmarks the performance of the newly implemented Token-Parallel kernel for KDA intra-chunk attention against the Original (Sub-Chunk) implementation.
The goal was to improve efficiency for large-scale training scenarios by giving each token its own thread block and processing the full `K` dimension at once (no tiling loop).

2. Methodology
2.1 Function Benchmarked
Target Function: `chunk_kda_fwd_intra` in `fla/ops/kda/chunk_intra.py`

Comparison:

- `chunk_kda_fwd_kernel_intra_sub_intra` (Sub-Chunk Parallelism): `use_token_parallel=False`
- `chunk_kda_fwd_intra_token_parallel` (Token Parallelism): `use_token_parallel=True` (New Default)

2.2 Environment
- Dtype: `torch.float16`

2.3 Test Configurations
We selected configurations representative of large-scale LLM training:
3. Performance Results

- B=16, T=4096, H=16, K=64
- B=8, T=4096, H=32, K=128
- B=8, T=8192, H=16, K=64
- B=4, T=8192, H=32, K=128

3.1 Key Observations
- The speedup grows as `K` increases to 128 (reaching 1.21x speedup).

4. Implementation Details
4.1 Original Implementation (Sub-Chunk)

- Grid: `(NT, NC, B*H)`, where `NT` is the number of chunks and `NC` is the number of sub-chunks per chunk.
- Tiled along the `K` dimension.

4.2 Token-Parallel Implementation

- Grid: `(Total Tokens, H/BH)`, where `BH` is the autotuned head block size.
- Fuses `BH` heads per thread block (head fusion).
- Each program handles one token `i_t`, skipping upper-triangular computations completely (unlike the tiled approach, which masks after compute).

5. Conclusion
The Token-Parallel implementation provides a significant performance boost for KDA intra-chunk attention, especially in computation-heavy scenarios relevant to large-scale training. It is now set as the default implementation.
- `use_token_parallel=True`: the new default.
- `use_token_parallel=False`: original implementation kept for backward compatibility/verification.
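The flag-based dispatch described throughout this report can be sketched as follows. This is a minimal, hypothetical illustration (the function bodies and the `_sketch` name are stand-ins, not the repo's actual implementation); it only shows how `use_token_parallel` selects between the two execution paths:

```python
# Minimal sketch of the dispatch: chunk_kda_fwd_intra routes to the
# token-parallel wrapper by default and falls back to the original
# sub-chunk path when use_token_parallel=False.
def run_token_parallel(q, k, g, **kw):
    # Stand-in for chunk_kda_fwd_intra_token_parallel (grid launch omitted).
    return "token_parallel"

def run_sub_intra(q, k, g, **kw):
    # Stand-in for the original sub-chunk kernel path.
    return "sub_intra"

def chunk_kda_fwd_intra_sketch(q, k, g, use_token_parallel=True, **kw):
    if use_token_parallel:
        return run_token_parallel(q, k, g, **kw)  # new default path
    return run_sub_intra(q, k, g, **kw)           # original fallback path

print(chunk_kda_fwd_intra_sketch(None, None, None))                            # -> token_parallel
print(chunk_kda_fwd_intra_sketch(None, None, None, use_token_parallel=False))  # -> sub_intra
```

Keeping both paths behind one entry point lets the original kernel serve as a correctness reference while the token-parallel path carries production traffic.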