Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7688ba0
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm
ClementLinCF Jul 16, 2025
510dc83
Update rmsnorm host reference
ClementLinCF Jul 18, 2025
2985332
Update tree reduction of rmsnorm for reference host
ClementLinCF Jul 18, 2025
997996f
Fix cross warp for m > 1 cases
MHYangAMD Jul 21, 2025
6a1ac38
Add RMSNorm model selectable option for host reference
ClementLinCF Jul 21, 2025
58a6ee8
Fix save_unquant cases
MHYangAMD Jul 22, 2025
b796269
Update reference rmsnorm forward function to use enum for model sensi…
ClementLinCF Jul 23, 2025
ac2ba69
Update reference rmsnorm calculation for model sensitivity
ClementLinCF Jul 24, 2025
3a141eb
Fix m warp for layernorm
MHYangAMD Jul 25, 2025
0c803d1
Adjust parameter of reference for twoPass
ClementLinCF Jul 28, 2025
b2e7af5
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Jul 30, 2025
1cb4149
Fix clang format
ClementLinCF Jul 30, 2025
16fce0c
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Aug 4, 2025
847cedd
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Aug 5, 2025
c441904
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Aug 7, 2025
59a92fe
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Aug 12, 2025
0dc5d41
Run clang-format-overwrite.sh to fix formating issue
ClementLinCF Aug 12, 2025
d71e744
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Aug 12, 2025
e358e8c
fix clang format
illsilin Aug 13, 2025
1a9a3a4
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
ClementLinCF Sep 8, 2025
bcf095f
solve the merge conflict
Oct 10, 2025
427e94b
Merge branch 'develop' into ck_tile/rmsnorm-smoke-test
Oct 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/02_layernorm2d/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class layernorm_fwd_codegen:
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
return total_warps;
}
else
{
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/10_rmsnorm2d/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class rmsnorm_fwd_codegen:
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
return total_warps;
Comment thread
MHYangAMD marked this conversation as resolved.
}
else
{
Expand Down Expand Up @@ -642,15 +642,15 @@ def get_blobs(self):
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
}
}
}

total_blob = list()

for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
current_trait_dict = h_trait_dicts[model_sensitive_flag]
for hs_key in current_trait_dict:
hs = current_trait_dict[hs_key]
hs = current_trait_dict[hs_key]
current_n = hs_key
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
prec_i, prec_o = dtype.split(',')
Expand Down
50 changes: 35 additions & 15 deletions example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ template <typename InDataType,
bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");

ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
Expand Down Expand Up @@ -193,7 +193,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
return base_str;
}();

std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
if(n > 8192)
{
use_model_sensitive_rmsnorm = 0;
}

std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;

Expand Down Expand Up @@ -294,7 +300,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const int N = acc_.mDesc.get_lengths()[1];
for(int n_ = 0; n_ < N; ++n_)
{
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
}

dquant_functor(m_, o_, acc_);
Expand All @@ -313,7 +319,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
default_and_dquant_functor);
default_and_dquant_functor,
use_model_sensitive_rmsnorm);
}
else
{
Expand All @@ -328,7 +335,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
dquant_functor);
dquant_functor,
use_model_sensitive_rmsnorm);
}
}
else
Expand All @@ -340,7 +348,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
YDataType,
InvRmsDataType,
ck_tile::null_type>(
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
x_host,
gamma_host,
y_host_ref,
invRms_host_ref,
unquant_y_null,
epsilon,
ck_tile::reference_rmsnorm2d_default_epilogue{},
use_model_sensitive_rmsnorm);
}

y_buf.FromDevice(y_host_dev.data());
Expand All @@ -351,6 +366,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data());
}

if constexpr(SaveUnquant)
{
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
}

auto [rtol, atol] = get_elimit<YDataType>();
if(x_stride == n)
{
Expand Down
124 changes: 80 additions & 44 deletions example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
Original file line number Diff line number Diff line change
@@ -1,49 +1,85 @@
#!/bin/sh
#!/bin/bash

EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"

for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
done
total=0
valid=0

# The following cases uses two pass pipeline which doesn't support quant epilogue.
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
run_case() {
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
echo "[CMD] $cmd"
output=$($cmd 2>&1)
echo "$output"
if echo "$output" | grep -q "valid:y"; then
valid=$((valid + 1))
fi
total=$((total + 1))
}

fquant_list=(
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
"-fquant=1 -prec_o=int8 -save_unquant=1"
"-fquant=2 -prec_o=int8 -save_unquant=1"
"-fquant=1 -prec_o=fp8 -save_unquant=1"
"-fquant=2 -prec_o=fp8 -save_unquant=1"
)

m_n_list=(
"99 13" "17 16" "1 100" "4 128" "80 127"
"7 599" "19 512" "11 510" "91 636"
"31 1024" "8 1501" "3 1826" "5 2040"
"7 2734" "1 3182" "9 4096" "3 8192"
)

### Add special stride test ###
m_n_stride_list=(
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
)

for fquant in "${fquant_list[@]}"; do
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
for pair in "${m_n_list[@]}"; do
m=$(echo $pair | cut -d ' ' -f1)
n=$(echo $pair | cut -d ' ' -f2)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
done

### Running tests with stride ###
for triple in "${m_n_stride_list[@]}"; do
m=$(echo $triple | cut -d ' ' -f1)
n=$(echo $triple | cut -d ' ' -f2)
stride_args=$(echo $triple | cut -d ' ' -f3-)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
done
done
done
done
done

# Special two-pass only
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
done
done
done

# Summary
echo "=============================="
echo "Total cases: $total"
echo "Valid cases: $valid"
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
echo "Accuracy: $accuracy%"
echo "=============================="
31 changes: 29 additions & 2 deletions include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"

namespace ck_tile {

Expand Down Expand Up @@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
HostTensor<InvRmsDataType>& invRms_m,
HostTensor<UnquantYDataType>& unquant_y_m_n,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
Epilogue epilogue_functor = {},
const int use_model_sensitive_rmsnorm =
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
Expand All @@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
acc(m, n) = x * divisor * gamma;
if(use_model_sensitive_rmsnorm ==
static_cast<int>(
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
{
acc(m, n) = x * divisor * gamma;
}
else if(use_model_sensitive_rmsnorm ==
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
{
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
acc(m, n) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(x * divisor);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
acc(m, n) = rmsn_;
}
}
}

if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
Expand All @@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}

} // namespace ck_tile
4 changes: 3 additions & 1 deletion include/ck_tile/ops/reduce/block/block_reduce2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
block_sync_lds();

// We let each warp holds a duplication to do reduction.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v = 0;
if(lane_id < num_reduce_warps)
{
v = smem_ptr[lane_id + i * num_warps];
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
}

// cross-lane reduce for replication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
set_tile(square_sum, 0);
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
{
sweep_tile(
acc,
Expand Down Expand Up @@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass

const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);

if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 =
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
Expand All @@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
}
else
{
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
rmsn(idx) = rmsn_;
}
Expand Down