[Bugfix] Support 16bits shfl_sync #1169
Conversation
- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.
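The helper shape described above can be sketched roughly as follows. This is a hedged illustration, not the merged code: the `TL_DEVICE` macro, the CUTLASS header paths, and the exact overload set are assumptions based on the review excerpts; the `shfl_down/up/sync` specializations (not shown) follow the same promote-to-float pattern.

```cuda
#include <cutlass/half.h>      // cutlass::half_t (path assumed)
#include <cutlass/bfloat16.h>  // cutlass::bfloat16_t (path assumed)

#define TL_DEVICE __forceinline__ __device__  // assumed definition

namespace tl {

// Generic passthrough: forwards to the CUDA intrinsic for types it
// natively supports (int, unsigned, float, double, ...).
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
  return __shfl_xor_sync(mask, val, laneMask);
}

// 16-bit specialization: promote to float, shuffle, convert back.
// Avoids relying on implicit conversions the intrinsic may not accept.
TL_DEVICE cutlass::half_t shfl_xor_sync(unsigned mask, cutlass::half_t val,
                                        int laneMask) {
  float f = __shfl_xor_sync(mask, static_cast<float>(val), laneMask);
  return cutlass::half_t(f);
}

TL_DEVICE cutlass::bfloat16_t shfl_xor_sync(unsigned mask,
                                            cutlass::bfloat16_t val,
                                            int laneMask) {
  float f = __shfl_xor_sync(mask, static_cast<float>(val), laneMask);
  return cutlass::bfloat16_t(f);
}

} // namespace tl
```

Overload resolution picks the non-template specialization for `half_t`/`bfloat16_t` arguments, so call sites in `reduce.h` need no casts.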
Walkthrough: Introduces type-safe warp shuffle helpers in namespace `tl`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller as Reduction Code
    participant Generic as tl::shfl_* Generic
    participant Spec as tl::shfl_* Specialization<br/>(half_t/bfloat16_t)
    participant CUDA as CUDA Intrinsic
    Caller->>Generic: Call shfl_xor_sync(mask, value, lane)
    Generic->>CUDA: __shfl_xor_sync() [for default types]
    CUDA-->>Generic: result
    Generic-->>Caller: return
    Caller->>Spec: Call shfl_xor_sync(mask, half_t, lane)
    Spec->>Spec: Promote to float
    Spec->>CUDA: __shfl_xor_sync() [float]
    CUDA-->>Spec: float result
    Spec->>Spec: Convert back to half_t
    Spec-->>Caller: return half_t
```
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (2 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/tl_templates/cuda/reduce.h (2)
76-76: Inconsistent usage: `SharedReduceWarp` should also use the new wrapper.

Line 76 still calls `__shfl_down_sync` directly while other reduction code has been migrated to `tl::shfl_down_sync`. For consistency and type safety with 16-bit types, this should also use the wrapper. Apply this diff:

```diff
- T other = __shfl_down_sync(mask, partial, offset);
+ T other = tl::shfl_down_sync(mask, partial, offset);
```
162-162: Inconsistent usage: `CumSum1D` should also use the new wrappers.

`CumSum1D` still uses the direct CUDA intrinsics (`__shfl_down_sync`, `__shfl_up_sync`, `__shfl_sync`) while the similar `CumSum2D` struct has been migrated to the `tl::shfl_*` wrappers. For consistency and type safety with 16-bit types, `CumSum1D` should also use the wrappers. Update the following lines:

- Line 162: `__shfl_down_sync` → `tl::shfl_down_sync`
- Lines 172, 175: `__shfl_sync` → `tl::shfl_sync`
- Line 185: `__shfl_up_sync` → `tl::shfl_up_sync`
- Lines 195, 198: `__shfl_sync` → `tl::shfl_sync`

Also applies to: 172-172, 175-175, 185-185, 195-195, 198-198
🧹 Nitpick comments (1)
src/tl_templates/cuda/common.h (1)
392-409: Consider: Optional width parameter not exposed.

The CUDA `__shfl_*_sync` intrinsics accept an optional `width` parameter (default 32) to support sub-warp shuffles. The current wrappers don't expose this parameter, which limits flexibility for advanced use cases requiring non-standard warp widths. If sub-warp shuffle support is needed in the future, consider adding overloads that expose the width parameter:

```cuda
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask, int width) {
  return __shfl_xor_sync(mask, val, laneMask, width);
}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `src/tl_templates/cuda/common.h` (1 hunks)
- `src/tl_templates/cuda/reduce.h` (6 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/tl_templates/cuda/common.h
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/tl_templates/cuda/common.h
🧬 Code graph analysis (2)
src/tl_templates/cuda/reduce.h (1)
src/tl_templates/cuda/common.h (16)
`tl` (195-289, 369-371, 372-374, 389-469), `shfl_xor_sync` (393-395, 413-417, 441-446), `shfl_down_sync` (398-400, 420-424, 449-453), `shfl_sync` (407-409, 433-437, 463-467), `shfl_up_sync` (403-405, 427-431, 456-460)
src/tl_templates/cuda/common.h (1)
src/tl_templates/cuda/reduce.h (2)
`tl` (5-88), `T` (208-280)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (5)
src/tl_templates/cuda/reduce.h (2)
105-105: LGTM! Type-safe shuffle wrappers applied correctly.

The replacement of `__shfl_xor_sync` with `tl::shfl_xor_sync` ensures proper type handling for 16-bit floating types in both the standard and Hopper reduction paths.

Also applies to: 125-125
237-237: LGTM! CumSum2D correctly migrated to type-safe shuffle wrappers.

All shuffle operations in CumSum2D (both reverse and forward modes) now use the `tl::shfl_*` wrappers, ensuring proper handling of 16-bit types.

Also applies to: 247-247, 250-250, 263-263, 273-273, 276-276
src/tl_templates/cuda/common.h (3)
382-410: LGTM! Generic shuffle wrappers provide a clean passthrough.

The generic templates correctly forward to CUDA's warp shuffle intrinsics, maintaining compatibility with all types that natively support shuffling.

411-437: LGTM! half_t specializations correctly handle 16-bit shuffle requirements.

The explicit float promotion and down-conversion pattern ensures type safety and avoids relying on implicit conversions. Float has sufficient precision to preserve all half_t values (including NaN and Inf) during the round-trip conversion.

439-467: LGTM! bfloat16_t specializations correctly handle 16-bit shuffle requirements.

The explicit float promotion and down-conversion pattern ensures type safety and avoids relying on implicit conversions. Float has sufficient precision to preserve all bfloat16_t values (including NaN and Inf) during the round-trip conversion.
* Add type-safe warp shuffle helpers for 16-bit float types in common.h
  - Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
  - Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
  - Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.
* lint fix
* [Test] Add cp async to avoid register spill
* [BugFix] GQA fwd and bwd
  - Fix the undefined behavior of -inf in acc_s
  - Fix the causal loop range in varlen scenario
* [TMA] Move on to TMA and locate the register spill issue
* [Debug] Not the reason of zero-assignment. Probably the combination of Parallel op & conditional qkT
* [Debug] The SIMT copy in producer occupies too many registers
* [BugFix] Use 3D lse and delta to avoid illegal instruction
* [Perf] Relaxed order for dQ and SIMT store for dKdV
* [Feat] For atomic add version
* [Lint]
* [Bugfix] Enable code lowering with producer-copy-only program (#1168)
* bugfix
* lint fix
* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.
* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.
* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.
* [Bugfix] Support 16bits shfl_sync (#1169)
  - Add type-safe warp shuffle helpers for 16-bit float types in common.h: generic passthrough functions for `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`; specializations for `cutlass::half_t` and `cutlass::bfloat16_t`; updated `reduce.h` to utilize the new shuffle functions.
* lint fix
* [Testing] Move TMA 1D and test for its functionality (#1167)
* [Lint]
* [Refactor]: Change the params in pytest to avoid oom error during ci (#1170)
* format
* fix
* Update test_example_cast.py
* Update parameters in test_example_cast
* Update test_example_flash_attention.py
* update
* format
* fix
* fix
* format
* [Bugfix] Fix tvm import path for editable build (#1172)
* [Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)
* remove debug print
* pipeline fix
* use the correct buffer access scope
* rs support
* warp warpgroup_fence_operand
* fix
* fp8 dtype ptx enhance
* mma fix
* TCGEN05 Interface
* tcgen05 support
* rebase
* update
* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.
* lint fix
* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.
* wgmma fix
* [Language] Add Correctness and performance check scripts for V2 (#1174)
* fix
* lint fix
* fix
* lint fix
* fix
* upd
* [Bugfix] Legalize Datatype for mma intrinsic codegen (#1179)
* fix
* lint fix
* Enhance CUDA code generation by updating register type handling for float data types. Introduced a workaround for TF32 type compatibility and improved the registration of MMA register types for A and B operands.
* [Perf] Add layout and use_tma to boost performance
* [Lint]
* [Note]

Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: Yuqi Dong <[email protected]>
Co-authored-by: Zhiwen Mo <[email protected]>
This pull request introduces new type-safe warp shuffle helper functions for 16-bit float types (such as `half_t` and `bfloat16_t`) and refactors the usage of CUDA shuffle intrinsics in reduction and cumulative sum kernels to use these helpers. The main goal is to avoid implicit type conversions that may be disallowed, by promoting values to `float` for shuffling and then down-converting, which improves code safety and maintainability.

Type-safe shuffle helpers

- A new namespace `tl` in `src/tl_templates/cuda/common.h` defines generic and specialized type-safe wrappers for the CUDA warp shuffle functions (`shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, `shfl_sync`), both for generic types and specifically for `cutlass::half_t` and `cutlass::bfloat16_t`. These helpers ensure explicit type promotion and conversion for 16-bit float types.

Refactoring to use type-safe helpers

- The `AllReduce` struct in `src/tl_templates/cuda/reduce.h` now uses the new `tl::shfl_xor_sync` helper instead of directly calling `__shfl_xor_sync`, ensuring type safety for reduction operations. [1] [2]
- The `CumSum2D` struct in `src/tl_templates/cuda/reduce.h` now uses the new `tl::shfl_down_sync`, `tl::shfl_up_sync`, and `tl::shfl_sync` helpers for cumulative sum operations, replacing direct calls to CUDA intrinsics and removing unnecessary casts. [1] [2] [3] [4]
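As a usage illustration, a butterfly warp reduction written against the wrappers might look like the sketch below. This is a hedged example, not code from the PR: the function name `warp_all_reduce_sum` is invented here, and only the `tl::shfl_xor_sync(mask, val, laneMask)` signature from the review excerpts is assumed.

```cuda
// Butterfly all-reduce over a full warp: after log2(32) = 5 xor steps,
// every lane holds the sum of all 32 lanes' values.
template <typename T>
__device__ T warp_all_reduce_sum(T val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    // For T = cutlass::half_t / bfloat16_t this resolves to the
    // float-promoting overload; for float/int it is a passthrough.
    val = val + tl::shfl_xor_sync(0xffffffffu, val, offset);
  }
  return val;
}
```

Because overload resolution does the dispatch, the same generic reduction body serves 16-bit and 32-bit element types with no casts at the call site.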