[Refactor] Rename boolean flags to follow constant naming convention#648
Conversation
- [KDA]: Add TMA support for chunk_kda_bwd_kernel_intra and tune BK
- [KDA] enable TMA on some simple kernels
- fix test
- fixup! [KDA] enable TMA on some simple kernels
- Rename `use_cuda_graph` -> `USE_CUDA_GRAPH`
- Rename `is_tf32_supported` -> `HAS_TF32_SUPPORT`
- Rename `is_gather_supported` -> `HAS_GATHER_SUPPORT`
- Rename `is_tma_supported` -> `FLA_USE_TMA`
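The renamed flags keep the same env-gated pattern as before. A minimal sketch of the idea (illustrative only, not the exact `fla/utils.py` code; the `_env_flag` helper and the `is_nvidia` placeholder are hypothetical):

```python
import os

def _env_flag(name: str, default: str = '0') -> bool:
    """Read a '0'/'1' environment variable as a boolean flag."""
    return os.environ.get(name, default) == '1'

# Placeholder for the real device detection performed in fla/utils.py.
is_nvidia = True

# For demonstration only: pin the environment to a known state.
os.environ['FLA_USE_CUDA_GRAPH'] = '1'
os.environ.pop('FLA_USE_TMA', None)

# Computed once at import time; uppercase names mark them as constants.
USE_CUDA_GRAPH = is_nvidia and _env_flag('FLA_USE_CUDA_GRAPH')
FLA_USE_TMA = is_nvidia and _env_flag('FLA_USE_TMA')
```

The uppercase spelling changes nothing at runtime; it only signals to readers that these values are fixed at import time.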
Walkthrough

This PR systematically renames feature flag constants in `fla/utils.py` and updates all references across the codebase.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

This is a systematic, highly repetitive refactoring where the same pattern is applied consistently across 21 files: importing renamed constants and updating autotune decorator/function call arguments. The changes are low-risk and homogeneous, making review straightforward, though thoroughness is needed to verify all references were properly updated throughout the codebase.
Pre-merge checks: ❌ 1 failed check (warning), ✅ 2 passed.
Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on enhancing code consistency and clarity by refactoring several global boolean flags. The flags are renamed to follow the standard constant naming convention, which makes their purpose and immutability explicit. The refactor touches multiple operational modules, updating all references to the new naming scheme and improving the overall maintainability of the codebase.
Code Review
This pull request refactors several boolean flags to use UPPER_SNAKE_CASE for constants, which improves code style and consistency. The changes are applied correctly across all relevant files. I have one suggestion to extend this refactoring to other similar boolean flags in fla/utils.py for even better consistency. Overall, this is a good quality-of-life improvement for the codebase.
```diff
  is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
  is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
- use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
+ USE_CUDA_GRAPH = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
```
For consistency with the changes in this pull request, consider also renaming other boolean flags in this file to UPPER_SNAKE_CASE. Specifically, is_amd, is_intel, is_nvidia, is_intel_alchemist, and is_nvidia_hopper could be renamed to IS_AMD, IS_INTEL, IS_NVIDIA, IS_INTEL_ALCHEMIST, and IS_NVIDIA_HOPPER respectively. This would require updating their usages throughout the codebase, but would improve overall consistency.
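If the wider `IS_*` rename suggested above were adopted, downstream imports of the old lowercase names would break. One hypothetical migration aid is a module-level `__getattr__` (PEP 562) that keeps the old names as deprecated aliases. Sketched here with a throwaway stand-in module, not the real `fla/utils.py`:

```python
import sys
import types
import warnings

# Hypothetical sketch: "fla_utils_demo" is a stand-in module, not the real
# fla/utils.py. The renamed constant lives under the new name; the old name
# still resolves, but emits a DeprecationWarning.
mod = types.ModuleType("fla_utils_demo")
mod.IS_NVIDIA = True  # the renamed constant

_DEPRECATED = {"is_nvidia": "IS_NVIDIA"}

def _module_getattr(name):
    if name in _DEPRECATED:
        new = _DEPRECATED[name]
        warnings.warn(f"{name} is deprecated; use {new}",
                      DeprecationWarning, stacklevel=2)
        return getattr(mod, new)
    raise AttributeError(f"module {mod.__name__!r} has no attribute {name!r}")

# PEP 562: a __getattr__ in the module's namespace handles failed lookups.
mod.__getattr__ = _module_getattr
sys.modules["fla_utils_demo"] = mod

import fla_utils_demo

with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    legacy_value = fla_utils_demo.is_nvidia  # old name still resolves
```

This lets the rename land without a coordinated update of every downstream consumer; the aliases can be dropped in a later release.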
Actionable comments posted: 0
🧹 Nitpick comments (1)
fla/ops/utils/solve_tril.py (1)
11-17: TMA flag rename is correct; consider deterministic DOT_PRECISION ordering

Wiring `FLA_USE_TMA` into both `DOT_PRECISION_AUTOTUNE_LIST` and the `USE_TMA` kernel arg preserves the original behavior of gating TMA-specific paths behind a single boolean, so the rename looks safe.

One small refinement: `DOT_PRECISION_AUTOTUNE_LIST = list({"ieee", FLA_TRIL_PRECISION})` relies on set iteration order, which is not guaranteed. If autotune config order matters for you (e.g., for determinism or warm-cache behavior), consider something like:

```python
DOT_PRECISION_AUTOTUNE_LIST = (
    ["ieee"]
    if not FLA_USE_TMA
    else (["ieee"] if FLA_TRIL_PRECISION == "ieee" else ["ieee", FLA_TRIL_PRECISION])
)
```

Also applies to: 382-383
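If only the ordering concern matters (leaving aside the extra `FLA_USE_TMA` gate), an order-stable dedup via `dict.fromkeys` is another option, since `dict` preserves insertion order in Python 3.7+. A small sketch with a placeholder precision value:

```python
# Order-stable deduplication: dict preserves insertion order (Python 3.7+),
# whereas set iteration order is unspecified.
FLA_TRIL_PRECISION = "tf32"  # placeholder; the real value comes from the env

DOT_PRECISION_AUTOTUNE_LIST = list(dict.fromkeys(["ieee", FLA_TRIL_PRECISION]))

# When the configured precision is already "ieee", the duplicate collapses:
ieee_only = list(dict.fromkeys(["ieee", "ieee"]))
```

This keeps `"ieee"` first in all cases without nesting conditionals.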
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
- fla/ops/common/chunk_delta_h.py (3 hunks)
- fla/ops/gated_delta_product/chunk_deltaproduct_h.py (3 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (4 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (3 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (2 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py (2 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (4 hunks)
- fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (2 hunks)
- fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (2 hunks)
- fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (2 hunks)
- fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (5 hunks)
- fla/ops/generalized_delta_rule/iplr/chunk.py (3 hunks)
- fla/ops/kda/chunk_inter.py (2 hunks)
- fla/ops/kda/chunk_intra.py (3 hunks)
- fla/ops/kda/wy_fast.py (2 hunks)
- fla/ops/rwkv6/chunk.py (9 hunks)
- fla/ops/rwkv7/channel_mixing.py (3 hunks)
- fla/ops/rwkv7/fused_addcmul.py (3 hunks)
- fla/ops/rwkv7/fused_recurrent.py (2 hunks)
- fla/ops/utils/op.py (2 hunks)
- fla/ops/utils/solve_tril.py (2 hunks)
- fla/utils.py (1 hunk)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
📚 Learning: 2025-02-23T02:25:11.939Z
Learnt from: uniartisan
Repo: fla-org/flash-linear-attention PR: 197
File: fla/utils.py:104-107
Timestamp: 2025-02-23T02:25:11.939Z
Learning: The `get_available_device()` function in `fla/utils.py` is specifically used within Triton backend context, so error handling for backend access is not required as Triton availability is guaranteed at this point.
Applied to files:
fla/ops/utils/op.py, fla/utils.py
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
Applied to files:
fla/ops/rwkv7/fused_addcmul.py, fla/ops/kda/wy_fast.py, fla/ops/utils/solve_tril.py, fla/ops/generalized_delta_rule/dplr/fused_recurrent.py, fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py, fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py, fla/utils.py, fla/ops/rwkv7/channel_mixing.py
🧬 Code graph analysis (9)
fla/ops/common/chunk_delta_h.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/rwkv7/fused_addcmul.py (1)
fla/utils.py (2)
check_pytorch_version (349-350), input_guard (135-166)
fla/ops/kda/wy_fast.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/kda/chunk_inter.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (1)
fla/utils.py (1)
check_shared_mem(445-451)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (23)
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (1)
9-9: LGTM! Constant naming now follows Python conventions.

The import and usage of `USE_CUDA_GRAPH` correctly follow the uppercase constant naming convention.

Also applies to: 25-25
fla/ops/rwkv7/fused_addcmul.py (1)
11-11: LGTM! Both autotune configurations updated correctly.

The import and all usages of `USE_CUDA_GRAPH` in both autotune decorators are consistent with the new naming convention.

Also applies to: 44-44, 104-104
fla/ops/kda/wy_fast.py (1)
10-10: LGTM! TF32 support flag renamed appropriately.

The import and usage of `HAS_TF32_SUPPORT` follow the uppercase constant naming convention. The naming change from `is_*` to `HAS_*` is also semantically clearer for capability flags.

Also applies to: 23-23
fla/ops/rwkv7/channel_mixing.py (1)
8-8: LGTM! All autotune configurations updated consistently.

The import and both usages of `USE_CUDA_GRAPH` in the autotune decorators are correct and follow the new naming convention.

Also applies to: 28-28, 195-195
fla/utils.py (2)
396-403: LGTM! Constant definitions follow Python naming conventions.

All four feature flags are now defined with uppercase names following PEP 8 constant naming guidelines:

- `USE_CUDA_GRAPH`
- `HAS_TF32_SUPPORT`
- `HAS_GATHER_SUPPORT`
- `FLA_USE_TMA`

The rename from `is_tma_supported` to `FLA_USE_TMA` is particularly well-chosen, as it reflects that this constant is gated by the `FLA_USE_TMA` environment variable rather than merely checking for hardware support.

405-411: LGTM! Internal usages updated correctly.

The conditional checks now reference the renamed constants `HAS_TF32_SUPPORT` and `FLA_USE_TMA` correctly. The log message at line 411 also appropriately uses the new constant name.

fla/ops/rwkv7/fused_recurrent.py (1)
11-11: LGTM! Constant usage follows the new naming convention.

The import and usage of `USE_CUDA_GRAPH` are consistent with the refactoring across the codebase.

Also applies to: 27-27
fla/ops/generalized_delta_rule/iplr/chunk.py (1)
12-12: LGTM! Both autotune configurations updated correctly.

The import and both usages of `USE_CUDA_GRAPH` in the autotune decorators are consistent with the new constant naming convention.

Also applies to: 34-34, 119-119
fla/ops/utils/op.py (1)
9-9: LGTM! Gather support flag renamed appropriately.

The import and usage of `HAS_GATHER_SUPPORT` follow the uppercase constant naming convention. The change from `is_*` to `HAS_*` is semantically clearer for capability flags.

Also applies to: 28-28
fla/ops/rwkv6/chunk.py (1)
14-20: CUDA graph flag rename is consistent across all kernels

`USE_CUDA_GRAPH` is imported once and threaded into every `@triton.autotune(..., use_cuda_graph=...)` site, keeping behavior identical while aligning with the constant-style naming. No functional or type issues here.

Also applies to: 37-39, 115-118, 191-194, 266-269, 400-403, 482-484, 623-625
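The rename is behavior-preserving at every `@triton.autotune` site because decorator arguments are evaluated once at import time: only the constant's name changes, not the value the decorator sees. A plain-Python stand-in illustrating this (the `autotune` stub here is hypothetical, not the real Triton API):

```python
# Stand-in for triton.autotune (not the real Triton API): decorator keyword
# arguments are evaluated once, at decoration time, so the renamed constant
# USE_CUDA_GRAPH carries exactly the value the old lowercase name did.
USE_CUDA_GRAPH = False  # placeholder value

def autotune(configs, key, use_cuda_graph=False):
    def wrap(fn):
        fn.configs = configs
        fn.use_cuda_graph = use_cuda_graph  # captured at decoration time
        return fn
    return wrap

@autotune(configs=[], key=['BK'], use_cuda_graph=USE_CUDA_GRAPH)
def chunk_kernel():
    pass
```

Note that the keyword argument `use_cuda_graph` stays lowercase: it is part of the decorator's API, while `USE_CUDA_GRAPH` is the module-level constant feeding it.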
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py (1)
10-11: USE_CUDA_GRAPH integration in dplr fwd looks good

The new `USE_CUDA_GRAPH` constant is cleanly imported and used in the `@triton.autotune` decorator, preserving the previous CUDA-graph behavior with only a naming cleanup.

Also applies to: 21-29
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (1)
9-10: CUDA graph flag rename in wy_fast_bwd is correct

`USE_CUDA_GRAPH` replaces the old flag in both the import and `@triton.autotune` call without altering configs or keys, so the backward kernel's tuning behavior remains the same.

Also applies to: 19-27
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (1)
10-11: Consistent CUDA graph flag rename in dplr backward kernel

Importing `USE_CUDA_GRAPH` and feeding it into `@triton.autotune(..., use_cuda_graph=...)` aligns this kernel with the new constant naming, with no functional changes.

Also applies to: 21-29
fla/ops/kda/chunk_inter.py (1)
10-10: LGTM! Clean refactoring to constant naming convention.

The import and usage of `FLA_USE_TMA` are consistent and follow Python's uppercase naming convention for constants.

Also applies to: 205-205
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (1)
10-10: LGTM! Consistent refactoring across import and usage sites.

Both `USE_CUDA_GRAPH` and `HAS_GATHER_SUPPORT` constants are correctly imported and used in the autotune decorator and kernel invocation.

Also applies to: 25-25, 195-195
fla/ops/gated_delta_product/chunk_deltaproduct_h.py (1)
10-10: LGTM! Consistent updates across multiple kernels.

The `USE_CUDA_GRAPH` constant is correctly imported and consistently used in both forward and backward kernel autotune decorators.

Also applies to: 30-30, 209-209
fla/ops/common/chunk_delta_h.py (1)
10-10: LGTM! Properly refactored across both kernel configurations.

The constant is correctly imported and consistently applied to both the forward (`chunk_gated_delta_rule_fwd_kernel_h_blockdim64`) and backward (`chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64`) kernel autotune configurations.

Also applies to: 31-31, 225-225
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (1)
9-9: LGTM! Straightforward constant refactoring.

The import and usage are clean and consistent with the project-wide naming convention update.
Also applies to: 28-28
fla/ops/kda/chunk_intra.py (1)
10-10: LGTM! Consistent TMA flag updates.

The `FLA_USE_TMA` constant is correctly imported and consistently used in both `chunk_kda_fwd_kernel_intra_sub_inter` and `chunk_kda_bwd_kernel_intra` kernel invocations.

Also applies to: 509-509, 594-594
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (1)
10-10: LGTM! Comprehensive refactoring across multiple usage patterns.

The constants are correctly imported and consistently used across:

- Three autotune decorator configurations (`USE_CUDA_GRAPH`)
- One kernel parameter default value (`GATHER_SUPPORTED = HAS_GATHER_SUPPORT`)

All usage patterns follow the new naming convention appropriately.
Also applies to: 22-22, 68-68, 82-82, 148-148
fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (1)
10-10: LGTM! Thorough refactoring across all backward kernels.

The `USE_CUDA_GRAPH` constant is correctly imported and consistently applied to all three kernel autotune configurations:

- `chunk_dplr_bwd_kernel_dAu`
- `chunk_dplr_bwd_o_kernel`
- `chunk_dplr_bwd_kernel_dv`

Also applies to: 27-27, 100-100, 230-230
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (2)
10-10: LGTM! Constant naming follows Python conventions.

The import correctly uses the new uppercase constant names (`USE_CUDA_GRAPH`, `HAS_GATHER_SUPPORT`), which follow Python PEP 8 naming conventions for module-level constants.
25-25: Refactoring verified: all constant names successfully and consistently renamed.

The verification confirms the constant renaming is complete and correct:

- All 42 matches for `use_cuda_graph` correctly show the Triton API parameter receiving the new `USE_CUDA_GRAPH` constant value
- Zero matches found for old constant names (`is_gather_supported`, `is_tf32_supported`, `is_tma_supported`), confirming full replacement
- No leftover inconsistencies across 15+ files in the ops directory
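A rename like this can also be double-checked locally with a small scan for the old names. A hypothetical sketch (the helper and demo files are made up; as noted above, bare `use_cuda_graph=` keyword arguments legitimately remain, so only the unambiguous old constant names are searched):

```python
import pathlib
import re
import tempfile

# Old constant names that should no longer appear anywhere.
OLD_NAMES = re.compile(r"\b(is_tf32_supported|is_gather_supported|is_tma_supported)\b")

def leftover_old_flags(root):
    """Return 'file:line: text' entries for any surviving old flag names."""
    hits = []
    for path in sorted(pathlib.Path(root).rglob("*.py")):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if OLD_NAMES.search(line):
                hits.append(f"{path.name}:{lineno}: {line.strip()}")
    return hits

# Demo on a throwaway directory standing in for the repo:
demo = tempfile.mkdtemp()
pathlib.Path(demo, "stale.py").write_text("x = is_tma_supported\n")
pathlib.Path(demo, "clean.py").write_text("y = FLA_USE_TMA\n")
hits = leftover_old_flags(demo)
```

An empty result on the real tree would confirm the zero-matches finding above.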