
[Feature] Support passing PrimExpr value in tile-level atomic operation #1796

Merged
LeiWang1999 merged 6 commits into tile-ai:main from SiriusNEO:chaofan/atomic_0204
Feb 6, 2026

Conversation

SiriusNEO (Collaborator) commented Feb 5, 2026

This PR adds support for tile-level operations such as T.atomic_add(A_buf, 1.0), which were previously supported only by the thread-level atomic add. It also improves the overall structure of the atomic operators and the argument checking of the copy operators.

Summary by CodeRabbit

Release Notes

  • New Features

    • Atomic operations now support both buffer regions and scalar values as sources
    • Added Tensor Memory Accelerator (TMA) path for atomic add operations
    • Enhanced vectorization support for atomic operations in parallel contexts
  • Tests

    • Expanded atomic operation test coverage with scalar sources, expressions, and memory ordering scenarios

SiriusNEO requested a review from LeiWang1999 on February 5, 2026, 06:50
github-actions bot commented Feb 5, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

SiriusNEO requested a review from Copilot on February 5, 2026, 06:50
coderabbitai bot (Contributor) commented Feb 5, 2026

Walkthrough

This PR refactors atomic operations (add, max, min) to support both buffer-like and scalar source values. A new utility function IsBufferLikeExpr categorizes expressions, unified SIMT loop construction handles either source type, and extent inference is centralized. Public APIs are extended with src_value field exposure and a use_tma parameter. Vectorization and testing are enhanced accordingly.
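
To make the unified source handling concrete, here is a minimal pure-Python sketch of the idea. It is not TileLang code: `BufferRegion`, `is_buffer_like`, and `tile_atomic_add` are hypothetical stand-ins, and the plain `+=` only models what would be a hardware atomic add. The point it illustrates is that one loop serves both source kinds, and only the way the source value is loaded differs.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class BufferRegion:
    """Hypothetical stand-in for a buffer-like source: a 2-D tile of values."""
    data: List[List[float]]


Source = Union[BufferRegion, int, float]


def is_buffer_like(src: Source) -> bool:
    """Toy analogue of a buffer-like check: tiles are buffer-like, scalars are not."""
    return isinstance(src, BufferRegion)


def tile_atomic_add(dst: List[List[float]], src: Source) -> None:
    """Add src into dst element by element, using one loop for both source kinds."""
    rows, cols = len(dst), len(dst[0])
    if is_buffer_like(src):
        # A buffer-like source must cover the same extent as the destination tile.
        if len(src.data) != rows or any(len(r) != cols for r in src.data):
            raise ValueError("source and destination extents differ")
    for i in range(rows):
        for j in range(cols):
            # The only per-source difference is how the value is loaded.
            value = src.data[i][j] if is_buffer_like(src) else src
            dst[i][j] += value  # models an atomic add; not actually atomic


acc = [[0.0, 0.0], [0.0, 0.0]]
tile_atomic_add(acc, 1.0)                                     # scalar source
tile_atomic_add(acc, BufferRegion([[1.0, 2.0], [3.0, 4.0]]))  # buffer-like source
print(acc)
```

In the sketch the scalar is simply reused for every element, which mirrors the described removal of a separate scalar fast path in favour of a single loop.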

Changes

| Cohort / File(s) | Summary |
|---|---|
| Atomic Operation Core: src/op/atomic_add.cc, src/op/atomic_add.h, src/op/atomic_reduce.cc, src/op/atomic_reduce.h | Unified source operand handling: AtomicAdd/AtomicMax/AtomicMin constructors now accept either buffer-like (src + src_range) or scalar (src_value) sources. Removed scalar fast-path in SIMT loop construction; unified source-argument loading via src_value_arg. Exposes new public src_value field via reflection binding. Tightened TMA usage check when src_value is defined. |
| Utility Functions: src/op/utils.cc, src/op/utils.h | Added new public utility IsBufferLikeExpr() to identify buffer-like expressions (BufferLoadNode, BufferRegionNode, or CallNode with RegionOp callee). |
| Type Extensions: tilelang/_typing.py | Introduced PyPrimExpr type alias extending PrimExpr to include Python primitives (int, float, bool). |
| High-Level Atomic API: tilelang/language/atomic.py, tilelang/language/utils.py | Refactored extent handling via new public utilities get_buffer_region_from_load() and get_extent(); added use_tma parameter to atomic_add(); unified scalar/buffer source paths with early extent checks and region normalization. |
| Copy Operation: tilelang/language/copy_op.py | Simplified extent deduction by delegating to external get_extent() utility; removed local extent-handling logic. |
| Utility Migration: tilelang/utils/language.py | Moved get_buffer_region_from_load implementation from this module to tilelang.language.utils; now imports and re-uses external version. |
| Vectorization: src/transform/vectorize_loop.cc | Enhanced atomic add vectorization: broadcasts src operand to vector size when not already in Ramp/Broadcast form; improved Var handling with let_var_map_ and SSA-style Let rebinding. |
| Infrastructure: tilelang/intrinsics/mfma_macro_generator.py | Updated import path for get_buffer_region_from_load from tilelang.utils.language to tilelang.language.utils. |
| Comprehensive Testing: testing/python/language/test_tilelang_language_atomic.py | Significant test expansion: added thread-level atomic add (tma_atomic_add), tile-level variants (tile_atomic_add, tile_atomic_add_expr, tile_atomic_add_scalar), tile-level max/min tests with expression paths, memory ordering and auxiliary operation tests; explicit CUDA gating and streamlined verification via runtime assertions. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 Hop hop, the atoms now unite,
Buffer or scalar—both paths shine bright!
No more fast trails that scatter about,
One loop to rule them, without a doubt!
With extents deduced and vectors in flight,
Atomic operations dance through the night!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 30.56% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title accurately describes the main feature: adding support for PrimExpr values in tile-level atomic operations, which is the primary objective of this changeset.


Copilot AI (Contributor) left a comment

Pull request overview

This pull request adds support for passing PrimExpr values (scalars like 1.0 or other expressions) in tile-level atomic operations, which was previously only supported at the thread level. It also refactors the atomic operator structure by extracting common utilities and improving argument checking.

Changes:

  • Moved get_buffer_region_from_load from tilelang/utils/language.py to tilelang/language/utils.py and added new get_extent utility function
  • Enhanced atomic operators (add/max/min) to support PrimExpr values as the source argument
  • Added vectorization support for scalar values in atomic operations
  • Reorganized and expanded test coverage for atomic operations

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 15 comments.

File Description
tilelang/utils/language.py Removed get_buffer_region_from_load function (moved to utils)
tilelang/language/utils.py Added get_buffer_region_from_load and new get_extent utility function
tilelang/language/copy_op.py Refactored to use shared get_extent utility
tilelang/language/atomic.py Enhanced atomic operations to support PrimExpr values, with improved argument checking
tilelang/_typing.py Added PyPrimExpr type alias for future use
tilelang/intrinsics/mfma_macro_generator.py Updated import path for get_buffer_region_from_load
tilelang/engine/phase.py Added debug print statements (should be removed)
testing/python/language/test_tilelang_language_atomic.py Reorganized tests and added new test cases for PrimExpr values
src/op/utils.h, src/op/utils.cc Added IsBufferLikeExpr utility function
src/op/atomic_reduce.h, src/op/atomic_reduce.cc Added src_value field to support scalar values, updated loop generation logic
src/op/atomic_add.h, src/op/atomic_add.cc Updated to support scalar values with TMA check
src/transform/vectorize_loop.cc Added broadcasting logic for scalar values in atomic operations


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@src/transform/vectorize_loop.cc`:
- Around line 566-570: The current logic calls BroadcastTo(src, vector_size,
...) whenever src.same_as(op->args[1]) without verifying src is actually a
scalar, which can trigger BroadcastTo assertions for non-scalar inputs; update
the branch to first check that src is a scalar (or that src.dtype().lanes() ==
1) before calling BroadcastTo, and if it is not scalar or its lane count
mismatches vector_size set need_scalarize = true (and skip BroadcastTo) so
VisitExpr will handle scalarization; reference symbols: src, op->args[1],
BroadcastTo, need_scalarize, and vector_size.

In `@tilelang/engine/phase.py`:
- Around line 239-241: Remove the temporary debug prints around the
VectorizeLoop transformation: delete the two unguarded print(mod) calls that
appear immediately before and after the call to
tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
so compilation no longer emits verbose output; keep only the transformation call
(in the function/method where mod is processed).
- Around line 182-184: Remove the unguarded debug prints printing the IRModule
by deleting the two print(mod) statements around the LowerTileOp invocation; if
conditional IR inspection is needed, wrap a diagnostic print behind the existing
should_enable_ast_print() check (or similar) so the IR is only emitted when that
flag is true, keeping the call to tilelang.transform.LowerTileOp()(mod)
unchanged.

In `@tilelang/language/utils.py`:
- Around line 118-120: In the isinstance(indice, tir.Ramp) branch (the assert on
extents), make the assertion message consistent with the condition: update the
assert to either check extents is not None if extents are required for
BufferLoad with Ramp indices, or (preferably here) change the message to
"extents should be None for BufferLoad with Ramp indices" so it matches the
existing condition assert extents is None; modify the assertion in the same
block that currently reads assert extents is None, "extents should be provided
for BufferLoad with Ramp indices".
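
The scalar check requested for src/transform/vectorize_loop.cc above can be sketched in plain Python. This is only a model under stated assumptions: a value's lane count is represented by the length of a list, `prepare_atomic_src` is a hypothetical name, and only `need_scalarize` and `vector_size` are taken from the review comment. It does not reproduce the actual C++ pass.

```python
from typing import List, Tuple


def prepare_atomic_src(src: List[float], vector_size: int) -> Tuple[List[float], bool]:
    """Return (operand, need_scalarize) for a source modeled as a list of lanes.

    Hypothetical model: len(src) plays the role of the lane count.
    """
    lanes = len(src)
    if lanes == vector_size:
        # Already vector-shaped; nothing to do.
        return src, False
    if lanes == 1:
        # Only a true scalar may be broadcast to the vector width.
        return src * vector_size, False
    # Lane count mismatch: do not broadcast, fall back to scalarization.
    return src, True


print(prepare_atomic_src([1.0], 4))
print(prepare_atomic_src([1.0, 2.0], 4))
```

The design point is the ordering of the checks: broadcasting is attempted only after the operand is known to be a single lane, so a mismatched vector never reaches the broadcast step.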
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic.py (1)

337-361: Remove debug prints from test runs.
These print calls add noise without asserting anything.

🧹 Suggested cleanup
-    print(A, ref_A)
@@
-    print(A, ref_A)

Also applies to: 579-603

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tilelang/language/atomic.py`:
- Around line 250-258: The code currently allows use_tma=True even when the
source is a scalar (src_extent is falsy); add a guard after src_extent is
determined that rejects or disables TMA: if use_tma and not src_extent, raise a
clear error (e.g., ValueError) or force use_tma=False to prevent emitting TMA
annotations for scalar PrimExpr sources. Update the logic near the existing
src_extent/dst_extent handling (the block that calls to_buffer_region and checks
Buffer instances) so the check runs before any TMA-related annotations are
emitted (refer to symbols src_extent, use_tma, to_buffer_region, Buffer).
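
A minimal sketch of the guard this comment asks for, taking the raise-an-error option. The helper name `check_tma_source` and the error message are hypothetical; only `use_tma` and `src_extent` come from the comment, and the sketch assumes a scalar source is represented by a missing or empty extent.

```python
from typing import Optional, Sequence


def check_tma_source(use_tma: bool, src_extent: Optional[Sequence[int]]) -> None:
    """Reject TMA for scalar sources (modeled as a missing or empty extent)."""
    if use_tma and not src_extent:
        raise ValueError(
            "use_tma=True requires a buffer-like source; got a scalar value"
        )


check_tma_source(use_tma=True, src_extent=[128, 64])  # buffer source: accepted
check_tma_source(use_tma=False, src_extent=None)      # scalar without TMA: accepted
```

Raising early keeps the failure at the call site; silently forcing `use_tma=False`, the other option mentioned in the comment, would hide the misuse from the caller.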
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic.py (1)

337-360: Use the dtype parameter when allocating torch tensors in expr tests.
Right now the expr tests always allocate float32 tensors even when a different dtype is passed, which reduces coverage for non-fp32 dtypes.

🛠️ Suggested fix
-    A = torch.zeros(M, N, dtype=torch.float32).cuda()
+    A = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
...
-    A = torch.randn(M, N, dtype=torch.float32).cuda()
+    A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()

Also applies to: 578-601

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tilelang/_typing.py`:
- Around line 38-41: Replace the PEP 604 union expression used by PyPrimExpr
with the typing.Union form to maintain Python 3.9 compatibility: change the
declaration of PyPrimExpr to use Union[tir.PrimExpr, int, float, bool] (and
ensure Union is imported from typing if not already) while keeping the TypeAlias
annotation; update the symbol PyPrimExpr accordingly so it matches the other
TypeAlias usages in the file.

SiriusNEO (Collaborator, Author) commented

@regression-perf

github-actions bot commented Feb 5, 2026

Performance Regression Test Report

Triggered by: @SiriusNEO
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21704845040

Results

File Original Latency Current Latency Speedup
example_warp_specialize_gemm_copy_1_gemm_0 0.036543 0.037505 0.97435
example_dequant_gemm_fp4_hopper 1.01496 1.03455 0.981063
example_gemm 0.022497 0.022881 0.983218
example_dequant_gemm_bf16_mxfp4_hopper 0.496905 0.505098 0.983779
example_gemm_autotune 0.022049 0.022304 0.988567
example_warp_specialize_gemm_copy_0_gemm_1 0.038688 0.038848 0.995881
example_tilelang_gemm_fp8 0.318469 0.319462 0.996893
example_mha_sink_bwd_bhsd_sliding_window 0.044237 0.0443651 0.997112
example_gemm_intrinsics 0.034496 0.034592 0.997225
example_mha_sink_bwd_bhsd 0.0615202 0.0616556 0.997804
example_gqa_bwd_wgmma_pipelined 0.0686207 0.068769 0.997844
example_tilelang_gemm_fp8_intrinsic 0.909904 0.911381 0.998379
example_tilelang_nsa_decode 0.00730685 0.00731646 0.998686
example_tilelang_nsa_fwd 0.0068136 0.00682159 0.998828
example_group_per_split_token_cast_to_fp8 0.0103164 0.0103282 0.998863
example_gemv 0.281573 0.28186 0.998981
example_tilelang_block_sparse_attn 0.0100648 0.0100744 0.999047
tilelang_example_sparse_tensorcore 0.0149016 0.0149145 0.999137
example_tilelang_gemm_fp8_2xAcc 0.186468 0.186627 0.999143
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144237 0.0144358 0.999167
example_gqa_bwd 0.0489774 0.0490116 0.999302
example_tilelang_sparse_gqa_decode_varlen_indice 0.0168907 0.016902 0.999327
block_sparse_attn_tilelang 0.010157 0.0101635 0.999359
fp8_lighting_indexer 0.0353793 0.0354001 0.999412
example_convolution 1.30972 1.31046 0.999439
sparse_mla_bwd 0.376631 0.376839 0.999449
example_gemm_schedule 0.0322643 0.0322775 0.999591
example_linear_attn_bwd 0.151374 0.151433 0.99961
example_dequant_gemv_fp16xint4 0.0283754 0.0283814 0.999789
example_gqa_sink_bwd_bhsd_sliding_window 0.0251474 0.0251504 0.999883
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0142976 0.0142992 0.999887
example_gqa_sink_bwd_bhsd 0.0408132 0.0408173 0.999898
example_per_token_cast_to_fp8 0.00739677 0.0073973 0.999929
example_mha_bwd_bhsd 0.0400306 0.0400328 0.999944
example_topk 0.01072 0.01072 1
example_mha_sink_fwd_bhsd 0.0157218 0.0157217 1
example_mha_fwd_varlen 0.0450046 0.0450037 1.00002
example_mha_bwd_bshd 0.0406125 0.0406109 1.00004
example_elementwise_add 0.294015 0.294 1.00005
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231352 0.023134 1.00005
example_vertical_slash_sparse_attn 0.231717 0.231703 1.00006
sparse_mla_fwd_pipelined 0.0946257 0.0946201 1.00006
example_mha_bwd_bshd_wgmma_pipelined 0.0254182 0.0254164 1.00007
example_linear_attn_fwd 0.0365575 0.036554 1.0001
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40175 1.40138 1.00026
example_mla_decode 0.44932 0.449191 1.00029
example_tilelang_gemm_splitk 1.40213 1.40166 1.00033
example_gqa_bwd_tma_reduce_varlen 0.0513034 0.0512861 1.00034
topk_selector 0.0531166 0.0530952 1.0004
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0153253 0.0153182 1.00046
example_blocksparse_gemm 0.0224181 0.0224057 1.00056
example_dequant_gemm_w4a8 5.30859 5.3045 1.00077
sparse_mla_fwd 0.12919 0.129081 1.00084
example_mha_sink_fwd_bhsd_sliding_window 0.0155555 0.0155415 1.0009
example_mha_inference 0.079424 0.079329 1.0012
example_dynamic 0.652338 0.651338 1.00154
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0152924 0.0152577 1.00228
example_convolution_autotune 0.991376 0.987992 1.00343
example_gqa_decode 0.048481 0.048193 1.00598
example_dequant_gemm_bf16_fp4_hopper 0.56337 0.560009 1.006
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.47059 3.41357 1.0167
example_warp_specialize_gemm_softpipe_stage2 0.039168 0.037729 1.03814
example_warp_specialize_gemm_barrierpipe_stage2 0.040064 0.037792 1.06012
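
The Speedup column appears to be the original latency divided by the current latency, so values below 1 mean the PR branch is slower and values above 1 mean it is faster. This reading is an inference from the numbers, not something the report states. The short sketch below recomputes the ratio for three rows copied from the table; the `speedup` helper is illustrative only.

```python
# (file, original latency, current latency, reported speedup) copied from the table.
rows = [
    ("example_warp_specialize_gemm_copy_1_gemm_0", 0.036543, 0.037505, 0.97435),
    ("example_gemm", 0.022497, 0.022881, 0.983218),
    ("example_warp_specialize_gemm_barrierpipe_stage2", 0.040064, 0.037792, 1.06012),
]


def speedup(original: float, current: float) -> float:
    """Assumed definition: ratio of original to current latency."""
    return original / current


for name, original, current, reported in rows:
    computed = speedup(original, current)
    # Compare with a small tolerance because the report rounds its values.
    print(name, round(computed, 5), abs(computed - reported) < 1e-4)
```

Under this reading, the largest slowdown in the report is about 2.6% and the largest gain is about 6%.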

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

LeiWang1999 merged commit 4349b2c into tile-ai:main on Feb 6, 2026
6 checks passed