[CP] Enable KCP for DPLR#822

Merged
zhiyuan1i merged 9 commits into main from lzy/dplr-kcp
Apr 12, 2026

Conversation

Collaborator

@zhiyuan1i zhiyuan1i commented Apr 11, 2026

Summary by CodeRabbit

  • New Features

    • Added Context Parallel (CP) implementations for token-shift and DPLR to support distributed sequence chunks and cross-rank exchanges; updated kernels to optionally accept an extra gating/transition input.
  • Tests

    • Added end-to-end distributed test suites for CP DPLR and CP token-shift (multiple layouts); simplified some legacy CP tests and tightened numeric tolerances.
  • Documentation

    • Added a comprehensive CP debugging guide for precision and troubleshooting.

@coderabbitai
Contributor

coderabbitai bot commented Apr 11, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 41b74c7e-da4b-4514-97b4-21bf86847370

📥 Commits

Reviewing files that changed from the base of the PR and between 1207333 and 8fc4b33.

📒 Files selected for processing (1)
  • fla/ops/common/intracard_cp.py

Walkthrough

Implements Context-Parallel (CP) token-shift autograd, extends CP gated-delta Triton kernels (USE_BG path), threads CP context through DPLR chunk ops, adds distributed CP tests and a CP debugging guide; includes minor test maintenance and kernel argument plumbing.

Changes

Cohort / File(s) Summary
Token Shift CP Implementation
fla/modules/token_shift_cp.py
New autograd TokenShiftCPFunction and helper token_shift_cp(...) adding CP forward/backward: validates cp_context, derives cu_seqlens/group, exchanges per-sequence cache via conv_cp_send_recv_fwd/bwd, calls existing token_shift_fwd/bwd, and merges CP gradient corrections.
Chunk Delta Rule Kernel Updates
fla/ops/cp/chunk_delta_h.py
Adds USE_BG conditional path and bg/v parameters to forward pre-process and merged backward Triton kernels; adjusts head-dim/stride handling and changes forward recurrence and backward dh/dm accumulation/signs for the gated path.
DPLR CP Integration
fla/ops/generalized_delta_rule/dplr/chunk.py
Plumbs optional cp_context through chunk_dplr_fwd / ChunkDPLRDeltaRuleFunction / chunk_dplr_delta_rule; adds CP preprocessing/postprocessing (chunk_gated_delta_rule_fwd_h_pre_process, compress_h0, expand_h0), returns/saves initial_state for recompute, and inserts CP-specific backward preprocessing.
Intracard CP Kernel Call
fla/ops/common/intracard_cp.py
Updated intracard_pre_scan to pass explicit bg=None and u=u to the forward pre-process kernel invocation.
CP Token Shift Tests
tests/context_parallel/test_cp_token_shift.py
New distributed CUDA tests exercising token_shift_cp across CP2/CP4 scenarios: spawn per-rank workers, build CP context, run local CP op, gather outputs/gradients, and compare to single-process reference.
CP DPLR Tests
tests/context_parallel/test_cp_dplr.py
New distributed tests validating CP DPLR operator: broadcast globals, compute naive reference, run local chunk_dplr_delta_rule with cp_context, gather results, and assert equivalence across multiple CP scenarios (including safe_gate).
Test Maintenance / Tolerances
tests/context_parallel/test_cp_bwd_gk_offset.py, tests/context_parallel/test_cp_kda.py
Removed distributed worker harness and one multi-GPU test in test_cp_bwd_gk_offset.py, simplified to a local kernel regression and added USE_BG=False to kernel launch; tightened numeric tolerance in test_cp_kda.py from 2e-2 to 8e-3.
Debugging Guide Added
tests/context_parallel/debug.md
New comprehensive debugging guide for Context-Parallel precision failures outlining simulation, kernel vs communication isolation steps, common failure modes, and recommended debug layout.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant TokenShiftCPFunction
    participant CP_SendRecv as "CP Send/Recv"
    participant TokenShiftFwd as "token_shift_fwd"
    participant Cache

    Client->>TokenShiftCPFunction: forward(x, cu_seqlens, chunk_indices, cp_context)
    TokenShiftCPFunction->>TokenShiftCPFunction: validate cp_context, derive cu_seqlens/group
    TokenShiftCPFunction->>CP_SendRecv: conv_cp_send_recv_fwd(send_last_token)
    CP_SendRecv-->>Cache: recv_prev_rank_token (non-first ranks)
    TokenShiftCPFunction->>TokenShiftFwd: call with x and cache
    TokenShiftFwd-->>TokenShiftCPFunction: shifted output
    TokenShiftCPFunction-->>Client: return output
sequenceDiagram
    participant Client
    participant TokenShiftCPFunction
    participant TokenShiftBwd as "token_shift_bwd"
    participant CP_SendRecv as "CP Send/Recv"

    Client->>TokenShiftCPFunction: backward(dy)
    TokenShiftCPFunction->>TokenShiftBwd: compute dx, grad_cache
    TokenShiftCPFunction->>CP_SendRecv: conv_cp_send_recv_bwd(send_grad_cache)
    CP_SendRecv-->>TokenShiftCPFunction: recv_grad_from_next_rank
    TokenShiftCPFunction->>TokenShiftCPFunction: add recv_grad to dx[0,-1,:]
    TokenShiftCPFunction-->>Client: return dx
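The flow in the two diagrams can be reproduced in a single process. The sketch below is a toy NumPy simulation (not the repository's token_shift_cp, which runs the cache exchange over NCCL): each "rank" shifts its chunk locally, using the last token of the previous chunk as the incoming cache, and the stitched result matches the unsplit reference.

```python
import numpy as np

def token_shift(x, cache):
    # shift right by one: position t sees token t-1; the first position sees `cache`
    return np.concatenate([cache[None, :], x[:-1]], axis=0)

# global sequence: T=8 tokens, D=4 channels
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))

# single-process reference (first position shifts in zeros)
ref = token_shift(x, np.zeros(4))

# simulated CP with 2 ranks: each rank holds a contiguous chunk
chunks = np.split(x, 2)
cache = np.zeros(4)            # rank 0 starts from zeros
outs = []
for chunk in chunks:           # "send_last_token" becomes the next rank's cache
    outs.append(token_shift(chunk, cache))
    cache = chunk[-1]
cp_out = np.concatenate(outs)

assert np.allclose(ref, cp_out)
```

The backward pass mirrors this in reverse: each rank's gradient for its last token picks up the cache gradient sent back from the next rank.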

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs
  • Nathancgy

Poem

🐇 I hop across ranks, token by token bright,
Caches whisper secrets through NCCL night,
Gates and deltas hum in parallel flight,
Ranks stitched together, gradients alight —
— a rabbit cheers this CP delight!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed. The PR title '[CP] Enable KCP for DPLR' clearly summarizes the main change: enabling Kimi Context Parallel (KCP) support for the DPLR operator. This is well supported by the changes: a new token shift CP module, modifications to chunk_delta_h.py for BG support, extensive updates to DPLR chunk operations with a cp_context parameter, and comprehensive tests for CP+DPLR integration.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 92.31%, which meets the required threshold of 80.00%.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Context Parallel (CP) support for the DPLR (diagonal-plus-low-rank) delta rule operator. Key changes include updating the Triton pre-processing kernels in fla/ops/cp/chunk_delta_h.py to handle DPLR-specific logic, such as use of the bg gate and modified state update rules. The high-level chunk_dplr_delta_rule is integrated with these CP steps, and a comprehensive test suite verifies the implementation against a reference recurrence. Review feedback identifies a potential bug in the kernel argument passing for DPLR mode and points out a misleading comment regarding pointer offsets.

Comment thread fla/ops/cp/chunk_delta_h.py Outdated
Comment thread fla/ops/cp/chunk_delta_h.py
@zhiyuan1i zhiyuan1i marked this pull request as ready for review April 11, 2026 16:25
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
fla/modules/token_shift_cp.py (1)

136-137: Avoid two sources of truth for sequence metadata.

TokenShiftCPFunction.forward always uses cp_context.cu_seqlens, but token_shift_cp still accepts and forwards a separate cu_seqlens. That lets callers pass mismatched cu_seqlens/chunk_indices without any signal. Either remove the public cu_seqlens parameter or validate it before dispatch.

Example cleanup
 def token_shift_cp(
     x: torch.Tensor,
     cp_context: FLACPContext,
     cu_seqlens: torch.Tensor | None = None,
     chunk_indices: torch.Tensor | None = None,
 ):
@@
-    assert cp_context.cu_seqlens is not None, "cu_seqlens must be provided for token_shift_cp"
+    if cp_context.cu_seqlens is None:
+        raise ValueError("cp_context.cu_seqlens must be provided for token_shift_cp")
+    if cu_seqlens is not None and not torch.equal(cu_seqlens, cp_context.cu_seqlens):
+        raise ValueError("cu_seqlens must match cp_context.cu_seqlens")
 
     return TokenShiftCPFunction.apply(
-        x, cu_seqlens, chunk_indices, cp_context
+        x, cp_context.cu_seqlens, chunk_indices, cp_context
     )

Also applies to: 221-227

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/modules/token_shift_cp.py` around lines 136 - 137, token_shift_cp is
accepting a cu_seqlens argument but TokenShiftCPFunction.forward always pulls
cu_seqlens from cp_context, creating two sources of truth; either remove the
public cu_seqlens parameter from token_shift_cp and all callers (so
TokenShiftCPFunction.forward uses only cp_context.cu_seqlens), or add a strict
validation at the start of token_shift_cp (and the corresponding code paths
around lines 221-227) that compares the supplied cu_seqlens to
cp_context.cu_seqlens (e.g., equality check or torch.equal) and raises a
ValueError if they differ; update the function signature and call sites if you
choose removal, or add the validation and an explanatory error if you choose to
keep the parameter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/generalized_delta_rule/dplr/chunk.py`:
- Around line 472-478: When cp_context is provided you must override any
caller-supplied CPU boundaries so prepare_chunk_indices doesn't mix views;
replace the conditional assignment that only sets cu_seqlens_cpu when
cp_context.cu_seqlens_cpu is not None with an unconditional override (i.e., set
cu_seqlens_cpu = cp_context.cu_seqlens_cpu) in the block that handles cp_context
(alongside cp_context, cu_seqlens, initial_state, output_final_state checks) so
prepare_chunk_indices always sees the CP view (including None) from cp_context.

In `@tests/context_parallel/test_cp_dplr.py`:
- Around line 245-255: The call to chunk_dplr_delta_rule hardcodes
safe_gate=True so the non-safe path is never exercised; change the argument to
pass the test matrix variable (e.g., safe_gate) instead of the literal True so
the test will invoke both safe_gate=True and safe_gate=False; locate the
invocation of chunk_dplr_delta_rule (the parameters q=q_local, k=k_local,
v=v_local, a=a_local, b=b_local, gk=gk_local, cp_context=context) and replace
safe_gate=True with safe_gate (or the actual parameter name used by the test).

In `@tests/context_parallel/test_cp_token_shift.py`:
- Around line 181-188: The test always sets a hard-coded rendezvous port (port =
29510) which causes races when pytest runs tests in parallel; change the port
selection in the setup before calling mp.start_processes (the place where port
is passed to run_cp_token_shift_test_worker) to use a unique free port per test
run (e.g., call a helper that returns an available ephemeral port or derive a
port from the current process/worker id or random offset) so each spawned test
uses its own rendezvous port rather than the fixed 29510.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7bf51201-398d-442e-8625-3e0190e7574f

📥 Commits

Reviewing files that changed from the base of the PR and between 967f8c0 and 77fe42d.

📒 Files selected for processing (7)
  • fla/modules/token_shift_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/generalized_delta_rule/dplr/chunk.py
  • tests/context_parallel/test_cp_bwd_gk_offset.py
  • tests/context_parallel/test_cp_dplr.py
  • tests/context_parallel/test_cp_kda.py
  • tests/context_parallel/test_cp_token_shift.py

Comment on lines +472 to +478
if cp_context is not None:
    assert initial_state is None, "Initial state is not supported for CP"
    assert output_final_state is False, "Output final state is not supported for CP"
    assert cp_context.cu_seqlens is not None, "cu_seqlens is required for CP"
    cu_seqlens = cp_context.cu_seqlens
    if cp_context.cu_seqlens_cpu is not None:
        cu_seqlens_cpu = cp_context.cu_seqlens_cpu
Contributor


⚠️ Potential issue | 🟠 Major

Override cu_seqlens_cpu unconditionally in CP mode.

Here cu_seqlens switches to the local CP view, but cu_seqlens_cpu keeps the caller-supplied value whenever cp_context.cu_seqlens_cpu is None. prepare_chunk_indices then mixes local GPU boundaries with stale CPU boundaries and can build the wrong chunks.

Minimal fix
     if cp_context is not None:
         assert initial_state is None, "Initial state is not supported for CP"
         assert output_final_state is False, "Output final state is not supported for CP"
         assert cp_context.cu_seqlens is not None, "cu_seqlens is required for CP"
         cu_seqlens = cp_context.cu_seqlens
-        if cp_context.cu_seqlens_cpu is not None:
-            cu_seqlens_cpu = cp_context.cu_seqlens_cpu
+        cu_seqlens_cpu = cp_context.cu_seqlens_cpu
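To see why the stale CPU boundaries matter, here is a toy model of the index builder (a hypothetical helper, not the repository's prepare_chunk_indices): with a global cu_seqlens_cpu of [0, 256] but only a 128-token local shard on this rank, the builder emits chunk indices that run past the rank-local tensors.

```python
def chunk_indices(cu_seqlens_cpu, chunk_size):
    # one (sequence, chunk) pair per chunk, driven purely by the CPU boundaries
    pairs = []
    for i in range(len(cu_seqlens_cpu) - 1):
        n = cu_seqlens_cpu[i + 1] - cu_seqlens_cpu[i]
        pairs += [(i, c) for c in range(-(-n // chunk_size))]  # ceil-div chunk count
    return pairs

# global view: one 256-token sequence; this rank's CP shard: 128 tokens
stale = chunk_indices([0, 256], 64)   # 4 chunks -- indexes past the local shard
local = chunk_indices([0, 128], 64)   # 2 chunks -- matches the rank-local tensors
assert len(stale) == 4 and len(local) == 2
```

Overriding cu_seqlens_cpu unconditionally keeps the CPU and GPU views of the boundaries in the same (local) frame.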

Comment on lines +245 to +255
o_local, _ = chunk_dplr_delta_rule(
    q=q_local,
    k=k_local,
    v=v_local,
    a=a_local,
    b=b_local,
    gk=gk_local,
    cp_context=context,
    safe_gate=True,
    chunk_size=64,
)
Contributor


⚠️ Potential issue | 🟠 Major

This never exercises safe_gate=False.

The worker plumbs safe_gate through its API and the test matrix varies it, but the actual call is hardcoded to safe_gate=True. That means the default cases stop validating the non-safe path, and test_cp2_safe_gate is effectively a duplicate.

Fix
         o_local, _ = chunk_dplr_delta_rule(
             q=q_local,
             k=k_local,
             v=v_local,
             a=a_local,
             b=b_local,
             gk=gk_local,
             cp_context=context,
-            safe_gate=True,
+            safe_gate=safe_gate,
             chunk_size=64,
         )

Comment on lines +181 to +188
port = 29510  # Different port from other CP tests
mp.start_processes(
    run_cp_token_shift_test_worker,
    args=(world_size, test_name, T, D, lengths, dtype, port),
    nprocs=world_size,
    join=True,
    start_method='spawn',
)
Contributor


⚠️ Potential issue | 🟡 Minor

Use a unique rendezvous port per spawned test run.

Every case in this module binds 29510. Under parallel pytest workers, two CP tests can race for the same localhost port and fail nondeterministically.

Example fix
+import socket
@@
 def run_cp_test_with_spawn(
     world_size: int,
     test_name: str,
     T: int,
     D: int,
     lengths: list[int],
     dtype=torch.float32,
 ):
     """Run CP test using torch.multiprocessing.spawn."""
-    port = 29510  # Different port from other CP tests
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(("127.0.0.1", 0))
+        port = sock.getsockname()[1]
     mp.start_processes(
         run_cp_token_shift_test_worker,
         args=(world_size, test_name, T, D, lengths, dtype, port),
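One way to realize this is a small helper that asks the OS for an ephemeral port (a sketch only; a narrow race remains between closing the socket and the rendezvous binding it, but in practice this removes the fixed-port collisions):

```python
import socket

def free_port() -> int:
    # bind to port 0 so the OS assigns an unused ephemeral port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]

# pass this to the spawned rendezvous instead of the fixed 29510
port = free_port()
```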

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/context_parallel/debug.md (1)

163-170: Make the 5e-3 acceptance bar explicitly conditional.

This threshold is useful, but it should be documented with conditions (hardware/dtype/seed/shape regime), otherwise it can be interpreted as universally stable and cause noisy triage.

📝 Suggested wording tweak
-Var-length KCP with `safe_gate=True`, bf16 inputs, and an unaligned
-cut point should land below 5e-3 `norm_ratio` per gradient against
+Var-length KCP with `safe_gate=True`, bf16 inputs, and an unaligned
+cut point is expected to land around/below 5e-3 `norm_ratio` per gradient against
 the per-token `naive` reference (measured per-sequence). That is pure
 bf16 chunk-vs-per-token noise and matches the magnitudes that the
 long-standing KCP tests (e.g. gated-delta-rule CP2) sit at. Anything
-above ~5e-3 means either the forward recomputation is using the wrong
+consistently above this range (under the same seed, shape regime, and hardware)
+usually means either the forward recomputation is using the wrong
 `initial_state`, or the merge kernel is being called with a stale
 `BV`, or both.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/context_parallel/debug.md` around lines 163 - 170, Update the prose
around the "5e-3" acceptance bar in the Var-length KCP test description (the
paragraph mentioning Var-length KCP with `safe_gate=True`, bf16 inputs,
unaligned cut point, `norm_ratio` vs per-token `naive` reference) to make the
threshold explicitly conditional: state the specific conditions under which 5e-3
was observed (hardware/GPU model, dtype bf16, RNG seed range or reproducibility
note, typical sequence shapes and batch regimes), and add a short guidance
sentence that the threshold may be higher/lower outside those regimes and should
be revalidated when any of those variables change. Ensure the modified text
names the metric (`norm_ratio`) and test configuration (`safe_gate=True`, bf16)
so readers can locate and re-run the same experiment.
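For reference, a relative-L2 metric of the kind the guide's `norm_ratio` appears to describe can be computed as below (this is an assumption about the exact definition; the guide's implementation may normalize differently, e.g. per-sequence):

```python
import numpy as np

def norm_ratio(grad: np.ndarray, ref: np.ndarray) -> float:
    # ||grad - ref|| / ||ref||: a scale-free measure of gradient mismatch
    return float(np.linalg.norm(grad - ref) / np.linalg.norm(ref))

ref = np.ones(1024)
noisy = ref + 1e-3 * np.random.default_rng(0).standard_normal(1024)
assert norm_ratio(noisy, ref) < 5e-3  # small per-element noise sits under the bar
```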

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 403f7c21-e8ab-4d83-9a0e-4476c5485279

📥 Commits

Reviewing files that changed from the base of the PR and between 77fe42d and 1207333.

📒 Files selected for processing (1)
  • tests/context_parallel/debug.md

@github-actions

⚠️ Benchmark Results (NVIDIA-H100-PT2-7)

Status: 7 regression(s) detected

GPU NVIDIA H200
CUDA 12.8
PyTorch 2.7.1+cu128
Base 967f8c09ff
Head c76cd01a0c
Threshold 5.0%
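The Change column and the 🔴 markers follow directly from the Base/Head timings; a minimal sketch of the flagging rule (assuming the bot flags head-slower changes above the 5.0% threshold):

```python
def pct_change(base_ms: float, head_ms: float) -> float:
    # positive means the head commit is slower than base
    return (head_ms / base_ms - 1.0) * 100.0

def is_regression(base_ms: float, head_ms: float, threshold_pct: float = 5.0) -> bool:
    return pct_change(base_ms, head_ms) > threshold_pct

# chunk_simple_gla fwd (B=4, T=2048, H=16, D=128): 0.177 -> 0.186 ms, flagged 🔴
assert is_regression(0.177, 0.186)
# chunk_comba fwd (B=4, T=2048, H=16, D=128): 0.621 -> 0.633 ms, within threshold
assert not is_regression(0.621, 0.633)
```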
Op Mode B T H D Base (ms) Head (ms) Speedup Change
chunk_comba fwd 1 8192 96 128 2.049 2.052 1.00x +0.1%
chunk_comba fwd 2 16384 16 128 1.436 1.437 1.00x +0.0%
chunk_comba fwd 4 2048 16 128 0.621 0.633 0.98x +1.8%
chunk_comba fwd 4 4096 64 128 2.563 2.560 1.00x -0.1%
chunk_comba fwd 8 1024 8 64 0.610 0.623 0.98x +2.0%
chunk_comba fwd 8 2048 32 256 2.706 2.705 1.00x -0.0%
chunk_delta_rule fwd 1 8192 96 128 1.586 1.587 1.00x +0.1%
chunk_delta_rule fwd 2 16384 16 128 1.077 1.076 1.00x -0.1%
chunk_delta_rule fwd 4 2048 16 128 0.335 0.337 0.99x +0.7%
chunk_delta_rule fwd 4 4096 64 128 1.954 1.953 1.00x -0.0%
chunk_delta_rule fwd 8 1024 8 64 0.323 0.315 1.02x -2.3%
chunk_delta_rule fwd 8 2048 32 256 2.174 2.174 1.00x +0.0%
chunk_dplr_delta_rule fwd 1 8192 96 128 5.939 5.940 1.00x +0.0%
chunk_dplr_delta_rule fwd 2 16384 16 128 5.871 5.873 1.00x +0.0%
chunk_dplr_delta_rule fwd 4 2048 16 128 1.128 1.129 1.00x +0.1%
chunk_dplr_delta_rule fwd 4 4096 64 128 7.442 7.445 1.00x +0.0%
chunk_dplr_delta_rule fwd 8 1024 8 64 0.416 0.414 1.00x -0.5%
chunk_dplr_delta_rule fwd 8 2048 32 256 11.612 11.601 1.00x -0.1%
chunk_gdn fwd 1 8192 96 128 1.677 1.672 1.00x -0.2%
chunk_gdn fwd 2 16384 16 128 1.169 1.166 1.00x -0.3%
chunk_gdn fwd 4 2048 16 128 0.505 0.522 0.97x +3.4%
chunk_gdn fwd 4 4096 64 128 2.037 2.030 1.00x -0.3%
chunk_gdn fwd 8 1024 8 64 0.497 0.505 0.99x +1.5%
chunk_gdn fwd 8 2048 32 256 2.390 2.391 1.00x +0.0%
chunk_gla fwd 1 8192 96 128 2.250 2.258 1.00x +0.3%
chunk_gla fwd 2 16384 16 128 1.831 1.830 1.00x -0.0%
chunk_gla fwd 4 2048 16 128 0.428 0.428 1.00x -0.0%
chunk_gla fwd 4 4096 64 128 3.010 3.010 1.00x +0.0%
chunk_gla fwd 8 1024 8 64 0.290 0.294 0.99x +1.5%
chunk_gla fwd 8 2048 32 256 3.697 3.693 1.00x -0.1%
chunk_kda fwd 1 8192 96 128 2.772 2.766 1.00x -0.2%
chunk_kda fwd 2 16384 16 128 1.925 1.927 1.00x +0.1%
chunk_kda fwd 4 2048 16 128 0.626 0.636 0.98x +1.6%
chunk_kda fwd 4 4096 64 128 3.468 3.463 1.00x -0.2%
chunk_kda fwd 8 1024 8 64 0.623 0.622 1.00x -0.1%
chunk_kda fwd 8 2048 32 256 4.404 4.409 1.00x +0.1%
chunk_lightning_attn fwd 1 8192 96 128 0.668 0.669 1.00x +0.2%
chunk_lightning_attn fwd 2 16384 16 128 0.532 0.537 0.99x +0.9%
chunk_lightning_attn fwd 4 2048 16 128 0.232 0.235 0.99x +1.4%
chunk_lightning_attn fwd 4 4096 64 128 0.873 0.875 1.00x +0.2%
chunk_lightning_attn fwd 8 1024 8 64 0.193 0.196 0.99x +1.5%
chunk_lightning_attn fwd 8 2048 32 256 1.197 1.200 1.00x +0.2%
chunk_linear_attn fwd 1 8192 96 128 4.306 4.306 1.00x +0.0%
chunk_linear_attn fwd 2 16384 16 128 6.908 6.907 1.00x -0.0%
chunk_linear_attn fwd 4 2048 16 128 1.001 1.002 1.00x +0.1%
chunk_linear_attn fwd 4 4096 64 128 3.255 3.257 1.00x +0.0%
chunk_linear_attn fwd 8 1024 8 64 0.440 0.440 1.00x -0.0%
chunk_linear_attn fwd 8 2048 32 256 2.682 2.680 1.00x -0.1%
chunk_retention fwd 1 8192 96 128 0.727 0.722 1.01x -0.7%
chunk_retention fwd 2 16384 16 128 0.588 0.586 1.00x -0.4%
chunk_retention fwd 4 2048 16 128 0.286 0.289 0.99x +0.9%
chunk_retention fwd 4 4096 64 128 0.930 0.931 1.00x +0.2%
chunk_retention fwd 8 1024 8 64 0.247 0.246 1.00x -0.5%
chunk_retention fwd 8 2048 32 256 1.257 1.256 1.00x -0.1%
chunk_rwkv6 fwd 1 8192 96 128 3.075 3.076 1.00x +0.0%
chunk_rwkv6 fwd 2 16384 16 128 2.383 2.385 1.00x +0.1%
chunk_rwkv6 fwd 4 2048 16 128 0.558 0.559 1.00x +0.1%
chunk_rwkv6 fwd 4 4096 64 128 3.893 3.895 1.00x +0.1%
chunk_rwkv6 fwd 8 1024 8 64 0.295 0.297 0.99x +0.6%
chunk_rwkv6 fwd 8 2048 32 256 5.139 5.133 1.00x -0.1%
chunk_rwkv7 fwd 1 8192 96 128 7.880 7.884 1.00x +0.1%
chunk_rwkv7 fwd 2 16384 16 128 7.198 7.201 1.00x +0.0%
chunk_rwkv7 fwd 4 2048 16 128 1.453 1.453 1.00x -0.0%
chunk_rwkv7 fwd 4 4096 64 128 10.072 10.063 1.00x -0.1%
chunk_rwkv7 fwd 8 1024 8 64 0.421 0.427 0.99x +1.3%
chunk_rwkv7 fwd 8 2048 32 256 45.371 45.377 1.00x +0.0%
chunk_simple_gla fwd 1 8192 96 128 0.669 0.672 1.00x +0.5%
chunk_simple_gla fwd 2 16384 16 128 0.616 0.615 1.00x -0.3%
chunk_simple_gla fwd 4 2048 16 128 0.177 0.186 0.95x +5.2% 🔴
chunk_simple_gla fwd 4 4096 64 128 0.847 0.844 1.00x -0.4%
chunk_simple_gla fwd 8 1024 8 64 0.167 0.164 1.02x -1.8%
chunk_simple_gla fwd 8 2048 32 256 1.259 1.265 1.00x +0.4%
flash_attn fwd 1 8192 96 128 4.512 4.547 0.99x +0.8%
flash_attn fwd 2 16384 16 128 6.102 6.154 0.99x +0.9%
flash_attn fwd 4 2048 16 128 0.250 0.251 1.00x +0.5%
flash_attn fwd 4 4096 64 128 3.133 3.158 0.99x +0.8%
flash_attn fwd 8 1024 8 64 0.058 0.058 1.00x -0.5%
flash_attn fwd 8 2048 32 256 1.957 1.961 1.00x +0.2%
fused_recurrent_hgrn fwd 1 8192 96 128 3.534 3.534 1.00x +0.0%
fused_recurrent_hgrn fwd 2 16384 16 128 6.990 6.993 1.00x +0.0%
fused_recurrent_hgrn fwd 4 2048 16 128 0.862 0.862 1.00x +0.0%
fused_recurrent_hgrn fwd 4 4096 64 128 1.864 1.864 1.00x -0.0%
fused_recurrent_hgrn fwd 8 1024 8 64 0.454 0.455 1.00x +0.2%
fused_recurrent_hgrn fwd 8 2048 32 256 0.970 0.971 1.00x +0.1%
parallel_attn fwd 1 8192 96 128 4.606 4.606 1.00x +0.0%
parallel_attn fwd 2 16384 16 128 5.993 5.989 1.00x -0.1%
parallel_attn fwd 4 2048 16 128 0.259 0.260 1.00x +0.3%
parallel_attn fwd 4 4096 64 128 3.159 3.162 1.00x +0.1%
parallel_attn fwd 8 1024 8 64 0.062 0.063 0.99x +1.3%
chunk_comba fwdbwd 1 8192 96 128 7.449 7.460 1.00x +0.1%
chunk_comba fwdbwd 2 16384 16 128 5.276 5.277 1.00x +0.0%
chunk_comba fwdbwd 4 2048 16 128 2.078 2.073 1.00x -0.2%
chunk_comba fwdbwd 4 4096 64 128 9.641 9.646 1.00x +0.0%
chunk_comba fwdbwd 8 1024 8 64 2.051 2.079 0.99x +1.4%
chunk_comba fwdbwd 8 2048 32 256 15.485 15.499 1.00x +0.1%
chunk_delta_rule fwdbwd 1 8192 96 128 5.000 4.998 1.00x -0.0%
chunk_delta_rule fwdbwd 2 16384 16 128 3.340 3.343 1.00x +0.1%
chunk_delta_rule fwdbwd 4 2048 16 128 1.271 1.218 1.04x -4.2%
chunk_delta_rule fwdbwd 4 4096 64 128 6.275 6.279 1.00x +0.1%
chunk_delta_rule fwdbwd 8 1024 8 64 1.251 1.206 1.04x -3.5%
chunk_delta_rule fwdbwd 8 2048 32 256 9.848 9.857 1.00x +0.1%
chunk_dplr_delta_rule fwdbwd 1 8192 96 128 24.605 24.622 1.00x +0.1%
chunk_dplr_delta_rule fwdbwd 2 16384 16 128 21.635 21.634 1.00x -0.0%
chunk_dplr_delta_rule fwdbwd 4 2048 16 128 4.512 4.516 1.00x +0.1%
chunk_dplr_delta_rule fwdbwd 4 4096 64 128 31.408 31.420 1.00x +0.0%
chunk_dplr_delta_rule fwdbwd 8 1024 8 64 1.929 1.874 1.03x -2.8%
chunk_dplr_delta_rule fwdbwd 8 2048 32 256 43.293 43.307 1.00x +0.0%
chunk_gdn fwdbwd 1 8192 96 128 6.805 6.790 1.00x -0.2%
chunk_gdn fwdbwd 2 16384 16 128 4.814 4.808 1.00x -0.1%
chunk_gdn fwdbwd 4 2048 16 128 1.841 1.851 0.99x +0.6%
chunk_gdn fwdbwd 4 4096 64 128 8.731 8.713 1.00x -0.2%
chunk_gdn fwdbwd 8 1024 8 64 1.792 1.772 1.01x -1.1%
chunk_gdn fwdbwd 8 2048 32 256 14.850 14.862 1.00x +0.1%
chunk_gla fwdbwd 1 8192 96 128 10.936 10.941 1.00x +0.0%
chunk_gla fwdbwd 2 16384 16 128 8.483 8.480 1.00x -0.0%
chunk_gla fwdbwd 4 2048 16 128 2.021 2.200 0.92x +8.9% 🔴
chunk_gla fwdbwd 4 4096 64 128 15.179 15.181 1.00x +0.0%
chunk_gla fwdbwd 8 1024 8 64 1.179 1.718 0.69x +45.8% 🔴
chunk_gla fwdbwd 8 2048 32 256 18.964 18.968 1.00x +0.0%
chunk_kda fwdbwd 1 8192 96 128 12.487 12.500 1.00x +0.1%
chunk_kda fwdbwd 2 16384 16 128 8.369 8.371 1.00x +0.0%
chunk_kda fwdbwd 4 2048 16 128 2.180 2.183 1.00x +0.1%
chunk_kda fwdbwd 4 4096 64 128 15.971 15.962 1.00x -0.1%
chunk_kda fwdbwd 8 1024 8 64 2.101 2.003 1.05x -4.7%
chunk_kda fwdbwd 8 2048 32 256 18.535 18.538 1.00x +0.0%
chunk_lightning_attn fwdbwd 1 8192 96 128 3.442 3.461 0.99x +0.6%
chunk_lightning_attn fwdbwd 2 16384 16 128 2.667 2.676 1.00x +0.4%
chunk_lightning_attn fwdbwd 4 2048 16 128 0.942 0.985 0.96x +4.6%
chunk_lightning_attn fwdbwd 4 4096 64 128 4.495 4.548 0.99x +1.2%
chunk_lightning_attn fwdbwd 8 1024 8 64 0.799 0.812 0.98x +1.6%
chunk_lightning_attn fwdbwd 8 2048 32 256 10.077 10.149 0.99x +0.7%
chunk_linear_attn fwdbwd 1 8192 96 128 12.339 12.345 1.00x +0.1%
chunk_linear_attn fwdbwd 2 16384 16 128 16.590 16.613 1.00x +0.1%
chunk_linear_attn fwdbwd 4 2048 16 128 2.698 2.700 1.00x +0.1%
chunk_linear_attn fwdbwd 4 4096 64 128 11.619 11.611 1.00x -0.1%
chunk_linear_attn fwdbwd 8 1024 8 64 1.016 1.069 0.95x +5.2% 🔴
chunk_linear_attn fwdbwd 8 2048 32 256 15.315 15.316 1.00x +0.0%
chunk_retention fwdbwd 1 8192 96 128 3.561 3.544 1.00x -0.5%
chunk_retention fwdbwd 2 16384 16 128 2.749 2.769 0.99x +0.7%
chunk_retention fwdbwd 4 2048 16 128 1.041 1.047 0.99x +0.6%
chunk_retention fwdbwd 4 4096 64 128 4.571 4.603 0.99x +0.7%
chunk_retention fwdbwd 8 1024 8 64 0.870 0.869 1.00x -0.2%
chunk_retention fwdbwd 8 2048 32 256 10.186 10.197 1.00x +0.1%
chunk_rwkv6 fwdbwd 1 8192 96 128 13.064 13.064 1.00x -0.0%
chunk_rwkv6 fwdbwd 2 16384 16 128 9.626 9.633 1.00x +0.1%
chunk_rwkv6 fwdbwd 4 2048 16 128 2.301 2.304 1.00x +0.1%
chunk_rwkv6 fwdbwd 4 4096 64 128 16.832 16.826 1.00x -0.0%
chunk_rwkv6 fwdbwd 8 1024 8 64 1.205 1.247 0.97x +3.5%
chunk_rwkv6 fwdbwd 8 2048 32 256 20.796 20.785 1.00x -0.1%
chunk_rwkv7 fwdbwd 1 8192 96 128 23.836 23.841 1.00x +0.0%
chunk_rwkv7 fwdbwd 2 16384 16 128 21.187 21.202 1.00x +0.1%
chunk_rwkv7 fwdbwd 4 2048 16 128 4.377 4.374 1.00x -0.1%
chunk_rwkv7 fwdbwd 4 4096 64 128 30.466 30.463 1.00x -0.0%
chunk_rwkv7 fwdbwd 8 1024 8 64 1.793 1.989 0.90x +10.9% 🔴
chunk_rwkv7 fwdbwd 8 2048 32 256 104.621 104.601 1.00x -0.0%
chunk_simple_gla fwdbwd 1 8192 96 128 3.899 3.902 1.00x +0.1%
chunk_simple_gla fwdbwd 2 16384 16 128 3.128 3.131 1.00x +0.1%
chunk_simple_gla fwdbwd 4 2048 16 128 1.018 0.928 1.10x -8.8% 🟢
chunk_simple_gla fwdbwd 4 4096 64 128 5.238 5.240 1.00x +0.0%
chunk_simple_gla fwdbwd 8 1024 8 64 0.919 1.028 0.89x +11.8% 🔴
chunk_simple_gla fwdbwd 8 2048 32 256 12.184 12.196 1.00x +0.1%
flash_attn fwdbwd 1 8192 96 128 18.668 18.299 1.02x -2.0%
flash_attn fwdbwd 2 16384 16 128 23.051 23.037 1.00x -0.1%
flash_attn fwdbwd 4 2048 16 128 1.041 1.042 1.00x +0.1%
flash_attn fwdbwd 4 4096 64 128 12.854 12.846 1.00x -0.1%
flash_attn fwdbwd 8 1024 8 64 0.278 0.344 0.81x +23.4% 🔴
flash_attn fwdbwd 8 2048 32 256 8.292 8.302 1.00x +0.1%
fused_recurrent_hgrn fwdbwd 1 8192 96 128 8.191 8.189 1.00x -0.0%
fused_recurrent_hgrn fwdbwd 2 16384 16 128 15.379 15.388 1.00x +0.1%
fused_recurrent_hgrn fwdbwd 4 2048 16 128 1.975 1.976 1.00x +0.0%
fused_recurrent_hgrn fwdbwd 4 4096 64 128 4.850 4.853 1.00x +0.1%
fused_recurrent_hgrn fwdbwd 8 1024 8 64 0.952 0.952 1.00x +0.0%
fused_recurrent_hgrn fwdbwd 8 2048 32 256 3.021 3.021 1.00x +0.0%
parallel_attn fwdbwd 1 8192 96 128 22.015 22.022 1.00x +0.0%
parallel_attn fwdbwd 2 16384 16 128 28.969 28.946 1.00x -0.1%
parallel_attn fwdbwd 4 2048 16 128 1.471 1.471 1.00x +0.0%
parallel_attn fwdbwd 4 4096 64 128 16.988 16.984 1.00x -0.0%
parallel_attn fwdbwd 8 1024 8 64 0.465 0.487 0.95x +4.7%

This comment is automatically updated with the latest benchmark results.

@zhiyuan1i zhiyuan1i merged commit c113780 into main Apr 12, 2026
6 checks passed
@zhiyuan1i zhiyuan1i deleted the lzy/dplr-kcp branch April 12, 2026 15:01
