
[Deltarule] Added intra-card context parallel optimization for KDA and GDN #743

Merged
zhiyuan1i merged 8 commits into main from lzy/speedup-prefill
Feb 12, 2026

Conversation

@zhiyuan1i
Collaborator

@zhiyuan1i zhiyuan1i commented Feb 10, 2026

Summary by CodeRabbit

  • New Features

    • Inference-only intra-card context-parallel optimization for variable-length sequences: splits long sequences into subsequences, processes them in parallel on a single card, and merges states; configurable max subsequences.
    • Multi-sequence and intra-card merge support in the forward/backward path and propagation of CPU-side cumulative-sequence lengths through the forward flow.
    • Added a common backend registry exposing the intra-card backend.
  • Tests

    • Added parameterized tests validating variable-length prefill behavior and final-state correctness.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Gated Delta Networks (GDN) and Kernelized Delta Attention (KDA) during the prefill phase, especially for very long sequences in inference mode. By introducing an intra-card context parallel mechanism, it intelligently breaks down long sequences into smaller, manageable sub-sequences that can be processed concurrently on the GPU, and then efficiently merges their states. This optimization aims to reduce latency and improve throughput without altering the core functionality or mathematical correctness of the models.

Highlights

  • Intra-Card Context Parallelism (CP) for GDN and KDA: Introduced a new IntraCardCPBackend to accelerate prefill operations for Gated Delta Networks (GDN) and Kernelized Delta Attention (KDA) by splitting long sequences into sub-sequences and processing them in parallel across Streaming Multiprocessors (SMs).
  • Unified Triton Kernel for Merging: Refactored the merge_fwd_bwd_kernel in Triton to support both existing Context Parallel (CP) mode and the new Intra-Card CP mode, allowing for flexible merging of states from sub-sequences.
  • Dynamic Backend Dispatch: Implemented a dispatch mechanism for chunk_gated_delta_rule_fwd_h to automatically select the appropriate backend (e.g., Intra-Card CP) based on runtime conditions, specifically when in inference mode and dealing with variable-length sequences.
  • Performance Benchmarking and Testing: Added new test cases for chunk_varlen_prefill in both test_gated_delta.py and test_kda.py to validate the new functionality, and included a dedicated benchmark script benchmark_intracard_cp.py to measure performance gains.
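To make the splitting idea above concrete, here is a minimal pure-Python sketch of cutting a variable-length batch's cumulative sequence lengths into subsequence offsets. The function name `split_cu_seqlens` and the fixed-size cutting policy are illustrative assumptions, not the PR's actual `prepare_subseq_cu_seqlens` logic, which also applies chunk-alignment and split-count limits.

```python
# Hypothetical sketch: a varlen batch is described by cumulative sequence
# lengths (cu_seqlens); every sequence longer than subseq_len is cut into
# pieces of at most subseq_len tokens so the pieces can run in parallel.
def split_cu_seqlens(cu_seqlens, subseq_len):
    out = [cu_seqlens[0]]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start
        while end - pos > subseq_len:
            pos += subseq_len
            out.append(pos)  # boundary of a new subsequence
        out.append(end)      # original sequence boundary is preserved
    return out

# One 10-token sequence and one 3-token sequence, subseq_len=4:
print(split_cu_seqlens([0, 10, 13], 4))  # [0, 4, 8, 10, 13]
```

The short sequence is left untouched; only the long one gains internal boundaries, which is why the real implementation needs per-sequence metadata (which subsequences are first/non-first) for the later merge.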


Changelog
  • benchmarks/cp/benchmark_chunk_delta_h_kernels.py
    • Added MULTI_SEQS=False parameter to kernel_fwd_merged call.
    • Introduced new parameters (seq_offsets, init_offsets, h0_seq_ids, h0, INTRACARD_MODE, NUM_SEQ_ENTRIES) to kernel_merge_fwd and kernel_merge_bwd calls to support the unified merge kernel.
  • fla/ops/common/backends/__init__.py
    • Added new file to define common backend registry.
    • Registered IntraCardCPBackend for common operations.
  • fla/ops/common/backends/intracard.py
    • Added new file defining the IntraCardCPBackend.
    • Implemented is_available and chunk_gated_delta_rule_fwd_h_verifier methods to determine when intra-card CP should be used.
    • Provided the chunk_gated_delta_rule_fwd_h implementation that orchestrates the intra-card parallel processing.
  • fla/ops/common/chunk_delta_h.py
    • Imported dispatch from fla.ops.backends.
    • Decorated chunk_gated_delta_rule_fwd_h with @dispatch('common') to enable backend selection.
    • Added cu_seqlens_cpu as a parameter to chunk_gated_delta_rule_fwd_h.
  • fla/ops/common/intracard_cp.py
    • Added new file containing the core logic for intra-card context parallelism.
    • Defined SplitSeqInfo for managing split sequence metadata.
    • Implemented _raw_chunk_gated_delta_rule_fwd_h as a helper for the base chunked delta rule.
    • Included compute_subseq_len to calculate optimal sub-sequence lengths.
    • Developed prepare_subseq_cu_seqlens to manage sequence splitting.
    • Created intracard_pre_scan and intracard_merge functions for processing and combining sub-sequence results.
    • Provided the main intracard_fwd_h function to coordinate the entire intra-card forward pass.
  • fla/ops/cp/chunk_delta_h.py
    • Modified pre_process_fwd_kernel_merged to accept a MULTI_SEQS constexpr and adjust hm indexing for multi-sequence scenarios.
    • Refactored merge_fwd_bwd_kernel into a unified kernel, adding INTRACARD_MODE, NUM_SEQ_ENTRIES, HAS_H0 constexprs and new parameters (seq_offsets, init_offsets, h0_seq_ids, h0) to support both CP and Intra-card modes.
    • Updated calls to merge_fwd_bwd_kernel in chunk_gated_delta_rule_fwd_h_pre_process and chunk_gated_delta_rule_bwd_dhu_pre_process to pass the newly introduced parameters.
  • fla/ops/kda/chunk.py
    • Removed direct imports of chunk_gated_delta_rule_bwd_dhu and chunk_gated_delta_rule_fwd_h to leverage the new dispatch mechanism.
    • Added cu_seqlens_cpu parameter to chunk_kda_fwd.
    • Introduced delayed imports for chunk_gated_delta_rule_fwd_h and chunk_gated_delta_rule_bwd_dhu to ensure proper backend dispatching.
    • Passed cu_seqlens_cpu to the chunk_gated_delta_rule_fwd_h call.
  • tests/ops/test_gated_delta.py
    • Added test_chunk_varlen_prefill to verify the correctness of the intra-card CP implementation for gated delta rule with variable-length sequences in inference mode.
  • tests/ops/test_kda.py
    • Added test_chunk_varlen_prefill to validate the intra-card CP functionality for KDA with variable-length sequences in inference mode.
  • x4/benchmark_intracard_cp.py
    • Added new file to provide a benchmark script.
    • Implemented run_benchmark function to compare the performance of baseline KDA prefill against intra-card CP KDA prefill, including split analysis and speedup calculations.

@zhiyuan1i zhiyuan1i closed this Feb 10, 2026
@zhiyuan1i zhiyuan1i reopened this Feb 10, 2026
@coderabbitai
Contributor

coderabbitai bot commented Feb 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds an intra‑card context‑parallel varlen path: a common backend registry and IntraCardCPBackend, a new intracard orchestration module to split/process subsequences, Triton kernel extensions for multi‑sequence/intracard modes, cu_seqlens_cpu threaded through the KDA call path, benchmark updates, and new varlen prefill tests.

Changes

Cohort / File(s) Summary
Intracard Backend Registry
fla/ops/common/backends/__init__.py, fla/ops/common/backends/intracard.py
Adds common_registry, re‑exports dispatch, and implements IntraCardCPBackend with verifier + delegation to intracard flow; exposes MAX_SUBSEQS.
Intracard Orchestration
fla/ops/common/intracard_cp.py
New module implementing subsequence splitting (SplitSeqInfo), compute_subseq_len, prepare_subseq_cu_seqlens, pre‑scan, merge, index precomputation, and public intracard_fwd_h to coordinate split → pre‑scan → merge → per‑subseq forward.
CP / Triton kernels
fla/ops/cp/chunk_delta_h.py, fla/ops/common/chunk_delta_h.py
Kernel signatures extended: MULTI_SEQS added; merge_fwd_bwd_kernel gains seq_offsets, init_offsets, h0_seq_ids, h0, INTRACARD_MODE, NUM_SEQ_ENTRIES, HAS_H0; dispatch decorator applied and cu_seqlens_cpu threaded into common entry. Adjusted offset/stride handling and specializations.
KDA call chain propagation
fla/ops/kda/chunk.py, fla/ops/kda/chunk_fwd.py
Threads new cu_seqlens_cpu argument through ChunkKDAFunction.forward and chunk_kda_fwd into the dispatch/kernel layer.
Benchmarks
benchmarks/cp/benchmark_chunk_delta_h_kernels.py
Updated benchmark kernel invocations to pass new optional args (e.g., MULTI_SEQS=False, seq_offsets=None, init_offsets=None, h0_seq_ids=None, h0=None, INTRACARD_MODE=False, NUM_SEQ_ENTRIES=0).
Tests
tests/ops/test_gated_delta.py, tests/ops/test_kda.py
Adds parameterized test_chunk_varlen_prefill tests validating varlen prefill / initial‑state behavior against reference recurrent implementations.
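The split → pre-scan → merge flow summarized above rests on the fact that the recurrent state of a (gated) delta rule combines associatively across chunks. A scalar toy model makes this visible; the sketch below is illustrative only (the PR's merge kernel operates on matrix-valued states, and the names `run_chunk` / `chunk_summary` are invented here).

```python
# For a linear recurrence S_t = a_t * S_{t-1} + u_t, running a chunk from
# S = 0 yields a partial state P and a total decay A = prod(a_t); the true
# exit state given an arbitrary entry state is then A * S_in + P. So chunks
# can be processed independently in parallel and merged with a short scan.
def run_chunk(decays, updates, s_in=0.0):
    s = s_in
    for a, u in zip(decays, updates):
        s = a * s + u
    return s

def chunk_summary(decays, updates):
    A = 1.0
    for a in decays:
        A *= a
    return A, run_chunk(decays, updates, 0.0)  # (total decay, partial state)

decays = [0.9, 0.8, 0.7, 0.6]
updates = [1.0, 2.0, 3.0, 4.0]

serial = run_chunk(decays, updates)           # reference: one sequential pass

A1, P1 = chunk_summary(decays[:2], updates[:2])  # chunk 1, computed in parallel
A2, P2 = chunk_summary(decays[2:], updates[2:])  # chunk 2, computed in parallel
merged = A2 * P1 + P2                            # merge: feed chunk 1's exit state in

assert abs(serial - merged) < 1e-12
```

The intracard path does the analogous thing per subsequence: a pre-scan produces per-subsequence summaries, the merge kernel reconstructs each subsequence's true initial state, and the per-subsequence forwards then run independently.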

Sequence Diagram(s)

sequenceDiagram
    participant User as User/Test
    participant ChunkKDA as ChunkKDAFunction
    participant ChunkFwd as chunk_kda_fwd
    participant Dispatch as dispatch('common')
    participant IntraCardBE as IntraCardCPBackend
    participant Intracard as intracard_fwd_h
    participant Kernel as Kernels (pre_scan / merge / fwd)

    User->>ChunkKDA: forward(..., cu_seqlens, cu_seqlens_cpu)
    ChunkKDA->>ChunkFwd: chunk_kda_fwd(..., cu_seqlens_cpu)
    ChunkFwd->>Dispatch: chunk_gated_delta_rule_fwd_h(..., cu_seqlens, cu_seqlens_cpu)
    Dispatch->>IntraCardBE: chunk_gated_delta_rule_fwd_h_verifier(cu_seqlens)
    IntraCardBE-->>Dispatch: available?
    Dispatch->>Intracard: intracard_fwd_h(..., cu_seqlens_cpu, max_splits)
    Intracard->>Intracard: compute_subseq_len / prepare_subseq_cu_seqlens
    alt splitting needed
        Intracard->>Kernel: intracard_pre_scan(...)
        Kernel-->>Intracard: hm (pre-scan)
        Intracard->>Kernel: merge_fwd_bwd_kernel(INTRACARD_MODE=True, seq_offsets, init_offsets, ...)
        Kernel-->>Intracard: merged states
        Intracard->>Kernel: _raw_chunk_gated_delta_rule_fwd_h per-subseq
        Kernel-->>Intracard: h, v_new, final_state
    else no split
        Intracard->>Kernel: _raw_chunk_gated_delta_rule_fwd_h(...)
        Kernel-->>Intracard: h, v_new, final_state
    end
    Intracard-->>Dispatch: results
    Dispatch-->>ChunkFwd: results
    ChunkFwd-->>ChunkKDA: outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • yzhangcs
  • Nathancgy

Poem

🐰✨

I split long hops into nimble sub-leaps,
cards hum in parallel while memory keeps,
pre‑scan then merge, the final state peeks,
cu_seqlens threaded, the sequence softly sleeps,
a rabbit's small hop — inference leaps!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 25.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title accurately summarizes the main change, adding intra-card context parallel optimization for KDA and GDN operations, which is reflected across all modified files introducing the IntraCardCPBackend and associated intracard processing infrastructure.



No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
fla/ops/common/intracard_cp.py (1)

314-327: Consider adding strict=True to zip() calls for defensive programming.

All three zip() calls iterate over attributes from the same SplitSeqInfo, so the lengths are guaranteed equal. However, strict=True (Python 3.10+) would surface bugs early if the invariant ever breaks.

Proposed fix
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1])
         S_split_total += n
...
-    for sid, nss in zip(split_ids, num_ss):
+    for sid, nss in zip(split_ids, num_ss, strict=True):
         num_subseqs_per_seq[sid] = nss
...
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         for j in range(1, n):


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an intra-card context parallel (CP) optimization for the delta rule kernels, aimed at accelerating prefill during inference. The changes are extensive, touching core kernels, adding a new backend dispatch system, and implementing the complex logic for splitting sequences and merging states. The addition of new tests and benchmarks for the feature is great. I've found one critical issue that could lead to a runtime error in an edge case. Other than that, the implementation looks solid.

@zhiyuan1i zhiyuan1i marked this pull request as ready for review February 11, 2026 15:24
@zhiyuan1i zhiyuan1i changed the title [Deltarule] Add intracard cp for gdn and kda [Deltarule] Added intra-card context parallel optimization for KDA and GDN Feb 11, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 62-64: The final_state is allocated with k.new_empty(...) which
leaves padded slots uninitialized; replace the allocation in intracard_cp.py
where final_state is set (variable name final_state, allocation via k.new_empty)
with a zero-initialized allocation using k.new_zeros(...) (matching the pattern
from chunk_gated_delta_rule_fwd_h) so padded positions are deterministic for
CUDA Graphs/vLLM; keep the dtype and shape the same and preserve the conditional
(output_final_state) logic.

In `@tests/ops/test_gated_delta.py`:
- Line 400: Remove the unnecessary .requires_grad_() call inside the
`@torch.inference_mode`() block: in the mapping that assigns q, k, v, beta, g, h0
(the lambda used on those variables) only move tensors to device (x.to(device))
and do not call .requires_grad_(); if any of these tensors actually need
gradients for a specific test, wrap that test in torch.enable_grad() instead of
setting requires_grad_() inside inference mode.

In `@tests/ops/test_kda.py`:
- Around line 569-571: The test is running under `@torch.inference_mode()`, so
calling requires_grad_() on tensors (q, k, v, g, beta, h0 and conditionally
A_log, dt_bias) is unnecessary and can raise RuntimeError; remove the
requires_grad_() calls and just move the tensors to device (e.g., replace
map(lambda x: x.to(device).requires_grad_(), ...) with map(lambda x:
x.to(device), ...) for q, k, v, g, beta, h0 and similarly for A_log and dt_bias
when use_gate_in_kernel is true) so tensors are placed on the device without
enabling gradient tracking.
🧹 Nitpick comments (7)
fla/ops/common/chunk_delta_h.py (1)

468-480: cu_seqlens_cpu is unused in the default implementation (intentional for dispatch).

Ruff flags cu_seqlens_cpu as unused (ARG001). This is expected — the parameter exists so that backends (e.g., IntraCardCPBackend) receive it via the @dispatch('common') wrapper. The default function correctly ignores it. Consider adding a brief comment or a noqa: ARG001 annotation to make the intent clear.

fla/ops/common/backends/intracard.py (1)

22-31: is_available override is redundant.

Since package_name = None, BaseBackend.is_available() already returns True. The override is harmless but unnecessary.

fla/ops/common/intracard_cp.py (5)

148-171: Splitting threshold (4×) vs early-return threshold (2×) causes a wasted-work window.

prepare_subseq_cu_seqlens splits a sequence only if seq_len_i >= 4 * subseq_len (line 154), but intracard_fwd_h skips the early-return path when any sequence is ≥ 2 * subseq_len (line 394). When the longest sequence falls in [2×, 4×), the code enters the splitting path, calls prepare_subseq_cu_seqlens, finds nothing to split, and falls back — doing unnecessary CPU work. This is functionally correct (the not split_info guard on line 399 catches it), but the mismatch is worth a comment or aligning the thresholds.
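The wasted-work window described above can be stated in a few lines. The helper names and the specific `subseq_len` value below are assumptions chosen to mirror the review comment, not the PR's code.

```python
# Early-return gate: skip the split path only if every sequence is shorter
# than 2 * subseq_len. Actual split condition: length >= 4 * subseq_len.
# Lengths in [2x, 4x) enter the split path but end up not splitting.
def early_return(seq_lens, subseq_len):
    return all(n < 2 * subseq_len for n in seq_lens)

def would_split(seq_len, subseq_len):
    return seq_len >= 4 * subseq_len

subseq_len = 8192  # assumed example value
wasted = [n for n in (10_000, 20_000, 40_000)
          if not early_return([n], subseq_len) and not would_split(n, subseq_len)]
print(wasted)  # only 20_000 falls in the wasted-work window [16384, 32768)
```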


290-362: Consider strict=True in zip() calls for safety.

The zip() calls on lines 315, 321, and 326 iterate over parallel lists from SplitSeqInfo that are expected to have equal length. Adding strict=True would catch bugs where the lists have mismatched lengths, at negligible runtime cost for these small arrays.

Proposed fix
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1])
         S_split_total += n

     # num_subseqs_per_seq: [N_orig], default 1 for unsplit sequences
     num_subseqs_per_seq = [1] * N_orig
-    for sid, nss in zip(split_ids, num_ss):
+    for sid, nss in zip(split_ids, num_ss, strict=True):
         num_subseqs_per_seq[sid] = nss

     # non_first_indices: for scattering merged initial states
     non_first_indices: list[int] = []
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         for j in range(1, n):

393-409: split_info relies on short-circuit evaluation to avoid UnboundLocalError — fragile pattern.

When early_return is True, split_info is never assigned (lines 395-398 are skipped). Line 399 works only because or short-circuits. This is correct but fragile — a future refactor (e.g., changing or to and, or splitting the condition) could introduce a runtime crash.

Consider initializing split_info before the branch:

Proposed fix
+    split_info = False
+    total_subseqs = 0
     early_return = (seq_lens < 2 * subseq_len).all()
     if not early_return:
         cu_seqlens_subseq, split_info, total_subseqs = prepare_subseq_cu_seqlens(
             cu_seqlens_cpu, subseq_len, chunk_size, max_splits=max_splits
         )
     if early_return or not split_info:

415-425: Unused variable num_subseqs_per_seq.

num_subseqs_per_seq is unpacked on line 418 but never referenced afterward. Prefix with _ to signal intent and silence the Ruff RUF059 warning.

Proposed fix
     (
         cu_seqlens_split_values,
         S_split_total,
-        num_subseqs_per_seq,
+        _num_subseqs_per_seq,
         non_first_indices,
         first_subseq_indices,
         last_subseq_indices,

448-456: Indexing with Python lists into a GPU tensor triggers implicit conversion.

first_subseq_indices, non_first_indices, and last_subseq_indices are Python list[int]. When used as indices into GPU tensors (lines 451, 453, 456, 476), PyTorch will convert them to a CPU LongTensor each time, then transfer to GPU. This works correctly, but if the number of sub-sequences grows, the repeated implicit conversions could add overhead. Given that these are typically small arrays and this is an optimization-focused module, consider pre-converting to GPU tensors once.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 112-115: The comment block incorrectly states the split threshold
as 2 * subseq_len; update the comment to reflect the actual logic:
prepare_subseq_cu_seqlens computes four_subseq_len = 4 * subseq_len and uses
that as the split threshold, while intracard_fwd_h uses an early-return check at
2 * subseq_len (a looser gate), so clarify both thresholds and their roles (2×
for early-return, 4× for actual split) and adjust the numeric examples
accordingly (e.g., chunk_size=64, MIN_SUBSEQ_CHUNKS=128 → subseq_len >= 8192
tokens, split threshold = 4 * subseq_len = 32768 tokens).
- Around line 394-409: The early-return logic relies on short-circuiting and can
leave split_info (and cu_seqlens_subseq/total_subseqs) unbound; change the flow
so you return immediately when early_return is True instead of using the
combined condition. Concretely, check if early_return: return the call to
_raw_chunk_gated_delta_rule_fwd_h(...) right away; only call
prepare_subseq_cu_seqlens(...) and use split_info when early_return is False,
and then keep the existing fallback of returning when split_info is empty. This
removes reliance on short-circuit evaluation and prevents UnboundLocalError.
- Around line 448-456: initial_state_expanded is allocated with new_empty and
leaves non-first subseq slots uninitialized when num_non_first > 0 but
initial_states_merge is None; update the allocation/assignment in the block
using initial_state_expanded, initial_state, initial_states_merge,
first_subseq_indices, non_first_indices, and num_non_first so those non-first
slots are deterministically initialized (either allocate with zeros/new_zeros or
explicitly set initial_state_expanded[non_first_indices] = 0 when
initial_states_merge is None) while preserving the existing assignments to
first_subseq_indices.
🧹 Nitpick comments (2)
fla/ops/common/intracard_cp.py (2)

415-425: Unused variable num_subseqs_per_seq.

Static analysis (RUF059) flags num_subseqs_per_seq as unused after destructuring. Prefix with _ to signal intent.

Proposed fix
     (
         cu_seqlens_split_values,
         S_split_total,
-        num_subseqs_per_seq,
+        _num_subseqs_per_seq,
         non_first_indices,
         first_subseq_indices,
         last_subseq_indices,

315-317: Consider adding strict=True to zip() calls for defensive checking.

All three zip() calls in this function iterate over paired lists from SplitSeqInfo which should always be the same length by construction. Adding strict=True (Python 3.10+) would catch any future mismatch early instead of silently truncating. This is also flagged by Ruff B905.

Proposed fix
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1])
         S_split_total += n

-    for sid, nss in zip(split_ids, num_ss):
+    for sid, nss in zip(split_ids, num_ss, strict=True):
         num_subseqs_per_seq[sid] = nss

-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):
         for j in range(1, n):

Also applies to: 321-322, 326-328

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 415-425: The tuple unpack from _precompute_intracard_indices
assigns an unused variable num_subseqs_per_seq; rename it to prefix with an
underscore (e.g., _num_subseqs_per_seq) in the unpacking expression to signal
intentional unusedness and satisfy static analysis; update the tuple on the call
site where _precompute_intracard_indices(...) is unpacked so all other names
(cu_seqlens_split_values, S_split_total, non_first_indices,
first_subseq_indices, last_subseq_indices, num_non_first, merge_seq_offsets,
merge_init_offsets) remain unchanged.
🧹 Nitpick comments (2)
fla/ops/common/intracard_cp.py (2)

315-317: Consider adding strict=True to zip() calls to catch length mismatches.

Lines 315, 321, and 326 use zip() on lists from SplitSeqInfo that should always be equal length. Adding strict=True would catch bugs from mismatched list lengths early (Python 3.10+). Flagged by Ruff B905.

Proposed fix (one example — apply to all three)
-    for s, n in zip(starts, num_ss):
+    for s, n in zip(starts, num_ss, strict=True):

290-362: _precompute_intracard_indices returns a 9-element tuple — consider a structured return type.

A 9-element tuple is hard to unpack correctly and fragile to extend. A NamedTuple or dataclass (similar to SplitSeqInfo) would improve readability and make the call site at lines 415–425 less error-prone.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@fla/ops/common/intracard_cp.py`:
- Around line 78-119: The docstring and inline comments in compute_subseq_len
incorrectly state the split threshold as "2 * subseq_len" while the actual
threshold used by prepare_subseq_cu_seqlens is "4 * subseq_len"; update the
compute_subseq_len docstring and any related inline comments to reflect the
correct split threshold (4 * subseq_len) and mention prepare_subseq_cu_seqlens
by name so readers know where the threshold is enforced.
🧹 Nitpick comments (3)
tests/ops/test_kda.py (1)

550-556: Redundant .to(device) calls on tensors already on device.

A_log and dt_bias (lines 555–556) are created with device=device and then .to(device) again. Similarly, q/k/v/g (lines 550–553) could use device=device directly like test_chunk_varlen does for A_log/dt_bias. This is a cosmetic nit — no functional impact.

fla/ops/common/intracard_cp.py (2)

314-327: Consider adding strict=True to zip() calls for safety.

Ruff B905 flags three zip() calls without strict=. While the lists should always be equal-length (they come from the same SplitSeqInfo), strict=True provides a runtime check that catches index computation bugs early.


122-185: prepare_subseq_cu_seqlens — return type SplitSeqInfo | bool is a bit unusual.

Returning False (line 135, 174) where SplitSeqInfo is expected works because of the __bool__ check at the call site, but it makes the type annotation on line 127 a bit awkward. A future caller using pattern matching or isinstance checks would be surprised. Not blocking, just noting.
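One conventional alternative to the `SplitSeqInfo | bool` return is `Optional[SplitSeqInfo]` with an explicit `is None` check at the call site. The sketch below is illustrative: the `SplitSeqInfo` fields and the `prepare_split` helper are placeholders, not the module's real signatures.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SplitSeqInfo:
    split_ids: list  # indices of sequences that were split (illustrative field)
    starts: list = field(default_factory=list)
    num_ss: list = field(default_factory=list)

def prepare_split(seq_lens, threshold) -> Optional[SplitSeqInfo]:
    ids = [i for i, n in enumerate(seq_lens) if n >= threshold]
    if not ids:
        return None  # callers test `info is None`, never rely on truthiness
    return SplitSeqInfo(split_ids=ids)

info = prepare_split([100, 5000], 1000)
assert info is not None and info.split_ids == [1]
assert prepare_split([10, 20], 1000) is None
```

This keeps the annotation honest for type checkers and survives callers that use `isinstance` checks or pattern matching.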

