Skip to content

[CP] fuse fwd/bwd kernels and fix IMA in long context#733

Merged
zhiyuan1i merged 3 commits intomainfrom
lzy/enhance-cp-and-fix-ima
Feb 3, 2026
Merged

[CP] fuse fwd/bwd kernels and fix IMA in long context#733
zhiyuan1i merged 3 commits intomainfrom
lzy/enhance-cp-and-fix-ima

Conversation

@zhiyuan1i
Copy link
Copy Markdown
Collaborator

@zhiyuan1i zhiyuan1i commented Jan 30, 2026

  • Fuse forward and backward kernels for better performance

  • Fix IMA (Illegal Memory Access) issue in long context scenarios

Summary by CodeRabbit

  • New Features

    • Added end-to-end benchmarking scripts for single-GPU and distributed kernel comparisons.
  • Performance Improvements

    • Fused multi-stage pre-processing into single merged kernels for forward and backward paths.
  • Stability / Bug Fixes

    • Extended index/address arithmetic to 64-bit across kernels for correctness with large or variable-length sequences.
  • Tests

    • Simplified test memory-guard behavior: legacy guard removed and floating tensors initialized with NaNs for clearer failure detection.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Jan 30, 2026

Walkthrough

Adds two new benchmark scripts, fuses CP pre-processing into merged Triton kernels for forward/backward, widens kernel index arithmetic to 64-bit in several Triton kernels, and simplifies the test memory-guard by replacing canary/padding logic with direct NaN-filling for floating/complex tensors.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/cp/benchmark_chunk_delta_h_kernels.py, benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
Add two benchmark scripts: a Triton microbenchmark for chunk_delta_h kernels and a distributed CP8 vs CP2TP benchmark with group setup, profiling, baseline comparison, and timing utilities.
Merged CP Triton Kernels
fla/ops/cp/chunk_delta_h.py
Introduce pre_process_fwd_kernel_merged and pre_process_bwd_kernel_merged; replace prior multi-stage dispatch with fused pre-process kernels, adjust BLOCK_SIZE/grid logic and varlen cu_seqlens handling.
64-bit Indexing / Pointer Safety
fla/ops/common/chunk_delta_h.py, fla/ops/gla/chunk.py, fla/ops/kda/chunk_bwd.py
Cast loop indices and cu_seqlens-derived values to 64-bit (tl.int64) for pointer arithmetic and addressing in varlen and non-varlen branches.
Test Memory Guard
tests/conftest.py
Remove legacy canary/alignment padding guard; simplify guarded allocations to directly fill floating/complex tensors with NaN and update patched torch allocation helpers; remove explicit memory-guard verification.

Sequence Diagram(s)

sequenceDiagram
    participant Driver as Driver (launch)
    participant Rank as Worker Rank
    participant CP as CP Group
    participant TP as TP Group
    participant GPU as GPU / Triton Kernel

    Driver->>Rank: start distributed benchmark (args)
    Rank->>CP: join CP group
    Rank->>TP: join TP group
    Rank->>GPU: run local fused pre-process kernel
    GPU-->>Rank: return local shard outputs
    Rank->>CP: intra-CP all-gather / all-reduce (assemble/reduce shards)
    CP->>TP: optional inter-group coordination
    TP->>GPU: run TP-side kernels / execute remainder
    GPU-->>Rank: return final results
    Rank->>Driver: report timings & profiler data
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs
  • Nathancgy

Poem

🐰
I hop through kernels, merged and bright,
I cast indices to make addresses right.
Benchmarks hum, timings on show,
NaNs stand guard where errors might grow. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 32.56% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[CP] fuse fwd/bwd kernels and fix IMA in long context' accurately summarizes the main changes: kernel fusion for CP operations and IMA bug fixes, which are reflected in the merged kernel additions and 64-bit integer casting fixes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch lzy/enhance-cp-and-fix-ima

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance and stability of the flash-linear-attention library by introducing fused kernels for chunk_delta_h operations, which are critical for recurrent state updates. It also resolves critical Illegal Memory Access bugs that previously affected long context processing, ensuring robust operation for larger sequence lengths. Furthermore, the PR includes new benchmarking tools to validate these performance improvements and provides detailed documentation for the Context Parallel implementation, improving clarity and maintainability.

Highlights

  • Kernel Fusion for chunk_delta_h: Forward and backward kernels for chunk_delta_h have been fused into single, more efficient kernels (pre_process_fwd_kernel_merged and pre_process_bwd_kernel_merged). This aims to improve performance by reducing overhead associated with launching multiple kernels.
  • Illegal Memory Access (IMA) Fixes for Long Contexts: Addressed and fixed Illegal Memory Access issues that occurred in long context scenarios across several Triton kernels. This was primarily achieved by explicitly casting intermediate index calculations to tl.int64 to prevent integer overflow when dealing with large sequence lengths.
  • New Benchmark Scripts: Introduced two new benchmark scripts: one to compare the performance of the newly fused chunk_delta_h kernels against their unfused counterparts, and another to benchmark chunk_kda under different Context Parallel (CP) and Tensor Parallel (TP) configurations (CP8 vs CP2TP).
  • Context Parallel (CP) Documentation: Added comprehensive documentation detailing the Context Parallel implementation for Gated Delta Net (GDN) and Key-Decoupled Attention (KDA), including mathematical formulations, parallel scan algorithms, and implementation specifics.
  • Autotuning Improvements: Switched to cache_autotune for certain GLA kernels to leverage cached autotuning results, potentially speeding up kernel compilation and selection.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@zhiyuan1i zhiyuan1i force-pushed the lzy/enhance-cp-and-fix-ima branch 2 times, most recently from d85e659 to f49979b Compare January 30, 2026 02:58
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces two significant improvements: the fusion of forward and backward pre-processing kernels for better performance in context parallelism, and a fix for potential Illegal Memory Access (IMA) errors in long context scenarios. The kernel fusion is well-implemented through new merged Triton kernels, which are correctly integrated and benchmarked, demonstrating a clear performance benefit. The IMA fixes are applied consistently across multiple kernels by ensuring pointer arithmetic is performed using 64-bit integers, which is a critical correction for handling long sequences. Furthermore, the inclusion of new, comprehensive benchmark scripts and detailed documentation for the context parallel implementation are excellent additions that enhance the library's testability and maintainability. Overall, the changes are of high quality and significantly improve the robustness and performance of the codebase.

@zhiyuan1i zhiyuan1i force-pushed the lzy/enhance-cp-and-fix-ima branch 2 times, most recently from e4d1d49 to da166ba Compare January 30, 2026 02:59
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 259-349: The custom-CP benchmark (kda_with_custom_cp) calls
chunk_kda with different kernel/config params than the simple-CP and baseline
variants, making timings incomparable; update the chunk_kda call inside
kda_with_custom_cp to pass the same kernel/config arguments used by
kda_with_simple_cp/kda_baseline_single_gpu (A_log, dt_bias,
use_gate_in_kernel=True, safe_gate=True, lower_bound=-5,
use_qk_l2norm_in_kernel=True) while keeping its cu_seqlens=cp_cu_seqlens and
existing TP all-reduce/backward handling so the computation is functionally
equivalent across kda_with_custom_cp, kda_with_simple_cp, and
kda_baseline_single_gpu.
🧹 Nitpick comments (7)
fla/ops/cp/chunk_delta_h.py (1)

838-843: Redundant if-else branches - both paths execute identical code.

The if USE_EXP2 / else branches perform the same tensor load operation. This pattern repeats at lines 850-853, 862-865, and 872-875.

♻️ Simplify by removing redundant conditional
             if USE_GK:
                 o_k1 = tl.arange(0, 64)
-                if USE_EXP2:
-                    b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
-                else:
-                    b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
+                b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
             b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))

Apply the same simplification to the other redundant branches (o_k2, o_k3, o_k4).

benchmarks/cp/benchmark_chunk_delta_h_kernels.py (2)

70-71: Consider using torch.int64 for cu_seqlens to match kernel expectations.

The kernels load cu_seqlens values with .to(tl.int64) (e.g., line 57 in chunk_delta_h.py). While the current int32 works due to the explicit cast in the kernel, using int64 here would be more consistent with the 64-bit indexing changes in this PR.

♻️ Proposed change
     # cu_seqlens for 1 chunk
-    cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
+    cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int64)

425-428: Minor: Guard against division by zero in FWD/BWD ratio.

While unlikely in practice, bwd_time could theoretically be zero.

♻️ Proposed change
     print(f"Forward pass total:  {fwd_time * 1000:.2f} us ({fwd_time/total_time*100:.1f}%)")
     print(f"Backward pass total: {bwd_time * 1000:.2f} us ({bwd_time/total_time*100:.1f}%)")
-    print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}")
+    print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}" if bwd_time > 0 else "FWD/BWD ratio: N/A")
     print()
benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (4)

48-48: Redundant argument configuration.

Using action="store_true" with default=True makes this flag always True. If you want to allow disabling the benchmark, consider using --no-bench with action="store_false".

♻️ Suggested fix
-    parser.add_argument("--bench", action="store_true", default=True, help="Run benchmark")
+    parser.add_argument("--no-bench", action="store_true", default=False, help="Skip benchmark")

Then update line 365:

-    if args.bench:
+    if not args.no_bench:

77-77: Rename unused loop variable to _.

The loop variable i is not used within the loop body. Per Python convention, rename to _ to indicate it's intentionally unused.

♻️ Suggested fix
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
 
     # Benchmark
     torch.cuda.synchronize()
     start_event.record()
-    for i in range(step):
+    for _ in range(step):
         fn()

Also applies to: 86-86


95-95: Remove unused rank parameter.

The rank parameter is passed to profile_kernels() but never used within the function.

♻️ Suggested fix
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):

And update the call site at line 360:

-        kernel_stats = profile_kernels(kda_with_custom_cp, rank, steps=5, warmup=2)
+        kernel_stats = profile_kernels(kda_with_custom_cp, steps=5, warmup=2)

237-243: Consider defining do_base unconditionally to avoid potential NameError.

When run_backward=True and test_baseline=False, do_base is never defined. While the current control flow prevents accessing it, this structure is fragile. Defining it unconditionally improves clarity.

♻️ Suggested fix
     # Output gradient for backward
+    do_base = None
     if run_backward:
         do = torch.randn(B, T_local, H_local, V, dtype=DTYPE, device=device)
         if test_baseline:
             do_base = torch.randn(B, T_baseline, H_local, V, dtype=DTYPE, device=device)
     else:
         do = None
-        do_base = None

Comment thread benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 188-198: Add explicit divisibility checks immediately after
computing T_total, T_local, T_baseline and H_local: verify T_total % cp_size ==
0 and H % tp_size == 0 (and that cp_size and tp_size are > 0), and raise or
assert with a clear explanatory message if they are not divisible so the user
sees which dimension is invalid; update the logic around variables T_local,
T_baseline and H_local to rely on these validated values (variables referenced:
T_total, T_local, T_baseline, H_local, cp_size, tp_size, H).
- Around line 187-225: The CP paths assume B==1 which breaks when batch size >1;
fix by making cu_seqlens and the all-gather logic batch-aware and removing any
unconditional squeeze(0) that collapses the batch dim. Specifically, ensure
cu_seqlens passed to build_cp_context is shaped per-batch (one cu_seqlens
sequence per sample) instead of a single [0, T_total], update the
all-gather/concat that previously used squeeze(0) to preserve the B dimension
when gathering q/k/v across CP ranks, and adjust chunk_kda inputs/reshaping to
accept (B, T_local, H_local, ...) tensors rather than relying on B==1. Update
references: cu_seqlens, build_cp_context, chunk_kda, and the squeeze(0) usage to
handle arbitrary B.
- Around line 418-435: The script binds CUDA using the global process rank which
breaks multi-node torchrun; in main() replace uses of rank for device-selection
and device-name queries with the LOCAL_RANK provided by torchrun: read
local_rank = int(os.environ.get("LOCAL_RANK", rank)) (or fallback to rank), call
torch.cuda.set_device(local_rank), and use
torch.cuda.get_device_name(local_rank); also use local_rank when seeding CUDA
(torch.cuda.manual_seed) so per-process seeds are per-local-device while keeping
distributed init via dist.init_process_group() and world_size/rank logic
unchanged.
🧹 Nitpick comments (2)
benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (2)

71-88: Avoid grad accumulation during timed iterations (and fix unused loop vars).

Grad accumulation can skew timing/memory, and Ruff flagged unused loop vars. Clear grads inside the timed loop and rename the index to _.

♻️ Proposed refactor
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
@@
-    for i in range(step):
+    for _ in range(step):
         fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None

95-98: Remove unused rank argument from profile_kernels.

rank isn’t used; simplify the signature and call site.

♻️ Proposed refactor
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):
@@
-        kernel_stats = profile_kernels(kda_with_cp, rank, steps=5, warmup=2)
+        kernel_stats = profile_kernels(kda_with_cp, steps=5, warmup=2)

Also applies to: 366-367

Comment thread benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
Comment thread benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
Comment on lines +418 to +435
def main():
dist.init_process_group()
world_size = dist.get_world_size()
rank = dist.get_rank()

torch.manual_seed(rank + 42)
torch.cuda.manual_seed(rank + 42)
random.seed(42)
torch.cuda.set_device(rank)

args = get_args()

# Print environment info
if rank == 0:
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(rank)}")
print(f"World size: {world_size}\n")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check if the file exists and examine the code at the specified lines
fd "benchmark_kda_cp8_vs_cp2tp.py" --type f

Repository: fla-org/flash-linear-attention

Length of output: 117


🏁 Script executed:

# Also check the full imports and context of the file
fd "benchmark_kda_cp8_vs_cp2tp.py" -x wc -l {}

Repository: fla-org/flash-linear-attention

Length of output: 123


🌐 Web query:

torchrun LOCAL_RANK environment variable torch.cuda.set_device best practice multi-node

💡 Result:

Best practice (multi-node, torchrun)

  1. Read ranks from env (don’t invent them):
  • LOCAL_RANK = GPU index on this node for this process (0…LOCAL_WORLD_SIZE-1)
  • RANK / WORLD_SIZE = global process id / total processes across all nodes (docs.pytorch.org)
  1. Set the CUDA device early (before creating CUDA tensors/modules):
  • Recommended pattern is to bind each process to exactly one GPU using LOCAL_RANK, e.g.:
    • torch.cuda.set_device(local_rank) (or a with torch.cuda.device(local_rank): ... scope) (docs.pytorch.org)
  • This avoids every process defaulting to cuda:0 unless you explicitly change it. (discuss.pytorch.org)
  1. Initialize distributed with env:// and wrap model with per-process device ids:
  • Use init_method="env://" (officially supported approach) (docs.pytorch.org)
  • For DistributedDataParallel, ensure each process uses only its GPU:
    • DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) (docs.pytorch.org)
  1. Don’t treat LOCAL_RANK as globally unique:
  • It’s only unique within a node—don’t use it for “rank-0 does logging/checkpointing” decisions; use global RANK==0 for that. (docs.pytorch.org)

Minimal, solid template

import os
import torch
import torch.distributed as dist

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)  # do this before any CUDA allocations
    dist.init_process_group(backend="nccl", init_method="env://")

    device = torch.device("cuda", local_rank)
    model = build_model().to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank
    )

torch.cuda.set_device(local_rank) vs CUDA_VISIBLE_DEVICES

Either works as long as each process exclusively operates on its assigned GPU:

  • set_device(local_rank): process can still “see” all GPUs, but default ops go to the selected device. (discuss.pytorch.org)
  • CUDA_VISIBLE_DEVICES per process: each process only sees one GPU (so it often reports cuda:0, mapped to a different physical GPU per process). (discuss.pytorch.org)

Citations:


🏁 Script executed:

# Read the file to see the actual implementation at lines 418-435
sed -n '410,443p' benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py

Repository: fla-org/flash-linear-attention

Length of output: 1075


🏁 Script executed:

# Check imports at the top of the file
head -30 benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py

Repository: fla-org/flash-linear-attention

Length of output: 1143


🏁 Script executed:

# Check if os is imported anywhere in the file
rg "^import os|^from os" benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py

Repository: fla-org/flash-linear-attention

Length of output: 56


Use LOCAL_RANK for device selection (multi-node safe).

torch.cuda.set_device(rank) uses global rank and breaks multi-node torchrun. This script explicitly uses torchrun --nproc_per_node=8, which sets LOCAL_RANK for each process. Use LOCAL_RANK instead of global rank for device binding, and apply it to both device-setting and device-name queries.

Proposed fix
+import os
 import argparse
 import random
@@
 def main():
     dist.init_process_group()
     world_size = dist.get_world_size()
     rank = dist.get_rank()
@@
-    torch.cuda.set_device(rank)
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
@@
-        print(f"CUDA device: {torch.cuda.get_device_name(rank)}")
+        print(f"CUDA device: {torch.cuda.get_device_name(local_rank)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def main():
dist.init_process_group()
world_size = dist.get_world_size()
rank = dist.get_rank()
torch.manual_seed(rank + 42)
torch.cuda.manual_seed(rank + 42)
random.seed(42)
torch.cuda.set_device(rank)
args = get_args()
# Print environment info
if rank == 0:
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(rank)}")
print(f"World size: {world_size}\n")
def main():
dist.init_process_group()
world_size = dist.get_world_size()
rank = dist.get_rank()
torch.manual_seed(rank + 42)
torch.cuda.manual_seed(rank + 42)
random.seed(42)
local_rank = int(os.environ.get("LOCAL_RANK", rank))
torch.cuda.set_device(local_rank)
args = get_args()
# Print environment info
if rank == 0:
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(local_rank)}")
print(f"World size: {world_size}\n")
🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py` around lines 418 - 435, The
script binds CUDA using the global process rank which breaks multi-node
torchrun; in main() replace uses of rank for device-selection and device-name
queries with the LOCAL_RANK provided by torchrun: read local_rank =
int(os.environ.get("LOCAL_RANK", rank)) (or fallback to rank), call
torch.cuda.set_device(local_rank), and use
torch.cuda.get_device_name(local_rank); also use local_rank when seeding CUDA
(torch.cuda.manual_seed) so per-process seeds are per-local-device while keeping
distributed init via dist.init_process_group() and world_size/rank logic
unchanged.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/conftest.py`:
- Around line 90-97: Remove the trailing-whitespace-only lines in the zeros-like
replacement block to satisfy lint: edit the function that calls
_ORIGINAL_ZEROS_LIKE so the lines around the result assignment and the
subsequent if/elif (where result.is_floating_point(), result.fill_(...),
result.is_complex(), result.fill_(...)) contain no trailing spaces or
blank-space-only lines (or run the repo's pre-commit hook) so the file has no
trailing whitespace.
🧹 Nitpick comments (1)
tests/conftest.py (1)

55-75: Align _guarded_zeros with compile bypass used elsewhere.

_guarded_zeros_like and _guarded_new_zeros skip poisoning when torch.compiler.is_compiling() is true, but _guarded_zeros does not. If torch.zeros is hit during compilation, NaN poisoning could leak into compiled graphs. Consider short‑circuiting on is_compiling() here as well for consistency.

♻️ Suggested tweak
 def _guarded_zeros(*args, **kwargs):
     """Create a tensor filled with NaN instead of zeros to detect incorrect dependency on zero initialization."""
-    dtype = kwargs.get('dtype') or torch.get_default_dtype()
+    if is_compiling() or not _is_called_from_fla():
+        return _ORIGINAL_ZEROS(*args, **kwargs)
+
+    dtype = kwargs.get('dtype') or torch.get_default_dtype()
     
     if not (dtype.is_floating_point or dtype.is_complex):
         return _ORIGINAL_ZEROS(*args, **kwargs)
-
-    # Check if this call is from fla package or from test directly
-    if not _is_called_from_fla():
-        # Direct call from test file - don't guard
-        return _ORIGINAL_ZEROS(*args, **kwargs)

Comment thread tests/conftest.py Outdated
@zhiyuan1i zhiyuan1i force-pushed the lzy/enhance-cp-and-fix-ima branch from f82f78f to 7c66d96 Compare February 3, 2026 01:12
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 174-184: The divisibility checks use undefined variables T_total
and H (they are defined later), causing a NameError; move the validation block
that checks T_total % cp_size and H % tp_size so it runs after T_total and H are
assigned (after the assignments currently at lines where T_total and H are set
inside run_benchmark or the surrounding scope), or alternatively compute/assign
T_total and H before performing the checks; ensure the checks reference the same
cp_size and tp_size variables used elsewhere (e.g., in run_benchmark) and remove
the original validation lines from their current location.
- Around line 257-259: The dt_bias tensor is created with shape [H_local] but
KDA expects a per-head, per-head-dimension bias of shape (H_local, K); change
the dt_bias construction so it uses both H_local and K (e.g.,
torch.randn(H_local, K, device=device, dtype=torch.float32)) to match KDA
requirements and avoid shape mismatches when used alongside A_log and in
chunk_kda computations.
🧹 Nitpick comments (2)
benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (2)

95-95: Unused function parameter rank.

The rank parameter is never used within profile_kernels. Consider removing it or prefixing with underscore if reserved for future use.

♻️ Proposed fix
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):

And update the call site at line 378:

-        kernel_stats = profile_kernels(kda_with_cp, rank, steps=5, warmup=2)
+        kernel_stats = profile_kernels(kda_with_cp, steps=5, warmup=2)

71-92: Unused loop variables i.

Loop control variables at lines 77 and 86 are not used within the loop bodies. Rename to _ per Python convention.

♻️ Proposed fix
     # Warmup
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None

     # Benchmark
     torch.cuda.synchronize()
     start_event.record()
-    for i in range(step):
+    for _ in range(step):
         fn()

Comment thread benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py Outdated
Comment thread benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 172-173: The runtime check using assert on cp_size * tp_size is
unsafe because asserts can be disabled; replace it with an explicit validation
that checks if cp_size * tp_size != world_size and then raise a clear exception
(e.g., ValueError) or call sys.exit with the same descriptive message
referencing cp_size, tp_size, and world_size so invalid CP/TP sizing fails
immediately and deterministically.

In `@tests/conftest.py`:
- Around line 55-75: _guarded_zeros is missing the is_compiling() bypass present
in _guarded_zeros_like and _guarded_new_zeros; add an early guard that calls
is_compiling() and if True returns _ORIGINAL_ZEROS(*args, **kwargs) (place this
check after determining dtype and before calling _is_called_from_fla()) so the
function short-circuits during compilation and preserves the existing
inspect/NaN behavior only at runtime.

Comment on lines +172 to +173
assert cp_size * tp_size == world_size, \
f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid assert for runtime config validation.

assert can be stripped with -O, allowing invalid CP/TP sizing to proceed and fail later. Prefer an explicit check.

🔧 Proposed fix
-    assert cp_size * tp_size == world_size, \
-        f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
+    if cp_size * tp_size != world_size:
+        raise ValueError(
+            f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert cp_size * tp_size == world_size, \
f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
if cp_size * tp_size != world_size:
raise ValueError(
f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
)
🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py` around lines 172 - 173, The
runtime check using assert on cp_size * tp_size is unsafe because asserts can be
disabled; replace it with an explicit validation that checks if cp_size *
tp_size != world_size and then raise a clear exception (e.g., ValueError) or
call sys.exit with the same descriptive message referencing cp_size, tp_size,
and world_size so invalid CP/TP sizing fails immediately and deterministically.

Comment thread tests/conftest.py Outdated
- Fuse forward and backward kernels for better performance

- Fix IMA (Illegal Memory Access) issue in long context scenarios
@zhiyuan1i zhiyuan1i force-pushed the lzy/enhance-cp-and-fix-ima branch from 7c66d96 to 9b2af83 Compare February 3, 2026 02:13
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/cp/chunk_delta_h.py (1)

1051-1059: ⚠️ Potential issue | 🟠 Major

Guard cu_seqlens slicing when it can be None.

The fixed-length CP path allows cu_seqlens=None; slicing will now raise. Preserve previous behavior by passing None through.

🛠️ Suggested fix
-        pre_process_fwd_kernel_merged[grid](
+        local_cu_seqlens = cu_seqlens[-2:] if cu_seqlens is not None else None
+        pre_process_fwd_kernel_merged[grid](
             k=k,
             v=u,
             w=w,
             g=g,
             gk=gk,
             hm=hm,
-            cu_seqlens=cu_seqlens[-2:],
+            cu_seqlens=local_cu_seqlens,
             T=T,
             H=H,
             K=K,
             V=V,
             BT=BT,
             BK1=BK,
             USE_EXP2=use_exp2,
             BLOCK_SIZE=BLOCK_SIZE,
         )
-        pre_process_bwd_kernel_merged[grid](
+        local_cu_seqlens = cu_seqlens[:2] if cu_seqlens is not None else None
+        pre_process_bwd_kernel_merged[grid](
             q=q,
             k=k,
             w=w,
             g=g,
             gk=gk,
             do=do,
             dhm=dhm,
             dv=dv,
-            cu_seqlens=cu_seqlens[:2],
+            cu_seqlens=local_cu_seqlens,
             scale=scale,
             T=T,
             H=H,
             K=K,
             V=V,
             BT=BT,
             BK1=BK,
             USE_EXP2=use_exp2,
             BLOCK_SIZE=BLOCK_SIZE,
         )

Also applies to: 1120-1129

🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_chunk_delta_h_kernels.py`:
- Around line 110-127: The benchmark_all_kernels function runs kernels that
hardcode i_n = 0 (see generate_chunk_delta_h_cache kernels), so if B > 1 the
benchmark only measures the first batch and yields misleading results; add an
explicit guard at the start of benchmark_all_kernels (before create_tensors)
that checks B and either assert/raise if B != 1 or normalize B to 1 for these
single-batch kernels (or skip running these kernels with a clear warning),
referencing benchmark_all_kernels, generate_chunk_delta_h_cache, create_tensors
and the i_n usage so reviewers can locate the change.
- Around line 203-205: The grid sizing uses triton.cdiv(V + K, BLOCK_SIZE) which
undercounts blocks because the kernel expects separate block counts; change the
calculation for grid_merged (and the other occurrences at lines ~333-335) to sum
triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE) for the first dimension
so the kernel's mapping (i_col, i_k_col = i_col - tl.cdiv(V, BLOCK_SIZE)) is
valid; update any related variables that build grid_merged accordingly
(referencing BLOCK_SIZE, grid_merged, V, K, i_col, i_k_col, tl.cdiv) so the K
portion is fully scheduled.

In `@fla/ops/cp/chunk_delta_h.py`:
- Around line 1049-1050: The grid sizing is under-allocating K tiles: change the
merged-kernel grid computation from using triton.cdiv(V + K, BLOCK_SIZE) to
summing separate cdivs so both V and K remainders are covered (e.g., grid =
(triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE), H)); update the same
pattern at the other occurrence around lines 1118-1119, keeping BLOCK_SIZE, grid
and the kernel logic that uses is_h_part = i_col * BLOCK_SIZE < V intact so the
kernel can correctly partition V and K blocks.
- Around line 893-897: The merged backward kernel pre_process_bwd_kernel_merged
differs from pre_process_bwd_kernel_stage1 by applying exp2 when USE_EXP2 is
true to GK (b_gk_last1 and other GK uses); if backward parity with
pre_process_bwd_kernel_stage1 is required, remove the USE_EXP2 conditional for
GK operations in pre_process_bwd_kernel_merged so GK always uses exp (i.e.,
replace conditional exp2/exp usages on b_gk_* with a single exp call); otherwise
confirm intent to mirror the forward merged kernel and leave USE_EXP2 in
place—update tests or a comment clarifying which parity (stage1 vs
forward-merged) is intended.

Comment on lines +110 to +127
def benchmark_all_kernels(B, T, H, K, V):
"""Benchmark all 6 kernels from generate_chunk_delta_h_cache()."""
print(f"\n{'=' * 80}")
print("Benchmarking chunk_delta_h kernels (using triton.testing.do_bench)")
print(f"{'=' * 80}")
print("Configuration:")
print(f" Batch size: {B}")
print(f" Heads: {H}")
print(f" Head dim (K): {K}")
print(f" Head dim (V): {V}")
print(f" SeqLen per rank: {T}")
print(f" Dtype: {DTYPE}")
print(f" Device: {device}")
print(f"{'=' * 80}\n")

# Create tensors
tensors = create_tensors(B, T, H, K, V, device, DTYPE)
BT = 64
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Guard batch size for single-batch kernels.

The kernels hardcode i_n = 0, so B > 1 benchmarks only the first batch. Add an explicit guard to avoid misleading results.

🛠️ Suggested fix
 def benchmark_all_kernels(B, T, H, K, V):
     """Benchmark all 6 kernels from generate_chunk_delta_h_cache()."""
+    if B != 1:
+        raise ValueError("This benchmark assumes batch size = 1 (kernels use i_n=0).")
🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_chunk_delta_h_kernels.py` around lines 110 - 127, The
benchmark_all_kernels function runs kernels that hardcode i_n = 0 (see
generate_chunk_delta_h_cache kernels), so if B > 1 the benchmark only measures
the first batch and yields misleading results; add an explicit guard at the
start of benchmark_all_kernels (before create_tensors) that checks B and either
assert/raise if B != 1 or normalize B to 1 for these single-batch kernels (or
skip running these kernels with a clear warning), referencing
benchmark_all_kernels, generate_chunk_delta_h_cache, create_tensors and the i_n
usage so reviewers can locate the change.

Comment thread benchmarks/cp/benchmark_chunk_delta_h_kernels.py
Comment thread fla/ops/cp/chunk_delta_h.py
Comment thread fla/ops/cp/chunk_delta_h.py Outdated
@zhiyuan1i zhiyuan1i merged commit 7a882c3 into main Feb 3, 2026
2 of 3 checks passed
@zhiyuan1i zhiyuan1i deleted the lzy/enhance-cp-and-fix-ima branch February 3, 2026 05:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant