Conversation
|
Caution Review failedThe pull request is closed. WalkthroughThis update streamlines and simplifies test parameterizations, removes environment-dependent logic, and adjusts Triton kernel autotuning configurations across several files. It also introduces conditional support for Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant TestSuite
participant Kernel
participant CI_Workflow
User->>TestSuite: Run tests (with fixed parameters)
TestSuite->>Kernel: Launch kernel with new autotuning configs
CI_Workflow->>CI_Workflow: Dynamically select Conda env by runner
CI_Workflow->>TestSuite: Run all tests in one pytest invocation
Kernel-->>TestSuite: Return results
TestSuite-->>User: Report test outcomes
Possibly related PRs
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (2)
✨ Finishing Touches
🧪 Generate Unit Tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (10)
tests/models/test_modeling_hgrn.py (2)
21-29: Suppress Pylint overly-strict argument warnings
Static analysis flags R0913/R0917 due to 7 parameters; consider disabling them for this test:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_modeling(
42-49: Suppress Pylint overly-strict argument warnings
6 parameters trigger R0913/R0917; you can silence this with:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_generation(tests/models/test_modeling_deltanet.py (2)
21-29: Suppress Pylint overly-strict argument warnings
To avoid R0913/R0917 complaints for this 7-arg test, add:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_modeling(
42-49: Suppress Pylint overly-strict argument warnings
Disable R0913/R0917 for this 6-arg test to silence false positives:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_generation(tests/models/test_modeling_transformer.py (2)
24-32: Suppress Pylint overly-strict argument warnings
7 parameters trigger R0913/R0917; consider disabling for this function:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_modeling(
45-52: Suppress Pylint overly-strict argument warnings
6 parameters prompts R0913/R0917; disable to avoid noisy warnings:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_generation(tests/models/test_modeling_mamba.py (2)
21-29: Suppress Pylint overly-strict argument warnings
To silence R0913/R0917 for this test, add:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_modeling(
42-49: Suppress Pylint overly-strict argument warnings
6-arg signature triggers R0913/R0917; consider disabling:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_generation(tests/models/test_modeling_comba.py (2)
21-29: Suppress Pylint overly-strict argument warnings
Add a disable for R0913/R0917 to prevent noise:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_modeling(
42-49: Suppress Pylint overly-strict argument warnings
To avoid false positives on argument count, use:@pytest.mark.parametrize(... ) +# pylint: disable=R0913,R0917 def test_generation(
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (24)
tests/models/test_modeling_abc.py(2 hunks)tests/models/test_modeling_bitnet.py(2 hunks)tests/models/test_modeling_comba.py(2 hunks)tests/models/test_modeling_deltanet.py(2 hunks)tests/models/test_modeling_forgetting_transformer.py(2 hunks)tests/models/test_modeling_gated_deltanet.py(2 hunks)tests/models/test_modeling_gated_deltaproduct.py(1 hunks)tests/models/test_modeling_gla.py(2 hunks)tests/models/test_modeling_gsa.py(2 hunks)tests/models/test_modeling_hgrn.py(2 hunks)tests/models/test_modeling_hgrn2.py(2 hunks)tests/models/test_modeling_lightnet.py(2 hunks)tests/models/test_modeling_linear_attn.py(2 hunks)tests/models/test_modeling_mamba.py(2 hunks)tests/models/test_modeling_mamba2.py(2 hunks)tests/models/test_modeling_mesanet.py(2 hunks)tests/models/test_modeling_nsa.py(2 hunks)tests/models/test_modeling_path_attn.py(2 hunks)tests/models/test_modeling_retnet.py(2 hunks)tests/models/test_modeling_rodimus.py(2 hunks)tests/models/test_modeling_rwkv6.py(2 hunks)tests/models/test_modeling_rwkv7.py(1 hunks)tests/models/test_modeling_samba.py(2 hunks)tests/models/test_modeling_transformer.py(2 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/models/test_modeling_abc.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_bitnet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_comba.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_deltanet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_forgetting_transformer.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gated_deltanet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gated_deltaproduct.py
[refactor] 24-24: Too many arguments (7/5)
(R0913)
[refactor] 24-24: Too many positional arguments (7/5)
(R0917)
tests/models/test_modeling_gla.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gsa.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_hgrn.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_hgrn2.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_lightnet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_linear_attn.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mamba.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mamba2.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mesanet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_nsa.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_path_attn.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_retnet.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rodimus.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rwkv6.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rwkv7.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 43-43: Too many arguments (6/5)
(R0913)
[refactor] 43-43: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_samba.py
[refactor] 21-21: Too many arguments (7/5)
(R0913)
[refactor] 21-21: Too many positional arguments (7/5)
(R0917)
[refactor] 42-42: Too many arguments (6/5)
(R0913)
[refactor] 42-42: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_transformer.py
[refactor] 24-24: Too many arguments (7/5)
(R0913)
[refactor] 24-24: Too many positional arguments (7/5)
(R0917)
[refactor] 45-45: Too many arguments (6/5)
(R0913)
[refactor] 45-45: Too many positional arguments (6/5)
(R0917)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test-ops
- GitHub Check: test-ops
🔇 Additional comments (47)
tests/models/test_modeling_gated_deltaproduct.py (1)
24-32: Add explicit type annotations totest_modelingsignature
The multi-line format with clear type hints for each parameter enhances readability and aligns with the PR’s consistency improvements across modeling tests.tests/models/test_modeling_abc.py (2)
21-29: Annotate and reformattest_modelingparameters
The explicitint,torch.dtype, andboolannotations in a multi-line signature improve clarity and maintain consistency with other model tests.
42-49: Add type hints and multi-line layout totest_generation
The updated signature withintandtorch.dtypeannotations elevates readability without altering behavior.tests/models/test_modeling_linear_attn.py (2)
21-29: Explicitly type and reformattest_modelingargs
Consistent multi-line formatting andint/torch.dtype/boolhints enhance the signature’s clarity across test suites.
42-49: Reformattest_generationwith type annotations
Addingintandtorch.dtypehints in a structured layout improves maintainability without side effects.tests/models/test_modeling_rodimus.py (2)
21-29: Improvetest_modelingsignature readability
The multi-line parameter list with explicit type hints ensures consistency and clearer intent in the test definitions.
42-49: Enhancetest_generationsignature with type hints
This refactor maintains existing logic while aligning the signature style with other model tests.tests/models/test_modeling_bitnet.py (2)
21-29: Add multi-line type-annotated signature totest_modeling
Explicit type hints and formatting keep the test definitions uniform and more readable.
42-49: Reformat and annotatetest_generationparameters
Type annotations (int,torch.dtype) and structured layout improve clarity without modifying behavior.tests/models/test_modeling_hgrn.py (2)
21-29: Add explicit type hints and multiline formatting to test_modeling signature
Type annotations improve clarity and align this test with the rest of the modeling suite.
42-49: Add explicit type hints and multiline formatting to test_generation signature
Consistent annotations across all modeling tests enhance maintainability.tests/models/test_modeling_deltanet.py (2)
21-29: Add explicit type hints and multiline formatting to test_modeling signature
These annotations match the style applied across other model tests.
42-49: Add explicit type hints and multiline formatting to test_generation signature
Maintains consistency and clarity in the test suite.tests/models/test_modeling_transformer.py (2)
24-32: Add explicit type hints and multiline formatting to test_modeling signature
Improves readability and enforces type safety in this test.
45-52: Add explicit type hints and multiline formatting to test_generation signature
Aligns with the rest of the modeling tests for consistency.tests/models/test_modeling_mamba.py (2)
21-29: Add explicit type hints and multiline formatting to test_modeling signature
This enhancement matches the style improvements across all model tests.
42-49: Add explicit type hints and multiline formatting to test_generation signature
Ensures uniformity in test definitions.tests/models/test_modeling_comba.py (2)
21-29: Add explicit type hints and multiline formatting to test_modeling signature
Type annotations and formatting boost clarity across the suite.
42-49: Add explicit type hints and multiline formatting to test_generation signature
Follows the PR’s consistent signature style.tests/models/test_modeling_nsa.py (2)
21-29: Add explicit type annotations totest_modelingsignature.
The multi-line signature withint,torch.dtype, andboolhints enhances readability and maintains consistency across modeling tests.
42-49: Add explicit type annotations totest_generationsignature.
Aligns the function signature with the typed, multi-line style used in other tests, improving clarity.tests/models/test_modeling_path_attn.py (2)
21-29: Add explicit type annotations totest_modelingsignature.
Introducingint,torch.dtype, andboolhints in a multi-line format improves readability and consistency for Path Attention tests.
42-49: Add explicit type annotations totest_generationsignature.
Matches the updated, typed signature format used elsewhere, making the test definitions uniform.tests/models/test_modeling_forgetting_transformer.py (2)
21-29: Add explicit type annotations totest_modelingsignature.
The added hints and multi-line layout boost clarity and maintain the standardized style across all model tests.
42-49: Add explicit type annotations totest_generationsignature.
Consistent multi-line, typed signatures align this test with the rest of the suite.tests/models/test_modeling_gla.py (2)
21-29: Add explicit type annotations totest_modelingsignature.
Enhances readability by clearly specifying parameter types in line with other GLA tests.
42-49: Add explicit type annotations totest_generationsignature.
Maintains uniform test signature style and type clarity across the modeling suite.tests/models/test_modeling_gsa.py (2)
21-29: Add explicit type annotations totest_modelingsignature.
Type hints in the signature improve code clarity and match the formatting of peer tests.
42-49: Add explicit type annotations totest_generationsignature.
Consistent application of multi-line, typed signature style enhances maintainability.tests/models/test_modeling_gated_deltanet.py (2)
21-28: Add explicit type annotations totest_modeling.
The multi-line, typed signature improves readability and enforces consistency across the test suite.
42-49: Add explicit type annotations totest_generation.
Aligns with the standardized function signature format used elsewhere.tests/models/test_modeling_mamba2.py (2)
21-29: Add explicit type annotations totest_modeling.
The revised, multi-line signature with parameter types enhances clarity and consistency.
42-49: Add explicit type annotations totest_generation.
Standardizes the test signature and improves maintainability.tests/models/test_modeling_retnet.py (2)
21-29: Add explicit type annotations totest_modeling.
Consistent multi-line formatting and typing mirror other model tests.
42-49: Add explicit type annotations totest_generation.
Maintains uniform signature style across the generation tests.tests/models/test_modeling_samba.py (2)
21-29: Add explicit type annotations totest_modeling.
Improves type clarity and follows the established formatting convention.
42-49: Add explicit type annotations totest_generation.
Ensures consistency with the rest of the test suite’s signature style.tests/models/test_modeling_rwkv6.py (2)
21-29: Add explicit type annotations totest_modeling.
The explicit, multi-line signature enhances readability and consistency.
42-49: Add explicit type annotations totest_generation.
Aligns this test’s signature with the updated project conventions.tests/models/test_modeling_mesanet.py (2)
21-29: Add explicit type hints and multi-line signature totest_modeling.
Standardizes parameter annotations without altering logic, matching other model test suites.
42-49: Add explicit type hints and multi-line signature totest_generation.
Consistent formatting update; no behavioral change.tests/models/test_modeling_hgrn2.py (2)
21-29: Add explicit type hints and multi-line signature totest_hgrn2_modeling.
Aligns with the unified style across modeling tests and preserves existing behavior.
42-49: Add explicit type hints and multi-line signature totest_generation.
Matches formatting in other test files; no functional impact.tests/models/test_modeling_rwkv7.py (2)
21-29: Rename and refactortest_modelingsignature with explicit type annotations.
Follows the standard multi-line style; internal logic remains unchanged.
43-50: Rename and refactortest_generationsignature with explicit type annotations.
Consistent update across model tests without affecting behavior.tests/models/test_modeling_lightnet.py (2)
21-29: Add explicit type hints and multi-line signature totest_modeling.
Standardizes the function header in line with other model tests; no changes to test logic.
42-49: Add explicit type hints and multi-line signature totest_generation.
Consistency improvement across the suite; behavior remains intact.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/ops/test_forgetting_attn.py (1)
123-123: Add explicitstrict=parameter to zip().The static analysis tool correctly identifies that
zip()should have an explicitstrictparameter for safety.Apply this diff to address the static analysis warning:
- for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:]): + for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:], strict=True):The
strict=Trueparameter ensures that both iterables have the same length, which should always be the case here since they're slices of the same tensor.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/ops/test_attn.py(3 hunks)tests/ops/test_forgetting_attn.py(4 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/ops/test_attn.py
[refactor] 33-33: Too many arguments (6/5)
(R0913)
[refactor] 33-33: Too many positional arguments (6/5)
(R0917)
[refactor] 33-33: Too many local variables (18/15)
(R0914)
[refactor] 81-81: Too many local variables (18/15)
(R0914)
tests/ops/test_forgetting_attn.py
[refactor] 48-48: Too many arguments (6/5)
(R0913)
[refactor] 48-48: Too many positional arguments (6/5)
(R0917)
[refactor] 48-48: Too many local variables (22/15)
(R0914)
[refactor] 105-105: Too many local variables (23/15)
(R0914)
🪛 Ruff (0.11.9)
tests/ops/test_forgetting_attn.py
123-123: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: test-ops
- GitHub Check: test-ops
- GitHub Check: test-ops
🔇 Additional comments (12)
tests/ops/test_attn.py (6)
4-4: Good addition of type hints.Adding
Listimport for proper type annotation improves code clarity.
20-32: Excellent simplification of test parameterization.The move from environment-dependent conditional parameterization to explicit fixed test cases improves:
- Test determinism and reproducibility
- Debugging experience (easier to identify which specific case failed)
- Code maintainability (no complex conditional logic)
The test cases cover a good range of dimensions and scale values.
43-44: Good placement of flash-attn availability check.Moving the check inside the test function is cleaner than having it in the parameterization logic.
47-50: Hardcoded dtype ensures test consistency.Using
torch.float16consistently across all tensors removes environment-dependent behavior and ensures tests run uniformly across different setups.
70-80: Consistent simplification for variable-length tests.The same beneficial pattern applied to the varlen test - explicit parameterization with fixed
cu_seqlensvalues instead of random generation.
85-89: Clean parameter handling for variable-length sequences.The explicit
cu_seqlensparameter with proper type annotation and tensor conversion is well-structured.tests/ops/test_forgetting_attn.py (6)
3-3: Good addition of comprehensive type hints.Adding
List, Optionalimports supports proper type annotations throughout the file.
35-47: Consistent test simplification pattern.Same beneficial pattern as in test_attn.py - explicit parameterization improves test determinism and maintainability.
56-57: Good addition of reproducibility controls.Adding manual seed and hardcoding dtype ensures consistent test behavior across runs and environments.
66-66: Verify the impact of simplified gating tensor initialization.The change from a complex log-sigmoid scaled initialization to a simple uniform distribution
(-0.1, -0.01)simplifies the test but may reduce the range of values being tested. Ensure this still adequately tests the forgetting attention mechanism.Consider whether the simpler initialization still provides adequate test coverage for the gating mechanism's behavior across different value ranges.
90-100: Consistent parameterization improvements for varlen tests.Same positive pattern applied to variable-length sequence tests.
119-119: Consistent gating tensor initialization.Same simplification as in the fixed-length test - ensure this provides adequate coverage.
There was a problem hiding this comment.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
fla/ops/utils/cumsum.py (1)
4-4: Remove unused import to fix linting error.The
warningsimport is no longer used and is causing a pipeline failure.-import warnings
🧹 Nitpick comments (1)
fla/ops/utils/cumsum.py (1)
181-182: Consider removing redundant condition for consistency.Similar to the vector kernel change on line 248, this condition
i_c >= 0is always true sincei_cranges from 0 to NT-1 in the loop. Consider removing it for consistency and clarity.- if i_c >= 0: - b_z += b_ss + b_z += b_ss
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/ops/utils/cumsum.py(3 hunks)
🧰 Additional context used
🪛 GitHub Actions: lint
fla/ops/utils/cumsum.py
[error] 4-4: flake8: 'warnings' imported but unused (F401)
🔇 Additional comments (2)
fla/ops/utils/cumsum.py (2)
168-168: Good formatting improvements for readability.Adding spaces around the subtraction operator improves code readability and follows Python style conventions.
Also applies to: 235-235
248-248: Verify the conditional logic removal is correct.The unconditional accumulation of
b_zappears to remove a redundant condition (sincei_cranges from 0 to NT-1,i_c >= 0is always true). However, there's an inconsistency with the scalar kernel which still has this condition on line 181-182.#!/bin/bash # Description: Check for similar conditional logic patterns in other kernels # Expected: Find similar i_c >= 0 conditions that might also be redundant rg -A 2 -B 2 "if i_c >= 0"
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/ops/test_gla.py (1)
188-256: Consider refactoring to reduce local variable complexity.The static analysis tool correctly identifies that this function has many local variables (26 out of 15 recommended). Consider extracting helper functions for tensor creation, forward/backward passes, or assertion checks to improve readability and maintainability.
Example refactor approach:
def _create_test_tensors(N, T, H, D, dtype, device, cu_seqlens): """Helper to create test tensors.""" # Move tensor creation logic here pass def _run_forward_backward(model_fn, tensors, do): """Helper to run forward and backward passes.""" # Move forward/backward logic here pass
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/ops/test_gla.py(7 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/ops/test_gla.py
[refactor] 188-188: Too many local variables (26/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test-ops
- GitHub Check: test-ops
🔇 Additional comments (10)
tests/ops/test_gla.py (10)
4-4: LGTM: Import addition for type hints.Good addition of the
Listtype for better type annotations.
15-28: LGTM: Improved parameterization with consolidated test cases.The consolidation of multiple
pytest.mark.parametrizedecorators into a single unified parameter tuple improves test clarity and maintainability. The explicit test IDs will make test failures easier to identify.
38-38: LGTM: New gate_logit_normalizer parameter.The addition of the
gate_logit_normalizerparameter allows for testing different gate scaling scenarios, which improves test coverage.
43-47: LGTM: Improved tensor initialization and gate normalization.The changes from
torch.randntotorch.randfor input tensors and the application ofgate_logit_normalizerfor gate scaling provide more controlled test scenarios. The use ofF.logsigmoidensures proper gate value ranges.
89-102: LGTM: Consistent parameterization pattern.The parameterization matches the pattern established in the first test function, maintaining consistency across the test suite.
118-122: LGTM: Consistent tensor initialization.The tensor initialization changes are consistent with the first test function, ensuring uniform testing conditions.
127-130: LGTM: Explicit keyword arguments improve readability.Using explicit keyword arguments (
q=q,k=k, etc.) makes the function calls more readable and less prone to parameter ordering errors.
142-145: Note the parameter name change:gtogk.The gate parameter has been renamed from
gtogkin thefused_recurrent_glafunction calls, whilechunk_glastill usesg. This suggests an API change in the fused recurrent implementation.Also applies to: 149-153
173-183: LGTM: Explicit cu_seqlens parameterization improves test determinism.The explicit parameterization with
cu_seqlensmakes the variable-length tests more deterministic and easier to debug compared to random generation within the test.
191-191: LGTM: Clean cu_seqlens-based test setup.The refactoring to derive batch size (
N) and total sequence length (T) from the explicitcu_seqlensparameter makes the test logic clearer and more predictable.Also applies to: 197-205
There was a problem hiding this comment.
Actionable comments posted: 11
♻️ Duplicate comments (3)
tests/models/test_modeling_comba.py (1)
14-33: Critical: Same parameter order mismatch as in other test files.This file has the identical parameter order issue where the parametrize decorator and function signature don't match.
tests/models/test_modeling_mamba.py (1)
14-33: Critical: Parameter order mismatch affects this file too.tests/models/test_modeling_samba.py (1)
14-33: Critical: Parameter order mismatch present here as well.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (24)
tests/models/test_modeling_abc.py(1 hunks)tests/models/test_modeling_bitnet.py(1 hunks)tests/models/test_modeling_comba.py(1 hunks)tests/models/test_modeling_deltanet.py(1 hunks)tests/models/test_modeling_forgetting_transformer.py(1 hunks)tests/models/test_modeling_gated_deltanet.py(1 hunks)tests/models/test_modeling_gated_deltaproduct.py(1 hunks)tests/models/test_modeling_gla.py(1 hunks)tests/models/test_modeling_gsa.py(1 hunks)tests/models/test_modeling_hgrn.py(1 hunks)tests/models/test_modeling_hgrn2.py(1 hunks)tests/models/test_modeling_lightnet.py(1 hunks)tests/models/test_modeling_linear_attn.py(1 hunks)tests/models/test_modeling_mamba.py(1 hunks)tests/models/test_modeling_mamba2.py(1 hunks)tests/models/test_modeling_mesanet.py(1 hunks)tests/models/test_modeling_nsa.py(1 hunks)tests/models/test_modeling_path_attn.py(1 hunks)tests/models/test_modeling_retnet.py(1 hunks)tests/models/test_modeling_rodimus.py(1 hunks)tests/models/test_modeling_rwkv6.py(1 hunks)tests/models/test_modeling_rwkv7.py(1 hunks)tests/models/test_modeling_samba.py(1 hunks)tests/models/test_modeling_transformer.py(1 hunks)
🧰 Additional context used
🪛 Pylint (3.3.7)
tests/models/test_modeling_abc.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_bitnet.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_comba.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_deltanet.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_forgetting_transformer.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gated_deltanet.py
[refactor] 24-24: Too many arguments (7/5)
(R0913)
[refactor] 24-24: Too many positional arguments (7/5)
(R0917)
[refactor] 48-48: Too many arguments (6/5)
(R0913)
[refactor] 48-48: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gated_deltaproduct.py
[refactor] 28-28: Too many arguments (7/5)
(R0913)
[refactor] 28-28: Too many positional arguments (7/5)
(R0917)
tests/models/test_modeling_gla.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_gsa.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_hgrn.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_hgrn2.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_lightnet.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_linear_attn.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mamba.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mamba2.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_mesanet.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_nsa.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_path_attn.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_retnet.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rodimus.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rwkv6.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_rwkv7.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 50-50: Too many arguments (6/5)
(R0913)
[refactor] 50-50: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_samba.py
[refactor] 25-25: Too many arguments (7/5)
(R0913)
[refactor] 25-25: Too many positional arguments (7/5)
(R0917)
[refactor] 49-49: Too many arguments (6/5)
(R0913)
[refactor] 49-49: Too many positional arguments (6/5)
(R0917)
tests/models/test_modeling_transformer.py
[refactor] 28-28: Too many arguments (7/5)
(R0913)
[refactor] 28-28: Too many positional arguments (7/5)
(R0917)
[refactor] 52-52: Too many arguments (6/5)
(R0913)
[refactor] 52-52: Too many positional arguments (6/5)
(R0917)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: test-ops
- GitHub Check: test-ops
- GitHub Check: test-ops
🔇 Additional comments (29)
tests/models/test_modeling_bitnet.py (1)
40-56: LGTM! Clean refactoring with improved type annotations.The consolidated parametrization and explicit type annotations improve test clarity and maintainability.
tests/models/test_modeling_comba.py (1)
40-56: LGTM! Generation test parameterization is correct.tests/models/test_modeling_mamba.py (1)
40-56: LGTM! Generation test is correctly structured.tests/models/test_modeling_samba.py (1)
40-56: LGTM! Generation test refactoring is solid.tests/models/test_modeling_retnet.py (1)
40-56: LGTM! Consistent and clean generation test structure.tests/models/test_modeling_transformer.py (2)
17-27: Good refactoring approach for test parameterization.The consolidation of multiple
@pytest.mark.parametrizedecorators into a single decorator with explicit parameter tuples and descriptive test IDs improves readability and maintainability.
17-36: Critical parameter order mismatch between tuples and function signature.The parameter order in the test tuples doesn't match the function signature, which will cause
dtypeanduse_l2warpparameters to be swapped during test execution.Fix the parameter order by updating the function signature:
def test_modeling( L: int, B: int, T: int, H: int, D: int, - dtype: torch.dtype, - use_l2warp: bool, + use_l2warp: bool, + dtype: torch.dtype, ):Or alternatively, reorder the tuples to match the current function signature.
Likely an incorrect or invalid review comment.
tests/models/test_modeling_rwkv6.py (1)
25-25: Good function naming consistency.The renaming from
test_rwkv6_modelingtotest_modelingimproves consistency across all model test files.tests/models/test_modeling_hgrn2.py (1)
40-56: LGTM! Clean parameterization consolidation.The test generation function has been properly refactored with consolidated parameterization and explicit type annotations. The parameter order is consistent between decorator and function signature.
tests/models/test_modeling_abc.py (1)
40-56: LGTM! Clean parameterization consolidation.The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.
tests/models/test_modeling_gsa.py (1)
40-56: LGTM! Clean parameterization consolidation.The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.
tests/models/test_modeling_gla.py (1)
40-56: LGTM! Clean parameterization consolidation.The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.
tests/models/test_modeling_forgetting_transformer.py (1)
40-56: LGTM! Clean parameterization consolidation.The test generation function has been properly refactored with consolidated parameterization and explicit type annotations.
tests/models/test_modeling_path_attn.py (3)
14-24: Well-structured parameterization consolidation.The refactoring successfully consolidates multiple decorators into a single, more readable parameterization with explicit test IDs. This improves test maintainability and makes test cases easier to identify.
25-33: Excellent addition of type annotations.The explicit type annotations significantly improve code clarity and help with IDE support and static analysis. The parameter types are correctly specified.
40-48: Consistent parameterization pattern applied.The generation test follows the same improved parameterization pattern as the modeling test, maintaining consistency across the test suite.
tests/models/test_modeling_rodimus.py (1)
14-33: Consistent refactoring pattern maintained.The parameterization consolidation and type annotations follow the same improved pattern as other test files, ensuring consistency across the test suite while improving maintainability and readability.
tests/models/test_modeling_gated_deltanet.py (1)
14-33: Consistent and appropriate refactoring.The same beneficial parameterization and type annotation patterns are applied consistently. The test case configuration appears appropriate for the GatedDeltaNet model.
tests/models/test_modeling_nsa.py (1)
14-33: Maintains consistent refactoring standards.The parameterization consolidation and type annotations are properly applied following the same pattern as other test files, contributing to a uniform and maintainable test suite.
tests/models/test_modeling_mesanet.py (2)
14-33: Completes consistent test suite refactoring.This file successfully applies the same parameterization consolidation and type annotation improvements as the other test files, resulting in a uniform and maintainable test suite across all model types.
40-56: Appropriate test parameter adjustment for faster testing.The reduction of sequence length T from 4000 to 2000 in generation tests aligns with the PR objective of "Fast testing" while maintaining adequate test coverage.
tests/models/test_modeling_mamba2.py (2)
14-33: Excellent refactoring of test parameterization and type annotations.The consolidation of pytest parameters into a single decorator with explicit test IDs significantly improves test organization and readability. The type annotations enhance code clarity and IDE support.
40-56: Well-structured generation test with proper parameterization.The same excellent refactoring pattern applied consistently. The explicit type annotations and descriptive test IDs make the test more maintainable.
tests/models/test_modeling_lightnet.py (2)
14-33: Consistent refactoring pattern maintained across test files.The parameterization consolidation and type annotations follow the same excellent pattern as other test files, ensuring consistency across the test suite.
40-56: Generation test properly refactored with appropriate config.The test correctly uses
LightNetConfigand maintains the same structural improvements as the modeling test.tests/models/test_modeling_rwkv7.py (2)
14-33: Systematic refactoring applied consistently.The parameterization consolidation and type annotations are well-executed, following the same pattern as other test files in this refactoring effort.
41-57: Generation test refactored with proper structure.The test maintains the same quality improvements with appropriate use of
RWKV7Configand consistent parameter organization.tests/models/test_modeling_gated_deltaproduct.py (2)
17-36: Modeling test refactored with consistent pattern.The parameterization consolidation and type annotations follow the same excellent pattern as other files, maintaining consistency across the test suite.
43-72: Generation test properly handles model-specific complexity.The refactored parameterization is well-structured, and the custom model creation logic appropriately handles the specific requirements of GatedDeltaProduct testing with different configurations.
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (2)
.github/workflows/reusable-ci-tests.yml (1)
45-55:check_gpu.pyresult is used but the script output is not loggedFor post-mortem debugging it’s useful to capture the script’s stdout/stderr even on success. Consider tee-ing it to the console or adding
set -euo pipefailso failures surface early..github/workflows/nvidia-h100.yml (1)
18-23: Use a build matrix instead of three near-identical jobsThe only variance across the three jobs is the PyTorch version / conda env name. A matrix keeps the file concise and guarantees future versions stay in sync:
strategy: fail-fast: false matrix: torch: [ {env: pytorch_2_7, ver: '2.7.0'}, {env: pytorch_nightly, ver: nightly}, {env: pytorch_2_6, ver: '2.6.0'} ] jobs: test-h100: uses: ./.github/workflows/reusable-ci-tests.yml with: runner: nvidia-h100 gpu_type: nvidia conda_env_name: ${{ matrix.torch.env }} pytorch_version: ${{ matrix.torch.ver }}This reduces maintenance overhead.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
.github/workflows/nvidia-h100.yml(1 hunks).github/workflows/reusable-ci-tests.yml(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
- GitHub Check: test-ops
🔇 Additional comments (1)
.github/workflows/reusable-ci-tests.yml (1)
129-133: Quote the substitution to preserve spaces & newlines returned bychanged-filesWithout quotes the shell performs word-splitting, so filenames that contain spaces or glob characters break the script and the output variable.
-TEST_FILES=$(TEST_SCOPE=EXCLUDE_MODELS python scripts/find_dependent_tests.py "${{ steps.changed-files.outputs.all_changed_files }}") +TEST_FILES=$(TEST_SCOPE=EXCLUDE_MODELS python scripts/find_dependent_tests.py \ + "${{ steps.changed-files.outputs.all_changed_files }}" )Same applies to the MODELS block below.
25565bd to
6d7d6e1
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (3)
.github/workflows/reusable-ci-tests.yml (3)
173-174:torch~=still risks pulling a CPU wheelEarlier reviews already covered this – the
~=operator may resolve to+cpuwheels when CUDA wheels lag behind. Use an explicit== …+${{ inputs.pytorch_cuda_version }}(or pin the exact filename).Also applies to: 185-186
200-204: Multi-line$TEST_FILEScorruptsGITHUB_OUTPUTWriting a newline-separated list via
key=valuetruncates after the first line. Use the heredoc pattern suggested in the previous review.Also applies to: 284-288
206-207:if:should test string emptiness explicitlyThe expression still relies on raw string truthiness:
if: steps.find-ops-tests.outputs.test_files && …As noted before, compare against
'':-if: steps.find-ops-tests.outputs.test_files && steps.check_skip.outputs.skip_tests == 'false' +if: steps.find-ops-tests.outputs.test_files != '' && steps.check_skip.outputs.skip_tests == 'false'Repeat for the models job.
Also applies to: 290-291
🧹 Nitpick comments (1)
.github/workflows/reusable-ci-tests.yml (1)
58-70: Minor: bake the runner name into a shell var onceNitpick: interpolate
${{ runner.name }}once to avoid repetition/quoting issues:RUNNER_NAME="${{ runner.name }}" echo "Determining conda env for runner: ${RUNNER_NAME}" case "${RUNNER_NAME}" in nvidia-h100-1) TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" ;; nvidia-h100-2) TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1" ;; *) TARGET_CONDA_ENV="${{ inputs.conda_env_name }}" ;; esac
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
.github/workflows/nvidia-h100.yml(2 hunks).github/workflows/reusable-ci-tests.yml(6 hunks)
✅ Files skipped from review due to trivial changes (1)
- .github/workflows/nvidia-h100.yml
…prove parameterization for variable lengths
There was a problem hiding this comment.
Actionable comments posted: 1
🔭 Outside diff range comments (2)
tests/ops/test_gsa.py (2)
82-82: Fix gradient assignment bugThere's a copy-paste error where
g.gradis incorrectly assigned tos.grad.- tri_dg, s.grad = g.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None
180-180: Fix gradient assignment bugSame copy-paste error as in line 82.
- tri_dg, s.grad = g.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
.github/workflows/reusable-ci-tests.yml(8 hunks)fla/ops/gsa/fused_recurrent.py(1 hunks)tests/ops/test_gla.py(2 hunks)tests/ops/test_gsa.py(13 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- .github/workflows/reusable-ci-tests.yml
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_gla.py (2)
fla/ops/common/fused_recurrent.py (1)
backward(517-535)fla/ops/gla/chunk.py (1)
backward(1195-1213)
tests/ops/test_gsa.py (4)
fla/ops/gsa/chunk.py (2)
chunk_gsa(996-1133)backward(967-992)fla/ops/gsa/fused_recurrent.py (2)
fused_recurrent_gsa(431-536)backward(405-428)fla/ops/gsa/naive.py (1)
naive_recurrent_gsa(9-69)fla/utils.py (2)
assert_close(78-90)check_shared_mem(434-440)
🪛 Pylint (3.3.7)
tests/ops/test_gsa.py
[refactor] 115-115: Too many local variables (41/15)
(R0914)
[refactor] 115-115: Too many statements (51/50)
(R0915)
[refactor] 214-214: Too many arguments (7/5)
(R0913)
[refactor] 214-214: Too many positional arguments (7/5)
(R0917)
[refactor] 214-214: Too many local variables (37/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: Test H100 (PyTorch Nightly) / test-ops
- GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (8)
fla/ops/gsa/fused_recurrent.py (1)
250-250: Good optimization for GPU kernel performance!Using
triton.next_power_of_2for block size calculation ensures power-of-2 block dimensions, which typically improves GPU kernel performance through better memory alignment and access patterns.tests/ops/test_gla.py (1)
211-211: Excellent test coverage improvement!Changing
dhtfrom zeros to random values and including it in the backward pass ensures proper gradient flow through the final hidden state, making the test more comprehensive.Also applies to: 236-236
tests/ops/test_gsa.py (6)
4-4: Clean import management!Adding
Listfor type hints and removing unused imports improves code clarity.Also applies to: 12-12
55-56: Great improvement to gradient testing coverage!Adding gradient tensors for initial hidden states and including them in backward passes ensures comprehensive gradient propagation testing.
Also applies to: 59-59, 77-77, 94-95
98-109: Excellent test parameterization improvements!Replacing conditional parameters with explicit lists and using direct
cu_seqlenstensors makes tests more deterministic and easier to understand.Also applies to: 119-119, 124-126
239-247: Good addition of scale parameter!Adding
scale=D**-0.5ensures proper attention scaling in both implementations.Also applies to: 257-266
288-298: Consistent improvements across variable-length tests!The changes maintain consistency with other test functions, properly handling cu_seqlens and scale parameters.
Also applies to: 318-321, 336-345, 356-365
418-419: Better test coverage with random initial states!Using random tensors instead of zeros for initial states tests the system under more realistic conditions.
There was a problem hiding this comment.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/ops/rwkv7/fused_addcmul.py(3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
- GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (3)
fla/ops/rwkv7/fused_addcmul.py (3)
4-4: Import added for Python version checking.The
sysimport is correctly added to support version checking.
19-20: Identity decorator implementation is correct.The identity decorator provides a clean fallback when torch.compile is not available.
179-179: Decorator change enables conditional compilation.The change from hardcoded
@torch.compile(fullgraph=True)to@torch_compilecorrectly enables conditional compilation based on the environment setup.
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/ops/test_gsa.py (1)
82-82: Fix critical gradient assignment bugs.The gradient assignment has a copy-paste error where
s.gradis being set toNoneinstead ofg.grad.Apply this fix for all three occurrences:
- tri_dg, s.grad = g.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), NoneAlso applies to: 180-180, 272-272
🧹 Nitpick comments (1)
tests/ops/test_delta_product.py (1)
28-32: Static analysis warnings are acceptable for test functions.The pylint warnings about too many local variables and statements are common in test functions that need extensive setup. These are acceptable given the testing context where multiple tensors, gradients, and validation steps are required.
Also applies to: 95-97
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
tests/ops/test_comba.py(2 hunks)tests/ops/test_cumsum.py(0 hunks)tests/ops/test_delta.py(3 hunks)tests/ops/test_delta_product.py(5 hunks)tests/ops/test_dplr_delta.py(6 hunks)tests/ops/test_gated_delta.py(3 hunks)tests/ops/test_gated_delta_product.py(3 hunks)tests/ops/test_gla.py(4 hunks)tests/ops/test_gsa.py(4 hunks)tests/ops/test_hgrn.py(3 hunks)tests/ops/test_iplr_delta.py(2 hunks)tests/ops/test_linear_attn.py(3 hunks)tests/ops/test_nsa.py(1 hunks)tests/ops/test_retention.py(3 hunks)tests/ops/test_rwkv6.py(3 hunks)tests/ops/test_rwkv7.py(2 hunks)tests/ops/test_simple_gla.py(6 hunks)tests/ops/test_utils.py(2 hunks)
💤 Files with no reviewable changes (1)
- tests/ops/test_cumsum.py
✅ Files skipped from review due to trivial changes (8)
- tests/ops/test_utils.py
- tests/ops/test_linear_attn.py
- tests/ops/test_gla.py
- tests/ops/test_hgrn.py
- tests/ops/test_retention.py
- tests/ops/test_comba.py
- tests/ops/test_iplr_delta.py
- tests/ops/test_simple_gla.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tests/ops/test_rwkv7.py
- tests/ops/test_dplr_delta.py
- tests/ops/test_rwkv6.py
🧰 Additional context used
🧬 Code Graph Analysis (4)
tests/ops/test_nsa.py (1)
fla/utils.py (1)
assert_close(78-90)
tests/ops/test_gsa.py (1)
fla/utils.py (1)
assert_close(78-90)
tests/ops/test_gated_delta.py (1)
fla/utils.py (1)
assert_close(78-90)
tests/ops/test_gated_delta_product.py (1)
fla/utils.py (1)
assert_close(78-90)
🪛 Pylint (3.3.7)
tests/ops/test_delta.py
[refactor] 25-25: Too many local variables (27/15)
(R0914)
[refactor] 87-87: Too many local variables (28/15)
(R0914)
tests/ops/test_delta_product.py
[refactor] 28-28: Too many local variables (28/15)
(R0914)
[refactor] 95-95: Too many local variables (40/15)
(R0914)
[refactor] 95-95: Too many statements (51/50)
(R0915)
🔇 Additional comments (11)
tests/ops/test_nsa.py (1)
165-165: LGTM! Consistent formatting improvement.The removal of the leading space from the assertion prefix improves consistency with the
assert_closefunction's right-aligned formatting.tests/ops/test_gsa.py (1)
86-93: LGTM! Consistent formatting improvement.The removal of leading spaces from assertion prefixes improves consistency with the
assert_closefunction's right-aligned formatting across all test functions.Also applies to: 184-191, 276-283, 374-381
tests/ops/test_gated_delta.py (1)
194-195: LGTM! Consistent formatting improvement.The systematic removal of leading spaces from assertion prefixes across all test functions improves consistency with the
assert_closefunction's formatting and enhances test output readability.Also applies to: 262-269, 349-356
tests/ops/test_gated_delta_product.py (1)
99-106: LGTM! Consistent formatting improvement.The removal of leading spaces from assertion prefixes maintains consistency with the broader test suite formatting standardization and improves test output readability.
Also applies to: 182-189, 209-216
tests/ops/test_delta_product.py (3)
2-2: LGTM! Copyright header updated.The copyright year range has been updated appropriately.
16-37: LGTM! Test parameterization simplification.The changes improve test predictability by:
- Replacing environment variable-based skips with clear platform-based skips
- Using explicit parameter tuples instead of complex conditional logic
- Simplifying function signatures by moving fixed parameters into the test body
This makes the tests more maintainable and deterministic.
76-82: LGTM! Consistent formatting improvement.The removal of leading spaces from assertion prefixes aligns with the broader test suite formatting standardization effort and improves consistency.
Also applies to: 149-155, 174-180
tests/ops/test_delta.py (4)
3-3: LGTM: Import additions are appropriate.The added imports support the updated test structure with explicit type hints and device platform checks.
Also applies to: 10-10
17-33: Verify test coverage impact of simplified parameterization.The move from environment-dependent to fixed parameters improves test predictability and maintainability. However, this significantly reduces test coverage compared to the previous parameterized approach.
Consider whether the fixed parameter values (B=2, H=3, T/D combinations) adequately cover edge cases and different tensor dimensions that were previously tested.
68-73: Good formatting cleanup.Removing leading spaces from assertion labels improves code consistency and readability.
135-140: Good formatting cleanup.Consistent with the previous assertion formatting improvements.
| @pytest.mark.parametrize( | ||
| ("cu_seqlens"), | ||
| [ | ||
| ([0, 15, 100, 300, 1203, 2000]), | ||
| ] | ||
| ) | ||
| @pytest.mark.skipif( | ||
| device_platform == 'intel', | ||
| reason='Intel Triton Failure' | ||
| ) | ||
| def test_chunk_varlen( | ||
| N: int, | ||
| T: int, | ||
| H: int, | ||
| D: int, | ||
| scale: float, | ||
| dtype: torch.dtype, | ||
| cu_seqlens: List[int], | ||
| ): | ||
| torch.manual_seed(42) | ||
| os.environ['TRITON_F32_DEFAULT'] = 'ieee' | ||
| # randomly split the sequence into N segments | ||
| cu_seqlens = torch.cat([ | ||
| torch.tensor([0], dtype=torch.long), | ||
| torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], | ||
| torch.tensor([T], dtype=torch.long) | ||
| ], 0).to(device).sort()[0] | ||
| cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) | ||
| T = cu_seqlens[-1] | ||
| H = 2 | ||
| D = 64 | ||
| dtype = torch.float16 | ||
| scale = 1.0 | ||
| N = len(cu_seqlens) - 1 |
There was a problem hiding this comment.
🛠️ Refactor suggestion
Simplify parameterization structure and consider refactoring for readability.
The parameterization has been simplified, but there are two concerns:
- The parameterization syntax
("cu_seqlens")should be("cu_seqlens",)(with trailing comma) to properly define a tuple parameter. - Both test functions have excessive local variables (27-28 variables) which impacts readability.
Apply this fix for the parameterization:
-@pytest.mark.parametrize(
- ("cu_seqlens"),
- [
- ([0, 15, 100, 300, 1203, 2000]),
- ]
-)
+@pytest.mark.parametrize(
+ ("cu_seqlens",),
+ [
+ ([0, 15, 100, 300, 1203, 2000]),
+ ]
+)Consider extracting test setup logic into helper functions to reduce local variable count:
def setup_test_tensors(B, T, H, D, dtype, N=None):
"""Helper to create and setup test tensors."""
# Move tensor creation logic here
pass🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 87-87: Too many local variables (28/15)
(R0914)
🤖 Prompt for AI Agents
In tests/ops/test_delta.py around lines 77 to 96, the parameterization syntax
should be corrected from ("cu_seqlens") to ("cu_seqlens",) to properly define a
tuple parameter. Additionally, to improve readability and reduce the large
number of local variables in the test functions, extract the test setup logic
such as tensor creation and initialization into separate helper functions.
Create a helper function like setup_test_tensors that accepts parameters like B,
T, H, D, dtype, and optionally N, and returns the necessary tensors and
variables for the tests. Replace the local setup code in the test functions with
calls to this helper to simplify the test bodies.
…e input configurations and improve readability.
…ndle unsupported configurations
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/ops/test_ttt.py (1)
96-96: Consider making hardcoded tensor slicing more robust.The hardcoded slicing
[:, :14, :, :]appears fragile and could break if test parameters change. Consider making this dynamic based on test parameters or add comments explaining the magic number14.- assert_close(" de0", ref_deta[:, :14, :, :], tri_deta[:, :14, :, :], 0.010) + # Test gradient for first 14 time steps to avoid edge effects + early_steps = min(14, T-1) + assert_close(" de0", ref_deta[:, :early_steps, :, :], tri_deta[:, :early_steps, :, :], 0.010)Also applies to: 182-182
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
tests/ops/test_hgrn.py(4 hunks)tests/ops/test_nsa.py(4 hunks)tests/ops/test_solve_tril.py(3 hunks)tests/ops/test_titans.py(3 hunks)tests/ops/test_ttt.py(4 hunks)tests/ops/test_utils.py(12 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/ops/test_utils.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_titans.py (1)
fla/utils.py (1)
assert_close(78-90)
tests/ops/test_ttt.py (4)
fla/ops/ttt/chunk.py (1)
chunk_ttt_linear(1340-1441)fla/ops/ttt/fused_chunk.py (1)
fused_chunk_ttt_linear(728-831)fla/ops/ttt/naive.py (1)
chunk_ttt_linear_ref(73-126)fla/utils.py (2)
assert_close(78-90)check_shared_mem(434-440)
🪛 Pylint (3.3.7)
tests/ops/test_titans.py
[refactor] 53-53: Too many local variables (19/15)
(R0914)
tests/ops/test_nsa.py
[refactor] 33-33: Too many arguments (9/5)
(R0913)
[refactor] 33-33: Too many positional arguments (9/5)
(R0917)
[refactor] 33-33: Too many local variables (26/15)
(R0914)
[refactor] 96-96: Too many arguments (7/5)
(R0913)
[refactor] 96-96: Too many positional arguments (7/5)
(R0917)
[refactor] 96-96: Too many local variables (26/15)
(R0914)
tests/ops/test_ttt.py
[refactor] 29-29: Too many arguments (6/5)
(R0913)
[refactor] 29-29: Too many positional arguments (6/5)
(R0917)
[refactor] 29-29: Too many local variables (40/15)
(R0914)
[refactor] 115-115: Too many arguments (6/5)
(R0913)
[refactor] 115-115: Too many positional arguments (6/5)
(R0917)
[refactor] 115-115: Too many local variables (40/15)
(R0914)
[refactor] 203-203: Too many local variables (25/15)
(R0914)
tests/ops/test_hgrn.py
[refactor] 76-76: Too many local variables (23/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: Test H100 (PyTorch Nightly) / test-ops
- GitHub Check: Test A770 (PyTorch 2.7) / test-ops
- GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
🔇 Additional comments (15)
tests/ops/test_titans.py (3)
40-52: LGTM: Test parameterization simplified effectively.The refactoring from conditional parameter lists to explicit tuples improves test clarity and maintainability. The test cases cover a good range of configurations.
53-59: LGTM: Function signature simplified by removing deprecated parameters.Removing the
head_firstandscaleparameters aligns with the standardization effort. The function is now focused on testing the core functionality without environment-dependent variations.
78-83: LGTM: Tensor shapes standardized to consistent format.The unconditional permutation to
(B, T, H, D)format eliminates the need for conditional logic and ensures all tests use the same tensor layout, improving consistency.tests/ops/test_hgrn.py (3)
15-27: LGTM: Test parameterization streamlined with explicit cases.The replacement of conditional parameter lists with explicit tuples improves test predictability and removes environment dependencies.
64-86: LGTM: Variable-length test refactored to use explicit sequence lengths.The change from random sequence splitting to explicit
cu_seqlenslists makes tests more predictable and reproducible. The type annotation forcu_seqlens: List[int]is appropriate.
57-61: LGTM: Assertion labels standardized by removing leading spaces.This change improves consistency in test output formatting across the codebase.
Also applies to: 119-123, 164-164
tests/ops/test_nsa.py (3)
17-29: LGTM: Test parameterization simplified with explicit test cases.The refactoring removes conditional logic and provides clear, explicit test parameters that improve test maintainability.
78-110: LGTM: Variable-length test properly refactored to use explicit sequence lengths.The implementation correctly handles the
cu_seqlensparameter pattern with proper type annotations and tensor conversion.
31-31: Note: Tests are currently skipped with "TBD" reason.Both test functions are marked as skipped. Consider addressing the underlying issues or providing more specific skip reasons if these tests are intentionally disabled.
Could you provide more context about why these tests are skipped and when they might be re-enabled?
Also applies to: 94-94
tests/ops/test_solve_tril.py (3)
15-27: LGTM: Test parameterization simplified with explicit cases.The explicit parameter tuples improve test clarity and remove environment dependencies.
51-77: LGTM: Variable-length test properly refactored with type annotations.The implementation correctly uses the
cu_seqlenspattern with proper type annotations. The function signature is clean and well-typed.
83-83: ```shell
#!/bin/bashExtract the full body of chunk_scaled_dot_kkt_fwd to confirm its return signature
sed -n '70,150p' fla/ops/common/chunk_scaled_dot_kkt.py
</details> <details> <summary>tests/ops/test_ttt.py (3)</summary> `15-28`: **LGTM: Test parameterization simplified with explicit cases.** The refactoring removes conditional logic and provides clear test parameters for both `test_chunk` and `test_fused_chunk` functions. Also applies to: 101-114 --- `40-47`: **LGTM: Tensor shapes standardized to consistent format.** All tensors are now created with the consistent `(B, T, H, D)` shape, eliminating the need for conditional permutations and improving test clarity. Also applies to: 126-133 --- `187-216`: **LGTM: Variable-length test properly refactored to use explicit sequence lengths.** The implementation correctly handles the `cu_seqlens` parameter pattern with proper type annotations and tensor indexing. The function calls are updated appropriately to use sequence-based slicing. </details> </blockquote></details> </details> <!-- This is an auto-generated comment by CodeRabbit for review status -->
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/ops/test_titans.py (2)
54-54: Address the FIXME: Test is unconditionally skipped.The test is currently disabled with an unconditional skip marked as 'FIXME'. This needs to be resolved to ensure test coverage.
What specific issue is preventing this test from running? I can help fix the underlying problem or create a tracking issue for this.
56-124: Consider refactoring to reduce local variable count.Static analysis flags this function for having too many local variables (19/15). Consider extracting tensor initialization into a helper function to improve readability and maintainability.
Example refactor:
+def setup_titans_test_tensors(B, H, T, D, dtype, BT=64): + """Initialize all tensors needed for titans test.""" + torch.manual_seed(1) + + theta = torch.rand(B, H, T, 1, dtype=dtype) + alpha = torch.rand(B, H, T, 1, dtype=dtype) + eta = torch.rand(B, H, T, 1, dtype=dtype) + + q = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + k = F.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn(B, H, T, D, dtype=dtype) + w = torch.randn(H, D, dtype=dtype) + b = torch.randn(H, D, dtype=dtype) + h0 = torch.randn(B, H, D, D, dtype=torch.float32) + + # Permute to (B, T, H, D) layout + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + theta = theta.permute(0, 2, 1, 3) + alpha = alpha.permute(0, 2, 1, 3) + eta = eta.permute(0, 2, 1, 3) + + tensors = (q, k, v, w, b, theta, alpha, eta) + tensors = tuple(x.to(device).requires_grad_(False) for x in tensors) + h0 = h0.to(device) + + return tensors + (h0,) def test_naive_chunk(B, H, T, D, dtype): BT = 64 - # Tensor initialization code... + q, k, v, w, b, theta, alpha, eta, h0 = setup_titans_test_tensors(B, H, T, D, dtype, BT) # Test logic remains the same...
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/ops/test_titans.py(3 hunks)tests/ops/test_ttt.py(4 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/ops/test_ttt.py (4)
fla/ops/ttt/chunk.py (1)
chunk_ttt_linear(1340-1441)fla/ops/ttt/fused_chunk.py (1)
fused_chunk_ttt_linear(728-831)fla/ops/ttt/naive.py (1)
chunk_ttt_linear_ref(73-126)fla/utils.py (2)
assert_close(78-90)check_shared_mem(434-440)
🪛 Pylint (3.3.7)
tests/ops/test_titans.py
[refactor] 56-56: Too many local variables (19/15)
(R0914)
tests/ops/test_ttt.py
[refactor] 29-29: Too many arguments (6/5)
(R0913)
[refactor] 29-29: Too many positional arguments (6/5)
(R0917)
[refactor] 29-29: Too many local variables (40/15)
(R0914)
[refactor] 117-117: Too many arguments (6/5)
(R0913)
[refactor] 117-117: Too many positional arguments (6/5)
(R0917)
[refactor] 117-117: Too many local variables (40/15)
(R0914)
[refactor] 207-207: Too many local variables (25/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Test 4090 (PyTorch 2.7) / test-ops
- GitHub Check: Test H100 (PyTorch Nightly) / test-ops
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: Test A770 (PyTorch 2.7) / test-ops
🔇 Additional comments (13)
tests/ops/test_ttt.py (9)
4-4: LGTM! Type annotation import added.Adding
Listimport supports the newcu_seqlens: List[int]type annotation in the variable-length test.
12-12: LGTM! Utility functions imported.The imports of
assert_closeandcheck_shared_memfromfla.utilsare used consistently throughout the tests.
15-28: Excellent refactoring of test parametrization.The consolidation from multiple conditional
pytest.mark.parametrizedecorators to a single explicit list of test cases significantly improves readability and maintainability. The explicit test cases are easier to understand and modify than the previous conditional logic.
42-49: LGTM! Tensor shapes standardized to sequence-first format.All tensor creations now use the consistent
(B, T, H, D)sequence-first format, aligning with the removal of thehead_firstparameter. This standardization improves code clarity and reduces conditional complexity.
103-116: LGTM! Consistent refactoring applied to fused chunk test.The same parametrization consolidation and skip condition improvements are consistently applied to the
test_fused_chunkfunction, maintaining uniformity across the test suite.Also applies to: 127-128
191-202: LGTM! Improved variable-length test parametrization.The explicit
cu_seqlenslists replace the previous random splitting approach, making the tests more deterministic and easier to debug. The test cases cover various sequence length patterns effectively.
243-243: LGTM! Correct variable-length sequence handling.The
cu_seqlensparameter is properly passed to the test function, and the reference implementation correctly slices the input tensors using the cumulative sequence lengths. This approach is more deterministic than the previous random splitting.Also applies to: 251-256
39-40: Verify the skip condition threshold.The skip condition for
T > 1000may be too restrictive and could miss important test coverage for longer sequences. Consider if this threshold is appropriate or if it should be higher.#!/bin/bash # Description: Check if there are other sequence length thresholds used in similar tests # Expected: Find other T thresholds or CI limitations that justify this value rg -A 3 -B 3 "T > [0-9]+" --type py rg -A 3 -B 3 "skip.*T" --type py rg -A 3 -B 3 "Current CI.*support.*config" --type py
98-98: Verify the partial gradient assertion logic.The assertion tests only the first 14 elements in the second dimension (
[:, :14, :, :]). Ensure this partial testing approach is sufficient and the magic number 14 is appropriate for the test scenarios.#!/bin/bash # Description: Check if similar partial gradient testing patterns exist and understand the rationale # Expected: Find similar patterns or documentation explaining why 14 elements are tested rg -A 2 -B 2 "\[:, :[0-9]+, :, :\]" --type py rg -A 5 -B 5 "because the last element of the chunk" --type pytests/ops/test_titans.py (4)
8-8: LGTM: Clean import simplification.Good cleanup removing unused imports as part of the refactoring.
40-52: LGTM: Clear and explicit test parametrization.The explicit parameter list is much cleaner than conditional logic and provides good test coverage across different tensor dimensions and data types.
56-62: LGTM: Simplified function signature.The removal of
head_firstandscaleparameters simplifies the test interface and aligns with the unconditional tensor permutations below.
81-86: To confirm the expected input layout, let’s pull in the full signature and doc-string aroundchunk_titans_linear_ref:#!/bin/bash # Show signature, doc-string, and first part of body for context rg -n -C10 'def chunk_titans_linear_ref' fla/ops/titans/naive.py
| T = cu_seqlens[-1] | ||
| N = len(cu_seqlens) - 1 | ||
| cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) | ||
|
|
There was a problem hiding this comment.
Fix the cu_seqlens tensor dtype.
The cu_seqlens tensor is created with dtype=torch.int32, but based on the function signature from the relevant code snippets, it expects Optional[torch.LongTensor] (int64).
- cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| T = cu_seqlens[-1] | |
| N = len(cu_seqlens) - 1 | |
| cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) | |
| T = cu_seqlens[-1] | |
| N = len(cu_seqlens) - 1 | |
| cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device) |
🤖 Prompt for AI Agents
In tests/ops/test_ttt.py around lines 217 to 220, the cu_seqlens tensor is
created with dtype=torch.int32, but the function expects an
Optional[torch.LongTensor] which corresponds to int64. Change the dtype of the
cu_seqlens tensor creation to torch.int64 to match the expected type and avoid
type errors.
Summary by CodeRabbit
torch.compilebased on Python version in certain modules.