[Feature] Support passing PrimExpr value in tile-level atomic operation#1796
[Feature] Support passing PrimExpr value in tile-level atomic operation#1796LeiWang1999 merged 6 commits intotile-ai:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughThis PR refactors atomic operations (add, max, min) to support both buffer-like and scalar source values. A new utility function IsBufferLikeExpr categorizes expressions, unified SIMT loop construction handles either source type, and extent inference is centralized. Public APIs are extended with src_value field exposure and a use_tma parameter. Vectorization and testing are enhanced accordingly. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This pull request adds support for passing PrimExpr values (scalars like 1.0 or other expressions) in tile-level atomic operations, which was previously only supported at the thread level. It also refactors the atomic operator structure by extracting common utilities and improving argument checking.
Changes:
- Moved
get_buffer_region_from_loadfromtilelang/utils/language.pytotilelang/language/utils.pyand added newget_extentutility function - Enhanced atomic operators (add/max/min) to support PrimExpr values as the source argument
- Added vectorization support for scalar values in atomic operations
- Reorganized and expanded test coverage for atomic operations
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Removed get_buffer_region_from_load function (moved to utils) |
| tilelang/language/utils.py | Added get_buffer_region_from_load and new get_extent utility function |
| tilelang/language/copy_op.py | Refactored to use shared get_extent utility |
| tilelang/language/atomic.py | Enhanced atomic operations to support PrimExpr values, with improved argument checking |
| tilelang/_typing.py | Added PyPrimExpr type alias for future use |
| tilelang/intrinsics/mfma_macro_generator.py | Updated import path for get_buffer_region_from_load |
| tilelang/engine/phase.py | Added debug print statements (should be removed) |
| testing/python/language/test_tilelang_language_atomic.py | Reorganized tests and added new test cases for PrimExpr values |
| src/op/utils.h, src/op/utils.cc | Added IsBufferLikeExpr utility function |
| src/op/atomic_reduce.h, src/op/atomic_reduce.cc | Added src_value field to support scalar values, updated loop generation logic |
| src/op/atomic_add.h, src/op/atomic_add.cc | Updated to support scalar values with TMA check |
| src/transform/vectorize_loop.cc | Added broadcasting logic for scalar values in atomic operations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@src/transform/vectorize_loop.cc`:
- Around line 566-570: The current logic calls BroadcastTo(src, vector_size,
...) whenever src.same_as(op->args[1]) without verifying src is actually a
scalar, which can trigger BroadcastTo assertions for non-scalar inputs; update
the branch to first check that src is a scalar (or that src.dtype().lanes() ==
1) before calling BroadcastTo, and if it is not scalar or its lane count
mismatches vector_size set need_scalarize = true (and skip BroadcastTo) so
VisitExpr will handle scalarization; reference symbols: src, op->args[1],
BroadcastTo, need_scalarize, and vector_size.
In `@tilelang/engine/phase.py`:
- Around line 239-241: Remove the temporary debug prints around the
VectorizeLoop transformation: delete the two unguarded print(mod) calls that
appear immediately before and after the call to
tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
so compilation no longer emits verbose output; keep only the transformation call
(in the function/method where mod is processed).
- Around line 182-184: Remove the unguarded debug prints printing the IRModule
by deleting the two print(mod) statements around the LowerTileOp invocation; if
conditional IR inspection is needed, wrap a diagnostic print behind the existing
should_enable_ast_print() check (or similar) so the IR is only emitted when that
flag is true, keeping the call to tilelang.transform.LowerTileOp()(mod)
unchanged.
In `@tilelang/language/utils.py`:
- Around line 118-120: In the isinstance(indice, tir.Ramp) branch (the assert on
extents), make the assertion message consistent with the condition: update the
assert to either check extents is not None if extents are required for
BufferLoad with Ramp indices, or (preferably here) change the message to
"extents should be None for BufferLoad with Ramp indices" so it matches the
existing condition assert extents is None; modify the assertion in the same
block that currently reads assert extents is None, "extents should be provided
for BufferLoad with Ramp indices".
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic.py (1)
337-361: Remove debug prints from test runs.
These🧹 Suggested cleanup
- print(A, ref_A) @@ - print(A, ref_A)Also applies to: 579-603
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tilelang/language/atomic.py`:
- Around line 250-258: The code currently allows use_tma=True even when the
source is a scalar (src_extent is falsy); add a guard after src_extent is
determined that rejects or disables TMA: if use_tma and not src_extent, raise a
clear error (e.g., ValueError) or force use_tma=False to prevent emitting TMA
annotations for scalar PrimExpr sources. Update the logic near the existing
src_extent/dst_extent handling (the block that calls to_buffer_region and checks
Buffer instances) so the check runs before any TMA-related annotations are
emitted (refer to symbols src_extent, use_tma, to_buffer_region, Buffer).
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic.py (1)
337-360: Use thedtypeparameter when allocating torch tensors in expr tests.
Right now the expr tests always allocate float32 tensors even when a different dtype is passed, which reduces coverage for non-fp32 dtypes.🛠️ Suggested fix
- A = torch.zeros(M, N, dtype=torch.float32).cuda() + A = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() ... - A = torch.randn(M, N, dtype=torch.float32).cuda() + A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()Also applies to: 578-601
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tilelang/_typing.py`:
- Around line 38-41: Replace the PEP 604 union expression used by PyPrimExpr
with the typing.Union form to maintain Python 3.9 compatibility: change the
declaration of PyPrimExpr to use Union[tir.PrimExpr, int, float, bool] (and
ensure Union is imported from typing if not already) while keeping the TypeAlias
annotation; update the symbol PyPrimExpr accordingly so it matches the other
TypeAlias usages in the file.
|
@regression-perf |
Performance Regression Test ReportTriggered by: @SiriusNEO Results
Artifacts
|
This PR supports operations like
T.atomic_add(A_buf, 1.0), which is primarily only supported in thread-level atomic add operations. And it also enhances the whole structure of atomic operators as well as the argument checking part of copy operators.Summary by CodeRabbit
Release Notes
New Features
Tests