From 7688ba0db3b7b3e36bc10aaadef4a3e122c9c525 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 16 Jul 2025 14:37:37 +0800 Subject: [PATCH 01/13] [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm --- example/ck_tile/10_rmsnorm2d/generate.py | 2 +- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 124 +++++++++++------- 2 files changed, 81 insertions(+), 45 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index b0ba400af1e..e682270ceb6 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -85,7 +85,7 @@ class rmsnorm_fwd_codegen: if constexpr(is_warp_per_row) { static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); + return total_warps; } else { diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 1c79dafadd8..3a0f7dbb66c 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,49 +1,85 @@ -#!/bin/sh +#!/bin/bash + EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\ - "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 -done -done -done -done +total=0 +valid=0 -# The following cases uses two pass pipeline which doesn't support quant epilogue. -for fquant in "" -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 -done -done +run_case() { + cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7" + echo "[CMD] $cmd" + output=$($cmd 2>&1) + echo "$output" + if echo "$output" | grep -q "valid:y"; then + valid=$((valid + 1)) + fi + total=$((total + 1)) +} + +fquant_list=( + "" + "-fquant=1 -prec_o=int8" + "-fquant=2 -prec_o=int8" + "-fquant=1 -prec_o=fp8" + "-fquant=2 -prec_o=fp8" + "-fquant=1 -prec_o=int8 -save_unquant=1" + "-fquant=2 -prec_o=int8 -save_unquant=1" + "-fquant=1 -prec_o=fp8 -save_unquant=1" + "-fquant=2 -prec_o=fp8 -save_unquant=1" +) + +m_n_list=( + "99 13" "17 16" "1 100" "4 128" "80 127" + "7 599" "19 512" "11 510" "91 636" + "31 1024" "8 1501" "3 1826" "5 2040" + "7 2734" "1 3182" "9 4096" "3 8192" +) + +### Add special stride test ### +m_n_stride_list=( + "22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256" + "33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000" + "171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818" + "12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800" + "100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812" + "64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004" +) + +for fquant in "${fquant_list[@]}"; do + for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + for pair in "${m_n_list[@]}"; do + m=$(echo $pair | cut -d ' ' -f1) + n=$(echo $pair | cut -d ' ' -f2) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "" + done + + ### Running tests with stride ### + for triple in "${m_n_stride_list[@]}"; do + m=$(echo $triple | cut -d ' ' -f1) + n=$(echo $triple | cut -d ' ' -f2) + stride_args=$(echo $triple | cut -d ' ' -f3-) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args" + done + done + done + done done + +# Special two-pass only +for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + run_case "$pr_i" "$fadd" "$s" "" "1" "10547" "" + done + done done + +# Summary +echo "==============================" +echo "Total cases: $total" +echo "Valid cases: $valid" +accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}") +echo "Accuracy: $accuracy%" +echo "==============================" From 510dc8374f27cc07d2e45df8bc43d8e8ec4cb87d Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Fri, 18 Jul 2025 18:12:11 +0800 Subject: [PATCH 02/13] Update rmsnorm host reference --- .../reference/reference_rmsnorm2d_fwd.hpp | 78 +++++++++++++++++-- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 070168b51db..55248879412 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -45,30 +45,91 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, ComputeDataType epsilon, Epilogue epilogue_functor = {}) { + constexpr int elements_per_thread = 5; + constexpr int warp_size = 64; + auto rmsnorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; - ComputeDataType mean_square = 0; - ComputeDataType divisor = 0; + const int num_threads = N / elements_per_thread; + const int num_warps = (num_threads + warp_size - 1) / warp_size; - for(int n = 0; n < N; ++n) + // Step 1: per-thread local partial sum + std::vector thread_partial_sum(num_threads, 0); + + for(int tid = 0; tid < num_threads; ++tid) { - ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); - mean_square += x * x; + for(int i = 0; i < elements_per_thread; ++i) + { + int n = tid * elements_per_thread + i; + if(n < N) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + thread_partial_sum[tid] += x * x; + } + } } - mean_square = mean_square / N; - divisor = ck_tile::type_convert(1) / ck_tile::sqrt(mean_square + epsilon); + // Step 2: intra-warp tree reduction + std::vector warp_partial_sum(num_warps, 0); + for(int w = 0; w < num_warps; ++w) + { + ComputeDataType warp_sum = 0; + for(int t = 0; t < warp_size; ++t) + { + int tid = w * warp_size + t; + if(tid < num_threads) + warp_sum += thread_partial_sum[tid]; + } + warp_partial_sum[w] = warp_sum; + } + + // Step 3: cross-warp reduction + // ComputeDataType total_sum = 0; + // for(int w = 0; w < num_warps; ++w) + // total_sum += warp_partial_sum[w]; + // Step 3: cross-warp tree reduction + ComputeDataType total_sum = 0; + { + std::vector buffer = warp_partial_sum; // copy for reduction + int size = static_cast(buffer.size()); + while(size > 1) + { + int half = size / 2; + for(int i = 0; i < half; ++i) + { + buffer[i] += buffer[i + half]; + } + if(size % 2 == 1) // handle odd case + { + buffer[0] += buffer[size - 1]; + size = half + 1; + } + else + { + size = half; + } + } + total_sum = buffer[0]; + } + + + ComputeDataType mean_square = total_sum / N; + ComputeDataType divisor = ck_tile::type_convert(1) / + ck_tile::sqrt(mean_square + epsilon); if constexpr(!std::is_same_v) invRms_m(m) = ck_tile::type_convert(divisor); + // Compute y = x * gamma * inv_rms HostTensor acc(x_m_n.get_lengths(), x_m_n.get_strides()); for(int n = 0; n < N; ++n) { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + YDataType tmp = ck_tile::type_convert(x*divisor); ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); - acc(m, n) = x * divisor * gamma; + ComputeDataType tmp1 = ck_tile::type_convert(tmp) * gamma; + acc(m, n) = tmp1; } if constexpr(!std::is_same_v) @@ -84,4 +145,5 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( std::thread::hardware_concurrency()); } + } // namespace ck_tile From 2985332cb1323b71503036abd7dfddc38b82e658 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Fri, 18 Jul 2025 18:31:24 +0800 Subject: [PATCH 03/13] Update tree reduction of rmsnorm for reference host --- .../reference/reference_rmsnorm2d_fwd.hpp | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 55248879412..5ce34caa73a 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -71,19 +71,48 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, } // Step 2: intra-warp tree reduction + // std::vector warp_partial_sum(num_warps, 0); + // for(int w = 0; w < num_warps; ++w) + // { + // ComputeDataType warp_sum = 0; + // for(int t = 0; t < warp_size; ++t) + // { + // int tid = w * warp_size + t; + // if(tid < num_threads) + // warp_sum += thread_partial_sum[tid]; + // } + // warp_partial_sum[w] = warp_sum; + // } + std::vector warp_partial_sum(num_warps, 0); for(int w = 0; w < num_warps; ++w) { - ComputeDataType warp_sum = 0; + std::vector buffer(warp_size, 0); for(int t = 0; t < warp_size; ++t) { int tid = w * warp_size + t; if(tid < num_threads) - warp_sum += thread_partial_sum[tid]; + buffer[t] = thread_partial_sum[tid]; + } + + // Tree reduction + int size = warp_size; + while(size > 1) + { + int half = size / 2; + for(int i = 0; i < half; ++i) + { + buffer[i] += buffer[i + half]; + } + if(size % 2 == 1) // odd case + buffer[0] += buffer[size - 1]; + size = (size + 1) / 2; } - warp_partial_sum[w] = warp_sum; + + warp_partial_sum[w] = buffer[0]; } + // Step 3: cross-warp reduction // ComputeDataType total_sum = 0; // for(int w = 0; w < num_warps; ++w) From 997996f1733f1245cd1bc4ebb8b6df5c5749c876 Mon Sep 17 00:00:00 2001 From: MHYang Date: Mon, 21 Jul 2025 14:46:05 +0800 Subject: [PATCH 04/13] Fix cross warp for m > 1 cases --- example/ck_tile/10_rmsnorm2d/generate.py | 6 +++--- include/ck_tile/ops/reduce/block/block_reduce2d.hpp | 5 +++-- .../rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index e682270ceb6..c308b283e54 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -642,15 +642,15 @@ def get_blobs(self): h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] - } + } } - + total_blob = list() for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive current_trait_dict = h_trait_dicts[model_sensitive_flag] for hs_key in current_trait_dict: - hs = current_trait_dict[hs_key] + hs = current_trait_dict[hs_key] current_n = hs_key for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 62c9944bd24..6803c25ecac 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -357,18 +357,19 @@ struct BlockReduce2dTreeCrossWarpSync if(lane_id == 0) { static_for<0, thread_buf_size, 1>{}([&](auto i) { - // Store the i-th element of this warp's thread_buffer into SMEM smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; }); } block_sync_lds(); // We let each warp holds a duplication to do reduction. + index_t local_warp_id = warp_id / num_reduce_warps; + index_t local_smem_os = local_warp_id * num_reduce_warps; static_for<0, thread_buf_size, 1>{}([&](auto i) { DataType v = 0; if(lane_id < num_reduce_warps) { - v = smem_ptr[lane_id + i * num_warps]; + v = smem_ptr[i * num_warps + local_smem_os + lane_id]; } // cross-lane reduce for replication diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index 810c3c52437..2c7f36ad199 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass // compute mean square each-thread->cross-lane->cross-warp auto square_sum = block_reduce2d.template MakeYBlockTile(); set_tile(square_sum, 0); - if constexpr(Problem::BlockShape::Vector_N % 2 == 0) + if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0) { sweep_tile( acc, @@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass const auto gamma_ = type_convert(gamma[j_idx]); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { const auto tmp0 = float_to_bf16(acc[idx] * inv_rms_[i_idx]); @@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass } else { - const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); + const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); const auto rmsn_ = type_convert(tmp) * gamma_; rmsn(idx) = rmsn_; } From 6a1ac38d37639a79a654172f09dbf51f59180acd Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Mon, 21 Jul 2025 20:18:11 +0800 Subject: [PATCH 05/13] Add RMSNorm model selectable option for host reference --- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 15 ++- .../reference/reference_rmsnorm2d_fwd.hpp | 123 ++++-------------- .../ops/reduce/block/block_reduce2d.hpp | 4 +- 3 files changed, 36 insertions(+), 106 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 049a0cad41b..3615858f60c 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -314,7 +314,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - default_and_dquant_functor); + default_and_dquant_functor, + use_model_sensitive_rmsnorm); } else { @@ -329,7 +330,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - dquant_functor); + dquant_functor, + use_model_sensitive_rmsnorm); } } else @@ -341,7 +343,14 @@ bool run(const ck_tile::ArgParser& arg_parser) YDataType, InvRmsDataType, ck_tile::null_type>( - x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); + x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_null, + epsilon, + ck_tile::reference_rmsnorm2d_default_epilogue{}, + use_model_sensitive_rmsnorm); } y_buf.FromDevice(y_host_dev.data()); diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 5ce34caa73a..b10ef230b03 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -43,122 +43,43 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor& invRms_m, HostTensor& unquant_y_m_n, ComputeDataType epsilon, - Epilogue epilogue_functor = {}) + Epilogue epilogue_functor = {}, + const int use_model_sensitive_rmsnorm = 0) { - constexpr int elements_per_thread = 5; - constexpr int warp_size = 64; - auto rmsnorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; - const int num_threads = N / elements_per_thread; - const int num_warps = (num_threads + warp_size - 1) / warp_size; - - // Step 1: per-thread local partial sum - std::vector thread_partial_sum(num_threads, 0); - - for(int tid = 0; tid < num_threads; ++tid) - { - for(int i = 0; i < elements_per_thread; ++i) - { - int n = tid * elements_per_thread + i; - if(n < N) - { - ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); - thread_partial_sum[tid] += x * x; - } - } - } - - // Step 2: intra-warp tree reduction - // std::vector warp_partial_sum(num_warps, 0); - // for(int w = 0; w < num_warps; ++w) - // { - // ComputeDataType warp_sum = 0; - // for(int t = 0; t < warp_size; ++t) - // { - // int tid = w * warp_size + t; - // if(tid < num_threads) - // warp_sum += thread_partial_sum[tid]; - // } - // warp_partial_sum[w] = warp_sum; - // } - - std::vector warp_partial_sum(num_warps, 0); - for(int w = 0; w < num_warps; ++w) - { - std::vector buffer(warp_size, 0); - for(int t = 0; t < warp_size; ++t) - { - int tid = w * warp_size + t; - if(tid < num_threads) - buffer[t] = thread_partial_sum[tid]; - } - - // Tree reduction - int size = warp_size; - while(size > 1) - { - int half = size / 2; - for(int i = 0; i < half; ++i) - { - buffer[i] += buffer[i + half]; - } - if(size % 2 == 1) // odd case - buffer[0] += buffer[size - 1]; - size = (size + 1) / 2; - } - - warp_partial_sum[w] = buffer[0]; - } + ComputeDataType mean_square = 0; + ComputeDataType divisor = 0; - - // Step 3: cross-warp reduction - // ComputeDataType total_sum = 0; - // for(int w = 0; w < num_warps; ++w) - // total_sum += warp_partial_sum[w]; - // Step 3: cross-warp tree reduction - ComputeDataType total_sum = 0; + for(int n = 0; n < N; ++n) { - std::vector buffer = warp_partial_sum; // copy for reduction - int size = static_cast(buffer.size()); - while(size > 1) - { - int half = size / 2; - for(int i = 0; i < half; ++i) - { - buffer[i] += buffer[i + half]; - } - if(size % 2 == 1) // handle odd case - { - buffer[0] += buffer[size - 1]; - size = half + 1; - } - else - { - size = half; - } - } - total_sum = buffer[0]; + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + mean_square += x * x; } - - ComputeDataType mean_square = total_sum / N; - ComputeDataType divisor = ck_tile::type_convert(1) / - ck_tile::sqrt(mean_square + epsilon); + mean_square = mean_square / N; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(mean_square + epsilon); if constexpr(!std::is_same_v) invRms_m(m) = ck_tile::type_convert(divisor); - // Compute y = x * gamma * inv_rms HostTensor acc(x_m_n.get_lengths(), x_m_n.get_strides()); for(int n = 0; n < N; ++n) { - ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); - YDataType tmp = ck_tile::type_convert(x*divisor); - ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); - ComputeDataType tmp1 = ck_tile::type_convert(tmp) * gamma; - acc(m, n) = tmp1; + if(use_model_sensitive_rmsnorm == 0) // 0: for no specific model + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + acc(m, n) = x * divisor * gamma; + } + else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like model + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + const auto tmp = type_convert(x * divisor); + const auto rmsn_ = type_convert(tmp) * gamma_n(n); + acc(m, n) = rmsn_; + } } if constexpr(!std::is_same_v) diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 6803c25ecac..2b1bfe1cbed 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -363,8 +363,8 @@ struct BlockReduce2dTreeCrossWarpSync block_sync_lds(); // We let each warp holds a duplication to do reduction. - index_t local_warp_id = warp_id / num_reduce_warps; - index_t local_smem_os = local_warp_id * num_reduce_warps; + const index_t local_warp_id = warp_id / num_reduce_warps; + const index_t local_smem_os = local_warp_id * num_reduce_warps; static_for<0, thread_buf_size, 1>{}([&](auto i) { DataType v = 0; if(lane_id < num_reduce_warps) From 58a6ee814d971303b2c374f405fdca54e20c1371 Mon Sep 17 00:00:00 2001 From: MHYang Date: Tue, 22 Jul 2025 13:19:32 +0800 Subject: [PATCH 06/13] Fix save_unquant cases --- example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 3615858f60c..3d9f324137e 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const int N = acc_.mDesc.get_lengths()[1]; for(int n_ = 0; n_ < N; ++n_) { - o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); } dquant_functor(m_, o_, acc_); @@ -361,6 +361,11 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } + if constexpr(SaveUnquant) + { + unquant_y_buf.FromDevice(unquant_y_host_dev.data()); + } + auto [rtol, atol] = get_elimit(); if(x_stride == n) { From b79626983e49994cb13b0e1791ea223672cbdf0f Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 23 Jul 2025 23:09:32 +0800 Subject: [PATCH 07/13] Update reference rmsnorm forward function to use enum for model sensitivity --- .../host/reference/reference_rmsnorm2d_fwd.hpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index b10ef230b03..2c73d1260f8 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" namespace ck_tile { @@ -43,8 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor& invRms_m, HostTensor& unquant_y_m_n, ComputeDataType epsilon, - Epilogue epilogue_functor = {}, - const int use_model_sensitive_rmsnorm = 0) + Epilogue epilogue_functor = {}, + const int use_model_sensitive_rmsnorm = + static_cast(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) { auto rmsnorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; @@ -67,13 +69,16 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor acc(x_m_n.get_lengths(), x_m_n.get_strides()); for(int n = 0; n < N; ++n) { - if(use_model_sensitive_rmsnorm == 0) // 0: for no specific model + if(use_model_sensitive_rmsnorm == + static_cast( + Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); acc(m, n) = x * divisor * gamma; } - else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like model + else if(use_model_sensitive_rmsnorm == + static_cast(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); const auto tmp = type_convert(x * divisor); From ac2ba697287538adcde3bb821f0aee5fefa6a826 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Thu, 24 Jul 2025 09:43:33 +0800 Subject: [PATCH 08/13] Update reference rmsnorm calculation for model sensitivity --- .../reference/reference_rmsnorm2d_fwd.hpp | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 2c73d1260f8..f4c90f4f7ba 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -69,21 +69,31 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor acc(x_m_n.get_lengths(), x_m_n.get_strides()); for(int n = 0; n < N; ++n) { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); if(use_model_sensitive_rmsnorm == static_cast( Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model { - ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); - ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); - acc(m, n) = x * divisor * gamma; + acc(m, n) = x * divisor * gamma; } else if(use_model_sensitive_rmsnorm == static_cast(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model { - ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); - const auto tmp = type_convert(x * divisor); - const auto rmsn_ = type_convert(tmp) * gamma_n(n); - acc(m, n) = rmsn_; + if constexpr(std::is_same_v) + { + const auto tmp0 = float_to_bf16(x * divisor); + const auto tmp1 = float_to_bf16( + type_convert(tmp0) * gamma); + const auto rmsn_ = type_convert(tmp1); + acc(m, n) = rmsn_; + } + else + { + const auto tmp = type_convert(x * divisor); + const auto rmsn_ = type_convert(tmp) * gamma_n(n); + acc(m, n) = rmsn_; + } } } From 3a141eb26f22f60eea721f54edb50ed5b3915931 Mon Sep 17 00:00:00 2001 From: MHYang Date: Fri, 25 Jul 2025 16:33:11 +0800 Subject: [PATCH 09/13] Fix m warp for layernorm --- example/ck_tile/02_layernorm2d/generate.py | 2 +- include/ck_tile/ops/reduce/block/block_reduce2d.hpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index d77582630a8..9f7aa1363fd 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -85,7 +85,7 @@ class layernorm_fwd_codegen: if constexpr(is_warp_per_row) { static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); + return total_warps; } else { diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 2b1bfe1cbed..3559ec736fc 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -357,6 +357,7 @@ struct BlockReduce2dTreeCrossWarpSync if(lane_id == 0) { static_for<0, thread_buf_size, 1>{}([&](auto i) { + // Store the i-th element of this warp's thread_buffer into SMEM smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; }); } From 0c803d1c16243d5c7c79b5930ebad7ad68557c23 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Mon, 28 Jul 2025 17:30:28 +0800 Subject: [PATCH 10/13] Adjust parameter of reference for twoPass --- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 25 +++++++++++-------- .../reference/reference_rmsnorm2d_fwd.hpp | 2 +- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 3d9f324137e..3c7a9b6e48c 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -67,16 +67,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -193,6 +193,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); + if(n > 8192) + { + use_model_sensitive_rmsnorm = 0; + } + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index f4c90f4f7ba..424fff44704 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -91,7 +91,7 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, else { const auto tmp = type_convert(x * divisor); - const auto rmsn_ = type_convert(tmp) * gamma_n(n); + const auto rmsn_ = type_convert(tmp) * gamma; acc(m, n) = rmsn_; } } From 1cb4149d8279e780683dc6dba3efd401f96ab9fb Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 30 Jul 2025 13:00:03 +0800 Subject: [PATCH 11/13] Fix clang format --- example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 0cb9985809a..3c7a9b6e48c 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -197,7 +197,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { use_model_sensitive_rmsnorm = 0; } - + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride From 0dc5d41584dce966d77387af735e57cc958747bc Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Tue, 12 Aug 2025 22:51:18 +0800 Subject: [PATCH 12/13] Run clang-format-overwrite.sh to fix formating issue --- .../02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp | 2 +- .../02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp | 2 +- .../02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp | 2 +- .../60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp | 2 +- .../contraction_multi_ABD_xdl_fp16.cpp | 2 +- .../ck/library/utility/host_tensor_generator.hpp | 8 ++++---- .../gpu/element/binary_element_wise_operation.hpp | 4 ++-- .../gpu/element/unary_element_wise_operation.hpp | 14 +++++++------- .../elementwise/unary_element_wise_operation.hpp | 12 ++++++------ 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 18731e810e1..b50925d7116 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -21,7 +21,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 87812369bd1..a9eef9c6cb6 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -21,7 +21,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index c3e6ef7d5df..aa39afe2774 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -20,7 +20,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 93034a8b70c..9b218ed5835 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -83,7 +83,7 @@ struct AddScale struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index e7c1d6f0be4..b50e876384b 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -42,7 +42,7 @@ static constexpr ck::index_t NumDimK = 2; struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index ab69412c155..bc376ffcdf3 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -483,7 +483,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev){}; + : generator(seed), distribution(mean, stddev) {}; template T operator()(Is...) @@ -501,7 +501,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev){}; + : generator(seed), distribution(mean, stddev) {}; template ck::f4x2_pk_t operator()(Is...) @@ -520,7 +520,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev){}; + : generator(seed), distribution(mean, stddev) {}; template ck::f6x32_pk_t operator()(Is...) @@ -542,7 +542,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev){}; + : generator(seed), distribution(mean, stddev) {}; template ck::bf6x32_pk_t operator()(Is...) diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index d86f01e2558..f326c4a28db 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -279,7 +279,7 @@ struct Subtract struct Bilinear { - Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; @@ -354,7 +354,7 @@ struct Bilinear struct AddClamp { AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) - : floor_(floor), ceil_(ceil){}; + : floor_(floor), ceil_(ceil) {}; template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 4a87e8a2775..80b8306a51a 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -756,7 +756,7 @@ struct UnarySqrt struct Clamp { Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) - : floor_(floor), ceil_(ceil){}; + : floor_(floor), ceil_(ceil) {}; template __host__ __device__ constexpr void operator()(Y& y, const X& x) const; @@ -1324,7 +1324,7 @@ struct Swish struct SoftRelu { - SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1353,7 +1353,7 @@ struct SoftRelu struct Power { Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma){}; + : alpha_(alpha), beta_(beta), gamma_(gamma) {}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1386,7 +1386,7 @@ struct Power struct ClippedRelu { - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1415,7 +1415,7 @@ struct ClippedRelu struct LeakyRelu { - LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1442,7 +1442,7 @@ struct LeakyRelu struct Elu { - Elu(float alpha = 1.f) : alpha_(alpha){}; + Elu(float alpha = 1.f) : alpha_(alpha) {}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1469,7 +1469,7 @@ struct Elu struct Logistic { - Logistic(float alpha = 1.f) : alpha_(alpha){}; + Logistic(float alpha = 1.f) : alpha_(alpha) {}; template __host__ __device__ void operator()(T& y, const T& x) const diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 0e385901edd..8a97231030c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1308,7 +1308,7 @@ struct Swish struct SoftRelu { - SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1327,7 +1327,7 @@ struct SoftRelu struct Power { Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma){}; + : alpha_(alpha), beta_(beta), gamma_(gamma) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1349,7 +1349,7 @@ struct Power struct ClippedRelu { - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1368,7 +1368,7 @@ struct ClippedRelu struct LeakyRelu { - LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1385,7 +1385,7 @@ struct LeakyRelu struct Elu { - Elu(float alpha = 1.f) : alpha_(alpha){}; + Elu(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1402,7 +1402,7 @@ struct Elu struct Logistic { - Logistic(float alpha = 1.f) : alpha_(alpha){}; + Logistic(float alpha = 1.f) : alpha_(alpha) {}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const From e358e8c07ae7c7359dfa06cc508bd714e4bc60fd Mon Sep 17 00:00:00 2001 From: illsilin_amdeng Date: Tue, 12 Aug 2025 18:55:17 -0700 Subject: [PATCH 13/13] fix clang format --- .../02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp | 2 +- .../02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp | 2 +- .../02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp | 2 +- .../60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp | 2 +- .../contraction_multi_ABD_xdl_fp16.cpp | 2 +- example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 3 +-- .../ck/library/utility/host_tensor_generator.hpp | 8 ++++---- .../gpu/element/binary_element_wise_operation.hpp | 4 ++-- .../gpu/element/unary_element_wise_operation.hpp | 14 +++++++------- .../elementwise/unary_element_wise_operation.hpp | 12 ++++++------ 10 files changed, 25 insertions(+), 26 deletions(-) diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index b50925d7116..18731e810e1 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -21,7 +21,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index a9eef9c6cb6..87812369bd1 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -21,7 +21,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta) {}; + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index aa39afe2774..c3e6ef7d5df 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -20,7 +20,7 @@ struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 9b218ed5835..93034a8b70c 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -83,7 +83,7 @@ struct AddScale struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index b50e876384b..e7c1d6f0be4 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -42,7 +42,7 @@ static constexpr ck::index_t NumDimK = 2; struct AlphaBetaAdd { - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}; + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 3c7a9b6e48c..9070346ad42 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -198,8 +198,7 @@ bool run(const ck_tile::ArgParser& arg_parser) use_model_sensitive_rmsnorm = 0; } - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index bc376ffcdf3..ab69412c155 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -483,7 +483,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev) {}; + : generator(seed), distribution(mean, stddev){}; template T operator()(Is...) @@ -501,7 +501,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev) {}; + : generator(seed), distribution(mean, stddev){}; template ck::f4x2_pk_t operator()(Is...) @@ -520,7 +520,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev) {}; + : generator(seed), distribution(mean, stddev){}; template ck::f6x32_pk_t operator()(Is...) @@ -542,7 +542,7 @@ struct GeneratorTensor_4 std::normal_distribution distribution; GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) - : generator(seed), distribution(mean, stddev) {}; + : generator(seed), distribution(mean, stddev){}; template ck::bf6x32_pk_t operator()(Is...) diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index f326c4a28db..d86f01e2558 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -279,7 +279,7 @@ struct Subtract struct Bilinear { - Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; @@ -354,7 +354,7 @@ struct Bilinear struct AddClamp { AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) - : floor_(floor), ceil_(ceil) {}; + : floor_(floor), ceil_(ceil){}; template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 80b8306a51a..4a87e8a2775 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -756,7 +756,7 @@ struct UnarySqrt struct Clamp { Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) - : floor_(floor), ceil_(ceil) {}; + : floor_(floor), ceil_(ceil){}; template __host__ __device__ constexpr void operator()(Y& y, const X& x) const; @@ -1324,7 +1324,7 @@ struct Swish struct SoftRelu { - SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1353,7 +1353,7 @@ struct SoftRelu struct Power { Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) {}; + : alpha_(alpha), beta_(beta), gamma_(gamma){}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1386,7 +1386,7 @@ struct Power struct ClippedRelu { - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1415,7 +1415,7 @@ struct ClippedRelu struct LeakyRelu { - LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1442,7 +1442,7 @@ struct LeakyRelu struct Elu { - Elu(float alpha = 1.f) : alpha_(alpha) {}; + Elu(float alpha = 1.f) : alpha_(alpha){}; template __host__ __device__ void operator()(T& y, const T& x) const @@ -1469,7 +1469,7 @@ struct Elu struct Logistic { - Logistic(float alpha = 1.f) : alpha_(alpha) {}; + Logistic(float alpha = 1.f) : alpha_(alpha){}; template __host__ __device__ void operator()(T& y, const T& x) const diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 8a97231030c..0e385901edd 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1308,7 +1308,7 @@ struct Swish struct SoftRelu { - SoftRelu(float alpha = 1.f) : alpha_(alpha) {}; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1327,7 +1327,7 @@ struct SoftRelu struct Power { Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) {}; + : alpha_(alpha), beta_(beta), gamma_(gamma){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1349,7 +1349,7 @@ struct Power struct ClippedRelu { - ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1368,7 +1368,7 @@ struct ClippedRelu struct LeakyRelu { - LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}; + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1385,7 +1385,7 @@ struct LeakyRelu struct Elu { - Elu(float alpha = 1.f) : alpha_(alpha) {}; + Elu(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const @@ -1402,7 +1402,7 @@ struct Elu struct Logistic { - Logistic(float alpha = 1.f) : alpha_(alpha) {}; + Logistic(float alpha = 1.f) : alpha_(alpha){}; template CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const