Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 63 additions & 37 deletions csrc/kernels/mhc_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ namespace aiter {
static_assert(tile_k % (mfma_k * 8) == 0, "tile_k must be divisible by (mfma_k * 8)");

int64_t idx = blockIdx.x * tile_m;
int k_split_idx = blockIdx.y;
int n_idx = blockIdx.y * tile_n;
int k_split_idx = blockIdx.z;
int k_split_offset = k_split_idx * (hc_hidden_size / split_k);
int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / warp_size);
int lane_id = threadIdx.x % warp_size;
Expand All @@ -69,14 +70,15 @@ namespace aiter {
using fp32xtile = opus::vector_t<float, vec_tile>;
using halfxtile = opus::vector_t<DTYPE_I, vec_tile>;

DTYPE_I* x_ptr = x + idx * static_cast<int64_t>(x_stride);
float* fn_ptr = fn;
float* out_ptr = out + (static_cast<int64_t>(k_split_idx * m) + idx) * static_cast<int64_t>(out_stride);
DTYPE_I* x_ptr = x + idx * x_stride;
float* fn_ptr = fn + n_idx * fn_stride;
float* out_ptr = out + (static_cast<int64_t>(k_split_idx * m) + idx) * out_stride + n_idx;
const int m_oob = m < idx + tile_m ? (m - idx) : tile_m;
static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
const int oob_i = (x_stride + ooba_i - 1) / ooba_i * ooba_i;
const int n_oob = hc_mult3 < (n_idx + tile_n) ? (hc_mult3 - n_idx) : tile_n;
auto g_a = opus::make_gmem<DTYPE_I>(x_ptr, x_stride * sizeof(DTYPE_I) * m_oob);
auto g_b = opus::make_gmem<float>(fn_ptr, fn_stride * sizeof(float) * hc_mult3);
auto g_b = opus::make_gmem<float>(fn_ptr, fn_stride * sizeof(float) * n_oob);
auto g_c = opus::make_gmem<float>(out_ptr, out_stride * sizeof(float) * m_oob);

int ga_offset = k_split_offset + (warp_id * mfma_m + lane_id % mfma_m) * x_stride + lane_id / mfma_m * 8;
Expand All @@ -91,8 +93,8 @@ namespace aiter {
// need load LDS[fn_row * 128 + (K_wanted ^ (fn_row & 0xF))]
// lane l → bank = (fn_row * 128 + (K_wanted ^ (fn_row & 0xF))) % 32
// K_wanted same to 16 lanes, but fn_row is different(0,1,2,3,...,15)
const int fn_row_base = warp_id * (tile_n / warp_per_block);
auto lds_load_fn_tile = [&](int k){
int fn_row_base = warp_id * (tile_n / warp_per_block);
float* s_fn_wr_ptr = k % 2 == 0 ? s_fn : (s_fn + tile_n * tile_k);
int s_offset = fn_row_base * tile_k;
s_fn_wr_ptr += s_offset;
Expand Down Expand Up @@ -133,8 +135,10 @@ namespace aiter {
fp32xtile v_af; \
for (int i = 0; i < vec_tile; i++) \
v_af[i] = static_cast<float>(v_a[BUF][i]); \
for (int i = 0; i < vec_tile; i++) \
sqrsum_part += v_af[i] * v_af[i]; \
if (n_idx == 0) { \
for (int i = 0; i < vec_tile; i++) \
sqrsum_part += v_af[i] * v_af[i]; \
} \
Comment thread
junhaha666 marked this conversation as resolved.
v_a[BUF] = load_vector_nbytes<DTYPE_I, vec_tile, 8 * sizeof(DTYPE_I), \
0, true, interleave_size>( \
g_a, ga_offset + ((k) + 2) * tile_k); \
Expand Down Expand Up @@ -196,9 +200,11 @@ namespace aiter {
}
}

float sqrsum_ = cross_row_sum_4(sqrsum_part, lane_id);
if ((warp_id * mfma_m + lane_id < m_oob)) {
sqrsum[k_split_idx * m + idx + warp_id * mfma_m + lane_id] = sqrsum_;
if (n_idx == 0) {
float sqrsum_ = cross_row_sum_4(sqrsum_part, lane_id);
if ((warp_id * mfma_m + lane_id < m_oob)) {
sqrsum[k_split_idx * m + idx + warp_id * mfma_m + lane_id] = sqrsum_;
}
}

for (int n = 0; n < repeat_n; n++) {
Expand All @@ -210,9 +216,10 @@ namespace aiter {
AITER_DISPATCH_FLOATING16_TYPES(x.scalar_type(), "mhc_pre_gemm_sqrsum", [&] { \
using DTYPE_I = typename t2opus<scalar_t>::type; \
const int tile_m = m_per_block; \
int n_blocks = (hc_mult3 + tile_n - 1) / tile_n; \
dim3 grid(m_blocks, n_blocks, split_k); \
TORCH_CHECK(hc_hidden_size % (tile_k * split_k) == 0, "hc_hidden_size must be divisible by tile_k * split_k"); \
TORCH_CHECK(hc_hidden_size >= (tile_k * split_k) * 2, "hc_hidden_size must >= tile_k * split_k * 2 stages prefetch"); \
Comment thread
junhaha666 marked this conversation as resolved.
TORCH_CHECK(hc_mult3 <= tile_n, "hc_mult3 must be less than or equal to tile_n"); \
mhc_pre_gemm_sqrsum_kernel<DTYPE_I, block_size, tile_m, tile_n, tile_k><<<grid, block, 0, stream>>>( \
reinterpret_cast<float*>(out.data_ptr()), \
reinterpret_cast<float*>(sqrsum.data_ptr()), \
Expand All @@ -230,9 +237,17 @@ namespace aiter {

#define MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH(tile_k) \
if (tile_k == 64) { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 64); \
if (cu_num * 2 > m_blocks * split_k) { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 64); \
} else { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 64); \
} \
} else if (tile_k == 128) { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 128); \
if (cu_num > m_blocks * split_k) { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 128); \
} else { \
MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 128); \
} \
} else { \
TORCH_CHECK(false, "tile_k must be 64 or 128"); \
}
Expand All @@ -256,12 +271,12 @@ namespace aiter {
const int block_size = 256;
const int warp_size = 64;
const int m_per_block = block_size / warp_size * 16;
int n_blocks = (m + m_per_block - 1) / m_per_block;
int m_blocks = (m + m_per_block - 1) / m_per_block;
const int cu_num = get_num_cu_func();

const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x));
const hipStream_t stream = at::hip::getCurrentHIPStream();

dim3 grid(n_blocks, split_k);
dim3 block(block_size);

MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH(tile_k);
Expand All @@ -282,6 +297,13 @@ namespace aiter {
data = reduce_op(opus::mov_dpp(data, opus::number<0x124>{}), data);
data = reduce_op(opus::mov_dpp(data, opus::number<0x128>{}), data);
return data;
// float res = data;
// asm volatile("s_nop 1");
// asm volatile("v_add_f32 %0, %1, %1 " "row_ror:4 row_mask:0xf bank_mask:0xf bound_ctrl:1" : "=&v"(res) : "v"(res));
// asm volatile("s_nop 1");
// asm volatile("v_add_f32 %0, %1, %1 " "row_ror:8 row_mask:0xf bank_mask:0xf bound_ctrl:1" : "=&v"(res) : "v"(res));
// asm volatile("s_nop 0");
// return res;
Comment thread
junhaha666 marked this conversation as resolved.
}

template <typename DTYPE_I, int block_size, int hc_mult, int num_rows, int residual_block>
Expand All @@ -304,7 +326,8 @@ namespace aiter {
float hc_sinkhorn_eps,
float hc_post_mult_value,
int sinkhorn_repeat,
int n_splits
int n_splits,
int sub_hidden_size
)
{
using opus::operator""_I;
Expand All @@ -322,6 +345,7 @@ namespace aiter {
using floatx8_t = opus::vector_t<float, 8>;
using halfx8_t = opus::vector_t<DTYPE_I, 8>;
const int m_idx = num_rows * blockIdx.x;
const int k_offset = sub_hidden_size * blockIdx.y;
const int land_id = threadIdx.x % warp_size;
const int m_oob = m < m_idx + num_rows ? (m - m_idx) : num_rows;
auto sigmoid = [](float x) { return 1.0f / (1.0f + __expf(-x)); };
Expand Down Expand Up @@ -389,10 +413,10 @@ namespace aiter {
const int res_rowhc_id = threadIdx.x % (num_rows * hc_mult);
const int residual_hc_stride = residual_stride / hc_mult;

DTYPE_I* residual_ptr = residual + static_cast<int64_t>(m_idx) * static_cast<int64_t>(residual_stride);
auto buffer_res = opus::make_gmem<DTYPE_I>(residual_ptr, (m - m_idx) * residual_stride * sizeof(DTYPE_I));
DTYPE_I* layer_input_ptr = layer_input + static_cast<int64_t>(m_idx) * static_cast<int64_t>(hidden_size);
auto buffer_layer_input = opus::make_gmem<DTYPE_I>(layer_input_ptr, (m - m_idx) * hidden_size * sizeof(DTYPE_I));
DTYPE_I* residual_ptr = residual + static_cast<int64_t>(m_idx) * static_cast<int64_t>(residual_stride) + k_offset;
auto buffer_res = opus::make_gmem<DTYPE_I>(residual_ptr, (m_oob * residual_stride - k_offset) * sizeof(DTYPE_I));
DTYPE_I* layer_input_ptr = layer_input + static_cast<int64_t>(m_idx) * static_cast<int64_t>(hidden_size) + k_offset;
auto buffer_layer_input = opus::make_gmem<DTYPE_I>(layer_input_ptr, (m_oob * hidden_size - k_offset) * sizeof(DTYPE_I));

const int lds_res_load_loop = (num_rows * hc_mult * residual_block) / (pre_thread_num * 2);
auto lds_load_res_tile = [&](int k){
Expand All @@ -419,7 +443,7 @@ namespace aiter {

static_assert(num_rows * hc_mult * residual_block % (pre_thread_num * 8) == 0,
"num_rows * hc_mult * residual_block must be divisible by pre_thread_num * 8");
const int out_loop = hidden_size / residual_block;
const int out_loop = sub_hidden_size / residual_block;
const int row_hc_step = pre_thread_num / (num_rows * hc_mult) * 8;
const int row_hc_iter = threadIdx.x / (num_rows * hc_mult);
for(int i = 0; i < out_loop; i++) {
Expand Down Expand Up @@ -448,7 +472,7 @@ namespace aiter {
}
}
}
else {
else if (k_offset == 0){
// _pre_split_mixes_fwd (post & comb)
float post_mix_v;
if (land_id < num_rows * hc_mult) {
Expand All @@ -466,14 +490,6 @@ namespace aiter {
comb_mix_v =comb_mix_v * hc_scale[2] + hc_base[land_id % hc_mult2 + 2 * hc_mult];
}

auto row_reduce = [&](float data) {
return reduce_in_4threads(data, sum_f);
};

auto col_reduce = [&](float data) {
return reduce_cross_4threads(data, sum_f);
};

// comb = comb.softmax(-1) + eps
float row_max = reduce_in_4threads(comb_mix_v, fmaxf);
comb_mix_v = expf(comb_mix_v - row_max);
Expand All @@ -497,10 +513,19 @@ namespace aiter {
}

#define MHC_PRE_BIG_FUSE_KERNEL_IMPL(block_size, hc_mult, num_rows, residual_block) \
dim3 grid((m + num_rows - 1) / num_rows); \
dim3 block(block_size); \
TORCH_CHECK(hidden_size % residual_block == 0, "hidden_size must be divisible by residual_block"); \
TORCH_CHECK(hidden_size >= residual_block * 2, "hidden_size must be >= residual_block * 2 stages prefetch"); \
int m_blocks = (m + num_rows - 1) / num_rows; \
int num_tg_cu = 32 / (block_size / WARP_SIZE); \
int max_k_blocks = cu_num * num_tg_cu / m_blocks; \
if (max_k_blocks < 1) max_k_blocks = 1; \
int k_blocks = max_k_blocks; \
for(; k_blocks > 1; k_blocks--) { \
if (hidden_size % (k_blocks * residual_block) == 0 && hidden_size / k_blocks >= residual_block * 2) break; \
} \
int sub_hidden_size = hidden_size / k_blocks; \
dim3 grid(m_blocks, k_blocks); \
dim3 block(block_size); \
AITER_DISPATCH_FLOATING16_TYPES(layer_input.scalar_type(), "mhc_pre_big_fuse", [&] { \
using DTYPE_I = typename t2opus<scalar_t>::type; \
mhc_pre_big_fuse_kernel<DTYPE_I, block_size, hc_mult, num_rows, residual_block><<<grid, block, 0, stream>>>( \
Expand All @@ -521,15 +546,16 @@ namespace aiter {
hc_sinkhorn_eps, \
hc_post_mult_value, \
sinkhorn_repeat, \
n_splits \
n_splits, \
sub_hidden_size \
); \
});

#define MHC_PRE_BIG_FUSE_KERNEL_DISPATCH(m) \
if (m <= cu_num * 12) { \
MHC_PRE_BIG_FUSE_KERNEL_IMPL(64 + 64 * 4, 4, 2, 256); \
if (m <= cu_num * 12 || get_gpu_arch() != "gfx942") { \
MHC_PRE_BIG_FUSE_KERNEL_IMPL((64 + 64 * 4), 4, 2, 256); \
} else { \
MHC_PRE_BIG_FUSE_KERNEL_IMPL(64 + 64 * 2, 4, 2, 128); \
MHC_PRE_BIG_FUSE_KERNEL_IMPL((64 + 64 * 2), 4, 2, 128); \
}

void mhc_pre_big_fuse(
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def test_mhc_post(m, hidden_size, hc_mult):
"-m",
type=int,
nargs="*",
default=[512, 1024, 2048, 8192, 65536],
default=[1, 32, 64, 128, 256, 512, 1024, 2048, 8192, 65536],
help="""M.
e.g.: -m 32""",
)
Comment thread
junhaha666 marked this conversation as resolved.
Expand Down
Loading