[Feature] Hierarchical reduction and warp reduction intrinsics support #1762

Open

tzj-fxz wants to merge 7 commits into tile-ai:main from tzj-fxz:redux

Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Jan 31, 2026

For #1761

  • Add hierarchical reduction from warp to block to reduce workspace size
  • Add redux.sync PTX templates to support faster reduction on (u)int32 with __CUDA_ARCH__>=800
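For context, here is a minimal sketch of the hierarchical pattern these bullets describe: an intra-warp reduction, one partial per warp written to a small workspace, and a final reduction over the per-warp partials. The names (block_reduce_sum, warp_results) and the int element type are illustrative assumptions, not the PR's actual code in src/tl_templates/cuda/reduce.h.

// Hedged sketch: hierarchical warp-to-block sum reduction.
// Assumes blockDim.x is a multiple of 32.
#define FULL_MASK 0xffffffffu

__device__ __forceinline__ int warp_reduce_sum(int val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __reduce_add_sync(FULL_MASK, val);  // lowers to redux.sync on sm_80+
#else
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(FULL_MASK, val, offset);
  return val;
#endif
}

// warp_results needs only blockDim.x / 32 slots, which is where the smaller
// workspace comes from.
__device__ int block_reduce_sum(int val, int *warp_results) {
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  const int num_warps = (blockDim.x + 31) / 32;

  val = warp_reduce_sum(val);               // 1) intra-warp reduction
  if (lane == 0) warp_results[warp] = val;  // 2) one partial per warp
  __syncthreads();

  if (warp == 0) {                          // 3) first warp reduces the partials
    val = (lane < num_warps) ? warp_results[lane] : 0;
    val = warp_reduce_sum(val);
  }
  return val;  // block-wide sum, valid in lane 0 of warp 0
}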

Summary by CodeRabbit

  • Performance

    • More efficient reductions via dynamic workspace sizing and CUDA-level optimizations, enabling hierarchical warp-and-warpgroup reductions for lower memory use and faster execution.
  • Bug Fixes

    • Improved handling and validation for integer reduce results to ensure correct casting and accurate sums.
  • Tests

    • Added robust reference-based tests exercising int32 reduce across multiple shapes.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 31, 2026

📝 Walkthrough


Introduces a hierarchical CUDA reduction path with a new warp_reduce and per-warp buffering, and makes AllReduce workspace allocation dynamic based on reducing thread count; adds integer reduce_sum tests using a Torch-backed reference.

Changes

Core reduction runtime (src/op/reduce.cc):
Conditional AllReduce workspace sizing: compute workspace as reducing_threads/32 when reducing_threads > 32, divisible by 32, and scale == 1; call AddWorkspace with the smaller size and pass it to the AllReduce thread reduction.

CUDA reduction templates (src/tl_templates/cuda/reduce.h):
Adds public warp_reduce declaration and implementation; implements architecture-aware warp-level reductions, hierarchical per-warp buffering and final reduction across warps; includes FP16/BF16 handling and shuffle/intrinsic fallbacks; removes a static_assert restricting thread counts.

Tests (testing/python/language/test_tilelang_language_reduce.py):
Disables cache at startup; replaces simple sum check with a Torch-backed integer-aware reference; adds test_reduce_sum_int32() covering multiple shapes.
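To make the src/op/reduce.cc change concrete, here is a hedged paraphrase of the sizing rule described above; the function name is an illustrative assumption, and the actual code passes the result to AddWorkspace.

// Illustrative sketch of the conditional workspace sizing (not the real code).
int ComputeWorkspaceSize(int reducing_threads, int scale) {
  // Hierarchical path: only one partial per warp needs buffering.
  if (reducing_threads > 32 && reducing_threads % 32 == 0 && scale == 1) {
    return reducing_threads / 32;
  }
  // Otherwise keep one slot per reducing thread.
  return reducing_threads;
}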

Sequence Diagram(s)

sequenceDiagram
    participant Kernel as Kernel (thread block)
    participant Warp as Warp (warp lanes)
    participant Workspace as Workspace (per-warp buffer)
    participant AllReduce as AllReduce (thread reduction)

    Kernel->>Warp: Each lane computes partial value
    Warp->>Warp: warp_reduce() performs warp-level shuffle reduction
    Warp->>Workspace: write per-warp result to Workspace (if hierarchical)
    Kernel->>AllReduce: AllReduce reads per-warp Workspace (or per-thread values)
    AllReduce->>AllReduce: final inter-warp reduction and writeback
    AllReduce-->>Kernel: reduced result available

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through warps with whiskers bright,
I merged small sums beneath the GPU light,
Per-warp crumbs gathered, stitched with care,
A trim-sized workspace, nimble and fair. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 10.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main features introduced: hierarchical reduction and warp reduction intrinsics support, which align with the substantial changes across all three modified files.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 99-105: The ref_fn uses torch.uint32 in a static dtype list which
raises AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_reduce.py (1)

7-7: Consider scoping disable_cache() to avoid global test side effects.
This flips a process-wide cache flag; if other tests run in the same session, they inherit the disabled cache. If that’s not intended, wrap it in a fixture/context that re-enables after this module.

Comment on lines +99 to +105
import torch

def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in [torch.uint32, torch.int32, torch.int64]:
        return res.to(A.dtype)
    return res
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, find the root-level configuration files to determine PyTorch version requirements
fd -e "pyproject.toml" -e "requirements.txt" -e "setup.py" -e "setup.cfg" | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for PyTorch version requirements in the repository
rg -i "torch.*version|pytorch" --type toml --type txt --type py -A 2 | head -50

Repository: tile-ai/tilelang

Length of output: 2715


🏁 Script executed:

# Check the specific test file and surrounding context
wc -l testing/python/language/test_tilelang_language_reduce.py

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

# Look for other torch.uint32 usage in the codebase
rg "torch\.uint32" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 3742


🏁 Script executed:

# Check for existing dtype guards or compatibility patterns
rg "hasattr.*torch" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 6548


🏁 Script executed:

# Look for version requirements
cat pyproject.toml 2>/dev/null | head -100

Repository: tile-ai/tilelang

Length of output: 3126


🏁 Script executed:

# Check setup.py or requirements files
find . -maxdepth 2 -name "requirements*.txt" -o -name "setup.py" | xargs cat 2>/dev/null | grep -i torch

Repository: tile-ai/tilelang

Length of output: 130


🌐 Web query:

When was torch.uint32 introduced in PyTorch - version history

💡 Result:

torch.uint32 was introduced in PyTorch 2.3.0 (released April 24, 2024) as an unsigned 32-bit “shell dtype” (available as a dtype object, but with limited eager-mode operator support). [1][2][3]

Before that, PyTorch generally did not support uint32/uint64 (beyond uint8). [4]

Sources:
[1] (dev-discuss.pytorch.org)
[2] (github.com)
[3] (docs.pytorch.org)
[4] (discuss.pytorch.org)


🏁 Script executed:

# Look at the context around the test function in question
sed -n '85,115p' testing/python/language/test_tilelang_language_reduce.py

Repository: tile-ai/tilelang

Length of output: 905


Guard torch.uint32 usage for torch-version compatibility.
torch.uint32 was introduced in PyTorch 2.3.0. Since the repository has no minimum PyTorch version requirement, this code will raise AttributeError at function definition time on PyTorch < 2.3.0, breaking tests even for other dtypes. Use a hasattr() check to build the dtype set conditionally, matching the pattern already used in the codebase for version-dependent dtypes like float8_e4m3fn.

🛠️ Suggested compatibility guard
+    _int_dtypes = {torch.int32, torch.int64}
+    if hasattr(torch, "uint32"):
+        _int_dtypes.add(torch.uint32)
+
     def ref_fn(A):
         res = A.sum(dim=1)
-        if A.dtype in [torch.uint32, torch.int32, torch.int64]:
+        if A.dtype in _int_dtypes:
             return res.to(A.dtype)
         return res
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

import torch

def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in [torch.uint32, torch.int32, torch.int64]:
        return res.to(A.dtype)
    return res

import torch

_int_dtypes = {torch.int32, torch.int64}
if hasattr(torch, "uint32"):
    _int_dtypes.add(torch.uint32)

def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in _int_dtypes:
        return res.to(A.dtype)
    return res
🤖 Prompt for AI Agents
In `@testing/python/language/test_tilelang_language_reduce.py` around lines 99 -
105, The ref_fn uses torch.uint32 in a static dtype list which raises
AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).

@tzj-fxz
Contributor Author

tzj-fxz commented Jan 31, 2026

Maybe we also need the performance regression tests. Where can I trigger them? @LeiWang1999

@Rachmanino
Collaborator

@regression-perf

@bucket-xv

Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:

  1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that this is also for sm80+ (a sketch follows after this list).
  2. Maybe leverage the redux instruction for f32 types? This is supported since sm100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
  3. At warp-level reduction, maybe any floating types can be cast to f32 and integral types to int32. This cast helps leverage the redux instruction and is almost always better for performance, since there is also an implicit cast for shfl.sync, which requires b32 types (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync).
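A hedged sketch of what suggestion 1 could look like: a thin wrapper around the sm_80+ warp-reduce intrinsic with a shuffle fallback. The wrapper name warp_reduce_max_u32 is an illustrative assumption, not anything in the PR.

__device__ __forceinline__ unsigned warp_reduce_max_u32(unsigned mask, unsigned v) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __reduce_max_sync(mask, v);  // intrinsic; compiles to redux.sync.max.u32
#else
  for (int offset = 16; offset > 0; offset >>= 1)
    v = max(v, __shfl_xor_sync(mask, v, offset));
  return v;
#endif
}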

@LeiWang1999
Member

@tzj-fxz Would you mind taking a look?

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 2, 2026

Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:

  1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that this is also for sm80+.
  2. Maybe leverage the redux instruction for f32 types? This is supported since sm100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
  3. At warp-level reduction, maybe any floating types can be cast to f32 and integral types to int32. This cast helps leverage the redux instruction and is almost always better for performance, since there is also an implicit cast for shfl.sync, which requires b32 types (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync).

Thank you for the advice. I will further implement these features.

template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
Member

Why do we need to remove this piece of code?

Contributor Author

It is more flexible: doing a cross-warp reduction after the intra-warp reduction supports reduction across any number of warps (not constrained to powers of 2).

@tzj-fxz tzj-fxz changed the title from "[Feature] Hierarchical reduction and redux PTX support" to "[Feature] Hierarchical reduction and warp reduction intrinsics support" on Feb 2, 2026
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 377-381: The fast-path incorrectly casts 64-bit integrals to
int32_t and uses run_reduce_sync, truncating results; update the compile-time
branch in the template so the run_reduce_sync path is only taken when T is
exactly int32_t or uint32_t (remove the branch that treats other
std::is_integral_v<T> types by casting), and let 64-bit types fall through to
the shfl_xor_sync-based fallback that correctly handles 64-bit operands; adjust
the condition around run_reduce_sync (the constexpr if that currently checks
std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t> and the following
std::is_integral_v<T> branch) accordingly so no cast-to-int32_t is performed for
64-bit integrals.
- Around line 344-358: The sm100 fast-path incorrectly casts floats and calls
__reduce_max_sync/__reduce_min_sync which only accept 32-bit integer operands;
modify the branch in reduce.h so the __reduce_*_sync calls are only compiled for
32-bit integer types (e.g., when std::is_same_v<T,int> or
std::is_same_v<T,unsigned int> / std::is_integral and sizeof(T)==4), and remove
the float/half/bfloat casting there; for all other types (float, half_t,
bfloat16_t, etc.) fall through to the existing non-sm100 handling. Ensure you
static_cast the value to int/unsigned before passing to
__reduce_max_sync/__reduce_min_sync and keep the ReduceOp/MaxOp/MinOp constexpr
checks to select the right intrinsic.

Comment on lines +377 to +381
  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
    return run_reduce_sync(value);
  } else if constexpr (std::is_integral_v<T>) {
    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
  }
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's understand the file structure and find the reduce.h file
find . -name "reduce.h" -type f | head -20

Repository: tile-ai/tilelang

Length of output: 142


🏁 Script executed:

# Read the reduce.h file around the lines in question to understand context
wc -l src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 94


🏁 Script executed:

# Get context around lines 377-381
sed -n '350,410p' src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 2198


🏁 Script executed:

# Search for run_reduce_sync definition to understand what it does
rg "run_reduce_sync" --type cpp --type h -B 3 -A 10

Repository: tile-ai/tilelang

Length of output: 2178


🌐 Web query:

CUDA __reduce_sync functions operand size constraints 32-bit 64-bit

💡 Result:

CUDA’s warp __reduce_*_sync(mask, value) intrinsics are limited to 32-bit integer operands:

  • The only provided overloads are for unsigned and int (both 32-bit). There are no 64-bit (long long / unsigned long long) overloads, so you can’t directly reduce 64-bit values with these intrinsics. [1]
  • If you need a warp reduction on 64-bit data, use a shuffle-based reduction (e.g., __shfl_xor_sync), which supports long long / unsigned long long and generally moves 8 bytes per thread for 64-bit types. [1]

Sources:
[1] NVIDIA CUDA C++ Programming Guide, sections “10.21 Warp Reduce Functions” and “10.22 Warp Shuffle Functions”. (docs.nvidia.com)


Avoid truncating 64‑bit integral reductions.

CUDA's __reduce_*_sync() intrinsics support only 32-bit integer operands. The current code casts int64_t and uint64_t to int32_t before reduction, truncating the upper 32 bits and corrupting results. Restrict the fast-path to 32-bit integrals and let 64-bit types fall through to the shfl_xor_sync-based fallback, which properly handles 64-bit operands.

🔧 Proposed fix
-  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
-    return run_reduce_sync(value);
-  } else if constexpr (std::is_integral_v<T>) {
-    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
-  }
+  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
+    return run_reduce_sync(value);
+  } else if constexpr (std::is_integral_v<T> && sizeof(T) <= 4) {
+    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
    return run_reduce_sync(value);
  } else if constexpr (std::is_integral_v<T>) {
    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
  }

  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
    return run_reduce_sync(value);
  } else if constexpr (std::is_integral_v<T> && sizeof(T) <= 4) {
    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
  }
🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/reduce.h` around lines 377 - 381, The fast-path
incorrectly casts 64-bit integrals to int32_t and uses run_reduce_sync,
truncating results; update the compile-time branch in the template so the
run_reduce_sync path is only taken when T is exactly int32_t or uint32_t (remove
the branch that treats other std::is_integral_v<T> types by casting), and let
64-bit types fall through to the shfl_xor_sync-based fallback that correctly
handles 64-bit operands; adjust the condition around run_reduce_sync (the
constexpr if that currently checks std::is_same_v<T, int32_t> ||
std::is_same_v<T, uint32_t> and the following std::is_integral_v<T> branch)
accordingly so no cast-to-int32_t is performed for 64-bit integrals.

@bucket-xv

bucket-xv commented Feb 3, 2026

As far as I know, __reduce_max_sync does not accept float types. Could you use assembly so that redux.sync is used for floating types since sm_100a?

It's indeed complex. It seems that __reduce_max_sync can be used in place of the ASM, but no intrinsic wraps redux.sync for floating types since sm_100a.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 344-366: The SM100 asm fast-path currently casts any T to float
and uses redux.sync.*.f32, which corrupts large integer ordering; change the
compile-time guard inside the SM100 block to only take this fast-path for
floating types (float, half_t, bfloat16_t) — e.g. wrap the asm branches with an
if constexpr that checks std::is_floating_point_v<T> || std::is_same_v<T,
half_t> || std::is_same_v<T, bfloat16_t> so integral types fall through to the
SM80+ __reduce_max_sync/__reduce_min_sync implementations; keep the existing
value_cast and asm/result handling for the floating case and do not alter the
fallback path for non-floating T.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 4, 2026

As far as I know, __reduce_max_sync does not accept float types. Could you use assembly so that redux.sync is used for floating types since sm_100a?

It's indeed complex. It seems that __reduce_max_sync can be used in place of the ASM, but no intrinsic wraps redux.sync for floating types since sm_100a.

Fixed in the latest commit. :)

Member

@LeiWang1999 LeiWang1999 left a comment

It would be better to have some benchmark results.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 5, 2026

@regression-perf

@github-actions

github-actions bot commented Feb 5, 2026

Performance Regression Test Report

Triggered by: @tzj-fxz
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21701890076

Results

File Original Latency Current Latency Speedup
example_gemv 0.281301 0.29055 0.968167
example_warp_specialize_gemm_barrierpipe_stage2 0.038721 0.039393 0.982941
example_dequant_gemm_bf16_fp4_hopper 0.617419 0.625101 0.987711
example_gqa_decode 0.047905 0.048193 0.994024
example_dequant_gemm_fp4_hopper 1.03851 1.04228 0.996386
example_mha_sink_fwd_bhsd_sliding_window 0.0155414 0.0155875 0.997045
example_tilelang_gemm_fp8 0.317253 0.318151 0.997176
example_gemm_intrinsics 0.034593 0.034656 0.998182
example_warp_specialize_gemm_softpipe_stage2 0.038049 0.038113 0.998321
example_convolution_autotune 0.991021 0.992591 0.998418
example_tilelang_gemm_fp8_intrinsic 0.91042 0.91147 0.998848
example_tilelang_sparse_gqa_decode_varlen_indice 0.0168975 0.0169135 0.999054
example_mha_sink_fwd_bhsd 0.0157149 0.0157277 0.999182
example_mha_fwd_varlen 0.0449671 0.0449979 0.999316
example_mha_bwd_bhsd 0.0399926 0.0400117 0.999523
example_gqa_bwd 0.049014 0.0490372 0.999525
example_mha_sink_bwd_bhsd_sliding_window 0.0443848 0.0444048 0.999551
example_tilelang_gemm_splitk 1.4019 1.40252 0.999554
example_gemm_schedule 0.0322591 0.0322731 0.999566
example_mha_inference 0.079969 0.08 0.999612
fp8_lighting_indexer 0.0353686 0.0353812 0.999644
sparse_mla_fwd_pipelined 0.0946215 0.0946542 0.999655
example_group_per_split_token_cast_to_fp8 0.0103231 0.0103262 0.999698
sparse_mla_bwd 0.377131 0.377235 0.999726
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144159 0.0144195 0.999749
sparse_mla_fwd 0.129611 0.129639 0.999779
example_mha_bwd_bshd_wgmma_pipelined 0.0254211 0.0254248 0.999856
example_tilelang_block_sparse_attn 0.0100668 0.0100682 0.999864
example_gqa_bwd_tma_reduce_varlen 0.0512859 0.051292 0.999882
tilelang_example_sparse_tensorcore 0.0149007 0.0149024 0.99989
example_linear_attn_bwd 0.152459 0.152465 0.999959
example_linear_attn_fwd 0.0365545 0.0365553 0.999978
example_mla_decode 0.449224 0.449226 0.999996
example_dequant_gemv_fp16xint4 0.0283622 0.0283618 1.00001
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.015343 0.0153427 1.00002
example_blocksparse_gemm 0.0224562 0.0224555 1.00003
example_elementwise_add 0.29402 0.29399 1.0001
example_tilelang_nsa_decode 0.00730636 0.00730504 1.00018
example_vertical_slash_sparse_attn 0.231701 0.231659 1.00018
example_convolution 1.30915 1.3088 1.00027
example_tilelang_sparse_gqa_decode_varlen_mask 0.023128 0.0231209 1.0003
example_gqa_sink_bwd_bhsd 0.0408243 0.0408075 1.00041
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40117 1.40047 1.0005
example_gqa_sink_bwd_bhsd_sliding_window 0.0251555 0.0251424 1.00052
example_gqa_bwd_wgmma_pipelined 0.0687511 0.0687154 1.00052
example_mha_bwd_bshd 0.0406295 0.0406056 1.00059
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0143039 0.0142944 1.00066
example_per_token_cast_to_fp8 0.00739874 0.00739373 1.00068
topk_selector 0.0531145 0.0530765 1.00072
example_dynamic 0.651212 0.650668 1.00084
example_tilelang_nsa_fwd 0.00681937 0.00681268 1.00098
block_sparse_attn_tilelang 0.0101621 0.0101517 1.00103
example_dequant_gemm_w4a8 5.30534 5.29934 1.00113
example_dequant_gemm_bf16_mxfp4_hopper 0.557416 0.556684 1.00131
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0152878 0.0152557 1.00211
example_mha_sink_bwd_bhsd 0.0615516 0.0614027 1.00242
example_tilelang_gemm_fp8_2xAcc 0.183363 0.182821 1.00297
example_gemm_autotune 0.022176 0.02208 1.00435
example_gemm 0.022752 0.022465 1.01278
example_warp_specialize_gemm_copy_0_gemm_1 0.038817 0.037985 1.0219
example_dequant_groupedgemm_bf16_mxfp4_hopper 4.10577 3.98573 1.03012
example_warp_specialize_gemm_copy_1_gemm_0 0.037473 0.036224 1.03448
example_topk 0.010752 0.009249 1.1625

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 5, 2026

Threads=128. One block.

Op M N Original (ms) Redux (ms) Original (tpt) Redux (tpt) Speedup
MAX 256 32 0.0077 0.0076 4.37 GB/s 4.46 GB/s +1.32%
MIN 256 32 0.0067 0.0062 5.05 GB/s 5.41 GB/s +7.13%
MAX 256 64 0.0068 0.0066 9.83 GB/s 10.09 GB/s +3.03%
MIN 256 64 0.0067 0.0063 9.90 GB/s 10.49 GB/s +6.35%
MAX 256 128 0.0067 0.0062 19.70 GB/s 21.31 GB/s +8.06%
MIN 256 128 0.0070 0.0065 18.88 GB/s 20.48 GB/s +7.69%

The regression performance test also shows that example_topk.py, which heavily uses reduce_max, has a 10%+ gain.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 February 5, 2026 07:43
  } else {
    value_cast = static_cast<float>(value);
  }
  if constexpr (std::is_same_v<ReduceOp, MaxOp>) {


As the AI bot points out, an integer value may also go through this branch, causing correctness problems.

  }
#endif
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  auto run_reduce_sync = [&]<typename T_cast>(T_cast val) {


These __reduce_sync functions only accept integers.

Contributor Author

Yes. This lambda expression will only be triggered when the input is an integer.

@tzj-fxz tzj-fxz requested a review from bucket-xv February 6, 2026 06:01
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 344-365: The fast-path in the CUDA reduce code leaves value_cast
as 0.0f for T=float (and mishandles double), so the inline asm reduces zeros;
update the branches in the block inside reduce.h (the section that sets
value_cast and emits redux.sync) to handle float explicitly by assigning
value_cast = value when std::is_same_v<T, float>, keep the existing assignments
for half_t and bfloat16_t, and ensure the fast-path is enabled only for the
three safe types (float, half_t, bfloat16_t) — do NOT include double — before
emitting the MaxOp/MinOp asm so the asm uses the real value_cast.
- Around line 368-382: The template lambda run_reduce_sync uses C++20-only
syntax and must be replaced for C++17: implement a regular templated helper
function (or a small function object) named e.g. run_reduce_sync_impl templated
on typename T_cast that contains the same if constexpr branches for ReduceOp vs
SumOp/MaxOp/MinOp/BitAndOp/BitOrOp/BitXorOp and calls the corresponding
__reduce_*_sync(mask, val), then call that helper from the existing site instead
of the lambda; ensure the helper has the same capture visibility (pass mask in
as a parameter or make it accessible) and is declared in the same scope so
ReduceOp and the __reduce_*_sync calls resolve correctly.
🧹 Nitpick comments (1)
src/tl_templates/cuda/reduce.h (1)

130-189: run_hopper duplicates run except for barrier calls — consider unifying.

The hierarchical reduction logic (warp reduce → shared-memory exchange → butterfly) is copy-pasted between run and run_hopper, differing only in barrier synchronization (__syncthreads() vs bar.sync asm). A barrier-policy template parameter or a shared helper parameterized on a sync functor would eliminate this ~60-line duplication and reduce maintenance risk when the reduction logic is updated.
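A hedged sketch of that unification, with the barrier choice factored into a sync policy so the two variants share one body; all names here are illustrative, not the file's actual identifiers.

struct BlockSync {
  __device__ __forceinline__ void operator()() const { __syncthreads(); }
};

template <int BarrierId, int NumThreads> struct NamedBarrierSync {
  __device__ __forceinline__ void operator()() const {
    asm volatile("bar.sync %0, %1;" ::"r"(BarrierId), "r"(NumThreads));
  }
};

// One reduction body; run() would pass BlockSync{}, run_hopper() a named barrier.
template <class Sync>
__device__ float block_reduce_sum_f32(float val, float *warp_buf, Sync sync) {
  const unsigned FULL = 0xffffffffu;
  const int lane = threadIdx.x % 32, warp = threadIdx.x / 32;
  const int num_warps = (blockDim.x + 31) / 32;
  for (int o = 16; o > 0; o >>= 1) val += __shfl_xor_sync(FULL, val, o);
  if (lane == 0) warp_buf[warp] = val;
  sync();  // the only point where the two variants differ
  val = (lane < num_warps) ? warp_buf[lane] : 0.0f;
  for (int o = 16; o > 0; o >>= 1) val += __shfl_xor_sync(FULL, val, o);
  return val;  // every thread ends up with the block-wide sum
}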

Comment on lines +344 to +365
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && \
    (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
  float value_cast = 0.0f;
  if constexpr (std::is_same_v<T, half_t>) {
    value_cast = __half2float(value);
  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
    value_cast = __bfloat162float(value);
  }
  if constexpr (std::is_same_v<ReduceOp, MaxOp> && !std::is_integral_v<T>) {
    float res;
    asm("redux.sync.max.f32 %0, %1, %2;"
        : "=f"(res)
        : "f"(value_cast), "r"(mask));
    return static_cast<T>(res);
  } else if constexpr (std::is_same_v<ReduceOp, MinOp> &&
                       !std::is_integral_v<T>) {
    float res;
    asm("redux.sync.min.f32 %0, %1, %2;"
        : "=f"(res)
        : "f"(value_cast), "r"(mask));
    return static_cast<T>(res);
  }
Contributor

⚠️ Potential issue | 🔴 Critical

value_cast is uninitialized for T=float, producing a zero result.

When T is float, none of the if constexpr branches on lines 347–351 execute, so value_cast remains 0.0f. The redux.sync asm then reduces over zeros instead of the actual values, silently corrupting the result. The same problem applies to T=double (zero + precision loss).

Add the missing assignment for float, and restrict the fast-path to the three types that actually benefit from it (float, half_t, bfloat16_t), keeping double out to avoid silent precision loss:

🐛 Proposed fix
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && \
     (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
-  float value_cast = 0.0f;
-  if constexpr (std::is_same_v<T, half_t>) {
-    value_cast = __half2float(value);
-  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
-    value_cast = __bfloat162float(value);
-  }
-  if constexpr (std::is_same_v<ReduceOp, MaxOp> && !std::is_integral_v<T>) {
-    float res;
-    asm("redux.sync.max.f32 %0, %1, %2;"
-        : "=f"(res)
-        : "f"(value_cast), "r"(mask));
-    return static_cast<T>(res);
-  } else if constexpr (std::is_same_v<ReduceOp, MinOp> &&
-                       !std::is_integral_v<T>) {
-    float res;
-    asm("redux.sync.min.f32 %0, %1, %2;"
-        : "=f"(res)
-        : "f"(value_cast), "r"(mask));
-    return static_cast<T>(res);
-  }
+  if constexpr (std::is_same_v<T, float> || std::is_same_v<T, half_t> ||
+                std::is_same_v<T, bfloat16_t>) {
+    float value_cast;
+    if constexpr (std::is_same_v<T, half_t>) {
+      value_cast = __half2float(value);
+    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
+      value_cast = __bfloat162float(value);
+    } else {
+      value_cast = value;
+    }
+    if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
+      float res;
+      asm("redux.sync.max.f32 %0, %1, %2;"
+          : "=f"(res)
+          : "f"(value_cast), "r"(mask));
+      return static_cast<T>(res);
+    } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
+      float res;
+      asm("redux.sync.min.f32 %0, %1, %2;"
+          : "=f"(res)
+          : "f"(value_cast), "r"(mask));
+      return static_cast<T>(res);
+    }
+  }
#endif
🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/reduce.h` around lines 344 - 365, The fast-path in the
CUDA reduce code leaves value_cast as 0.0f for T=float (and mishandles double),
so the inline asm reduces zeros; update the branches in the block inside
reduce.h (the section that sets value_cast and emits redux.sync) to handle float
explicitly by assigning value_cast = value when std::is_same_v<T, float>, keep
the existing assignments for half_t and bfloat16_t, and ensure the fast-path is
enabled only for the three safe types (float, half_t, bfloat16_t) — do NOT
include double — before emitting the MaxOp/MinOp asm so the asm uses the real
value_cast.

Comment on lines +368 to +382
  auto run_reduce_sync = [&]<typename T_cast>(T_cast val) {
    if constexpr (std::is_same_v<ReduceOp, SumOp>) {
      return __reduce_add_sync(mask, val);
    } else if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
      return __reduce_max_sync(mask, val);
    } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
      return __reduce_min_sync(mask, val);
    } else if constexpr (std::is_same_v<ReduceOp, BitAndOp>) {
      return __reduce_and_sync(mask, val);
    } else if constexpr (std::is_same_v<ReduceOp, BitOrOp>) {
      return __reduce_or_sync(mask, val);
    } else if constexpr (std::is_same_v<ReduceOp, BitXorOp>) {
      return __reduce_xor_sync(mask, val);
    }
  };
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find and examine CMake configuration files for C++ and CUDA standards
find . -name "CMakeLists.txt" -o -name "*.cmake" | head -20

Repository: tile-ai/tilelang

Length of output: 128


🏁 Script executed:

#!/bin/bash
# Search for C++ standard configuration
rg "CMAKE_CXX_STANDARD|cxx_std_|--std=" -B2 -A2 2>/dev/null | head -50

Repository: tile-ai/tilelang

Length of output: 265


🏁 Script executed:

#!/bin/bash
# Search for CUDA standard and version requirements
rg "CMAKE_CUDA_STANDARD|CUDA_STANDARD|CUDA.*VERSION|find_package.*CUDA" -B2 -A2 2>/dev/null | head -100

Repository: tile-ai/tilelang

Length of output: 5607


🏁 Script executed:

#!/bin/bash
# Find where reduce.h is included and used
rg "reduce\.h|reduce_h" -g "*.cc" -g "*.h" -g "*.cpp" -B2 -A2

Repository: tile-ai/tilelang

Length of output: 1771


🏁 Script executed:

#!/bin/bash
# Check for per-target C++ standard settings in CMakeLists.txt
rg "set_property.*CXX_STANDARD|target_compile_features" CMakeLists.txt -B2 -A2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Check the actual reduce.h file location and context
head -50 src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 1138


🏁 Script executed:

#!/bin/bash
# Check if reduce.h has any C++ version guards or special compilation settings
rg "__cplusplus|C\+\+20|c\+\+20" src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Search for any target_compile_features or per-file C++ standard overrides
rg "target_compile_features|set_source_files_properties.*LANGUAGE.*CXX|CXX_STANDARD" CMakeLists.txt cmake/

Repository: tile-ai/tilelang

Length of output: 101


C++20 template lambda incompatible with C++17 project standard.

The project is configured with CMAKE_CXX_STANDARD 17, but the template lambda syntax [&]<typename T_cast>(T_cast val) { ... } (P0428R2) requires C++20. This will fail compilation with C++17. Replace with a regular template helper function or explicit overloads to maintain C++17 compatibility.

🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/reduce.h` around lines 368 - 382, The template lambda
run_reduce_sync uses C++20-only syntax and must be replaced for C++17: implement
a regular templated helper function (or a small function object) named e.g.
run_reduce_sync_impl templated on typename T_cast that contains the same if
constexpr branches for ReduceOp vs SumOp/MaxOp/MinOp/BitAndOp/BitOrOp/BitXorOp
and calls the corresponding __reduce_*_sync(mask, val), then call that helper
from the existing site instead of the lambda; ensure the helper has the same
capture visibility (pass mask in as a parameter or make it accessible) and is
declared in the same scope so ReduceOp and the __reduce_*_sync calls resolve
correctly.
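A hedged C++17 sketch of that replacement: a free helper templated on the value type, with mask passed explicitly. The name run_reduce_sync_impl follows the prompt above; SumOp, MaxOp, MinOp and the Bit*Op tags are assumed to be the op tags already used in reduce.h.

template <class ReduceOp, typename T_cast>
__device__ __forceinline__ T_cast run_reduce_sync_impl(unsigned mask, T_cast val) {
  if constexpr (std::is_same_v<ReduceOp, SumOp>) {
    return __reduce_add_sync(mask, val);
  } else if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
    return __reduce_max_sync(mask, val);
  } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
    return __reduce_min_sync(mask, val);
  } else if constexpr (std::is_same_v<ReduceOp, BitAndOp>) {
    return __reduce_and_sync(mask, val);
  } else if constexpr (std::is_same_v<ReduceOp, BitOrOp>) {
    return __reduce_or_sync(mask, val);
  } else {
    return __reduce_xor_sync(mask, val);
  }
}

// Call site replaces the lambda: run_reduce_sync_impl<ReduceOp>(mask, value);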

@bucket-xv

Could you add some more tests? For example:

  1. All data types and all reduce ops: from fp4 and fp8 up to fp32, and int8 to int64; also max, min, add, absmax, absmin, etc.
  2. All shapes and block sizes. The current ones do not involve num threads < 32.

If that is all prepared, I have no further suggestions.
