
[Feature] add T.sync_warp & T.shfl_sync; change extern pdl into intrin (#1614)

Merged
LeiWang1999 merged 4 commits into tile-ai:main from silentCoder-dev:sync-warp on Jan 6, 2026

Conversation

@silentCoder-dev
Collaborator

@silentCoder-dev silentCoder-dev commented Jan 6, 2026

Solve #1598

Summary by CodeRabbit

  • New Features

    • Added sync_warp() to synchronize threads within a warp (optional mask) and shfl_sync() to broadcast values across a warp (optional width).
    • Added programmatic launch and grid-dependency sync primitives for improved cross-kernel coordination.
  • Tests

    • Added tests validating warp synchronization and shuffle/broadcast behavior, including end-to-end checks when run as a script.
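The broadcast behavior summarized above can be modeled in plain Python. This is a hedged sketch of CUDA's __shfl_sync semantics, not the TileLang API: shfl_sync here is a hypothetical stand-in that operates on a list of per-lane values.

```python
# Plain-Python model of warp-shuffle broadcast semantics (illustrative only).
WARP_SIZE = 32

def shfl_sync(mask: int, values: list, src_lane: int, width: int = WARP_SIZE) -> list:
    """Each lane whose mask bit is set reads the value held by the source lane.

    With width < 32 the warp splits into independent segments, and src_lane is
    taken modulo width inside each segment (mirroring CUDA's documented rule).
    Lanes outside the mask keep their old value here (undefined in real CUDA).
    """
    out = list(values)
    for lane in range(WARP_SIZE):
        if (mask >> lane) & 1:
            seg_base = (lane // width) * width
            out[lane] = values[seg_base + src_lane % width]
    return out

lanes = [tx * 10 for tx in range(WARP_SIZE)]   # per-lane values, like val = tx * 10
broadcast = shfl_sync(0xFFFFFFFF, lanes, 31)   # broadcast lane 31's value to all
print(broadcast[0], broadcast[17])             # both print 310
```

With a smaller width, each 16-lane segment broadcasts independently from its own source lane, which is the behavior the optional width parameter exposes.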


@github-actions

github-actions bot commented Jan 6, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 6, 2026

📝 Walkthrough

Walkthrough

Adds TileLang warp-level intrinsics (sync_warp, shfl_sync), registers new TL builtins (sync_warp, pdl_trigger, pdl_sync), updates CUDA codegen to emit __syncwarp/__shfl_sync and PDL calls, and adds Python tests validating warp sync and shuffle behavior and generated source.

Changes

Cohort / File(s): Summary

  • Language builtins (tilelang/language/builtin.py): Added sync_warp(mask: int = None) and shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int = None) to expose warp-level sync and shuffle APIs.
  • TL builtin registrations (src/op/builtin.cc, src/op/builtin.h): Registered new TL builtins sync_warp (variadic inputs, opaque effect), pdl_trigger, and pdl_sync; added corresponding header declarations.
  • CUDA codegen (src/target/codegen_cuda.cc): Lowers tl::sync_warp to __syncwarp(<opt>), tl::pdl_trigger to cudaTriggerProgrammaticLaunchCompletion(), and tl::pdl_sync to cudaGridDependencySynchronize().
  • PDL helpers (tilelang/language/pdl.py): Replaced extern CUDA API calls with TL intrinsics (tl.pdl_trigger, tl.pdl_sync) returning void.
  • Tests (testing/python/language/test_tilelang_language_warp_sync.py): New tests kernel_with_warp_sync (checks __syncwarp in generated source and runtime B[0] == -1) and kernel_with_shfl_sync (checks __shfl_sync and broadcast correctness).

Sequence Diagram(s)

(omitted — changes add intrinsics, codegen, and tests but do not introduce a new multi-component control flow that requires a sequence diagram.)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped into the CUDA night,
Where lanes converse in shared light,
A snooze, a sync, a shuffle's call,
Values leap and land for all,
Warps now dance — small, synchronous delight.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 40.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title clearly and concisely summarizes the main changes: adding two new warp synchronization features (sync_warp and shfl_sync) to the TileLang API.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 660f035 and a485bd9.

📒 Files selected for processing (4)
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • tilelang/language/pdl.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/pdl.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
src/op/builtin.h (2)
tilelang/language/builtin.py (1)
  • sync_warp (712-716)
tilelang/language/pdl.py (2)
  • pdl_trigger (10-14)
  • pdl_sync (17-21)
src/op/builtin.cc (2)
tilelang/language/builtin.py (1)
  • sync_warp (712-716)
tilelang/language/pdl.py (2)
  • pdl_trigger (10-14)
  • pdl_sync (17-21)
🔇 Additional comments (8)
src/op/builtin.h (3)

445-451: LGTM!

The sync_warp declaration follows the established pattern, is well-documented, and is logically positioned after sync_grid.


453-459: LGTM!

The pdl_trigger declaration is consistent with the codebase conventions and provides clear documentation.


461-467: LGTM!

The pdl_sync declaration maintains consistency with existing code patterns and includes appropriate documentation.

src/op/builtin.cc (3)

291-292: LGTM!

The sync_warp registration correctly uses variadic inputs to support the optional mask parameter (as seen in the Python wrapper), and kOpaque is the appropriate call effect for synchronization primitives.


294-297: LGTM!

The pdl_trigger registration is correct with zero inputs (matching the parameterless Python wrapper) and uses the appropriate kOpaque effect for an operation that triggers programmatic launches.


299-300: LGTM!

The pdl_sync registration correctly specifies zero inputs and uses kOpaque call effect, consistent with the synchronization semantics and the Python wrapper interface.

tilelang/language/pdl.py (2)

17-21: Return type is correctly void; error handling does not rely on return values.

The void return type is intentional and consistent with pdl_trigger(). All call sites use pdl_sync() as a standalone statement without capturing any return value. Error handling for CUDA synchronization is managed through the MarkCudaSyncCalls transform that tracks which functions contain PDL sync calls, not through return values from the intrinsic itself. This is the standard TVM/TIR pattern for side-effect operations.


10-14: No actionable issues found.

All call sites of pdl_trigger() already treat it as a void-returning function (used as a standalone statement, not assigned or checked). The implementation is consistent with pdl_sync() and other void-returning intrinsic calls in the language module. CUDA error handling in this codebase is performed at the wrapper/adapter level rather than per-intrinsic-call, which is the appropriate design pattern. There is no evidence this function previously returned int32 or that the change breaks existing code.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI Agents
In @tilelang/language/builtin.py:
- Around line 712-717: Update the signature of sync_warp to use a proper
optional type for mask (e.g., mask: int | None = None or mask: Optional[int] =
None) to comply with PEP 484, and update its docstring to document the mask
parameter and default behavior (explain that mask selects which threads in the
warp to synchronize and that omitting it synchronizes all threads). Reference
the function name sync_warp and the parameter mask when making these changes.
- Around line 719-724: Update the shfl_sync signature to use a PEP 484-compliant
optional type (change width: int = None to width: int | None = None) and expand
the docstring for shfl_sync to describe all parameters and semantics: document
mask (active lanes bitmask), value (the value to shuffle and its dtype), srcLane
(source lane index within the warp), and width (shuffle width/stride and that it
is optional and affects which lanes participate), and note return type and that
it calls the underlying "__shfl_sync" extern; keep references to the function
name shfl_sync and the tir.call_extern invocation when editing.
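A minimal sketch of the PEP 484-compliant signatures the comments above ask for. The bodies here only render the CUDA call as a string so the sketch runs standalone; the real TileLang functions invoke tir.call_intrin / tir.call_extern instead.

```python
from typing import Optional

# Hedged sketch of the requested signatures; NOT the actual TileLang bodies.
def sync_warp(mask: Optional[int] = None) -> str:
    """Synchronize threads within the current warp.

    mask: bitmask selecting which lanes to synchronize; None syncs all lanes.
    """
    return "__syncwarp()" if mask is None else f"__syncwarp({mask:#x})"

def shfl_sync(mask: int, value: int, src_lane: int, width: Optional[int] = None) -> str:
    """Broadcast value from src_lane to the lanes selected by mask.

    width: optional segment width; when given, the shuffle operates within
    independent width-sized segments of the warp.
    """
    args = [f"{mask:#x}", str(value), str(src_lane)]
    if width is not None:
        args.append(str(width))
    return f"__shfl_sync({', '.join(args)})"

print(sync_warp())                     # __syncwarp()
print(sync_warp(0xFFFFFFFF))           # __syncwarp(0xffffffff)
print(shfl_sync(0xFFFFFFFF, 310, 31))  # __shfl_sync(0xffffffff, 310, 31)
```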
🧹 Nitpick comments (2)
tilelang/language/builtin.py (1)

719-724: Consider adding HIP/ROCm support for consistency.

The existing shuffle functions (shfl_xor, shfl_down, shfl_up at lines 659-700) check _IS_HIP_AVAILABLE and use different function names for AMD GPUs. Consider adding similar HIP support to shfl_sync for consistency and cross-platform compatibility.

🔎 Example implementation with HIP support
 def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int | None = None):
     """Broadcast a value from one thread to other threads in the warp."""
-    if width is None:
-        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
-    return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)
+    if _IS_HIP_AVAILABLE:
+        if width is None:
+            return tir.call_extern(value.dtype, "__shfl", value, srcLane)
+        return tir.call_extern(value.dtype, "__shfl", value, srcLane, width)
+    else:
+        if width is None:
+            return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
+        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)

Note: Verify the correct HIP function name and signature before implementing.

testing/python/language/test_tilelang_language_warp_sync.py (1)

28-29: Consider using torch.zeros() for clearer test initialization.

The tests use torch.empty() which leaves tensors uninitialized. While this works correctly since the kernels write to all accessed elements, using torch.zeros() would make the tests more robust and easier to debug if issues arise.

🔎 Suggested change
 def test_warp_sync():
-    a = torch.empty((1), device="cuda", dtype=torch.int32)
-    b = torch.empty((1), device="cuda", dtype=torch.int32)
+    a = torch.zeros((1), device="cuda", dtype=torch.int32)
+    b = torch.zeros((1), device="cuda", dtype=torch.int32)
     kernel = kernel_with_warp_sync()
     assert "__syncwarp" in kernel.get_kernel_source()
     kernel(a, b)
     assert b[0] == -1


 def test_shfl_sync():
-    a = torch.empty((32), device="cuda", dtype=torch.int32)
+    a = torch.zeros((32), device="cuda", dtype=torch.int32)
     kernel = kernel_with_shfl_sync()
     assert "__shfl_sync" in kernel.get_kernel_source()
     kernel(a)

Also applies to: 52-52

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cfbc49b and 5f63962.

📒 Files selected for processing (2)
  • testing/python/language/test_tilelang_language_warp_sync.py
  • tilelang/language/builtin.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_warp_sync.py
🧬 Code graph analysis (2)
tilelang/language/builtin.py (2)
tilelang/language/tir/op.py (1)
  • call_extern (173-195)
src/tl_templates/cuda/common.h (3)
  • shfl_sync (617-619)
  • shfl_sync (643-647)
  • shfl_sync (673-677)
testing/python/language/test_tilelang_language_warp_sync.py (3)
tilelang/language/kernel.py (1)
  • threads (214-218)
tilelang/language/tir/op.py (2)
  • call_extern (173-195)
  • all (1913-1930)
tilelang/language/builtin.py (2)
  • sync_warp (712-716)
  • shfl_sync (719-723)
🪛 Ruff (0.14.10)
tilelang/language/builtin.py

712-712: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


719-719: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
testing/python/language/test_tilelang_language_warp_sync.py (2)

8-33: Well-designed synchronization test!

The test correctly validates warp synchronization by:

  1. Using __nanosleep to increase the likelihood of a race condition without sync
  2. Verifying both source code generation ("__syncwarp" in source) and runtime behavior (B[0] == -1)
  3. Using a realistic scenario with 32 threads (standard NVIDIA warp size)

The test logic is sound and should effectively catch synchronization issues.
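The race the test guards against has a loose CPU analogy: with an ordinary threading.Barrier standing in for __syncwarp(), a reader that waits at the barrier always observes every writer's update. Illustrative only; CPU threads are not warp lanes and Python threads give no SIMT guarantees.

```python
import threading

# CPU-thread analogy for warp synchronization (illustrative, not CUDA).
NUM_THREADS = 32
barrier = threading.Barrier(NUM_THREADS)   # every "lane" must arrive before any proceeds
shared = [0] * NUM_THREADS
result = []

def worker(tid: int) -> None:
    shared[tid] = tid * 10     # each thread publishes its value, like val = tx * 10
    barrier.wait()             # without this, the reader below could see stale zeros
    if tid == 0:
        result.append(sum(shared))   # safe: all writes completed before the barrier

threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(result[0])   # 4960, i.e. sum(10 * i for i in range(32))
```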


36-56: Effective shuffle broadcast test!

The test correctly validates shfl_sync by:

  1. Computing unique values per lane (val = tx * 10)
  2. Broadcasting lane 31's value (310) to all 32 threads using mask 0xFFFFFFFF
  3. Verifying both source generation and that all output elements equal 310

The test logic is correct and comprehensive.

@kurisu6912 kurisu6912 linked an issue on Jan 6, 2026 that may be closed by this pull request
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI Agents
In @tilelang/language/builtin.py:
- Around line 719-723: The function shfl_sync has an invalid PEP 484 annotation
and a sparse docstring; change the parameter annotation from width: int = None
to width: int | None = None and expand the docstring for shfl_sync to describe
each parameter: mask (active lanes bitmask), value (the value to shuffle, note
its dtype is used in the call), srcLane (source lane index), and width (optional
warp segment width), while keeping existing behavior that calls tir.call_extern
with or without width based on None.
- Around line 712-716: Update the sync_warp signature and docstring: change the
parameter annotation from "mask: int = None" to "mask: int | None = None" to be
PEP 484 compliant, and expand the docstring for sync_warp to document the mask
parameter as a bitmask selecting which threads to synchronize and note that when
mask is None the call synchronizes all threads in the warp; keep existing return
behavior using tir.call_intrin with or without mask.
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 41f69be and 660f035.

📒 Files selected for processing (4)
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • tilelang/language/builtin.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/op/builtin.h (1)
tilelang/language/builtin.py (1)
  • sync_warp (712-716)
src/target/codegen_cuda.cc (1)
tilelang/language/builtin.py (1)
  • sync_warp (712-716)
src/op/builtin.cc (1)
tilelang/language/builtin.py (1)
  • sync_warp (712-716)
tilelang/language/builtin.py (2)
tilelang/language/tir/op.py (2)
  • call_intrin (120-145)
  • call_extern (173-195)
src/tl_templates/cuda/common.h (3)
  • shfl_sync (617-619)
  • shfl_sync (643-647)
  • shfl_sync (673-677)
🪛 Ruff (0.14.10)
tilelang/language/builtin.py

712-712: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


719-719: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (3)
src/op/builtin.h (1)

445-451: LGTM! Well-placed warp synchronization declaration.

The new sync_warp() declaration follows the established pattern for synchronization primitives and is logically positioned next to sync_grid().

src/op/builtin.cc (1)

291-292: LGTM! Proper builtin registration.

The registration correctly uses variadic inputs to support the optional mask parameter and marks the call effect as opaque, which is appropriate for a synchronization primitive.

src/target/codegen_cuda.cc (1)

1880-1886: LGTM! Correct CUDA codegen for warp synchronization.

The implementation correctly emits __syncwarp() with an optional mask parameter, matching the CUDA intrinsic signature.

@silentCoder-dev silentCoder-dev changed the title from "[Feature] add T.sync_warp & T.shfl_sync" to "[Feature] add T.sync_warp & T.shfl_sync; change extern pdl into intrin" on Jan 6, 2026
@LeiWang1999 LeiWang1999 merged commit a756074 into tile-ai:main on Jan 6, 2026
6 checks passed


Development

Successfully merging this pull request may close these issues.

[Feature Request] T.sync_warp support

2 participants