[Deltarule] Added intra-card context parallel optimization for KDA and GDN #743
Conversation
Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the performance of Gated Delta Networks (GDN) and Kernelized Delta Attention (KDA) during the prefill phase, especially for very long sequences in inference mode. By introducing an intra-card context parallel mechanism, it intelligently breaks long sequences down into smaller, manageable sub-sequences that can be processed concurrently on the GPU, and then efficiently merges their states. This optimization aims to reduce latency and improve throughput without altering the core functionality or mathematical correctness of the models.

Highlights
Changelog
Branch updated: 55ac5a1 → 4fdb39a
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
Walkthrough

Adds an intra-card context-parallel varlen path: a common backend registry and IntraCardCPBackend, a new intracard orchestration module to split/process subsequences, Triton kernel extensions for multi-sequence/intracard modes, cu_seqlens_cpu threaded through the KDA call path, benchmark updates, and new varlen prefill tests.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User/Test
    participant ChunkKDA as ChunkKDAFunction
    participant ChunkFwd as chunk_kda_fwd
    participant Dispatch as dispatch('common')
    participant IntraCardBE as IntraCardCPBackend
    participant Intracard as intracard_fwd_h
    participant Kernel as Kernels (pre_scan / merge / fwd)
    User->>ChunkKDA: forward(..., cu_seqlens, cu_seqlens_cpu)
    ChunkKDA->>ChunkFwd: chunk_kda_fwd(..., cu_seqlens_cpu)
    ChunkFwd->>Dispatch: chunk_gated_delta_rule_fwd_h(..., cu_seqlens, cu_seqlens_cpu)
    Dispatch->>IntraCardBE: chunk_gated_delta_rule_fwd_h_verifier(cu_seqlens)
    IntraCardBE-->>Dispatch: available?
    Dispatch->>Intracard: intracard_fwd_h(..., cu_seqlens_cpu, max_splits)
    Intracard->>Intracard: compute_subseq_len / prepare_subseq_cu_seqlens
    alt splitting needed
        Intracard->>Kernel: intracard_pre_scan(...)
        Kernel-->>Intracard: hm (pre-scan)
        Intracard->>Kernel: merge_fwd_bwd_kernel(INTRACARD_MODE=True, seq_offsets, init_offsets, ...)
        Kernel-->>Intracard: merged states
        Intracard->>Kernel: _raw_chunk_gated_delta_rule_fwd_h per-subseq
        Kernel-->>Intracard: h, v_new, final_state
    else no split
        Intracard->>Kernel: _raw_chunk_gated_delta_rule_fwd_h(...)
        Kernel-->>Intracard: h, v_new, final_state
    end
    Intracard-->>Dispatch: results
    Dispatch-->>ChunkFwd: results
    ChunkFwd-->>ChunkKDA: outputs
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
🧪 Generate unit tests (beta)
No actionable comments were generated in the recent review. 🎉 🧹 Recent nitpick comments
Code Review
This pull request introduces an intra-card context parallel (CP) optimization for the delta rule kernels, aimed at accelerating prefill during inference. The changes are extensive, touching core kernels, adding a new backend dispatch system, and implementing the complex logic for splitting sequences and merging states. The addition of new tests and benchmarks for the feature is great. I've found one critical issue that could lead to a runtime error in an edge case. Other than that, the implementation looks solid.
Branch updated: acced75 → db4760c
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 62-64: The final_state is allocated with k.new_empty(...) which
leaves padded slots uninitialized; replace the allocation in intracard_cp.py
where final_state is set (variable name final_state, allocation via k.new_empty)
with a zero-initialized allocation using k.new_zeros(...) (matching the pattern
from chunk_gated_delta_rule_fwd_h) so padded positions are deterministic for
CUDA Graphs/vLLM; keep the dtype and shape the same and preserve the conditional
(output_final_state) logic.
In `@tests/ops/test_gated_delta.py`:
- Line 400: Remove the unnecessary .requires_grad_() call inside the
`@torch.inference_mode`() block: in the mapping that assigns q, k, v, beta, g, h0
(the lambda used on those variables) only move tensors to device (x.to(device))
and do not call .requires_grad_(); if any of these tensors actually need
gradients for a specific test, wrap that test in torch.enable_grad() instead of
setting requires_grad_() inside inference mode.
In `@tests/ops/test_kda.py`:
- Around line 569-571: The test is running under `@torch.inference_mode`(), so
calling requires_grad_() on tensors (q, k, v, g, beta, h0 and conditionally
A_log, dt_bias) is unnecessary and can raise RuntimeError; remove the
requires_grad_() calls and just move the tensors to device (e.g., replace
map(lambda x: x.to(device).requires_grad_(), ...) with map(lambda x:
x.to(device), ...) for q, k, v, g, beta, h0 and similarly for A_log and dt_bias
when use_gate_in_kernel is true) so tensors are placed on the device without
enabling gradient tracking.
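The pattern both fixes describe can be sketched as follows; the tensor shapes and the `device` value are placeholders rather than the tests' actual fixtures:

```python
import torch

with torch.inference_mode():
    device = "cpu"  # placeholder; the real tests use a GPU device fixture
    # Only move tensors to the target device. Calling .requires_grad_() here
    # is pointless for inference and can raise a RuntimeError on
    # inference-mode tensors, which is what the review comments flag.
    q, k, v = (torch.randn(2, 4).to(device) for _ in range(3))

assert not any(t.requires_grad for t in (q, k, v))
```

If a particular test really needs gradients, wrapping that test in `torch.enable_grad()` (as the first comment suggests) is the cleaner escape hatch.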
🧹 Nitpick comments (7)
fla/ops/common/chunk_delta_h.py (1)
468-480: `cu_seqlens_cpu` is unused in the default implementation (intentional for dispatch).

Ruff flags `cu_seqlens_cpu` as unused (ARG001). This is expected — the parameter exists so that backends (e.g., `IntraCardCPBackend`) receive it via the `@dispatch('common')` wrapper. The default function correctly ignores it. Consider adding a brief comment or a `noqa: ARG001` annotation to make the intent clear.

fla/ops/common/backends/intracard.py (1)

22-31: `is_available` override is redundant.

Since `package_name = None`, `BaseBackend.is_available()` already returns `True`. The override is harmless but unnecessary.

fla/ops/common/intracard_cp.py (5)
148-171: Splitting threshold (4×) vs early-return threshold (2×) causes a wasted-work window.

`prepare_subseq_cu_seqlens` splits a sequence only if `seq_len_i >= 4 * subseq_len` (line 154), but `intracard_fwd_h` skips the early-return path when any sequence is ≥ `2 * subseq_len` (line 394). When the longest sequence falls in [2×, 4×), the code enters the splitting path, calls `prepare_subseq_cu_seqlens`, finds nothing to split, and falls back — doing unnecessary CPU work. This is functionally correct (the `not split_info` guard on line 399 catches it), but the mismatch is worth a comment or aligning the thresholds.
290-362: Consider `strict=True` in `zip()` calls for safety.

The `zip()` calls on lines 315, 321, and 326 iterate over parallel lists from `SplitSeqInfo` that are expected to have equal length. Adding `strict=True` would catch bugs where the lists have mismatched lengths, at negligible runtime cost for these small arrays.

Proposed fix

```diff
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1])
         S_split_total += n

     # num_subseqs_per_seq: [N_orig], default 1 for unsplit sequences
     num_subseqs_per_seq = [1] * N_orig
-    for sid, nss in zip(split_ids, num_ss):
+    for sid, nss in zip(split_ids, num_ss, strict=True):
         num_subseqs_per_seq[sid] = nss

     # non_first_indices: for scattering merged initial states
     non_first_indices: list[int] = []
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         for j in range(1, n):
```
393-409: `split_info` relies on short-circuit evaluation to avoid `UnboundLocalError` — fragile pattern.

When `early_return` is `True`, `split_info` is never assigned (lines 395-398 are skipped). Line 399 works only because `or` short-circuits. This is correct but fragile — a future refactor (e.g., changing `or` to `and`, or splitting the condition) could introduce a runtime crash.

Consider initializing `split_info` before the branch:

Proposed fix

```diff
+    split_info = False
+    total_subseqs = 0
     early_return = (seq_lens < 2 * subseq_len).all()
     if not early_return:
         cu_seqlens_subseq, split_info, total_subseqs = prepare_subseq_cu_seqlens(
             cu_seqlens_cpu, subseq_len, chunk_size, max_splits=max_splits
         )
     if early_return or not split_info:
```
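The restructured control flow can be sketched in pure Python; the function name, return values, and thresholds below are illustrative stand-ins for the real kernel calls, not the actual implementation:

```python
def intracard_fwd_h_sketch(seq_lens: list[int], subseq_len: int) -> str:
    # Return early before split_info is ever mentioned, so no name can be
    # left unbound and no short-circuit evaluation is required.
    if all(s < 2 * subseq_len for s in seq_lens):
        return "raw_path"          # stand-in for _raw_chunk_gated_delta_rule_fwd_h
    split_info = [s for s in seq_lens if s >= 4 * subseq_len]
    if not split_info:
        return "raw_path"          # nothing crossed the 4x split threshold
    return "split_path"            # stand-in for the pre-scan + merge path

assert intracard_fwd_h_sketch([100], 64) == "raw_path"     # below 2x: early return
assert intracard_fwd_h_sketch([200], 64) == "raw_path"     # in [2x, 4x): fallback
assert intracard_fwd_h_sketch([300], 64) == "split_path"   # >= 4x: split
```

The [2×, 4×) case in the middle assertion is exactly the wasted-work window the earlier nitpick describes.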
415-425: Unused variable `num_subseqs_per_seq`.

`num_subseqs_per_seq` is unpacked on line 418 but never referenced afterward. Prefix with `_` to signal intent and silence the Ruff RUF059 warning.

Proposed fix

```diff
     (
         cu_seqlens_split_values,
         S_split_total,
-        num_subseqs_per_seq,
+        _num_subseqs_per_seq,
         non_first_indices,
         first_subseq_indices,
         last_subseq_indices,
```
448-456: Indexing with Python lists into a GPU tensor triggers implicit conversion.

`first_subseq_indices`, `non_first_indices`, and `last_subseq_indices` are Python `list[int]`. When used as indices into GPU tensors (lines 451, 453, 456, 476), PyTorch will convert them to a CPU LongTensor each time, then transfer to GPU. This works correctly, but if the number of sub-sequences grows, the repeated implicit conversions could add overhead. Given that these are typically small arrays and this is an optimization-focused module, consider pre-converting to GPU tensors once.
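A minimal sketch of the suggested pre-conversion (the list values and tensor shapes are invented for illustration; only the names mirror the comment):

```python
import torch

first_subseq_indices = [0, 2, 5]   # small Python index list, as in the module
h = torch.randn(6, 4)              # stand-in for a per-subsequence state tensor

# Convert once, on the tensor's device, then reuse for every gather/scatter;
# fancy indexing with `idx` skips the per-call list->tensor conversion.
idx = torch.tensor(first_subseq_indices, dtype=torch.long, device=h.device)
selected = h[idx]

assert selected.shape == (3, 4)
```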
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 112-115: The comment block incorrectly states the split threshold
as 2 * subseq_len; update the comment to reflect the actual logic:
prepare_subseq_cu_seqlens computes four_subseq_len = 4 * subseq_len and uses
that as the split threshold, while intracard_fwd_h uses an early-return check at
2 * subseq_len (a looser gate), so clarify both thresholds and their roles (2×
for early-return, 4× for actual split) and adjust the numeric examples
accordingly (e.g., chunk_size=64, MIN_SUBSEQ_CHUNKS=128 → subseq_len >= 8192
tokens, split threshold = 4 * subseq_len = 32768 tokens).
- Around line 394-409: The early-return logic relies on short-circuiting and can
leave split_info (and cu_seqlens_subseq/total_subseqs) unbound; change the flow
so you return immediately when early_return is True instead of using the
combined condition. Concretely, check if early_return: return the call to
_raw_chunk_gated_delta_rule_fwd_h(...) right away; only call
prepare_subseq_cu_seqlens(...) and use split_info when early_return is False,
and then keep the existing fallback of returning when split_info is empty. This
removes reliance on short-circuit evaluation and prevents UnboundLocalError.
- Around line 448-456: initial_state_expanded is allocated with new_empty and
leaves non-first subseq slots uninitialized when num_non_first > 0 but
initial_states_merge is None; update the allocation/assignment in the block
using initial_state_expanded, initial_state, initial_states_merge,
first_subseq_indices, non_first_indices, and num_non_first so those non-first
slots are deterministically initialized (either allocate with zeros/new_zeros or
explicitly set initial_state_expanded[non_first_indices] = 0 when
initial_states_merge is None) while preserving the existing assignments to
first_subseq_indices.
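Both zero-initialization fixes above follow the same pattern; here is a minimal sketch with invented shapes (only `k.new_zeros` vs `k.new_empty` is the point):

```python
import torch

N_padded, H, K, V = 4, 2, 8, 8   # assumed padded slot count and head/feature dims
k = torch.randn(2, 16, H, K)     # stand-in for the kernel's key tensor

# k.new_empty(...) would leave the padded slots with arbitrary memory contents;
# k.new_zeros(...) keeps the same dtype/device but zero-fills everything,
# which is what CUDA Graphs/vLLM need for determinism.
final_state = k.new_zeros(N_padded, H, K, V)
final_state[:2] = torch.randn(2, H, K, V)   # only the real sequences are written

assert final_state[2:].abs().sum().item() == 0.0   # padding is deterministic
```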
🧹 Nitpick comments (2)
fla/ops/common/intracard_cp.py (2)
415-425: Unused variable `num_subseqs_per_seq`.

Static analysis (RUF059) flags `num_subseqs_per_seq` as unused after destructuring. Prefix with `_` to signal intent.

Proposed fix

```diff
     (
         cu_seqlens_split_values,
         S_split_total,
-        num_subseqs_per_seq,
+        _num_subseqs_per_seq,
         non_first_indices,
         first_subseq_indices,
         last_subseq_indices,
```
315-317: Consider adding `strict=True` to `zip()` calls for defensive checking.

All three `zip()` calls in this function iterate over paired lists from `SplitSeqInfo` which should always be the same length by construction. Adding `strict=True` (Python 3.10+) would catch any future mismatch early instead of silently truncating. This is also flagged by Ruff B905.

Proposed fix

```diff
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1])
         S_split_total += n
-    for sid, nss in zip(split_ids, num_ss):
+    for sid, nss in zip(split_ids, num_ss, strict=True):
         num_subseqs_per_seq[sid] = nss
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         for j in range(1, n):
```

Also applies to: 321-322, 326-328
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 415-425: The tuple unpack from _precompute_intracard_indices
assigns an unused variable num_subseqs_per_seq; rename it to prefix with an
underscore (e.g., _num_subseqs_per_seq) in the unpacking expression to signal
intentional unusedness and satisfy static analysis; update the tuple on the call
site where _precompute_intracard_indices(...) is unpacked so all other names
(cu_seqlens_split_values, S_split_total, non_first_indices,
first_subseq_indices, last_subseq_indices, num_non_first, merge_seq_offsets,
merge_init_offsets) remain unchanged.
🧹 Nitpick comments (2)
fla/ops/common/intracard_cp.py (2)
315-317: Consider adding `strict=True` to `zip()` calls to catch length mismatches.

Lines 315, 321, and 326 use `zip()` on lists from `SplitSeqInfo` that should always be equal length. Adding `strict=True` would catch bugs from mismatched list lengths early (Python 3.10+). Flagged by Ruff B905.

Proposed fix (one example — apply to all three)

```diff
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
```
290-362: `_precompute_intracard_indices` returns a 9-element tuple — consider a structured return type.

A 9-element tuple is hard to unpack correctly and fragile to extend. A `NamedTuple` or `dataclass` (similar to `SplitSeqInfo`) would improve readability and make the call site at lines 415-425 less error-prone.
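One hypothetical shape such a structured return type could take — the field names are copied from the tuple elements listed in the AI-agent comment above, but this is only a sketch, not the actual implementation:

```python
from typing import NamedTuple

class IntracardIndices(NamedTuple):
    # Hypothetical mirror of the 9-tuple from _precompute_intracard_indices.
    cu_seqlens_split_values: list[int]
    S_split_total: int
    num_subseqs_per_seq: list[int]
    non_first_indices: list[int]
    first_subseq_indices: list[int]
    last_subseq_indices: list[int]
    num_non_first: int
    merge_seq_offsets: list[int]
    merge_init_offsets: list[int]

# Call sites then read fields by name instead of positional unpacking:
idx = IntracardIndices([0, 8192], 1, [1], [], [0], [0], 0, [0], [0])
assert idx.S_split_total == 1
assert idx.num_non_first == 0
```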
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 78-119: The docstring and inline comments in compute_subseq_len
incorrectly state the split threshold as "2 * subseq_len" while the actual
threshold used by prepare_subseq_cu_seqlens is "4 * subseq_len"; update the
compute_subseq_len docstring and any related inline comments to reflect the
correct split threshold (4 * subseq_len) and mention prepare_subseq_cu_seqlens
by name so readers know where the threshold is enforced.
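Plugging in the numbers from the comment above (chunk_size=64, MIN_SUBSEQ_CHUNKS=128), the two gates work out as follows; the constant names are assumed to match the module's, but this is only a re-derivation of the stated arithmetic:

```python
chunk_size = 64          # from the comment's numeric example
MIN_SUBSEQ_CHUNKS = 128  # minimum sub-sequence size, in chunks

subseq_len = MIN_SUBSEQ_CHUNKS * chunk_size  # minimum sub-sequence length in tokens
early_return_gate = 2 * subseq_len           # looser gate in intracard_fwd_h
split_threshold = 4 * subseq_len             # enforced in prepare_subseq_cu_seqlens

assert subseq_len == 8192
assert early_return_gate == 16384
assert split_threshold == 32768
```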
🧹 Nitpick comments (3)
tests/ops/test_kda.py (1)
550-556: Redundant `.to(device)` calls on tensors already on `device`.

`A_log` and `dt_bias` (lines 555-556) are created with `device=device` and then moved with `.to(device)` again. Similarly, `q`/`k`/`v`/`g` (lines 550-553) could use `device=device` directly like `test_chunk_varlen` does for `A_log`/`dt_bias`. This is a cosmetic nit — no functional impact.

fla/ops/common/intracard_cp.py (2)
314-327: Consider adding `strict=True` to `zip()` calls for safety.

Ruff B905 flags three `zip()` calls without `strict=`. While the lists should always be equal-length (they come from the same `SplitSeqInfo`), `strict=True` provides a runtime check that catches index computation bugs early.
122-185: `prepare_subseq_cu_seqlens` — return type `SplitSeqInfo | bool` is a bit unusual.

Returning `False` (lines 135, 174) where `SplitSeqInfo` is expected works because of the `__bool__` check at the call site, but it makes the type annotation on line 127 a bit awkward. A future caller using pattern matching or isinstance checks would be surprised. Not blocking, just noting.
Summary by CodeRabbit
New Features
Tests