
Conversation

@LeiWang1999 (Member) commented Oct 31, 2025

This pull request introduces type-safe warp shuffle helper functions for 16-bit float types (such as half_t and bfloat16_t) and refactors the reduction and cumulative sum kernels to use these helpers instead of calling the CUDA shuffle intrinsics directly. The main goal is to avoid implicit type conversions that the intrinsics may disallow: values are explicitly promoted to float for the shuffle and then converted back down, which improves code safety and maintainability.

Type-safe shuffle helpers

  • Added a new namespace tl in src/tl_templates/cuda/common.h that defines type-safe wrappers for the CUDA warp shuffle intrinsics (shfl_xor_sync, shfl_down_sync, shfl_up_sync, shfl_sync): generic passthrough templates for arbitrary types, plus specializations for cutlass::half_t and cutlass::bfloat16_t that perform explicit promotion to float and conversion back.
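
As a minimal sketch of the pattern this bullet describes, assuming cutlass's float conversion operators (the exact signatures in common.h may differ):

```cpp
// Hypothetical sketch: generic passthrough plus a half_t specialization.
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
  return __shfl_xor_sync(mask, val, laneMask); // native types pass through
}

template <>
TL_DEVICE cutlass::half_t shfl_xor_sync(unsigned mask, cutlass::half_t val,
                                        int laneMask) {
  float f = static_cast<float>(val);      // explicit promotion to float
  f = __shfl_xor_sync(mask, f, laneMask); // shuffle at 32-bit width
  return cutlass::half_t(f);              // explicit demotion back to half
}
```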

Refactoring to use type-safe helpers

  • Updated the AllReduce struct in src/tl_templates/cuda/reduce.h to use the new tl::shfl_xor_sync helper instead of calling __shfl_xor_sync directly, ensuring type safety for reduction operations.
  • Refactored the CumSum2D struct in src/tl_templates/cuda/reduce.h to use the new tl::shfl_down_sync, tl::shfl_up_sync, and tl::shfl_sync helpers for cumulative sum operations, replacing direct calls to CUDA intrinsics and removing unnecessary casts (see the call-site sketch below).
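
To see why the wrappers matter at the call sites, here is a hypothetical warp-reduction loop written against tl::shfl_xor_sync; with T = cutlass::half_t the call dispatches to the float-promoting specialization, where the raw intrinsic might otherwise require a disallowed implicit conversion:

```cpp
// Hypothetical call-site sketch, not code from the PR itself.
template <typename T>
TL_DEVICE T warp_all_reduce_sum(T val) {
  // Butterfly reduction across the 32 lanes of a warp.
  for (int offset = 16; offset > 0; offset >>= 1)
    val = val + tl::shfl_xor_sync(0xffffffffu, val, offset);
  return val; // every lane ends up holding the warp-wide sum
}
```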

Summary by CodeRabbit

  • Refactor
    • Enhanced type-safety for low-precision floating-point GPU operations through improved shuffle utility functions.
    • Updated internal GPU communication infrastructure to use centralized, type-aware shuffle operations.

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project so that your changes are properly linted and formatted. This will help your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot (Contributor) commented Oct 31, 2025

Walkthrough

Introduces type-safe warp shuffle helpers in namespace tl that forward to the CUDA shuffle intrinsics. Adds generic templates and specializations for 16-bit types (half_t, bfloat16_t) that promote to float before shuffling and convert back. Updates existing reduction code to use these new wrappers instead of direct CUDA calls.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| New shuffle helper functions<br/>src/tl_templates/cuda/common.h | Added four generic template functions (shfl_xor_sync, shfl_down_sync, shfl_up_sync, shfl_sync) that forward to CUDA intrinsics. Added template specializations for half_t and bfloat16_t that promote inputs to float, shuffle, then demote back to the original 16-bit type. |
| Updated reduction code<br/>src/tl_templates/cuda/reduce.h | Replaced direct CUDA intrinsic calls with the new tl::shfl_* wrappers across the AllReduce and CumSum2D implementations, changing __shfl_xor_sync, __shfl_down_sync, __shfl_up_sync, and __shfl_sync invocations to their tl:: equivalents. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller as Reduction Code
    participant Generic as tl::shfl_* Generic
    participant Spec as tl::shfl_* Specialization<br/>(half_t/bfloat16_t)
    participant CUDA as CUDA Intrinsic

    Caller->>Generic: Call shfl_xor_sync(mask, value, lane)
    Generic->>CUDA: __shfl_xor_sync() [for default types]
    CUDA-->>Generic: result
    Generic-->>Caller: return

    Caller->>Spec: Call shfl_xor_sync(mask, half_t, lane)
    Spec->>Spec: Promote to float
    Spec->>CUDA: __shfl_xor_sync() [float]
    CUDA-->>Spec: float result
    Spec->>Spec: Convert back to half_t
    Spec-->>Caller: return half_t
```

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

  • Substitution pattern in reduce.h is straightforward and repetitive across call sites
  • Template definitions and specializations in common.h follow a consistent pattern
  • No complex control flow or algorithmic changes

Areas requiring attention:

  • Verify all __shfl_* calls in reduce.h have been properly replaced with tl:: equivalents
  • Confirm 16-bit type specializations correctly handle float conversion round-trips without precision loss (a sanity-check sketch follows this list)
  • Check that generic templates properly forward to underlying CUDA intrinsics for non-specialized types
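
On the round-trip point: float's 24-bit significand and 8-bit exponent strictly contain both half's 11-bit significand/5-bit exponent and bfloat16's 8-bit significand/8-bit exponent, so every finite 16-bit value survives promotion and demotion exactly. A hypothetical host-side check (not part of the PR) over all half bit patterns:

```cpp
// Compile with nvcc so the cuda_fp16.h conversions are host-callable.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <cuda_fp16.h>

int main() {
  for (uint32_t bits = 0; bits <= 0xFFFF; ++bits) {
    uint16_t b16 = static_cast<uint16_t>(bits);
    // Skip NaNs: payloads need not be preserved bit-for-bit.
    if ((b16 & 0x7C00u) == 0x7C00u && (b16 & 0x03FFu)) continue;
    __half h;
    std::memcpy(&h, &b16, sizeof(h));
    __half back = __float2half_rn(__half2float(h)); // promote, demote
    uint16_t b2;
    std::memcpy(&b2, &back, sizeof(b2));
    assert(b2 == b16); // finite values and infinities round-trip exactly
  }
  return 0;
}
```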

Poem

🐰 Shuffling through lanes with type-safe grace,
Half and bfloat now find their place,
Float promotion, then demote with care,
Warp exchanges happen fair,
No silent casts to lead us astray,
The templates guide us on our way!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 12.50%, below the required threshold of 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title "[Bugfix] Support 16bits shfl_sync" accurately reflects the primary change: introducing type-safe warp shuffle helpers for 16-bit floating types (cutlass::half_t and cutlass::bfloat16_t) and refactoring existing code to use them. It is concise, specific, and unambiguous for teammates scanning the git history. |

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/tl_templates/cuda/reduce.h (2)

76-76: Inconsistent usage: SharedReduceWarp should also use the new wrapper.

Line 76 still calls __shfl_down_sync directly while other reduction code has been migrated to tl::shfl_down_sync. For consistency and type safety with 16-bit types, this should also use the wrapper.

Apply this diff:

```diff
-        T other = __shfl_down_sync(mask, partial, offset);
+        T other = tl::shfl_down_sync(mask, partial, offset);
```

162-162: Inconsistent usage: CumSum1D should also use the new wrappers.

CumSum1D still uses direct CUDA intrinsics (__shfl_down_sync, __shfl_up_sync, __shfl_sync) while the similar CumSum2D struct has been migrated to tl::shfl_* wrappers. For consistency and type safety with 16-bit types, CumSum1D should also use the wrappers.

Update the following lines:

  • Line 162: `__shfl_down_sync` → `tl::shfl_down_sync`
  • Lines 172, 175: `__shfl_sync` → `tl::shfl_sync`
  • Line 185: `__shfl_up_sync` → `tl::shfl_up_sync`
  • Lines 195, 198: `__shfl_sync` → `tl::shfl_sync`

Also applies to: 172-172, 175-175, 185-185, 195-195, 198-198

🧹 Nitpick comments (1)
src/tl_templates/cuda/common.h (1)

392-409: Consider: Optional width parameter not exposed.

The CUDA __shfl_*_sync intrinsics accept an optional width parameter (default 32) to support sub-warp shuffles. The current wrappers don't expose this parameter, which limits flexibility for advanced use cases requiring non-standard warp widths.

If sub-warp shuffle support is needed in the future, consider adding overloads that expose the width parameter:

```cpp
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask, int width) {
  return __shfl_xor_sync(mask, val, laneMask, width);
}
```
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 10911e2 and 2e69c0d.

📒 Files selected for processing (2)
  • src/tl_templates/cuda/common.h (1 hunks)
  • src/tl_templates/cuda/reduce.h (6 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/common.h
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/tl_templates/cuda/common.h
🧬 Code graph analysis (2)
src/tl_templates/cuda/reduce.h (1)
src/tl_templates/cuda/common.h (16)
  • tl (195-289)
  • tl (369-371)
  • tl (372-374)
  • tl (389-469)
  • shfl_xor_sync (393-395)
  • shfl_xor_sync (413-417)
  • shfl_xor_sync (441-446)
  • shfl_down_sync (398-400)
  • shfl_down_sync (420-424)
  • shfl_down_sync (449-453)
  • shfl_sync (407-409)
  • shfl_sync (433-437)
  • shfl_sync (463-467)
  • shfl_up_sync (403-405)
  • shfl_up_sync (427-431)
  • shfl_up_sync (456-460)
src/tl_templates/cuda/common.h (1)
src/tl_templates/cuda/reduce.h (2)
  • tl (5-88)
  • T (208-280)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (5)
src/tl_templates/cuda/reduce.h (2)

105-105: LGTM! Type-safe shuffle wrappers applied correctly.

The replacement of __shfl_xor_sync with tl::shfl_xor_sync ensures proper type handling for 16-bit floating types in both the standard and Hopper reduction paths.

Also applies to: 125-125


237-237: LGTM! CumSum2D correctly migrated to type-safe shuffle wrappers.

All shuffle operations in CumSum2D (both reverse and forward modes) now use the tl::shfl_* wrappers, ensuring proper handling of 16-bit types.

Also applies to: 247-247, 250-250, 263-263, 273-273, 276-276

src/tl_templates/cuda/common.h (3)

382-410: LGTM! Generic shuffle wrappers provide clean passthrough.

The generic templates correctly forward to CUDA's warp shuffle intrinsics, maintaining compatibility with all types that natively support shuffling.


411-437: LGTM! half_t specializations correctly handle 16-bit shuffle requirements.

The explicit float promotion and down-conversion pattern ensures type safety and avoids relying on implicit conversions. Float has sufficient precision to preserve all half_t values (including NaN and Inf) during the round-trip conversion.


439-467: LGTM! bfloat16_t specializations correctly handle 16-bit shuffle requirements.

The explicit float promotion and down-conversion pattern ensures type safety and avoids relying on implicit conversions. Float has sufficient precision to preserve all bfloat16_t values (including NaN and Inf) during the round-trip conversion.

@LeiWang1999 LeiWang1999 merged commit 54d4bd6 into tile-ai:main Oct 31, 2025
5 of 7 checks passed
tzj-fxz pushed a commit to tzj-fxz/tilelang that referenced this pull request Nov 3, 2025
* Add type-safe warp shuffle helpers for 16-bit float types in common.h

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.

* lint fix
LeiWang1999 added a commit that referenced this pull request Nov 5, 2025
* [Test] Add cp async to avoid register spill

* [BugFix] GQA fwd and bwd
- Fix the undefined behavior of -inf in acc_s
- Fix the causal loop range in varlen scenario

* [TMA] Move on to TMA and locate the register spill issue

* [Debug] Not the reason of zero-assignment. Probably the combination of Parallel op & conditional qkT

* [Debug] The SIMT copy in producer occupies too many registers

* [BugFix] Use 3D lse and delta to avoid illegal instruction

* [Perf] Relaxed order for dQ and SIMT store for dKdV

* [Feat] For atomic add version

* [Lint]

* [Bugfix] Enable code lowering with producer‑copy‑only program (#1168)

* bugfix

* lint fix

* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.

* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.

* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.

* [Bugfix] Support 16bits shfl_sync (#1169)

* Add type-safe warp shuffle helpers for 16-bit float types in common.h

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.

* lint fix

* [Testing] Move TMA 1D and test for its functionality (#1167)

* [Testing] Move TMA 1D and test for its functionality

* [Lint]

* [Refactor]: Change the params in pytest to avoid oom error during ci (#1170)

* [Refactor]: Change the params in pytest to avoid oom error during ci

* format

* fix

* Update test_example_cast.py

* Update parameters in test_example_cast

* Update test_example_flash_attention.py

* update

* format

* fix

* fix

* format

* [Bugfix] Fix tvm import path for editable build (#1172)

* [Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)

* remove debug print

* pipeline fix

* use the correct buffer access scope

* rs support

* warp warpgroup_fence_operand

* fix

* fp8 dtype ptx enhance

* mma fix

* TCGEN05 Interface

* tcgen05 support

* rebase

* update

* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.

* lint fix

* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.

* wgmma fix

---------

Co-authored-by: Zhiwen Mo <[email protected]>

* [Language] Add Correctness and performance check scripts for V2 (#1174)

* fix

* lint fix

* fix

* lint fix

* fix

* upd

* [Bugfix] Legalize Datatype for mma intrinisc codegen  (#1179)

* fix

* lint fix

* Enhance CUDA code generation by updating register type handling for float data types. Introduced a workaround for TF32 type compatibility and improved the registration of MMA register types for A and B operands.

* [Perf] Add layout and use_tma to boost performance

* [Lint]

* [Note]

---------

Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: Yuqi Dong <[email protected]>
Co-authored-by: Zhiwen Mo <[email protected]>
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
