[Feature] add T.sync_warp & T.shfl_sync; change extern pdl into intrin #1614

LeiWang1999 merged 4 commits into tile-ai:main

Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Adds TileLang warp-level intrinsics (T.sync_warp and T.shfl_sync) with CUDA codegen and tests, and converts the extern PDL calls into intrinsics.

Changes
Sequence Diagram(s)

(omitted — changes add intrinsics, codegen, and tests but do not introduce a new multi-component control flow that requires a sequence diagram.)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
📜 Recent review details

Configuration used: defaults
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used

🧬 Code graph analysis (3)

tilelang/language/pdl.py (1)
src/op/builtin.h (2)
src/op/builtin.cc (2)
🔇 Additional comments (8)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @tilelang/language/builtin.py:
- Around line 712-717: Update the signature of sync_warp to use a proper
optional type for mask (e.g., mask: int | None = None or mask: Optional[int] =
None) to comply with PEP 484, and update its docstring to document the mask
parameter and default behavior (explain that mask selects which threads in the
warp to synchronize and that omitting it synchronizes all threads). Reference
the function name sync_warp and the parameter mask when making these changes.
- Around line 719-724: Update the shfl_sync signature to use a PEP 484-compliant
optional type (change width: int = None to width: int | None = None) and expand
the docstring for shfl_sync to describe all parameters and semantics: document
mask (active lanes bitmask), value (the value to shuffle and its dtype), srcLane
(source lane index within the warp), and width (shuffle width/stride and that it
is optional and affects which lanes participate), and note return type and that
it calls the underlying "__shfl_sync" extern; keep references to the function
name shfl_sync and the tir.call_extern invocation when editing.
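A minimal sketch of the fix this prompt describes, keeping the call_extern behavior already shown in the diffs below; the docstring wording is illustrative, not from the PR:

```python
from tvm import tir
from tvm.tir import PrimExpr


def shfl_sync(mask: int, value: PrimExpr, srcLane: int, width: int | None = None):
    """Shuffle a value across the warp via the CUDA __shfl_sync extern.

    Parameters
    ----------
    mask : int
        Bitmask of the lanes that participate in the shuffle.
    value : PrimExpr
        The value to shuffle; its dtype becomes the return dtype.
    srcLane : int
        Source lane index within the warp (or warp segment).
    width : int, optional
        Warp segment width; when omitted, the full warp participates.
    """
    if width is None:
        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
    return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)
```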
🧹 Nitpick comments (2)
tilelang/language/builtin.py (1)
719-724: Consider adding HIP/ROCm support for consistency.

The existing shuffle functions (shfl_xor, shfl_down, shfl_up at lines 659-700) check _IS_HIP_AVAILABLE and use different function names for AMD GPUs. Consider adding similar HIP support to shfl_sync for consistency and cross-platform compatibility.

🔎 Example implementation with HIP support
```diff
 def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int | None = None):
     """Broadcast a value from one thread to other threads in the warp."""
-    if width is None:
-        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
-    return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)
+    if _IS_HIP_AVAILABLE:
+        if width is None:
+            return tir.call_extern(value.dtype, "__shfl", value, srcLane)
+        return tir.call_extern(value.dtype, "__shfl", value, srcLane, width)
+    else:
+        if width is None:
+            return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
+        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)
```

Note: Verify the correct HIP function name and signature before implementing.
testing/python/language/test_tilelang_language_warp_sync.py (1)
28-29: Consider using torch.zeros() for clearer test initialization.

The tests use torch.empty(), which leaves tensors uninitialized. While this works correctly since the kernels write to all accessed elements, using torch.zeros() would make the tests more robust and easier to debug if issues arise.

🔎 Suggested change
```diff
 def test_warp_sync():
-    a = torch.empty((1), device="cuda", dtype=torch.int32)
-    b = torch.empty((1), device="cuda", dtype=torch.int32)
+    a = torch.zeros((1), device="cuda", dtype=torch.int32)
+    b = torch.zeros((1), device="cuda", dtype=torch.int32)
     kernel = kernel_with_warp_sync()
     assert "__syncwarp" in kernel.get_kernel_source()
     kernel(a, b)
     assert b[0] == -1


 def test_shfl_sync():
-    a = torch.empty((32), device="cuda", dtype=torch.int32)
+    a = torch.zeros((32), device="cuda", dtype=torch.int32)
     kernel = kernel_with_shfl_sync()
     assert "__shfl_sync" in kernel.get_kernel_source()
     kernel(a)
```

Also applies to: 52-52
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- testing/python/language/test_tilelang_language_warp_sync.py
- tilelang/language/builtin.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_warp_sync.py
🧬 Code graph analysis (2)
tilelang/language/builtin.py (2)
tilelang/language/tir/op.py (1)
call_extern (173-195)

src/tl_templates/cuda/common.h (3)

shfl_sync (617-619)
shfl_sync (643-647)
shfl_sync (673-677)
testing/python/language/test_tilelang_language_warp_sync.py (3)
tilelang/language/kernel.py (1)
threads (214-218)

tilelang/language/tir/op.py (2)

call_extern (173-195)
all (1913-1930)

tilelang/language/builtin.py (2)

sync_warp (712-716)
shfl_sync (719-723)
🪛 Ruff (0.14.10)
tilelang/language/builtin.py
712-712: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
719-719: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
testing/python/language/test_tilelang_language_warp_sync.py (2)
8-33: Well-designed synchronization test!

The test correctly validates warp synchronization by:

- Using __nanosleep to increase the likelihood of a race condition without sync
- Verifying both source code generation ("__syncwarp" in source) and runtime behavior (B[0] == -1)
- Using a realistic scenario with 32 threads (standard NVIDIA warp size)
The test logic is sound and should effectively catch synchronization issues.
36-56: Effective shuffle broadcast test!

The test correctly validates shfl_sync by:

- Computing unique values per lane (val = tx * 10)
- Broadcasting lane 31's value (310) to all 32 threads using mask 0xFFFFFFFF
- Verifying both source generation and that all output elements equal 310
The test logic is correct and comprehensive.
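For concreteness, the broadcast invariant described above can be checked host-side roughly like this (a sketch reusing the kernel_with_shfl_sync factory from the test file):

```python
import torch

# After the kernel runs across 32 lanes, every slot should hold
# lane 31's value: 31 * 10 == 310.
a = torch.zeros((32,), device="cuda", dtype=torch.int32)
kernel = kernel_with_shfl_sync()
kernel(a)
assert torch.all(a == 310), "shfl_sync broadcast should replicate lane 31's value"
```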
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @tilelang/language/builtin.py:
- Around line 719-723: The function shfl_sync has an invalid PEP 484 annotation
and a sparse docstring; change the parameter annotation from width: int = None
to width: int | None = None and expand the docstring for shfl_sync to describe
each parameter: mask (active lanes bitmask), value (the value to shuffle, note
its dtype is used in the call), srcLane (source lane index), and width (optional
warp segment width), while keeping existing behavior that calls tir.call_extern
with or without width based on None.
- Around line 712-716: Update the sync_warp signature and docstring: change the
parameter annotation from "mask: int = None" to "mask: int | None = None" to be
PEP 484 compliant, and expand the docstring for sync_warp to document the mask
parameter as a bitmask selecting which threads to synchronize and note that when
mask is None the call synchronizes all threads in the warp; keep existing return
behavior using tir.call_intrin with or without mask.
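A minimal sketch of what the described sync_warp fix could look like; the registered op name "tl.sync_warp" is an assumption for illustration, with the real registration living in src/op/builtin.cc:

```python
from tvm import tir


def sync_warp(mask: int | None = None):
    """Synchronize threads within a warp.

    Parameters
    ----------
    mask : int, optional
        Bitmask selecting which threads in the warp to synchronize.
        When None, all threads in the warp are synchronized.
    """
    # "tl.sync_warp" is an assumed op name for this sketch.
    if mask is None:
        return tir.call_intrin("handle", "tl.sync_warp")
    return tir.call_intrin("handle", "tl.sync_warp", mask)
```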
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/builtin.cc
- src/op/builtin.h
- src/target/codegen_cuda.cc
- tilelang/language/builtin.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/op/builtin.h (1)
tilelang/language/builtin.py (1)
sync_warp (712-716)
src/target/codegen_cuda.cc (1)
tilelang/language/builtin.py (1)
sync_warp (712-716)
src/op/builtin.cc (1)
tilelang/language/builtin.py (1)
sync_warp (712-716)
tilelang/language/builtin.py (2)
tilelang/language/tir/op.py (2)
call_intrin (120-145)
call_extern (173-195)

src/tl_templates/cuda/common.h (3)

shfl_sync (617-619)
shfl_sync (643-647)
shfl_sync (673-677)
🪛 Ruff (0.14.10)
tilelang/language/builtin.py
712-712: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
719-719: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (3)
src/op/builtin.h (1)
445-451: LGTM! Well-placed warp synchronization declaration.

The new sync_warp() declaration follows the established pattern for synchronization primitives and is logically positioned next to sync_grid().

src/op/builtin.cc (1)
291-292: LGTM! Proper builtin registration.

The registration correctly uses variadic inputs to support the optional mask parameter and marks the call effect as opaque, which is appropriate for a synchronization primitive.
src/target/codegen_cuda.cc (1)
1880-1886: LGTM! Correct CUDA codegen for warp synchronization.

The implementation correctly emits __syncwarp() with an optional mask parameter, matching the CUDA intrinsic signature.
T.sync_warp & T.shfl_sync; change extern pdl into intrin
Solves #1598
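A hypothetical end-to-end usage sketch of the two new intrinsics, mirroring the broadcast test reviewed above; the kernel scaffolding (tilelang.jit, T.Kernel, T.get_thread_binding) follows common TileLang patterns and is an assumption, not code from this PR:

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit
def broadcast_from_lane31():

    @T.prim_func
    def main(A: T.Tensor((32,), "int32")):
        with T.Kernel(1, threads=32) as bx:
            tx = T.get_thread_binding()
            val = tx * 10
            T.sync_warp()  # lowers to __syncwarp(); pass a mask to sync a subset
            A[tx] = T.shfl_sync(0xFFFFFFFF, val, 31)  # broadcast lane 31's value

    return main


a = torch.zeros((32,), device="cuda", dtype=torch.int32)
kernel = broadcast_from_lane31()
kernel(a)
assert torch.all(a == 310)  # every lane received 31 * 10
```

With 32 threads, lane 31 computes 310, so after the launch every element of A should equal 310, matching the behavior the tests verify.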
Summary by CodeRabbit
New Features
Tests