Conversation
No actionable comments were generated in the recent review. 🎉
Walkthrough

Implements Context-Parallel (CP) token-shift autograd, extends CP gated-delta Triton kernels (USE_BG path), threads CP context through DPLR chunk ops, adds distributed CP tests and a CP debugging guide; includes minor test maintenance and kernel-argument plumbing.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant TokenShiftCPFunction
    participant CP_SendRecv as "CP Send/Recv"
    participant TokenShiftFwd as "token_shift_fwd"
    participant Cache
    Client->>TokenShiftCPFunction: forward(x, cu_seqlens, chunk_indices, cp_context)
    TokenShiftCPFunction->>TokenShiftCPFunction: validate cp_context, derive cu_seqlens/group
    TokenShiftCPFunction->>CP_SendRecv: conv_cp_send_recv_fwd(send_last_token)
    CP_SendRecv-->>Cache: recv_prev_rank_token (non-first ranks)
    TokenShiftCPFunction->>TokenShiftFwd: call with x and cache
    TokenShiftFwd-->>TokenShiftCPFunction: shifted output
    TokenShiftCPFunction-->>Client: return output
```
```mermaid
sequenceDiagram
    participant Client
    participant TokenShiftCPFunction
    participant TokenShiftBwd as "token_shift_bwd"
    participant CP_SendRecv as "CP Send/Recv"
    Client->>TokenShiftCPFunction: backward(dy)
    TokenShiftCPFunction->>TokenShiftBwd: compute dx, grad_cache
    TokenShiftCPFunction->>CP_SendRecv: conv_cp_send_recv_bwd(send_grad_cache)
    CP_SendRecv-->>TokenShiftCPFunction: recv_grad_from_next_rank
    TokenShiftCPFunction->>TokenShiftCPFunction: add recv_grad to dx[0,-1,:]
    TokenShiftCPFunction-->>Client: return dx
```
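The two diagrams describe a halo-exchange pattern: each rank sends the last token of its slice to the next rank, and gradients flow back the opposite way. A framework-free sketch of the forward exchange (plain Python lists stand in for tensors and for `dist.send`/`dist.recv`; the names here are illustrative, not the fla API):

```python
# Illustrative sketch of the CP token-shift halo exchange (not the fla API).
# Each rank holds a contiguous slice of the sequence; token shift needs the
# last token of the previous rank's slice as a "cache" for position 0.

def token_shift_local(x, cache):
    """Shift tokens right by one within a local slice.

    x: tokens held by this rank; cache: last token of the previous rank's
    slice (None on rank 0, which shifts in a zero token instead).
    """
    first = cache if cache is not None else 0
    return [first] + x[:-1]

def token_shift_cp(slices):
    """Simulate the forward send/recv across ranks: rank r sends its last
    token to rank r+1, then every rank shifts locally."""
    outputs = []
    cache = None  # rank 0 receives nothing from a previous rank
    for local in slices:
        outputs.append(token_shift_local(local, cache))
        cache = local[-1]  # what this rank sends to the next rank
    return outputs

# One 8-token sequence split across 2 "ranks":
ranks = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(token_shift_cp(ranks))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

The backward pass is symmetric: each rank sends the gradient accumulated on its received cache token back to the previous rank, which adds it to the gradient of its own last position.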
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Code Review
This pull request introduces Context Parallel (CP) support for the Delta Product Linear RNN (DPLR) operator. Key changes include updating the Triton pre-processing kernels in fla/ops/cp/chunk_delta_h.py to handle DPLR-specific logic, such as the use of the bg gate and modified state update rules. The high-level chunk_dplr_delta_rule was integrated with these CP steps, and a comprehensive test suite was added to verify the implementation against a reference recurrence. Review feedback identifies a potential bug in the kernel argument passing for DPLR mode and points out a misleading comment regarding pointer offsets.
Actionable comments posted: 3
🧹 Nitpick comments (1)
fla/modules/token_shift_cp.py (1)
136-137: Avoid two sources of truth for sequence metadata.
`TokenShiftCPFunction.forward` always uses `cp_context.cu_seqlens`, but `token_shift_cp` still accepts and forwards a separate `cu_seqlens`. That lets callers pass mismatched `cu_seqlens`/`chunk_indices` without any signal. Either remove the public `cu_seqlens` parameter or validate it before dispatch.

Example cleanup

```diff
 def token_shift_cp(
     x: torch.Tensor,
     cp_context: FLACPContext,
     cu_seqlens: torch.Tensor | None = None,
     chunk_indices: torch.Tensor | None = None,
 ):
@@
-    assert cp_context.cu_seqlens is not None, "cu_seqlens must be provided for token_shift_cp"
+    if cp_context.cu_seqlens is None:
+        raise ValueError("cp_context.cu_seqlens must be provided for token_shift_cp")
+    if cu_seqlens is not None and not torch.equal(cu_seqlens, cp_context.cu_seqlens):
+        raise ValueError("cu_seqlens must match cp_context.cu_seqlens")
     return TokenShiftCPFunction.apply(
-        x, cu_seqlens, chunk_indices, cp_context
+        x, cp_context.cu_seqlens, chunk_indices, cp_context
     )
```

Also applies to: 221-227
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/modules/token_shift_cp.py` around lines 136 - 137, token_shift_cp is accepting a cu_seqlens argument but TokenShiftCPFunction.forward always pulls cu_seqlens from cp_context, creating two sources of truth; either remove the public cu_seqlens parameter from token_shift_cp and all callers (so TokenShiftCPFunction.forward uses only cp_context.cu_seqlens), or add a strict validation at the start of token_shift_cp (and the corresponding code paths around lines 221-227) that compares the supplied cu_seqlens to cp_context.cu_seqlens (e.g., equality check or torch.equal) and raises a ValueError if they differ; update the function signature and call sites if you choose removal, or add the validation and an explanatory error if you choose to keep the parameter.
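If the parameter is kept, the validation path above can be sketched with a hypothetical helper (plain Python lists stand in for tensors here; the real code would compare tensors with `torch.equal`):

```python
def validate_cu_seqlens(user_cu_seqlens, context_cu_seqlens):
    """Reject mismatched sequence metadata instead of silently ignoring it.

    Hypothetical helper: in the actual module both values would be
    torch.Tensors and the equality check would use torch.equal.
    """
    if context_cu_seqlens is None:
        raise ValueError("cp_context.cu_seqlens must be provided")
    if user_cu_seqlens is not None and user_cu_seqlens != context_cu_seqlens:
        raise ValueError("cu_seqlens must match cp_context.cu_seqlens")
    # Always dispatch with the CP context's view, the single source of truth.
    return context_cu_seqlens

print(validate_cu_seqlens(None, [0, 4, 8]))  # [0, 4, 8]
```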
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/generalized_delta_rule/dplr/chunk.py`:
- Around line 472-478: When cp_context is provided you must override any
caller-supplied CPU boundaries so prepare_chunk_indices doesn't mix views;
replace the conditional assignment that only sets cu_seqlens_cpu when
cp_context.cu_seqlens_cpu is not None with an unconditional override (i.e., set
cu_seqlens_cpu = cp_context.cu_seqlens_cpu) in the block that handles cp_context
(alongside cp_context, cu_seqlens, initial_state, output_final_state checks) so
prepare_chunk_indices always sees the CP view (including None) from cp_context.
In `@tests/context_parallel/test_cp_dplr.py`:
- Around line 245-255: The call to chunk_dplr_delta_rule hardcodes
safe_gate=True so the non-safe path is never exercised; change the argument to
pass the test matrix variable (e.g., safe_gate) instead of the literal True so
the test will invoke both safe_gate=True and safe_gate=False; locate the
invocation of chunk_dplr_delta_rule (the parameters q=q_local, k=k_local,
v=v_local, a=a_local, b=b_local, gk=gk_local, cp_context=context) and replace
safe_gate=True with safe_gate (or the actual parameter name used by the test).
In `@tests/context_parallel/test_cp_token_shift.py`:
- Around line 181-188: The test always sets a hard-coded rendezvous port (port =
29510) which causes races when pytest runs tests in parallel; change the port
selection in the setup before calling mp.start_processes (the place where port
is passed to run_cp_token_shift_test_worker) to use a unique free port per test
run (e.g., call a helper that returns an available ephemeral port or derive a
port from the current process/worker id or random offset) so each spawned test
uses its own rendezvous port rather than the fixed 29510.
---
Nitpick comments:
In `@fla/modules/token_shift_cp.py`:
- Around line 136-137: token_shift_cp is accepting a cu_seqlens argument but
TokenShiftCPFunction.forward always pulls cu_seqlens from cp_context, creating
two sources of truth; either remove the public cu_seqlens parameter from
token_shift_cp and all callers (so TokenShiftCPFunction.forward uses only
cp_context.cu_seqlens), or add a strict validation at the start of
token_shift_cp (and the corresponding code paths around lines 221-227) that
compares the supplied cu_seqlens to cp_context.cu_seqlens (e.g., equality check
or torch.equal) and raises a ValueError if they differ; update the function
signature and call sites if you choose removal, or add the validation and an
explanatory error if you choose to keep the parameter.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 7bf51201-398d-442e-8625-3e0190e7574f
📒 Files selected for processing (7)
fla/modules/token_shift_cp.py
fla/ops/cp/chunk_delta_h.py
fla/ops/generalized_delta_rule/dplr/chunk.py
tests/context_parallel/test_cp_bwd_gk_offset.py
tests/context_parallel/test_cp_dplr.py
tests/context_parallel/test_cp_kda.py
tests/context_parallel/test_cp_token_shift.py
```python
if cp_context is not None:
    assert initial_state is None, "Initial state is not supported for CP"
    assert output_final_state is False, "Output final state is not supported for CP"
    assert cp_context.cu_seqlens is not None, "cu_seqlens is required for CP"
    cu_seqlens = cp_context.cu_seqlens
    if cp_context.cu_seqlens_cpu is not None:
        cu_seqlens_cpu = cp_context.cu_seqlens_cpu
```
Override cu_seqlens_cpu unconditionally in CP mode.
Here cu_seqlens switches to the local CP view, but cu_seqlens_cpu keeps the caller-supplied value whenever cp_context.cu_seqlens_cpu is None. prepare_chunk_indices then mixes local GPU boundaries with stale CPU boundaries and can build the wrong chunks.
Minimal fix
```diff
 if cp_context is not None:
     assert initial_state is None, "Initial state is not supported for CP"
     assert output_final_state is False, "Output final state is not supported for CP"
     assert cp_context.cu_seqlens is not None, "cu_seqlens is required for CP"
     cu_seqlens = cp_context.cu_seqlens
-    if cp_context.cu_seqlens_cpu is not None:
-        cu_seqlens_cpu = cp_context.cu_seqlens_cpu
+    cu_seqlens_cpu = cp_context.cu_seqlens_cpu
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/generalized_delta_rule/dplr/chunk.py` around lines 472 - 478, When
cp_context is provided you must override any caller-supplied CPU boundaries so
prepare_chunk_indices doesn't mix views; replace the conditional assignment that
only sets cu_seqlens_cpu when cp_context.cu_seqlens_cpu is not None with an
unconditional override (i.e., set cu_seqlens_cpu = cp_context.cu_seqlens_cpu) in
the block that handles cp_context (alongside cp_context, cu_seqlens,
initial_state, output_final_state checks) so prepare_chunk_indices always sees
the CP view (including None) from cp_context.
```python
o_local, _ = chunk_dplr_delta_rule(
    q=q_local,
    k=k_local,
    v=v_local,
    a=a_local,
    b=b_local,
    gk=gk_local,
    cp_context=context,
    safe_gate=True,
    chunk_size=64,
)
```
This never exercises safe_gate=False.
The worker plumbs safe_gate through its API and the test matrix varies it, but the actual call is hardcoded to safe_gate=True. That means the default cases stop validating the non-safe path, and test_cp2_safe_gate is effectively a duplicate.
Fix
```diff
 o_local, _ = chunk_dplr_delta_rule(
     q=q_local,
     k=k_local,
     v=v_local,
     a=a_local,
     b=b_local,
     gk=gk_local,
     cp_context=context,
-    safe_gate=True,
+    safe_gate=safe_gate,
     chunk_size=64,
 )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
o_local, _ = chunk_dplr_delta_rule(
    q=q_local,
    k=k_local,
    v=v_local,
    a=a_local,
    b=b_local,
    gk=gk_local,
    cp_context=context,
    safe_gate=safe_gate,
    chunk_size=64,
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/context_parallel/test_cp_dplr.py` around lines 245 - 255, The call to
chunk_dplr_delta_rule hardcodes safe_gate=True so the non-safe path is never
exercised; change the argument to pass the test matrix variable (e.g.,
safe_gate) instead of the literal True so the test will invoke both
safe_gate=True and safe_gate=False; locate the invocation of
chunk_dplr_delta_rule (the parameters q=q_local, k=k_local, v=v_local,
a=a_local, b=b_local, gk=gk_local, cp_context=context) and replace
safe_gate=True with safe_gate (or the actual parameter name used by the test).
```python
port = 29510  # Different port from other CP tests
mp.start_processes(
    run_cp_token_shift_test_worker,
    args=(world_size, test_name, T, D, lengths, dtype, port),
    nprocs=world_size,
    join=True,
    start_method='spawn',
)
```
Use a unique rendezvous port per spawned test run.
Every case in this module binds 29510. Under parallel pytest workers, two CP tests can race for the same localhost port and fail nondeterministically.
Example fix
```diff
+import socket
@@
 def run_cp_test_with_spawn(
     world_size: int,
     test_name: str,
     T: int,
     D: int,
     lengths: list[int],
     dtype=torch.float32,
 ):
     """Run CP test using torch.multiprocessing.spawn."""
-    port = 29510  # Different port from other CP tests
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(("127.0.0.1", 0))
+        port = sock.getsockname()[1]
     mp.start_processes(
         run_cp_token_shift_test_worker,
         args=(world_size, test_name, T, D, lengths, dtype, port),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/context_parallel/test_cp_token_shift.py` around lines 181 - 188, The
test always sets a hard-coded rendezvous port (port = 29510) which causes races
when pytest runs tests in parallel; change the port selection in the setup
before calling mp.start_processes (the place where port is passed to
run_cp_token_shift_test_worker) to use a unique free port per test run (e.g.,
call a helper that returns an available ephemeral port or derive a port from the
current process/worker id or random offset) so each spawned test uses its own
rendezvous port rather than the fixed 29510.
🧹 Nitpick comments (1)
tests/context_parallel/debug.md (1)
163-170: Make the `5e-3` acceptance bar explicitly conditional.

This threshold is useful, but it should be documented with conditions (hardware/dtype/seed/shape regime); otherwise it can be interpreted as universally stable and cause noisy triage.
📝 Suggested wording tweak
```diff
-Var-length KCP with `safe_gate=True`, bf16 inputs, and an unaligned
-cut point should land below 5e-3 `norm_ratio` per gradient against
+Var-length KCP with `safe_gate=True`, bf16 inputs, and an unaligned
+cut point is expected to land around/below 5e-3 `norm_ratio` per gradient against
 the per-token `naive` reference (measured per-sequence). That is pure
 bf16 chunk-vs-per-token noise and matches the magnitudes that the
 long-standing KCP tests (e.g. gated-delta-rule CP2) sit at. Anything
-above ~5e-3 means either the forward recomputation is using the wrong
+consistently above this range (under the same seed, shape regime, and hardware)
+usually means either the forward recomputation is using the wrong
 `initial_state`, or the merge kernel is being called with a stale `BV`,
 or both.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/context_parallel/debug.md` around lines 163 - 170, Update the prose around the "5e-3" acceptance bar in the Var-length KCP test description (the paragraph mentioning Var-length KCP with `safe_gate=True`, bf16 inputs, unaligned cut point, `norm_ratio` vs per-token `naive` reference) to make the threshold explicitly conditional: state the specific conditions under which 5e-3 was observed (hardware/GPU model, dtype bf16, RNG seed range or reproducibility note, typical sequence shapes and batch regimes), and add a short guidance sentence that the threshold may be higher/lower outside those regimes and should be revalidated when any of those variables change. Ensure the modified text names the metric (`norm_ratio`) and test configuration (`safe_gate=True`, bf16) so readers can locate and re-run the same experiment.
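The `norm_ratio` metric discussed above is presumably a relative L2 error of the form ‖test − ref‖ / ‖ref‖; a minimal sketch under that assumption (flattened values as plain Python floats):

```python
import math

def norm_ratio(test, ref, eps=1e-12):
    """Relative L2 error between a test tensor and a reference, flattened.

    Mirrors the usual "norm ratio" check: ||test - ref|| / ||ref||.
    """
    diff = math.sqrt(sum((t - r) ** 2 for t, r in zip(test, ref)))
    base = math.sqrt(sum(r * r for r in ref))
    return diff / max(base, eps)

ref = [1.0, 2.0, 3.0]
test = [1.0, 2.0, 3.003]
print(norm_ratio(test, ref))  # ≈ 8.0e-4, comfortably below a 5e-3 bar
```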
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/context_parallel/debug.md`:
- Around line 163-170: Update the prose around the "5e-3" acceptance bar in the
Var-length KCP test description (the paragraph mentioning Var-length KCP with
`safe_gate=True`, bf16 inputs, unaligned cut point, `norm_ratio` vs per-token
`naive` reference) to make the threshold explicitly conditional: state the
specific conditions under which 5e-3 was observed (hardware/GPU model, dtype
bf16, RNG seed range or reproducibility note, typical sequence shapes and batch
regimes), and add a short guidance sentence that the threshold may be
higher/lower outside those regimes and should be revalidated when any of those
variables change. Ensure the modified text names the metric (`norm_ratio`) and
test configuration (`safe_gate=True`, bf16) so readers can locate and re-run the
same experiment.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 403f7c21-e8ab-4d83-9a0e-4476c5485279
📒 Files selected for processing (1)
tests/context_parallel/debug.md
| Setting | Value |
|---|---|
| GPU | NVIDIA H200 |
| CUDA | 12.8 |
| PyTorch | 2.7.1+cu128 |
| Base | 967f8c09ff |
| Head | c76cd01a0c |
| Threshold | 5.0% |
| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Speedup | Change |
|---|---|---|---|---|---|---|---|---|---|
| chunk_comba | fwd | 1 | 8192 | 96 | 128 | 2.049 | 2.052 | 1.00x | +0.1% |
| chunk_comba | fwd | 2 | 16384 | 16 | 128 | 1.436 | 1.437 | 1.00x | +0.0% |
| chunk_comba | fwd | 4 | 2048 | 16 | 128 | 0.621 | 0.633 | 0.98x | +1.8% |
| chunk_comba | fwd | 4 | 4096 | 64 | 128 | 2.563 | 2.560 | 1.00x | -0.1% |
| chunk_comba | fwd | 8 | 1024 | 8 | 64 | 0.610 | 0.623 | 0.98x | +2.0% |
| chunk_comba | fwd | 8 | 2048 | 32 | 256 | 2.706 | 2.705 | 1.00x | -0.0% |
| chunk_delta_rule | fwd | 1 | 8192 | 96 | 128 | 1.586 | 1.587 | 1.00x | +0.1% |
| chunk_delta_rule | fwd | 2 | 16384 | 16 | 128 | 1.077 | 1.076 | 1.00x | -0.1% |
| chunk_delta_rule | fwd | 4 | 2048 | 16 | 128 | 0.335 | 0.337 | 0.99x | +0.7% |
| chunk_delta_rule | fwd | 4 | 4096 | 64 | 128 | 1.954 | 1.953 | 1.00x | -0.0% |
| chunk_delta_rule | fwd | 8 | 1024 | 8 | 64 | 0.323 | 0.315 | 1.02x | -2.3% |
| chunk_delta_rule | fwd | 8 | 2048 | 32 | 256 | 2.174 | 2.174 | 1.00x | +0.0% |
| chunk_dplr_delta_rule | fwd | 1 | 8192 | 96 | 128 | 5.939 | 5.940 | 1.00x | +0.0% |
| chunk_dplr_delta_rule | fwd | 2 | 16384 | 16 | 128 | 5.871 | 5.873 | 1.00x | +0.0% |
| chunk_dplr_delta_rule | fwd | 4 | 2048 | 16 | 128 | 1.128 | 1.129 | 1.00x | +0.1% |
| chunk_dplr_delta_rule | fwd | 4 | 4096 | 64 | 128 | 7.442 | 7.445 | 1.00x | +0.0% |
| chunk_dplr_delta_rule | fwd | 8 | 1024 | 8 | 64 | 0.416 | 0.414 | 1.00x | -0.5% |
| chunk_dplr_delta_rule | fwd | 8 | 2048 | 32 | 256 | 11.612 | 11.601 | 1.00x | -0.1% |
| chunk_gdn | fwd | 1 | 8192 | 96 | 128 | 1.677 | 1.672 | 1.00x | -0.2% |
| chunk_gdn | fwd | 2 | 16384 | 16 | 128 | 1.169 | 1.166 | 1.00x | -0.3% |
| chunk_gdn | fwd | 4 | 2048 | 16 | 128 | 0.505 | 0.522 | 0.97x | +3.4% |
| chunk_gdn | fwd | 4 | 4096 | 64 | 128 | 2.037 | 2.030 | 1.00x | -0.3% |
| chunk_gdn | fwd | 8 | 1024 | 8 | 64 | 0.497 | 0.505 | 0.99x | +1.5% |
| chunk_gdn | fwd | 8 | 2048 | 32 | 256 | 2.390 | 2.391 | 1.00x | +0.0% |
| chunk_gla | fwd | 1 | 8192 | 96 | 128 | 2.250 | 2.258 | 1.00x | +0.3% |
| chunk_gla | fwd | 2 | 16384 | 16 | 128 | 1.831 | 1.830 | 1.00x | -0.0% |
| chunk_gla | fwd | 4 | 2048 | 16 | 128 | 0.428 | 0.428 | 1.00x | -0.0% |
| chunk_gla | fwd | 4 | 4096 | 64 | 128 | 3.010 | 3.010 | 1.00x | +0.0% |
| chunk_gla | fwd | 8 | 1024 | 8 | 64 | 0.290 | 0.294 | 0.99x | +1.5% |
| chunk_gla | fwd | 8 | 2048 | 32 | 256 | 3.697 | 3.693 | 1.00x | -0.1% |
| chunk_kda | fwd | 1 | 8192 | 96 | 128 | 2.772 | 2.766 | 1.00x | -0.2% |
| chunk_kda | fwd | 2 | 16384 | 16 | 128 | 1.925 | 1.927 | 1.00x | +0.1% |
| chunk_kda | fwd | 4 | 2048 | 16 | 128 | 0.626 | 0.636 | 0.98x | +1.6% |
| chunk_kda | fwd | 4 | 4096 | 64 | 128 | 3.468 | 3.463 | 1.00x | -0.2% |
| chunk_kda | fwd | 8 | 1024 | 8 | 64 | 0.623 | 0.622 | 1.00x | -0.1% |
| chunk_kda | fwd | 8 | 2048 | 32 | 256 | 4.404 | 4.409 | 1.00x | +0.1% |
| chunk_lightning_attn | fwd | 1 | 8192 | 96 | 128 | 0.668 | 0.669 | 1.00x | +0.2% |
| chunk_lightning_attn | fwd | 2 | 16384 | 16 | 128 | 0.532 | 0.537 | 0.99x | +0.9% |
| chunk_lightning_attn | fwd | 4 | 2048 | 16 | 128 | 0.232 | 0.235 | 0.99x | +1.4% |
| chunk_lightning_attn | fwd | 4 | 4096 | 64 | 128 | 0.873 | 0.875 | 1.00x | +0.2% |
| chunk_lightning_attn | fwd | 8 | 1024 | 8 | 64 | 0.193 | 0.196 | 0.99x | +1.5% |
| chunk_lightning_attn | fwd | 8 | 2048 | 32 | 256 | 1.197 | 1.200 | 1.00x | +0.2% |
| chunk_linear_attn | fwd | 1 | 8192 | 96 | 128 | 4.306 | 4.306 | 1.00x | +0.0% |
| chunk_linear_attn | fwd | 2 | 16384 | 16 | 128 | 6.908 | 6.907 | 1.00x | -0.0% |
| chunk_linear_attn | fwd | 4 | 2048 | 16 | 128 | 1.001 | 1.002 | 1.00x | +0.1% |
| chunk_linear_attn | fwd | 4 | 4096 | 64 | 128 | 3.255 | 3.257 | 1.00x | +0.0% |
| chunk_linear_attn | fwd | 8 | 1024 | 8 | 64 | 0.440 | 0.440 | 1.00x | -0.0% |
| chunk_linear_attn | fwd | 8 | 2048 | 32 | 256 | 2.682 | 2.680 | 1.00x | -0.1% |
| chunk_retention | fwd | 1 | 8192 | 96 | 128 | 0.727 | 0.722 | 1.01x | -0.7% |
| chunk_retention | fwd | 2 | 16384 | 16 | 128 | 0.588 | 0.586 | 1.00x | -0.4% |
| chunk_retention | fwd | 4 | 2048 | 16 | 128 | 0.286 | 0.289 | 0.99x | +0.9% |
| chunk_retention | fwd | 4 | 4096 | 64 | 128 | 0.930 | 0.931 | 1.00x | +0.2% |
| chunk_retention | fwd | 8 | 1024 | 8 | 64 | 0.247 | 0.246 | 1.00x | -0.5% |
| chunk_retention | fwd | 8 | 2048 | 32 | 256 | 1.257 | 1.256 | 1.00x | -0.1% |
| chunk_rwkv6 | fwd | 1 | 8192 | 96 | 128 | 3.075 | 3.076 | 1.00x | +0.0% |
| chunk_rwkv6 | fwd | 2 | 16384 | 16 | 128 | 2.383 | 2.385 | 1.00x | +0.1% |
| chunk_rwkv6 | fwd | 4 | 2048 | 16 | 128 | 0.558 | 0.559 | 1.00x | +0.1% |
| chunk_rwkv6 | fwd | 4 | 4096 | 64 | 128 | 3.893 | 3.895 | 1.00x | +0.1% |
| chunk_rwkv6 | fwd | 8 | 1024 | 8 | 64 | 0.295 | 0.297 | 0.99x | +0.6% |
| chunk_rwkv6 | fwd | 8 | 2048 | 32 | 256 | 5.139 | 5.133 | 1.00x | -0.1% |
| chunk_rwkv7 | fwd | 1 | 8192 | 96 | 128 | 7.880 | 7.884 | 1.00x | +0.1% |
| chunk_rwkv7 | fwd | 2 | 16384 | 16 | 128 | 7.198 | 7.201 | 1.00x | +0.0% |
| chunk_rwkv7 | fwd | 4 | 2048 | 16 | 128 | 1.453 | 1.453 | 1.00x | -0.0% |
| chunk_rwkv7 | fwd | 4 | 4096 | 64 | 128 | 10.072 | 10.063 | 1.00x | -0.1% |
| chunk_rwkv7 | fwd | 8 | 1024 | 8 | 64 | 0.421 | 0.427 | 0.99x | +1.3% |
| chunk_rwkv7 | fwd | 8 | 2048 | 32 | 256 | 45.371 | 45.377 | 1.00x | +0.0% |
| chunk_simple_gla | fwd | 1 | 8192 | 96 | 128 | 0.669 | 0.672 | 1.00x | +0.5% |
| chunk_simple_gla | fwd | 2 | 16384 | 16 | 128 | 0.616 | 0.615 | 1.00x | -0.3% |
| chunk_simple_gla | fwd | 4 | 2048 | 16 | 128 | 0.177 | 0.186 | 0.95x | +5.2% 🔴 |
| chunk_simple_gla | fwd | 4 | 4096 | 64 | 128 | 0.847 | 0.844 | 1.00x | -0.4% |
| chunk_simple_gla | fwd | 8 | 1024 | 8 | 64 | 0.167 | 0.164 | 1.02x | -1.8% |
| chunk_simple_gla | fwd | 8 | 2048 | 32 | 256 | 1.259 | 1.265 | 1.00x | +0.4% |
| flash_attn | fwd | 1 | 8192 | 96 | 128 | 4.512 | 4.547 | 0.99x | +0.8% |
| flash_attn | fwd | 2 | 16384 | 16 | 128 | 6.102 | 6.154 | 0.99x | +0.9% |
| flash_attn | fwd | 4 | 2048 | 16 | 128 | 0.250 | 0.251 | 1.00x | +0.5% |
| flash_attn | fwd | 4 | 4096 | 64 | 128 | 3.133 | 3.158 | 0.99x | +0.8% |
| flash_attn | fwd | 8 | 1024 | 8 | 64 | 0.058 | 0.058 | 1.00x | -0.5% |
| flash_attn | fwd | 8 | 2048 | 32 | 256 | 1.957 | 1.961 | 1.00x | +0.2% |
| fused_recurrent_hgrn | fwd | 1 | 8192 | 96 | 128 | 3.534 | 3.534 | 1.00x | +0.0% |
| fused_recurrent_hgrn | fwd | 2 | 16384 | 16 | 128 | 6.990 | 6.993 | 1.00x | +0.0% |
| fused_recurrent_hgrn | fwd | 4 | 2048 | 16 | 128 | 0.862 | 0.862 | 1.00x | +0.0% |
| fused_recurrent_hgrn | fwd | 4 | 4096 | 64 | 128 | 1.864 | 1.864 | 1.00x | -0.0% |
| fused_recurrent_hgrn | fwd | 8 | 1024 | 8 | 64 | 0.454 | 0.455 | 1.00x | +0.2% |
| fused_recurrent_hgrn | fwd | 8 | 2048 | 32 | 256 | 0.970 | 0.971 | 1.00x | +0.1% |
| parallel_attn | fwd | 1 | 8192 | 96 | 128 | 4.606 | 4.606 | 1.00x | +0.0% |
| parallel_attn | fwd | 2 | 16384 | 16 | 128 | 5.993 | 5.989 | 1.00x | -0.1% |
| parallel_attn | fwd | 4 | 2048 | 16 | 128 | 0.259 | 0.260 | 1.00x | +0.3% |
| parallel_attn | fwd | 4 | 4096 | 64 | 128 | 3.159 | 3.162 | 1.00x | +0.1% |
| parallel_attn | fwd | 8 | 1024 | 8 | 64 | 0.062 | 0.063 | 0.99x | +1.3% |
| chunk_comba | fwdbwd | 1 | 8192 | 96 | 128 | 7.449 | 7.460 | 1.00x | +0.1% |
| chunk_comba | fwdbwd | 2 | 16384 | 16 | 128 | 5.276 | 5.277 | 1.00x | +0.0% |
| chunk_comba | fwdbwd | 4 | 2048 | 16 | 128 | 2.078 | 2.073 | 1.00x | -0.2% |
| chunk_comba | fwdbwd | 4 | 4096 | 64 | 128 | 9.641 | 9.646 | 1.00x | +0.0% |
| chunk_comba | fwdbwd | 8 | 1024 | 8 | 64 | 2.051 | 2.079 | 0.99x | +1.4% |
| chunk_comba | fwdbwd | 8 | 2048 | 32 | 256 | 15.485 | 15.499 | 1.00x | +0.1% |
| chunk_delta_rule | fwdbwd | 1 | 8192 | 96 | 128 | 5.000 | 4.998 | 1.00x | -0.0% |
| chunk_delta_rule | fwdbwd | 2 | 16384 | 16 | 128 | 3.340 | 3.343 | 1.00x | +0.1% |
| chunk_delta_rule | fwdbwd | 4 | 2048 | 16 | 128 | 1.271 | 1.218 | 1.04x | -4.2% |
| chunk_delta_rule | fwdbwd | 4 | 4096 | 64 | 128 | 6.275 | 6.279 | 1.00x | +0.1% |
| chunk_delta_rule | fwdbwd | 8 | 1024 | 8 | 64 | 1.251 | 1.206 | 1.04x | -3.5% |
| chunk_delta_rule | fwdbwd | 8 | 2048 | 32 | 256 | 9.848 | 9.857 | 1.00x | +0.1% |
| chunk_dplr_delta_rule | fwdbwd | 1 | 8192 | 96 | 128 | 24.605 | 24.622 | 1.00x | +0.1% |
| chunk_dplr_delta_rule | fwdbwd | 2 | 16384 | 16 | 128 | 21.635 | 21.634 | 1.00x | -0.0% |
| chunk_dplr_delta_rule | fwdbwd | 4 | 2048 | 16 | 128 | 4.512 | 4.516 | 1.00x | +0.1% |
| chunk_dplr_delta_rule | fwdbwd | 4 | 4096 | 64 | 128 | 31.408 | 31.420 | 1.00x | +0.0% |
| chunk_dplr_delta_rule | fwdbwd | 8 | 1024 | 8 | 64 | 1.929 | 1.874 | 1.03x | -2.8% |
| chunk_dplr_delta_rule | fwdbwd | 8 | 2048 | 32 | 256 | 43.293 | 43.307 | 1.00x | +0.0% |
| chunk_gdn | fwdbwd | 1 | 8192 | 96 | 128 | 6.805 | 6.790 | 1.00x | -0.2% |
| chunk_gdn | fwdbwd | 2 | 16384 | 16 | 128 | 4.814 | 4.808 | 1.00x | -0.1% |
| chunk_gdn | fwdbwd | 4 | 2048 | 16 | 128 | 1.841 | 1.851 | 0.99x | +0.6% |
| chunk_gdn | fwdbwd | 4 | 4096 | 64 | 128 | 8.731 | 8.713 | 1.00x | -0.2% |
| chunk_gdn | fwdbwd | 8 | 1024 | 8 | 64 | 1.792 | 1.772 | 1.01x | -1.1% |
| chunk_gdn | fwdbwd | 8 | 2048 | 32 | 256 | 14.850 | 14.862 | 1.00x | +0.1% |
| chunk_gla | fwdbwd | 1 | 8192 | 96 | 128 | 10.936 | 10.941 | 1.00x | +0.0% |
| chunk_gla | fwdbwd | 2 | 16384 | 16 | 128 | 8.483 | 8.480 | 1.00x | -0.0% |
| chunk_gla | fwdbwd | 4 | 2048 | 16 | 128 | 2.021 | 2.200 | 0.92x | +8.9% 🔴 |
| chunk_gla | fwdbwd | 4 | 4096 | 64 | 128 | 15.179 | 15.181 | 1.00x | +0.0% |
| chunk_gla | fwdbwd | 8 | 1024 | 8 | 64 | 1.179 | 1.718 | 0.69x | +45.8% 🔴 |
| chunk_gla | fwdbwd | 8 | 2048 | 32 | 256 | 18.964 | 18.968 | 1.00x | +0.0% |
| chunk_kda | fwdbwd | 1 | 8192 | 96 | 128 | 12.487 | 12.500 | 1.00x | +0.1% |
| chunk_kda | fwdbwd | 2 | 16384 | 16 | 128 | 8.369 | 8.371 | 1.00x | +0.0% |
| chunk_kda | fwdbwd | 4 | 2048 | 16 | 128 | 2.180 | 2.183 | 1.00x | +0.1% |
| chunk_kda | fwdbwd | 4 | 4096 | 64 | 128 | 15.971 | 15.962 | 1.00x | -0.1% |
| chunk_kda | fwdbwd | 8 | 1024 | 8 | 64 | 2.101 | 2.003 | 1.05x | -4.7% |
| chunk_kda | fwdbwd | 8 | 2048 | 32 | 256 | 18.535 | 18.538 | 1.00x | +0.0% |
| chunk_lightning_attn | fwdbwd | 1 | 8192 | 96 | 128 | 3.442 | 3.461 | 0.99x | +0.6% |
| chunk_lightning_attn | fwdbwd | 2 | 16384 | 16 | 128 | 2.667 | 2.676 | 1.00x | +0.4% |
| chunk_lightning_attn | fwdbwd | 4 | 2048 | 16 | 128 | 0.942 | 0.985 | 0.96x | +4.6% |
| chunk_lightning_attn | fwdbwd | 4 | 4096 | 64 | 128 | 4.495 | 4.548 | 0.99x | +1.2% |
| chunk_lightning_attn | fwdbwd | 8 | 1024 | 8 | 64 | 0.799 | 0.812 | 0.98x | +1.6% |
| chunk_lightning_attn | fwdbwd | 8 | 2048 | 32 | 256 | 10.077 | 10.149 | 0.99x | +0.7% |
| chunk_linear_attn | fwdbwd | 1 | 8192 | 96 | 128 | 12.339 | 12.345 | 1.00x | +0.1% |
| chunk_linear_attn | fwdbwd | 2 | 16384 | 16 | 128 | 16.590 | 16.613 | 1.00x | +0.1% |
| chunk_linear_attn | fwdbwd | 4 | 2048 | 16 | 128 | 2.698 | 2.700 | 1.00x | +0.1% |
| chunk_linear_attn | fwdbwd | 4 | 4096 | 64 | 128 | 11.619 | 11.611 | 1.00x | -0.1% |
| chunk_linear_attn | fwdbwd | 8 | 1024 | 8 | 64 | 1.016 | 1.069 | 0.95x | +5.2% 🔴 |
| chunk_linear_attn | fwdbwd | 8 | 2048 | 32 | 256 | 15.315 | 15.316 | 1.00x | +0.0% |
| chunk_retention | fwdbwd | 1 | 8192 | 96 | 128 | 3.561 | 3.544 | 1.00x | -0.5% |
| chunk_retention | fwdbwd | 2 | 16384 | 16 | 128 | 2.749 | 2.769 | 0.99x | +0.7% |
| chunk_retention | fwdbwd | 4 | 2048 | 16 | 128 | 1.041 | 1.047 | 0.99x | +0.6% |
| chunk_retention | fwdbwd | 4 | 4096 | 64 | 128 | 4.571 | 4.603 | 0.99x | +0.7% |
| chunk_retention | fwdbwd | 8 | 1024 | 8 | 64 | 0.870 | 0.869 | 1.00x | -0.2% |
| chunk_retention | fwdbwd | 8 | 2048 | 32 | 256 | 10.186 | 10.197 | 1.00x | +0.1% |
| chunk_rwkv6 | fwdbwd | 1 | 8192 | 96 | 128 | 13.064 | 13.064 | 1.00x | -0.0% |
| chunk_rwkv6 | fwdbwd | 2 | 16384 | 16 | 128 | 9.626 | 9.633 | 1.00x | +0.1% |
| chunk_rwkv6 | fwdbwd | 4 | 2048 | 16 | 128 | 2.301 | 2.304 | 1.00x | +0.1% |
| chunk_rwkv6 | fwdbwd | 4 | 4096 | 64 | 128 | 16.832 | 16.826 | 1.00x | -0.0% |
| chunk_rwkv6 | fwdbwd | 8 | 1024 | 8 | 64 | 1.205 | 1.247 | 0.97x | +3.5% |
| chunk_rwkv6 | fwdbwd | 8 | 2048 | 32 | 256 | 20.796 | 20.785 | 1.00x | -0.1% |
| chunk_rwkv7 | fwdbwd | 1 | 8192 | 96 | 128 | 23.836 | 23.841 | 1.00x | +0.0% |
| chunk_rwkv7 | fwdbwd | 2 | 16384 | 16 | 128 | 21.187 | 21.202 | 1.00x | +0.1% |
| chunk_rwkv7 | fwdbwd | 4 | 2048 | 16 | 128 | 4.377 | 4.374 | 1.00x | -0.1% |
| chunk_rwkv7 | fwdbwd | 4 | 4096 | 64 | 128 | 30.466 | 30.463 | 1.00x | -0.0% |
| chunk_rwkv7 | fwdbwd | 8 | 1024 | 8 | 64 | 1.793 | 1.989 | 0.90x | +10.9% 🔴 |
| chunk_rwkv7 | fwdbwd | 8 | 2048 | 32 | 256 | 104.621 | 104.601 | 1.00x | -0.0% |
| chunk_simple_gla | fwdbwd | 1 | 8192 | 96 | 128 | 3.899 | 3.902 | 1.00x | +0.1% |
| chunk_simple_gla | fwdbwd | 2 | 16384 | 16 | 128 | 3.128 | 3.131 | 1.00x | +0.1% |
| chunk_simple_gla | fwdbwd | 4 | 2048 | 16 | 128 | 1.018 | 0.928 | 1.10x | -8.8% 🟢 |
| chunk_simple_gla | fwdbwd | 4 | 4096 | 64 | 128 | 5.238 | 5.240 | 1.00x | +0.0% |
| chunk_simple_gla | fwdbwd | 8 | 1024 | 8 | 64 | 0.919 | 1.028 | 0.89x | +11.8% 🔴 |
| chunk_simple_gla | fwdbwd | 8 | 2048 | 32 | 256 | 12.184 | 12.196 | 1.00x | +0.1% |
| flash_attn | fwdbwd | 1 | 8192 | 96 | 128 | 18.668 | 18.299 | 1.02x | -2.0% |
| flash_attn | fwdbwd | 2 | 16384 | 16 | 128 | 23.051 | 23.037 | 1.00x | -0.1% |
| flash_attn | fwdbwd | 4 | 2048 | 16 | 128 | 1.041 | 1.042 | 1.00x | +0.1% |
| flash_attn | fwdbwd | 4 | 4096 | 64 | 128 | 12.854 | 12.846 | 1.00x | -0.1% |
| flash_attn | fwdbwd | 8 | 1024 | 8 | 64 | 0.278 | 0.344 | 0.81x | +23.4% 🔴 |
| flash_attn | fwdbwd | 8 | 2048 | 32 | 256 | 8.292 | 8.302 | 1.00x | +0.1% |
| fused_recurrent_hgrn | fwdbwd | 1 | 8192 | 96 | 128 | 8.191 | 8.189 | 1.00x | -0.0% |
| fused_recurrent_hgrn | fwdbwd | 2 | 16384 | 16 | 128 | 15.379 | 15.388 | 1.00x | +0.1% |
| fused_recurrent_hgrn | fwdbwd | 4 | 2048 | 16 | 128 | 1.975 | 1.976 | 1.00x | +0.0% |
| fused_recurrent_hgrn | fwdbwd | 4 | 4096 | 64 | 128 | 4.850 | 4.853 | 1.00x | +0.1% |
| fused_recurrent_hgrn | fwdbwd | 8 | 1024 | 8 | 64 | 0.952 | 0.952 | 1.00x | +0.0% |
| fused_recurrent_hgrn | fwdbwd | 8 | 2048 | 32 | 256 | 3.021 | 3.021 | 1.00x | +0.0% |
| parallel_attn | fwdbwd | 1 | 8192 | 96 | 128 | 22.015 | 22.022 | 1.00x | +0.0% |
| parallel_attn | fwdbwd | 2 | 16384 | 16 | 128 | 28.969 | 28.946 | 1.00x | -0.1% |
| parallel_attn | fwdbwd | 4 | 2048 | 16 | 128 | 1.471 | 1.471 | 1.00x | +0.0% |
| parallel_attn | fwdbwd | 4 | 4096 | 64 | 128 | 16.988 | 16.984 | 1.00x | -0.0% |
| parallel_attn | fwdbwd | 8 | 1024 | 8 | 64 | 0.465 | 0.487 | 0.95x | +4.7% |
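The Speedup and Change columns in the table above follow directly from the two timings; a small sketch of the presumed formulas, with the 🔴/🟢 flags keyed to the 5.0% threshold:

```python
def bench_row(base_ms: float, head_ms: float, threshold_pct: float = 5.0):
    """Derive the Speedup and Change columns from base/head timings."""
    speedup = base_ms / head_ms                       # >1.0x means head is faster
    change = (head_ms - base_ms) / base_ms * 100.0    # signed % change vs base
    flag = ""
    if change >= threshold_pct:
        flag = "🔴"   # regression beyond the threshold
    elif change <= -threshold_pct:
        flag = "🟢"   # improvement beyond the threshold
    return speedup, change, flag

# chunk_simple_gla fwd, B=4 T=2048: 0.177 ms -> 0.186 ms
print(bench_row(0.177, 0.186))  # speedup ≈ 0.95x, change ≈ +5.1%, flagged 🔴
```

The table's printed percentages differ slightly from these recomputed values because they are rounded from the unrounded raw timings.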
This comment is automatically updated with the latest benchmark results.
Summary by CodeRabbit

- New Features
- Tests
- Documentation