Skip to content

[Misc.] Fast testing & Autotune#476

Merged
yzhangcs merged 68 commits intomainfrom
fast_test_autotune
Jun 26, 2025
Merged

[Misc.] Fast testing & Autotune#476
yzhangcs merged 68 commits intomainfrom
fast_test_autotune

Conversation

@sustcsonglin
Copy link
Copy Markdown
Collaborator

@sustcsonglin sustcsonglin commented Jun 21, 2025

Summary by CodeRabbit

  • Chores
    • Simplified and consolidated test parameterizations across multiple test files, removing environment-dependent and conditional logic for explicit, fixed test cases.
    • Enhanced CI workflows to dynamically select Conda environments and updated runner assignments; removed PyTorch 2.6 job.
    • Improved kernel autotuning configurations for several Triton kernels, reducing unnecessary combinations and clarifying block size selection.
    • Updated conditional support for torch.compile based on Python version in certain modules.
    • Minor code readability improvements and cleanup of unused imports, warnings, and leading spaces in test assertion labels.
  • Bug Fixes
    • Fixed potential test skipping logic by basing skips on actual device platform detection instead of environment variables.
  • Refactor
    • Streamlined test logic and removed redundant or overly complex test cases, focusing on essential coverage.
    • Removed an entire cumulative sum test suite to reduce redundancy.
    • Deleted a fused attention kernel test file to consolidate testing efforts.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Jun 21, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

This update streamlines and simplifies test parameterizations, removes environment-dependent logic, and adjusts Triton kernel autotuning configurations across several files. It also introduces conditional support for torch.compile in one module, updates CI workflow environment handling, and modifies block size calculations in a kernel function. No major functional logic or control flow changes are introduced.

Changes

File(s) Change Summary
fla/ops/attn/parallel.py Removed warp count 1 from Triton kernel autotuning configs; only [2, 4] (plus [8] for 'hopper') are now used.
fla/ops/common/chunk_o.py Replaced parameter grid autotuning with three explicit configs for chunk_fwd_kernel_o.
fla/ops/gsa/fused_recurrent.py Block sizes BK, BV, BM now use triton.next_power_of_2 for dimension rounding before min with 64.
fla/ops/rwkv7/fused_addcmul.py Added Python-version-conditional torch.compile decorator for addcmul_bwd2, using a no-op if version ≤ 3.10.
fla/ops/utils/cumsum.py Removed warnings import and runtime warning; minor code readability tweaks; made accumulation unconditional in one kernel.
fla/ops/generalized_delta_rule/iplr/chunk.py Restricted warp counts for Hopper GPUs; removed num_stages variation in one kernel's autotuning configs.
tests/ops/test_attn.py Simplified parameterization for all tests; removed dtype as parameter; fixed test cases; cleaned up imports and logic.
tests/ops/test_delta.py Simplified and fixed test parameters; removed environment logic; removed test_l2_in_kernel; reduced parameterization.
tests/ops/test_delta_product.py Simplified test parameters; removed environment skips; fixed test values; now skips on Intel platforms.
tests/ops/test_forgetting_attn.py Simplified and fixed test parameterization; removed environment-based skips; fixed gating tensor initialization.
tests/ops/test_rwkv6.py Removed conditional parameter lists; now uses explicit parameter tuples; cleaned up imports.
tests/ops/test_rwkv7.py Reduced test coverage to only largest embedding/feedforward sizes in parameterization.
tests/ops/test_dplr_delta.py Consolidated parameterization to explicit tuples; removed conditional logic and unused imports; added stochastic mask parameter.
tests/ops/test_comba.py Consolidated parameterization; removed leading spaces from assertion labels; cleaned imports.
tests/ops/test_gated_delta.py Removed leading spaces from assertion label strings in assert_close calls.
tests/ops/test_gated_delta_product.py Removed leading spaces from assertion label strings in assert_close calls.
tests/ops/test_gla.py Removed leading spaces from assertion label strings in assert_close calls.
tests/ops/test_gsa.py Removed leading spaces from assertion label strings in assert_close calls.
tests/ops/test_hgrn.py Removed leading spaces from assertion label strings in assert_close calls; simplified parameterization.
tests/ops/test_iplr_delta.py Renamed test functions; added gradient checks in fused recurrent test; removed gradients in chunk test; simplified parameterization.
tests/ops/test_linear_attn.py Simplified parameterization; unified gradient computations; removed leading spaces from assertion labels.
tests/ops/test_nsa.py Simplified parameterization; updated function signatures; removed leading spaces from assertion labels.
tests/ops/test_retention.py Simplified parameterization; unified reference outputs; removed leading spaces from assertion labels; removed unused imports.
tests/ops/test_simple_gla.py Simplified parameterization; added explicit cu_seqlens parameters; removed leading spaces from assertion labels.
tests/ops/test_utils.py Simplified parameterization; added explicit cu_seqlens parameters; removed leading spaces from assertion labels.
tests/ops/test_cumsum.py Deleted entire test module for cumulative sum operations.
tests/test_fused_chunk.py Deleted entire test module implementing fused attention kernel and autograd function.
tests/ops/test_path_attn.py Simplified parameterization to explicit tuples; reordered parameters; removed COMPILER_MODE import.
tests/ops/test_based.py Simplified parameterization to explicit tuples; removed environment skips and imports.
tests/ops/test_mesa.py Removed COMPILER_MODE import and conditional test parameter lists; no other changes.
tests/ops/test_solve_tril.py Simplified parameterization; added type annotations; updated input tensor shapes; removed conditional skips.
tests/ops/test_titans.py Simplified test; removed head_first and scale parameters; removed commented-out tests; removed conditional skips and imports.
tests/ops/test_ttt.py Simplified parameterization; removed head_first parameter; standardized tensor shapes; replaced random splits with explicit cu_seqlens.
.github/workflows/reusable-ci-tests.yml Made Conda env selection dynamic by runner; consolidated pytest invocations; improved logging and env handling.
.github/workflows/nvidia-h100.yml Changed runner labels for two jobs; removed PyTorch 2.6 job entirely.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant TestSuite
    participant Kernel
    participant CI_Workflow

    User->>TestSuite: Run tests (with fixed parameters)
    TestSuite->>Kernel: Launch kernel with new autotuning configs
    CI_Workflow->>CI_Workflow: Dynamically select Conda env by runner
    CI_Workflow->>TestSuite: Run all tests in one pytest invocation
    Kernel-->>TestSuite: Return results
    TestSuite-->>User: Report test outcomes
Loading

Possibly related PRs

  • #260: Adds Triton parallel attention kernels in fla/ops/attn/parallel.py, which are the same kernels whose autotuning configs are updated here.
  • #256: Modifies autotuning warp counts conditionally based on device capabilities in fla/ops/common/chunk_o.py, related to autotuning changes here.
  • #291: Refactors tests/ops/test_forgetting_attn.py with parameterization and gating tensor initialization changes, related to test simplifications here.

Poem

In the garden where test seeds grow,
The rabbit hops with kernels in tow.
No more warps of one, just two and four,
Parameter lists clutter the ground no more.
Conda paths twist, but CI runs true—
With simpler tests, the carrots accrue!
🥕✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c306b32 and ebcd751.

📒 Files selected for processing (2)
  • tests/ops/test_attn.py (3 hunks)
  • tests/ops/test_rwkv6.py (3 hunks)
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate Unit Tests
  • Create PR with Unit Tests
  • Post Copyable Unit Tests in Comment
  • Commit Unit Tests in branch fast_test_autotune

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai auto-generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@yzhangcs yzhangcs linked an issue Jun 23, 2025 that may be closed by this pull request
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (10)
tests/models/test_modeling_hgrn.py (2)

21-29: Suppress Pylint overly-strict argument warnings
Static analysis flags R0913/R0917 due to 7 parameters; consider disabling them for this test:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_modeling(

42-49: Suppress Pylint overly-strict argument warnings
6 parameters trigger R0913/R0917; you can silence this with:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_generation(
tests/models/test_modeling_deltanet.py (2)

21-29: Suppress Pylint overly-strict argument warnings
To avoid R0913/R0917 complaints for this 7-arg test, add:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_modeling(

42-49: Suppress Pylint overly-strict argument warnings
Disable R0913/R0917 for this 6-arg test to silence false positives:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_generation(
tests/models/test_modeling_transformer.py (2)

24-32: Suppress Pylint overly-strict argument warnings
7 parameters trigger R0913/R0917; consider disabling for this function:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_modeling(

45-52: Suppress Pylint overly-strict argument warnings
6 parameters prompts R0913/R0917; disable to avoid noisy warnings:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_generation(
tests/models/test_modeling_mamba.py (2)

21-29: Suppress Pylint overly-strict argument warnings
To silence R0913/R0917 for this test, add:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_modeling(

42-49: Suppress Pylint overly-strict argument warnings
6-arg signature triggers R0913/R0917; consider disabling:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_generation(
tests/models/test_modeling_comba.py (2)

21-29: Suppress Pylint overly-strict argument warnings
Add a disable for R0913/R0917 to prevent noise:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_modeling(

42-49: Suppress Pylint overly-strict argument warnings
To avoid false positives on argument count, use:

 @pytest.mark.parametrize(... )
+# pylint: disable=R0913,R0917
 def test_generation(
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between acc38de and 542a4bd.

📒 Files selected for processing (24)
  • tests/models/test_modeling_abc.py (2 hunks)
  • tests/models/test_modeling_bitnet.py (2 hunks)
  • tests/models/test_modeling_comba.py (2 hunks)
  • tests/models/test_modeling_deltanet.py (2 hunks)
  • tests/models/test_modeling_forgetting_transformer.py (2 hunks)
  • tests/models/test_modeling_gated_deltanet.py (2 hunks)
  • tests/models/test_modeling_gated_deltaproduct.py (1 hunks)
  • tests/models/test_modeling_gla.py (2 hunks)
  • tests/models/test_modeling_gsa.py (2 hunks)
  • tests/models/test_modeling_hgrn.py (2 hunks)
  • tests/models/test_modeling_hgrn2.py (2 hunks)
  • tests/models/test_modeling_lightnet.py (2 hunks)
  • tests/models/test_modeling_linear_attn.py (2 hunks)
  • tests/models/test_modeling_mamba.py (2 hunks)
  • tests/models/test_modeling_mamba2.py (2 hunks)
  • tests/models/test_modeling_mesanet.py (2 hunks)
  • tests/models/test_modeling_nsa.py (2 hunks)
  • tests/models/test_modeling_path_attn.py (2 hunks)
  • tests/models/test_modeling_retnet.py (2 hunks)
  • tests/models/test_modeling_rodimus.py (2 hunks)
  • tests/models/test_modeling_rwkv6.py (2 hunks)
  • tests/models/test_modeling_rwkv7.py (1 hunks)
  • tests/models/test_modeling_samba.py (2 hunks)
  • tests/models/test_modeling_transformer.py (2 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/models/test_modeling_abc.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_bitnet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_comba.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_deltanet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_forgetting_transformer.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gated_deltanet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gated_deltaproduct.py

[refactor] 24-24: Too many arguments (7/5)

(R0913)


[refactor] 24-24: Too many positional arguments (7/5)

(R0917)

tests/models/test_modeling_gla.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gsa.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_hgrn.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_hgrn2.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_lightnet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_linear_attn.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mamba.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mamba2.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mesanet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_nsa.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_path_attn.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_retnet.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rodimus.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rwkv6.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rwkv7.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 43-43: Too many arguments (6/5)

(R0913)


[refactor] 43-43: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_samba.py

[refactor] 21-21: Too many arguments (7/5)

(R0913)


[refactor] 21-21: Too many positional arguments (7/5)

(R0917)


[refactor] 42-42: Too many arguments (6/5)

(R0913)


[refactor] 42-42: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_transformer.py

[refactor] 24-24: Too many arguments (7/5)

(R0913)


[refactor] 24-24: Too many positional arguments (7/5)

(R0917)


[refactor] 45-45: Too many arguments (6/5)

(R0913)


[refactor] 45-45: Too many positional arguments (6/5)

(R0917)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test-ops
  • GitHub Check: test-ops
🔇 Additional comments (47)
tests/models/test_modeling_gated_deltaproduct.py (1)

24-32: Add explicit type annotations to test_modeling signature
The multi-line format with clear type hints for each parameter enhances readability and aligns with the PR’s consistency improvements across modeling tests.

tests/models/test_modeling_abc.py (2)

21-29: Annotate and reformat test_modeling parameters
The explicit int, torch.dtype, and bool annotations in a multi-line signature improve clarity and maintain consistency with other model tests.


42-49: Add type hints and multi-line layout to test_generation
The updated signature with int and torch.dtype annotations elevates readability without altering behavior.

tests/models/test_modeling_linear_attn.py (2)

21-29: Explicitly type and reformat test_modeling args
Consistent multi-line formatting and int/torch.dtype/bool hints enhance the signature’s clarity across test suites.


42-49: Reformat test_generation with type annotations
Adding int and torch.dtype hints in a structured layout improves maintainability without side effects.

tests/models/test_modeling_rodimus.py (2)

21-29: Improve test_modeling signature readability
The multi-line parameter list with explicit type hints ensures consistency and clearer intent in the test definitions.


42-49: Enhance test_generation signature with type hints
This refactor maintains existing logic while aligning the signature style with other model tests.

tests/models/test_modeling_bitnet.py (2)

21-29: Add multi-line type-annotated signature to test_modeling
Explicit type hints and formatting keep the test definitions uniform and more readable.


42-49: Reformat and annotate test_generation parameters
Type annotations (int, torch.dtype) and structured layout improve clarity without modifying behavior.

tests/models/test_modeling_hgrn.py (2)

21-29: Add explicit type hints and multiline formatting to test_modeling signature
Type annotations improve clarity and align this test with the rest of the modeling suite.


42-49: Add explicit type hints and multiline formatting to test_generation signature
Consistent annotations across all modeling tests enhance maintainability.

tests/models/test_modeling_deltanet.py (2)

21-29: Add explicit type hints and multiline formatting to test_modeling signature
These annotations match the style applied across other model tests.


42-49: Add explicit type hints and multiline formatting to test_generation signature
Maintains consistency and clarity in the test suite.

tests/models/test_modeling_transformer.py (2)

24-32: Add explicit type hints and multiline formatting to test_modeling signature
Improves readability and enforces type safety in this test.


45-52: Add explicit type hints and multiline formatting to test_generation signature
Aligns with the rest of the modeling tests for consistency.

tests/models/test_modeling_mamba.py (2)

21-29: Add explicit type hints and multiline formatting to test_modeling signature
This enhancement matches the style improvements across all model tests.


42-49: Add explicit type hints and multiline formatting to test_generation signature
Ensures uniformity in test definitions.

tests/models/test_modeling_comba.py (2)

21-29: Add explicit type hints and multiline formatting to test_modeling signature
Type annotations and formatting boost clarity across the suite.


42-49: Add explicit type hints and multiline formatting to test_generation signature
Follows the PR’s consistent signature style.

tests/models/test_modeling_nsa.py (2)

21-29: Add explicit type annotations to test_modeling signature.
The multi-line signature with int, torch.dtype, and bool hints enhances readability and maintains consistency across modeling tests.


42-49: Add explicit type annotations to test_generation signature.
Aligns the function signature with the typed, multi-line style used in other tests, improving clarity.

tests/models/test_modeling_path_attn.py (2)

21-29: Add explicit type annotations to test_modeling signature.
Introducing int, torch.dtype, and bool hints in a multi-line format improves readability and consistency for Path Attention tests.


42-49: Add explicit type annotations to test_generation signature.
Matches the updated, typed signature format used elsewhere, making the test definitions uniform.

tests/models/test_modeling_forgetting_transformer.py (2)

21-29: Add explicit type annotations to test_modeling signature.
The added hints and multi-line layout boost clarity and maintain the standardized style across all model tests.


42-49: Add explicit type annotations to test_generation signature.
Consistent multi-line, typed signatures align this test with the rest of the suite.

tests/models/test_modeling_gla.py (2)

21-29: Add explicit type annotations to test_modeling signature.
Enhances readability by clearly specifying parameter types in line with other GLA tests.


42-49: Add explicit type annotations to test_generation signature.
Maintains uniform test signature style and type clarity across the modeling suite.

tests/models/test_modeling_gsa.py (2)

21-29: Add explicit type annotations to test_modeling signature.
Type hints in the signature improve code clarity and match the formatting of peer tests.


42-49: Add explicit type annotations to test_generation signature.
Consistent application of multi-line, typed signature style enhances maintainability.

tests/models/test_modeling_gated_deltanet.py (2)

21-28: Add explicit type annotations to test_modeling.
The multi-line, typed signature improves readability and enforces consistency across the test suite.


42-49: Add explicit type annotations to test_generation.
Aligns with the standardized function signature format used elsewhere.

tests/models/test_modeling_mamba2.py (2)

21-29: Add explicit type annotations to test_modeling.
The revised, multi-line signature with parameter types enhances clarity and consistency.


42-49: Add explicit type annotations to test_generation.
Standardizes the test signature and improves maintainability.

tests/models/test_modeling_retnet.py (2)

21-29: Add explicit type annotations to test_modeling.
Consistent multi-line formatting and typing mirror other model tests.


42-49: Add explicit type annotations to test_generation.
Maintains uniform signature style across the generation tests.

tests/models/test_modeling_samba.py (2)

21-29: Add explicit type annotations to test_modeling.
Improves type clarity and follows the established formatting convention.


42-49: Add explicit type annotations to test_generation.
Ensures consistency with the rest of the test suite’s signature style.

tests/models/test_modeling_rwkv6.py (2)

21-29: Add explicit type annotations to test_modeling.
The explicit, multi-line signature enhances readability and consistency.


42-49: Add explicit type annotations to test_generation.
Aligns this test’s signature with the updated project conventions.

tests/models/test_modeling_mesanet.py (2)

21-29: Add explicit type hints and multi-line signature to test_modeling.
Standardizes parameter annotations without altering logic, matching other model test suites.


42-49: Add explicit type hints and multi-line signature to test_generation.
Consistent formatting update; no behavioral change.

tests/models/test_modeling_hgrn2.py (2)

21-29: Add explicit type hints and multi-line signature to test_hgrn2_modeling.
Aligns with the unified style across modeling tests and preserves existing behavior.


42-49: Add explicit type hints and multi-line signature to test_generation.
Matches formatting in other test files; no functional impact.

tests/models/test_modeling_rwkv7.py (2)

21-29: Rename and refactor test_modeling signature with explicit type annotations.
Follows the standard multi-line style; internal logic remains unchanged.


43-50: Rename and refactor test_generation signature with explicit type annotations.
Consistent update across model tests without affecting behavior.

tests/models/test_modeling_lightnet.py (2)

21-29: Add explicit type hints and multi-line signature to test_modeling.
Standardizes the function header in line with other model tests; no changes to test logic.


42-49: Add explicit type hints and multi-line signature to test_generation.
Consistency improvement across the suite; behavior remains intact.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/ops/test_forgetting_attn.py (1)

123-123: Add explicit strict= parameter to zip().

The static analysis tool correctly identifies that zip() should have an explicit strict parameter for safety.

Apply this diff to address the static analysis warning:

-    for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+    for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:], strict=True):

The strict=True parameter ensures that both iterables have the same length, which should always be the case here since they're slices of the same tensor.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 542a4bd and c093386.

📒 Files selected for processing (2)
  • tests/ops/test_attn.py (3 hunks)
  • tests/ops/test_forgetting_attn.py (4 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/ops/test_attn.py

[refactor] 33-33: Too many arguments (6/5)

(R0913)


[refactor] 33-33: Too many positional arguments (6/5)

(R0917)


[refactor] 33-33: Too many local variables (18/15)

(R0914)


[refactor] 81-81: Too many local variables (18/15)

(R0914)

tests/ops/test_forgetting_attn.py

[refactor] 48-48: Too many arguments (6/5)

(R0913)


[refactor] 48-48: Too many positional arguments (6/5)

(R0917)


[refactor] 48-48: Too many local variables (22/15)

(R0914)


[refactor] 105-105: Too many local variables (23/15)

(R0914)

🪛 Ruff (0.11.9)
tests/ops/test_forgetting_attn.py

123-123: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test-ops
  • GitHub Check: test-ops
  • GitHub Check: test-ops
🔇 Additional comments (12)
tests/ops/test_attn.py (6)

4-4: Good addition of type hints.

Adding List import for proper type annotation improves code clarity.


20-32: Excellent simplification of test parameterization.

The move from environment-dependent conditional parameterization to explicit fixed test cases improves:

  • Test determinism and reproducibility
  • Debugging experience (easier to identify which specific case failed)
  • Code maintainability (no complex conditional logic)

The test cases cover a good range of dimensions and scale values.


43-44: Good placement of flash-attn availability check.

Moving the check inside the test function is cleaner than having it in the parameterization logic.


47-50: Hardcoded dtype ensures test consistency.

Using torch.float16 consistently across all tensors removes environment-dependent behavior and ensures tests run uniformly across different setups.


70-80: Consistent simplification for variable-length tests.

The same beneficial pattern applied to the varlen test - explicit parameterization with fixed cu_seqlens values instead of random generation.


85-89: Clean parameter handling for variable-length sequences.

The explicit cu_seqlens parameter with proper type annotation and tensor conversion is well-structured.

tests/ops/test_forgetting_attn.py (6)

3-3: Good addition of comprehensive type hints.

Adding List, Optional imports supports proper type annotations throughout the file.


35-47: Consistent test simplification pattern.

Same beneficial pattern as in test_attn.py - explicit parameterization improves test determinism and maintainability.


56-57: Good addition of reproducibility controls.

Adding manual seed and hardcoding dtype ensures consistent test behavior across runs and environments.


66-66: Verify the impact of simplified gating tensor initialization.

The change from a complex log-sigmoid scaled initialization to a simple uniform distribution (-0.1, -0.01) simplifies the test but may reduce the range of values being tested. Ensure this still adequately tests the forgetting attention mechanism.

Consider whether the simpler initialization still provides adequate test coverage for the gating mechanism's behavior across different value ranges.


90-100: Consistent parameterization improvements for varlen tests.

Same positive pattern applied to variable-length sequence tests.


119-119: Consistent gating tensor initialization.

Same simplification as in the fixed-length test - ensure this provides adequate coverage.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
fla/ops/utils/cumsum.py (1)

4-4: Remove unused import to fix linting error.

The warnings import is no longer used and is causing a pipeline failure.

-import warnings
🧹 Nitpick comments (1)
fla/ops/utils/cumsum.py (1)

181-182: Consider removing redundant condition for consistency.

Similar to the vector kernel change on line 248, this condition i_c >= 0 is always true since i_c ranges from 0 to NT-1 in the loop. Consider removing it for consistency and clarity.

-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e389ef0 and ff63caf.

📒 Files selected for processing (1)
  • fla/ops/utils/cumsum.py (3 hunks)
🧰 Additional context used
🪛 GitHub Actions: lint
fla/ops/utils/cumsum.py

[error] 4-4: flake8: 'warnings' imported but unused (F401)

🔇 Additional comments (2)
fla/ops/utils/cumsum.py (2)

168-168: Good formatting improvements for readability.

Adding spaces around the subtraction operator improves code readability and follows Python style conventions.

Also applies to: 235-235


248-248: Verify the conditional logic removal is correct.

The unconditional accumulation of b_z appears to remove a redundant condition (since i_c ranges from 0 to NT-1, i_c >= 0 is always true). However, there's an inconsistency with the scalar kernel which still has this condition on line 181-182.

#!/bin/bash
# Description: Check for similar conditional logic patterns in other kernels
# Expected: Find similar i_c >= 0 conditions that might also be redundant

rg -A 2 -B 2 "if i_c >= 0" 

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/ops/test_gla.py (1)

188-256: Consider refactoring to reduce local variable complexity.

The static analysis tool correctly identifies that this function has many local variables (26 out of 15 recommended). Consider extracting helper functions for tensor creation, forward/backward passes, or assertion checks to improve readability and maintainability.

Example refactor approach:

def _create_test_tensors(N, T, H, D, dtype, device, cu_seqlens):
    """Helper to create test tensors."""
    # Move tensor creation logic here
    pass

def _run_forward_backward(model_fn, tensors, do):
    """Helper to run forward and backward passes."""
    # Move forward/backward logic here
    pass
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7a1cc65 and cb4afaf.

📒 Files selected for processing (1)
  • tests/ops/test_gla.py (7 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/ops/test_gla.py

[refactor] 188-188: Too many local variables (26/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test-ops
  • GitHub Check: test-ops
🔇 Additional comments (10)
tests/ops/test_gla.py (10)

4-4: LGTM: Import addition for type hints.

Good addition of the List type for better type annotations.


15-28: LGTM: Improved parameterization with consolidated test cases.

The consolidation of multiple pytest.mark.parametrize decorators into a single unified parameter tuple improves test clarity and maintainability. The explicit test IDs will make test failures easier to identify.


38-38: LGTM: New gate_logit_normalizer parameter.

The addition of the gate_logit_normalizer parameter allows for testing different gate scaling scenarios, which improves test coverage.


43-47: LGTM: Improved tensor initialization and gate normalization.

The changes from torch.randn to torch.rand for input tensors and the application of gate_logit_normalizer for gate scaling provide more controlled test scenarios. The use of F.logsigmoid ensures proper gate value ranges.


89-102: LGTM: Consistent parameterization pattern.

The parameterization matches the pattern established in the first test function, maintaining consistency across the test suite.


118-122: LGTM: Consistent tensor initialization.

The tensor initialization changes are consistent with the first test function, ensuring uniform testing conditions.


127-130: LGTM: Explicit keyword arguments improve readability.

Using explicit keyword arguments (q=q, k=k, etc.) makes the function calls more readable and less prone to parameter ordering errors.


142-145: Note the parameter name change: g to gk.

The gate parameter has been renamed from g to gk in the fused_recurrent_gla function calls, while chunk_gla still uses g. This suggests an API change in the fused recurrent implementation.

Also applies to: 149-153


173-183: LGTM: Explicit cu_seqlens parameterization improves test determinism.

The explicit parameterization with cu_seqlens makes the variable-length tests more deterministic and easier to debug compared to random generation within the test.


191-191: LGTM: Clean cu_seqlens-based test setup.

The refactoring to derive batch size (N) and total sequence length (T) from the explicit cu_seqlens parameter makes the test logic clearer and more predictable.

Also applies to: 197-205

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 11

♻️ Duplicate comments (3)
tests/models/test_modeling_comba.py (1)

14-33: Critical: Same parameter order mismatch as in other test files.

This file has the identical parameter order issue where the parametrize decorator and function signature don't match.

tests/models/test_modeling_mamba.py (1)

14-33: Critical: Parameter order mismatch affects this file too.

tests/models/test_modeling_samba.py (1)

14-33: Critical: Parameter order mismatch present here as well.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cb4afaf and bc39bf9.

📒 Files selected for processing (24)
  • tests/models/test_modeling_abc.py (1 hunks)
  • tests/models/test_modeling_bitnet.py (1 hunks)
  • tests/models/test_modeling_comba.py (1 hunks)
  • tests/models/test_modeling_deltanet.py (1 hunks)
  • tests/models/test_modeling_forgetting_transformer.py (1 hunks)
  • tests/models/test_modeling_gated_deltanet.py (1 hunks)
  • tests/models/test_modeling_gated_deltaproduct.py (1 hunks)
  • tests/models/test_modeling_gla.py (1 hunks)
  • tests/models/test_modeling_gsa.py (1 hunks)
  • tests/models/test_modeling_hgrn.py (1 hunks)
  • tests/models/test_modeling_hgrn2.py (1 hunks)
  • tests/models/test_modeling_lightnet.py (1 hunks)
  • tests/models/test_modeling_linear_attn.py (1 hunks)
  • tests/models/test_modeling_mamba.py (1 hunks)
  • tests/models/test_modeling_mamba2.py (1 hunks)
  • tests/models/test_modeling_mesanet.py (1 hunks)
  • tests/models/test_modeling_nsa.py (1 hunks)
  • tests/models/test_modeling_path_attn.py (1 hunks)
  • tests/models/test_modeling_retnet.py (1 hunks)
  • tests/models/test_modeling_rodimus.py (1 hunks)
  • tests/models/test_modeling_rwkv6.py (1 hunks)
  • tests/models/test_modeling_rwkv7.py (1 hunks)
  • tests/models/test_modeling_samba.py (1 hunks)
  • tests/models/test_modeling_transformer.py (1 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/models/test_modeling_abc.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_bitnet.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_comba.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_deltanet.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_forgetting_transformer.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gated_deltanet.py

[refactor] 24-24: Too many arguments (7/5)

(R0913)


[refactor] 24-24: Too many positional arguments (7/5)

(R0917)


[refactor] 48-48: Too many arguments (6/5)

(R0913)


[refactor] 48-48: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gated_deltaproduct.py

[refactor] 28-28: Too many arguments (7/5)

(R0913)


[refactor] 28-28: Too many positional arguments (7/5)

(R0917)

tests/models/test_modeling_gla.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_gsa.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_hgrn.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_hgrn2.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_lightnet.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_linear_attn.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mamba.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mamba2.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_mesanet.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_nsa.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_path_attn.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_retnet.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rodimus.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rwkv6.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_rwkv7.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 50-50: Too many arguments (6/5)

(R0913)


[refactor] 50-50: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_samba.py

[refactor] 25-25: Too many arguments (7/5)

(R0913)


[refactor] 25-25: Too many positional arguments (7/5)

(R0917)


[refactor] 49-49: Too many arguments (6/5)

(R0913)


[refactor] 49-49: Too many positional arguments (6/5)

(R0917)

tests/models/test_modeling_transformer.py

[refactor] 28-28: Too many arguments (7/5)

(R0913)


[refactor] 28-28: Too many positional arguments (7/5)

(R0917)


[refactor] 52-52: Too many arguments (6/5)

(R0913)


[refactor] 52-52: Too many positional arguments (6/5)

(R0917)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test-ops
  • GitHub Check: test-ops
  • GitHub Check: test-ops
🔇 Additional comments (29)
tests/models/test_modeling_bitnet.py (1)

40-56: LGTM! Clean refactoring with improved type annotations.

The consolidated parametrization and explicit type annotations improve test clarity and maintainability.

tests/models/test_modeling_comba.py (1)

40-56: LGTM! Generation test parameterization is correct.

tests/models/test_modeling_mamba.py (1)

40-56: LGTM! Generation test is correctly structured.

tests/models/test_modeling_samba.py (1)

40-56: LGTM! Generation test refactoring is solid.

tests/models/test_modeling_retnet.py (1)

40-56: LGTM! Consistent and clean generation test structure.

tests/models/test_modeling_transformer.py (2)

17-27: Good refactoring approach for test parameterization.

The consolidation of multiple @pytest.mark.parametrize decorators into a single decorator with explicit parameter tuples and descriptive test IDs improves readability and maintainability.


17-36: Critical parameter order mismatch between tuples and function signature.

The parameter order in the test tuples doesn't match the function signature, which will cause dtype and use_l2warp parameters to be swapped during test execution.

Fix the parameter order by updating the function signature:

 def test_modeling(
     L: int,
     B: int,
     T: int,
     H: int,
     D: int,
-    dtype: torch.dtype,
-    use_l2warp: bool,
+    use_l2warp: bool,
+    dtype: torch.dtype,
 ):

Or alternatively, reorder the tuples to match the current function signature.

Likely an incorrect or invalid review comment.

tests/models/test_modeling_rwkv6.py (1)

25-25: Good function naming consistency.

The renaming from test_rwkv6_modeling to test_modeling improves consistency across all model test files.

tests/models/test_modeling_hgrn2.py (1)

40-56: LGTM! Clean parameterization consolidation.

The test generation function has been properly refactored with consolidated parameterization and explicit type annotations. The parameter order is consistent between decorator and function signature.

tests/models/test_modeling_abc.py (1)

40-56: LGTM! Clean parameterization consolidation.

The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.

tests/models/test_modeling_gsa.py (1)

40-56: LGTM! Clean parameterization consolidation.

The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.

tests/models/test_modeling_gla.py (1)

40-56: LGTM! Clean parameterization consolidation.

The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.

tests/models/test_modeling_forgetting_transformer.py (1)

40-56: LGTM! Clean parameterization consolidation.

The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.

tests/models/test_modeling_path_attn.py (3)

14-24: Well-structured parameterization consolidation.

The refactoring successfully consolidates multiple decorators into a single, more readable parameterization with explicit test IDs. This improves test maintainability and makes test cases easier to identify.


25-33: Excellent addition of type annotations.

The explicit type annotations significantly improve code clarity and help with IDE support and static analysis. The parameter types are correctly specified.


40-48: Consistent parameterization pattern applied.

The generation test follows the same improved parameterization pattern as the modeling test, maintaining consistency across the test suite.

tests/models/test_modeling_rodimus.py (1)

14-33: Consistent refactoring pattern maintained.

The parameterization consolidation and type annotations follow the same improved pattern as other test files, ensuring consistency across the test suite while improving maintainability and readability.

tests/models/test_modeling_gated_deltanet.py (1)

14-33: Consistent and appropriate refactoring.

The same beneficial parameterization and type annotation patterns are applied consistently. The test case configuration appears appropriate for the GatedDeltaNet model.

tests/models/test_modeling_nsa.py (1)

14-33: Maintains consistent refactoring standards.

The parameterization consolidation and type annotations are properly applied following the same pattern as other test files, contributing to a uniform and maintainable test suite.

tests/models/test_modeling_mesanet.py (2)

14-33: Completes consistent test suite refactoring.

This file successfully applies the same parameterization consolidation and type annotation improvements as the other test files, resulting in a uniform and maintainable test suite across all model types.


40-56: Appropriate test parameter adjustment for faster testing.

The reduction of sequence length T from 4000 to 2000 in generation tests aligns with the PR objective of "Fast testing" while maintaining adequate test coverage.

tests/models/test_modeling_mamba2.py (2)

14-33: Excellent refactoring of test parameterization and type annotations.

The consolidation of pytest parameters into a single decorator with explicit test IDs significantly improves test organization and readability. The type annotations enhance code clarity and IDE support.


40-56: Well-structured generation test with proper parameterization.

The same excellent refactoring pattern applied consistently. The explicit type annotations and descriptive test IDs make the test more maintainable.

tests/models/test_modeling_lightnet.py (2)

14-33: Consistent refactoring pattern maintained across test files.

The parameterization consolidation and type annotations follow the same excellent pattern as other test files, ensuring consistency across the test suite.


40-56: Generation test properly refactored with appropriate config.

The test correctly uses LightNetConfig and maintains the same structural improvements as the modeling test.

tests/models/test_modeling_rwkv7.py (2)

14-33: Systematic refactoring applied consistently.

The parameterization consolidation and type annotations are well-executed, following the same pattern as other test files in this refactoring effort.


41-57: Generation test refactored with proper structure.

The test maintains the same quality improvements with appropriate use of RWKV7Config and consistent parameter organization.

tests/models/test_modeling_gated_deltaproduct.py (2)

17-36: Modeling test refactored with consistent pattern.

The parameterization consolidation and type annotations follow the same excellent pattern as other files, maintaining consistency across the test suite.


43-72: Generation test properly handles model-specific complexity.

The refactored parameterization is well-structured, and the custom model creation logic appropriately handles the specific requirements of GatedDeltaProduct testing with different configurations.

Comment thread tests/models/test_modeling_bitnet.py
Comment thread tests/models/test_modeling_retnet.py
Comment thread tests/models/test_modeling_deltanet.py
Comment thread tests/models/test_modeling_hgrn.py
Comment thread tests/models/test_modeling_linear_attn.py
Comment thread tests/models/test_modeling_hgrn2.py
Comment thread tests/models/test_modeling_abc.py
Comment thread tests/models/test_modeling_gsa.py
Comment thread tests/models/test_modeling_gla.py
Comment thread tests/models/test_modeling_forgetting_transformer.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (2)
.github/workflows/reusable-ci-tests.yml (1)

45-55: check_gpu.py result is used but the script output is not logged

For post-mortem debugging it’s useful to capture the script’s stdout/stderr even on success. Consider tee-ing it to the console or adding set -euo pipefail so failures surface early.

.github/workflows/nvidia-h100.yml (1)

18-23: Use a build matrix instead of three near-identical jobs

The only variance across the three jobs is the PyTorch version / conda env name. A matrix keeps the file concise and guarantees future versions stay in sync:

strategy:
  fail-fast: false
  matrix:
    torch: [ {env: pytorch_2_7, ver: '2.7.0'},
             {env: pytorch_nightly, ver: nightly},
             {env: pytorch_2_6, ver: '2.6.0'} ]

jobs:
  test-h100:
    uses: ./.github/workflows/reusable-ci-tests.yml
    with:
      runner: nvidia-h100
      gpu_type: nvidia
      conda_env_name: ${{ matrix.torch.env }}
      pytorch_version: ${{ matrix.torch.ver }}

This reduces maintenance overhead.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bc39bf9 and 25565bd.

📒 Files selected for processing (2)
  • .github/workflows/nvidia-h100.yml (1 hunks)
  • .github/workflows/reusable-ci-tests.yml (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: test-ops
🔇 Additional comments (1)
.github/workflows/reusable-ci-tests.yml (1)

129-133: Quote the substitution to preserve spaces & newlines returned by changed-files

Without quotes the shell performs word-splitting, so filenames that contain spaces or glob characters break the script and the output variable.

-TEST_FILES=$(TEST_SCOPE=EXCLUDE_MODELS python scripts/find_dependent_tests.py "${{ steps.changed-files.outputs.all_changed_files }}")
+TEST_FILES=$(TEST_SCOPE=EXCLUDE_MODELS python scripts/find_dependent_tests.py \
+  "${{ steps.changed-files.outputs.all_changed_files }}" )

Same applies to the MODELS block below.

Comment thread .github/workflows/reusable-ci-tests.yml
Comment thread .github/workflows/reusable-ci-tests.yml
Comment thread .github/workflows/reusable-ci-tests.yml Outdated
Comment thread .github/workflows/nvidia-h100.yml Outdated
@zhiyuan1i zhiyuan1i force-pushed the fast_test_autotune branch 2 times, most recently from 25565bd to 6d7d6e1 Compare June 25, 2025 19:33
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (3)
.github/workflows/reusable-ci-tests.yml (3)

173-174: torch~= still risks pulling a CPU wheel

Earlier reviews already covered this – the ~= operator may resolve to +cpu wheels when CUDA wheels lag behind. Use an explicit == …+${{ inputs.pytorch_cuda_version }} (or pin the exact filename).

Also applies to: 185-186


200-204: Multi-line $TEST_FILES corrupts GITHUB_OUTPUT

Writing a newline-separated list via key=value truncates after the first line. Use the heredoc pattern suggested in the previous review.

Also applies to: 284-288


206-207: if: should test string emptiness explicitly

The expression still relies on raw string truthiness:

if: steps.find-ops-tests.outputs.test_files && …

As noted before, compare against '':

-if: steps.find-ops-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false'
+if: steps.find-ops-tests.outputs.test_files != '' && steps.check_skip.outputs.skip_tests == 'false'

Repeat for the models job.

Also applies to: 290-291

🧹 Nitpick comments (1)
.github/workflows/reusable-ci-tests.yml (1)

58-70: Minor: bake the runner name into a shell var once

Nitpick: interpolate ${{ runner.name }} once to avoid repetition/quoting issues:

RUNNER_NAME="${{ runner.name }}"
echo "Determining conda env for runner: ${RUNNER_NAME}"
case "${RUNNER_NAME}" in
  nvidia-h100-1) TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" ;;
  nvidia-h100-2) TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1" ;;
  *)             TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" ;;
esac
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51a11dc and 63b33fd.

📒 Files selected for processing (2)
  • .github/workflows/nvidia-h100.yml (2 hunks)
  • .github/workflows/reusable-ci-tests.yml (6 hunks)
✅ Files skipped from review due to trivial changes (1)
  • .github/workflows/nvidia-h100.yml

Comment thread .github/workflows/reusable-ci-tests.yml
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (2)
tests/ops/test_gsa.py (2)

82-82: Fix gradient assignment bug

There's a copy-paste error where g.grad is incorrectly assigned to s.grad.

-    tri_dg, s.grad = g.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None

180-180: Fix gradient assignment bug

Same copy-paste error as in line 82.

-    tri_dg, s.grad = g.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 63b33fd and 37488f7.

📒 Files selected for processing (4)
  • .github/workflows/reusable-ci-tests.yml (8 hunks)
  • fla/ops/gsa/fused_recurrent.py (1 hunks)
  • tests/ops/test_gla.py (2 hunks)
  • tests/ops/test_gsa.py (13 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • .github/workflows/reusable-ci-tests.yml
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_gla.py (2)
fla/ops/common/fused_recurrent.py (1)
  • backward (517-535)
fla/ops/gla/chunk.py (1)
  • backward (1195-1213)
tests/ops/test_gsa.py (4)
fla/ops/gsa/chunk.py (2)
  • chunk_gsa (996-1133)
  • backward (967-992)
fla/ops/gsa/fused_recurrent.py (2)
  • fused_recurrent_gsa (431-536)
  • backward (405-428)
fla/ops/gsa/naive.py (1)
  • naive_recurrent_gsa (9-69)
fla/utils.py (2)
  • assert_close (78-90)
  • check_shared_mem (434-440)
🪛 Pylint (3.3.7)
tests/ops/test_gsa.py

[refactor] 115-115: Too many local variables (41/15)

(R0914)


[refactor] 115-115: Too many statements (51/50)

(R0915)


[refactor] 214-214: Too many arguments (7/5)

(R0913)


[refactor] 214-214: Too many positional arguments (7/5)

(R0917)


[refactor] 214-214: Too many local variables (37/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: Test H100 (PyTorch Nightly) / test-ops
  • GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (8)
fla/ops/gsa/fused_recurrent.py (1)

250-250: Good optimization for GPU kernel performance!

Using triton.next_power_of_2 for block size calculation ensures power-of-2 block dimensions, which typically improves GPU kernel performance through better memory alignment and access patterns.

tests/ops/test_gla.py (1)

211-211: Excellent test coverage improvement!

Changing dht from zeros to random values and including it in the backward pass ensures proper gradient flow through the final hidden state, making the test more comprehensive.

Also applies to: 236-236

tests/ops/test_gsa.py (6)

4-4: Clean import management!

Adding List for type hints and removing unused imports improves code clarity.

Also applies to: 12-12


55-56: Great improvement to gradient testing coverage!

Adding gradient tensors for initial hidden states and including them in backward passes ensures comprehensive gradient propagation testing.

Also applies to: 59-59, 77-77, 94-95


98-109: Excellent test parameterization improvements!

Replacing conditional parameters with explicit lists and using direct cu_seqlens tensors makes tests more deterministic and easier to understand.

Also applies to: 119-119, 124-126


239-247: Good addition of scale parameter!

Adding scale=D**-0.5 ensures proper attention scaling in both implementations.

Also applies to: 257-266


288-298: Consistent improvements across variable-length tests!

The changes maintain consistency with other test functions, properly handling cu_seqlens and scale parameters.

Also applies to: 318-321, 336-345, 356-365


418-419: Better test coverage with random initial states!

Using random tensors instead of zeros for initial states tests the system under more realistic conditions.

Comment thread tests/ops/test_gsa.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 37488f7 and 145ed19.

📒 Files selected for processing (1)
  • fla/ops/rwkv7/fused_addcmul.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (3)
fla/ops/rwkv7/fused_addcmul.py (3)

4-4: Import added for Python version checking.

The sys import is correctly added to support version checking.


19-20: Identity decorator implementation is correct.

The identity decorator provides a clean fallback when torch.compile is not available.


179-179: Decorator change enables conditional compilation.

The change from hardcoded @torch.compile(fullgraph=True) to @torch_compile correctly enables conditional compilation based on the environment setup.

Comment thread fla/ops/rwkv7/fused_addcmul.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/ops/test_gsa.py (1)

82-82: Fix critical gradient assignment bugs.

The gradient assignment has a copy-paste error where s.grad is being set to None instead of g.grad.

Apply this fix for all three occurrences:

-    tri_dg, s.grad = g.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None

Also applies to: 180-180, 272-272

🧹 Nitpick comments (1)
tests/ops/test_delta_product.py (1)

28-32: Static analysis warnings are acceptable for test functions.

The pylint warnings about too many local variables and statements are common in test functions that need extensive setup. These are acceptable given the testing context where multiple tensors, gradients, and validation steps are required.

Also applies to: 95-97

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 145ed19 and 0a8cd55.

📒 Files selected for processing (18)
  • tests/ops/test_comba.py (2 hunks)
  • tests/ops/test_cumsum.py (0 hunks)
  • tests/ops/test_delta.py (3 hunks)
  • tests/ops/test_delta_product.py (5 hunks)
  • tests/ops/test_dplr_delta.py (6 hunks)
  • tests/ops/test_gated_delta.py (3 hunks)
  • tests/ops/test_gated_delta_product.py (3 hunks)
  • tests/ops/test_gla.py (4 hunks)
  • tests/ops/test_gsa.py (4 hunks)
  • tests/ops/test_hgrn.py (3 hunks)
  • tests/ops/test_iplr_delta.py (2 hunks)
  • tests/ops/test_linear_attn.py (3 hunks)
  • tests/ops/test_nsa.py (1 hunks)
  • tests/ops/test_retention.py (3 hunks)
  • tests/ops/test_rwkv6.py (3 hunks)
  • tests/ops/test_rwkv7.py (2 hunks)
  • tests/ops/test_simple_gla.py (6 hunks)
  • tests/ops/test_utils.py (2 hunks)
💤 Files with no reviewable changes (1)
  • tests/ops/test_cumsum.py
✅ Files skipped from review due to trivial changes (8)
  • tests/ops/test_utils.py
  • tests/ops/test_linear_attn.py
  • tests/ops/test_gla.py
  • tests/ops/test_hgrn.py
  • tests/ops/test_retention.py
  • tests/ops/test_comba.py
  • tests/ops/test_iplr_delta.py
  • tests/ops/test_simple_gla.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tests/ops/test_rwkv7.py
  • tests/ops/test_dplr_delta.py
  • tests/ops/test_rwkv6.py
🧰 Additional context used
🧬 Code Graph Analysis (4)
tests/ops/test_nsa.py (1)
fla/utils.py (1)
  • assert_close (78-90)
tests/ops/test_gsa.py (1)
fla/utils.py (1)
  • assert_close (78-90)
tests/ops/test_gated_delta.py (1)
fla/utils.py (1)
  • assert_close (78-90)
tests/ops/test_gated_delta_product.py (1)
fla/utils.py (1)
  • assert_close (78-90)
🪛 Pylint (3.3.7)
tests/ops/test_delta.py

[refactor] 25-25: Too many local variables (27/15)

(R0914)


[refactor] 87-87: Too many local variables (28/15)

(R0914)

tests/ops/test_delta_product.py

[refactor] 28-28: Too many local variables (28/15)

(R0914)


[refactor] 95-95: Too many local variables (40/15)

(R0914)


[refactor] 95-95: Too many statements (51/50)

(R0915)

🔇 Additional comments (11)
tests/ops/test_nsa.py (1)

165-165: LGTM! Consistent formatting improvement.

The removal of the leading space from the assertion prefix improves consistency with the assert_close function's right-aligned formatting.

tests/ops/test_gsa.py (1)

86-93: LGTM! Consistent formatting improvement.

The removal of leading spaces from assertion prefixes improves consistency with the assert_close function's right-aligned formatting across all test functions.

Also applies to: 184-191, 276-283, 374-381

tests/ops/test_gated_delta.py (1)

194-195: LGTM! Consistent formatting improvement.

The systematic removal of leading spaces from assertion prefixes across all test functions improves consistency with the assert_close function's formatting and enhances test output readability.

Also applies to: 262-269, 349-356

tests/ops/test_gated_delta_product.py (1)

99-106: LGTM! Consistent formatting improvement.

The removal of leading spaces from assertion prefixes maintains consistency with the broader test suite formatting standardization and improves test output readability.

Also applies to: 182-189, 209-216

tests/ops/test_delta_product.py (3)

2-2: LGTM! Copyright header updated.

The copyright year range has been updated appropriately.


16-37: LGTM! Test parameterization simplification.

The changes improve test predictability by:

  • Replacing environment variable-based skips with clear platform-based skips
  • Using explicit parameter tuples instead of complex conditional logic
  • Simplifying function signatures by moving fixed parameters into the test body

This makes the tests more maintainable and deterministic.


76-82: LGTM! Consistent formatting improvement.

The removal of leading spaces from assertion prefixes aligns with the broader test suite formatting standardization effort and improves consistency.

Also applies to: 149-155, 174-180

tests/ops/test_delta.py (4)

3-3: LGTM: Import additions are appropriate.

The added imports support the updated test structure with explicit type hints and device platform checks.

Also applies to: 10-10


17-33: Verify test coverage impact of simplified parameterization.

The move from environment-dependent to fixed parameters improves test predictability and maintainability. However, this significantly reduces test coverage compared to the previous parameterized approach.

Consider whether the fixed parameter values (B=2, H=3, T/D combinations) adequately cover edge cases and different tensor dimensions that were previously tested.


68-73: Good formatting cleanup.

Removing leading spaces from assertion labels improves code consistency and readability.


135-140: Good formatting cleanup.

Consistent with the previous assertion formatting improvements.

Comment thread tests/ops/test_delta.py
Comment on lines +77 to +96
@pytest.mark.parametrize(
("cu_seqlens"),
[
([0, 15, 100, 300, 1203, 2000]),
]
)
@pytest.mark.skipif(
device_platform == 'intel',
reason='Intel Triton Failure'
)
def test_chunk_varlen(
N: int,
T: int,
H: int,
D: int,
scale: float,
dtype: torch.dtype,
cu_seqlens: List[int],
):
torch.manual_seed(42)
os.environ['TRITON_F32_DEFAULT'] = 'ieee'
# randomly split the sequence into N segments
cu_seqlens = torch.cat([
torch.tensor([0], dtype=torch.long),
torch.arange(16, T)[torch.randperm(T - 16)[:N-1]],
torch.tensor([T], dtype=torch.long)
], 0).to(device).sort()[0]
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
T = cu_seqlens[-1]
H = 2
D = 64
dtype = torch.float16
scale = 1.0
N = len(cu_seqlens) - 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Simplify parameterization structure and consider refactoring for readability.

The parameterization has been simplified, but there are two concerns:

  1. The parameterization syntax ("cu_seqlens") should be ("cu_seqlens",) (with trailing comma) to properly define a tuple parameter.
  2. Both test functions have excessive local variables (27-28 variables) which impacts readability.

Apply this fix for the parameterization:

-@pytest.mark.parametrize(
-    ("cu_seqlens"),
-    [
-        ([0, 15, 100, 300, 1203, 2000]),
-    ]
-)
+@pytest.mark.parametrize(
+    ("cu_seqlens",),
+    [
+        ([0, 15, 100, 300, 1203, 2000]),
+    ]
+)

Consider extracting test setup logic into helper functions to reduce local variable count:

def setup_test_tensors(B, T, H, D, dtype, N=None):
    """Helper to create and setup test tensors."""
    # Move tensor creation logic here
    pass
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 87-87: Too many local variables (28/15)

(R0914)

🤖 Prompt for AI Agents
In tests/ops/test_delta.py around lines 77 to 96, the parameterization syntax
should be corrected from ("cu_seqlens") to ("cu_seqlens",) to properly define a
tuple parameter. Additionally, to improve readability and reduce the large
number of local variables in the test functions, extract the test setup logic
such as tensor creation and initialization into separate helper functions.
Create a helper function like setup_test_tensors that accepts parameters like B,
T, H, D, dtype, and optionally N, and returns the necessary tensors and
variables for the tests. Replace the local setup code in the test functions with
calls to this helper to simplify the test bodies.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/ops/test_ttt.py (1)

96-96: Consider making hardcoded tensor slicing more robust.

The hardcoded slicing [:, :14, :, :] appears fragile and could break if test parameters change. Consider making this dynamic based on test parameters or add comments explaining the magic number 14.

-    assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.010)
+    # Test gradient for first 14 time steps to avoid edge effects
+    early_steps = min(14, T-1)
+    assert_close(" de0", ref_deta[:, :early_steps, :, :], tri_deta[:, :early_steps, :, :], 0.010)

Also applies to: 182-182

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfb9a53 and 4f26f23.

📒 Files selected for processing (6)
  • tests/ops/test_hgrn.py (4 hunks)
  • tests/ops/test_nsa.py (4 hunks)
  • tests/ops/test_solve_tril.py (3 hunks)
  • tests/ops/test_titans.py (3 hunks)
  • tests/ops/test_ttt.py (4 hunks)
  • tests/ops/test_utils.py (12 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/test_utils.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_titans.py (1)
fla/utils.py (1)
  • assert_close (78-90)
tests/ops/test_ttt.py (4)
fla/ops/ttt/chunk.py (1)
  • chunk_ttt_linear (1340-1441)
fla/ops/ttt/fused_chunk.py (1)
  • fused_chunk_ttt_linear (728-831)
fla/ops/ttt/naive.py (1)
  • chunk_ttt_linear_ref (73-126)
fla/utils.py (2)
  • assert_close (78-90)
  • check_shared_mem (434-440)
🪛 Pylint (3.3.7)
tests/ops/test_titans.py

[refactor] 53-53: Too many local variables (19/15)

(R0914)

tests/ops/test_nsa.py

[refactor] 33-33: Too many arguments (9/5)

(R0913)


[refactor] 33-33: Too many positional arguments (9/5)

(R0917)


[refactor] 33-33: Too many local variables (26/15)

(R0914)


[refactor] 96-96: Too many arguments (7/5)

(R0913)


[refactor] 96-96: Too many positional arguments (7/5)

(R0917)


[refactor] 96-96: Too many local variables (26/15)

(R0914)

tests/ops/test_ttt.py

[refactor] 29-29: Too many arguments (6/5)

(R0913)


[refactor] 29-29: Too many positional arguments (6/5)

(R0917)


[refactor] 29-29: Too many local variables (40/15)

(R0914)


[refactor] 115-115: Too many arguments (6/5)

(R0913)


[refactor] 115-115: Too many positional arguments (6/5)

(R0917)


[refactor] 115-115: Too many local variables (40/15)

(R0914)


[refactor] 203-203: Too many local variables (25/15)

(R0914)

tests/ops/test_hgrn.py

[refactor] 76-76: Too many local variables (23/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: Test H100 (PyTorch Nightly) / test-ops
  • GitHub Check: Test A770 (PyTorch 2.7) / test-ops
  • GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (15)
tests/ops/test_titans.py (3)

40-52: LGTM: Test parameterization simplified effectively.

The refactoring from conditional parameter lists to explicit tuples improves test clarity and maintainability. The test cases cover a good range of configurations.


53-59: LGTM: Function signature simplified by removing deprecated parameters.

Removing the head_first and scale parameters aligns with the standardization effort. The function is now focused on testing the core functionality without environment-dependent variations.


78-83: LGTM: Tensor shapes standardized to consistent format.

The unconditional permutation to (B, T, H, D) format eliminates the need for conditional logic and ensures all tests use the same tensor layout, improving consistency.

tests/ops/test_hgrn.py (3)

15-27: LGTM: Test parameterization streamlined with explicit cases.

The replacement of conditional parameter lists with explicit tuples improves test predictability and removes environment dependencies.


64-86: LGTM: Variable-length test refactored to use explicit sequence lengths.

The change from random sequence splitting to explicit cu_seqlens lists makes tests more predictable and reproducible. The type annotation for cu_seqlens: List[int] is appropriate.


57-61: LGTM: Assertion labels standardized by removing leading spaces.

This change improves consistency in test output formatting across the codebase.

Also applies to: 119-123, 164-164

tests/ops/test_nsa.py (3)

17-29: LGTM: Test parameterization simplified with explicit test cases.

The refactoring removes conditional logic and provides clear, explicit test parameters that improve test maintainability.


78-110: LGTM: Variable-length test properly refactored to use explicit sequence lengths.

The implementation correctly handles the cu_seqlens parameter pattern with proper type annotations and tensor conversion.


31-31: Note: Tests are currently skipped with "TBD" reason.

Both test functions are marked as skipped. Consider addressing the underlying issues or providing more specific skip reasons if these tests are intentionally disabled.

Could you provide more context about why these tests are skipped and when they might be re-enabled?

Also applies to: 94-94

tests/ops/test_solve_tril.py (3)

15-27: LGTM: Test parameterization simplified with explicit cases.

The explicit parameter tuples improve test clarity and remove environment dependencies.


51-77: LGTM: Variable-length test properly refactored with type annotations.

The implementation correctly uses the cu_seqlens pattern with proper type annotations. The function signature is clean and well-typed.


83-83: ```shell
#!/bin/bash

Extract the full body of chunk_scaled_dot_kkt_fwd to confirm its return signature

sed -n '70,150p' fla/ops/common/chunk_scaled_dot_kkt.py


</details>
<details>
<summary>tests/ops/test_ttt.py (3)</summary>

`15-28`: **LGTM: Test parameterization simplified with explicit cases.**

The refactoring removes conditional logic and provides clear test parameters for both `test_chunk` and `test_fused_chunk` functions.




Also applies to: 101-114

---

`40-47`: **LGTM: Tensor shapes standardized to consistent format.**

All tensors are now created with the consistent `(B, T, H, D)` shape, eliminating the need for conditional permutations and improving test clarity.




Also applies to: 126-133

---

`187-216`: **LGTM: Variable-length test properly refactored to use explicit sequence lengths.**

The implementation correctly handles the `cu_seqlens` parameter pattern with proper type annotations and tensor indexing. The function calls are updated appropriately to use sequence-based slicing.

</details>

</blockquote></details>

</details>

<!-- This is an auto-generated comment by CodeRabbit for review status -->

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/ops/test_titans.py (2)

54-54: Address the FIXME: Test is unconditionally skipped.

The test is currently disabled with an unconditional skip marked as 'FIXME'. This needs to be resolved to ensure test coverage.

What specific issue is preventing this test from running? I can help fix the underlying problem or create a tracking issue for this.


56-124: Consider refactoring to reduce local variable count.

Static analysis flags this function for having too many local variables (19/15). Consider extracting tensor initialization into a helper function to improve readability and maintainability.

Example refactor:

+def setup_titans_test_tensors(B, H, T, D, dtype, BT=64):
+    """Initialize all tensors needed for titans test."""
+    torch.manual_seed(1)
+    
+    theta = torch.rand(B, H, T, 1, dtype=dtype)
+    alpha = torch.rand(B, H, T, 1, dtype=dtype)
+    eta = torch.rand(B, H, T, 1, dtype=dtype)
+    
+    q = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype)
+    k = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype)
+    v = torch.randn(B, H, T, D, dtype=dtype)
+    w = torch.randn(H, D, dtype=dtype)
+    b = torch.randn(H, D, dtype=dtype)
+    h0 = torch.randn(B, H, D, D, dtype=torch.float32)
+    
+    # Permute to (B, T, H, D) layout
+    q = q.permute(0, 2, 1, 3)
+    k = k.permute(0, 2, 1, 3)
+    v = v.permute(0, 2, 1, 3)
+    theta = theta.permute(0, 2, 1, 3)
+    alpha = alpha.permute(0, 2, 1, 3)
+    eta = eta.permute(0, 2, 1, 3)
+    
+    tensors = (q, k, v, w, b, theta, alpha, eta)
+    tensors = tuple(x.to(device).requires_grad_(False) for x in tensors)
+    h0 = h0.to(device)
+    
+    return tensors + (h0,)

 def test_naive_chunk(B, H, T, D, dtype):
     BT = 64
-    # Tensor initialization code...
+    q, k, v, w, b, theta, alpha, eta, h0 = setup_titans_test_tensors(B, H, T, D, dtype, BT)
     
     # Test logic remains the same...
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4f26f23 and c306b32.

📒 Files selected for processing (2)
  • tests/ops/test_titans.py (3 hunks)
  • tests/ops/test_ttt.py (4 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/ops/test_ttt.py (4)
fla/ops/ttt/chunk.py (1)
  • chunk_ttt_linear (1340-1441)
fla/ops/ttt/fused_chunk.py (1)
  • fused_chunk_ttt_linear (728-831)
fla/ops/ttt/naive.py (1)
  • chunk_ttt_linear_ref (73-126)
fla/utils.py (2)
  • assert_close (78-90)
  • check_shared_mem (434-440)
🪛 Pylint (3.3.7)
tests/ops/test_titans.py

[refactor] 56-56: Too many local variables (19/15)

(R0914)

tests/ops/test_ttt.py

[refactor] 29-29: Too many arguments (6/5)

(R0913)


[refactor] 29-29: Too many positional arguments (6/5)

(R0917)


[refactor] 29-29: Too many local variables (40/15)

(R0914)


[refactor] 117-117: Too many arguments (6/5)

(R0913)


[refactor] 117-117: Too many positional arguments (6/5)

(R0917)


[refactor] 117-117: Too many local variables (40/15)

(R0914)


[refactor] 207-207: Too many local variables (25/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
  • GitHub Check: Test H100 (PyTorch Nightly) / test-ops
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: Test A770 (PyTorch 2.7) / test-ops
🔇 Additional comments (13)
tests/ops/test_ttt.py (9)

4-4: LGTM! Type annotation import added.

Adding List import supports the new cu_seqlens: List[int] type annotation in the variable-length test.


12-12: LGTM! Utility functions imported.

The imports of assert_close and check_shared_mem from fla.utils are used consistently throughout the tests.


15-28: Excellent refactoring of test parametrization.

The consolidation from multiple conditional pytest.mark.parametrize decorators to a single explicit list of test cases significantly improves readability and maintainability. The explicit test cases are easier to understand and modify than the previous conditional logic.


42-49: LGTM! Tensor shapes standardized to sequence-first format.

All tensor creations now use the consistent (B, T, H, D) sequence-first format, aligning with the removal of the head_first parameter. This standardization improves code clarity and reduces conditional complexity.


103-116: LGTM! Consistent refactoring applied to fused chunk test.

The same parametrization consolidation and skip condition improvements are consistently applied to the test_fused_chunk function, maintaining uniformity across the test suite.

Also applies to: 127-128


191-202: LGTM! Improved variable-length test parametrization.

The explicit cu_seqlens lists replace the previous random splitting approach, making the tests more deterministic and easier to debug. The test cases cover various sequence length patterns effectively.


243-243: LGTM! Correct variable-length sequence handling.

The cu_seqlens parameter is properly passed to the test function, and the reference implementation correctly slices the input tensors using the cumulative sequence lengths. This approach is more deterministic than the previous random splitting.

Also applies to: 251-256


39-40: Verify the skip condition threshold.

The skip condition for T > 1000 may be too restrictive and could miss important test coverage for longer sequences. Consider if this threshold is appropriate or if it should be higher.

#!/bin/bash
# Description: Check if there are other sequence length thresholds used in similar tests
# Expected: Find other T thresholds or CI limitations that justify this value

rg -A 3 -B 3 "T > [0-9]+" --type py
rg -A 3 -B 3 "skip.*T" --type py  
rg -A 3 -B 3 "Current CI.*support.*config" --type py

98-98: Verify the partial gradient assertion logic.

The assertion tests only the first 14 elements in the second dimension ([:, :14, :, :]). Ensure this partial testing approach is sufficient and the magic number 14 is appropriate for the test scenarios.

#!/bin/bash
# Description: Check if similar partial gradient testing patterns exist and understand the rationale
# Expected: Find similar patterns or documentation explaining why 14 elements are tested

rg -A 2 -B 2 "\[:, :[0-9]+, :, :\]" --type py
rg -A 5 -B 5 "because the last element of the chunk" --type py
tests/ops/test_titans.py (4)

8-8: LGTM: Clean import simplification.

Good cleanup removing unused imports as part of the refactoring.


40-52: LGTM: Clear and explicit test parametrization.

The explicit parameter list is much cleaner than conditional logic and provides good test coverage across different tensor dimensions and data types.


56-62: LGTM: Simplified function signature.

The removal of head_first and scale parameters simplifies the test interface and aligns with the unconditional tensor permutations below.


81-86: To confirm the expected input layout, let’s pull in the full signature and doc-string around chunk_titans_linear_ref:

#!/bin/bash
# Show signature, doc-string, and first part of body for context
rg -n -C10 'def chunk_titans_linear_ref' fla/ops/titans/naive.py

Comment thread tests/ops/test_ttt.py
Comment on lines +217 to +220
T = cu_seqlens[-1]
N = len(cu_seqlens) - 1
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix the cu_seqlens tensor dtype.

The cu_seqlens tensor is created with dtype=torch.int32, but based on the function signature from the relevant code snippets, it expects Optional[torch.LongTensor] (int64).

-    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
T = cu_seqlens[-1]
N = len(cu_seqlens) - 1
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
T = cu_seqlens[-1]
N = len(cu_seqlens) - 1
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device)
🤖 Prompt for AI Agents
In tests/ops/test_ttt.py around lines 217 to 220, the cu_seqlens tensor is
created with dtype=torch.int32, but the function expects an
Optional[torch.LongTensor] which corresponds to int64. Change the dtype of the
cu_seqlens tensor creation to torch.int64 to match the expected type and avoid
type errors.

@yzhangcs yzhangcs merged commit 16bafa7 into main Jun 26, 2025
2 of 6 checks passed
@yzhangcs yzhangcs deleted the fast_test_autotune branch June 26, 2025 23:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC] Reduce Redundant Tests and Enhancing Tests

3 participants