diff --git a/csrc/kernels/mhc_kernels.cu b/csrc/kernels/mhc_kernels.cu index 5b37bc6c14..8465fa426f 100644 --- a/csrc/kernels/mhc_kernels.cu +++ b/csrc/kernels/mhc_kernels.cu @@ -55,7 +55,8 @@ namespace aiter { static_assert(tile_k % (mfma_k * 8) == 0, "tile_k must be divisible by (mfma_k * 8)"); int64_t idx = blockIdx.x * tile_m; - int k_split_idx = blockIdx.y; + int n_idx = blockIdx.y * tile_n; + int k_split_idx = blockIdx.z; int k_split_offset = k_split_idx * (hc_hidden_size / split_k); int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / warp_size); int lane_id = threadIdx.x % warp_size; @@ -69,14 +70,15 @@ namespace aiter { using fp32xtile = opus::vector_t; using halfxtile = opus::vector_t; - DTYPE_I* x_ptr = x + idx * static_cast(x_stride); - float* fn_ptr = fn; - float* out_ptr = out + (static_cast(k_split_idx * m) + idx) * static_cast(out_stride); + DTYPE_I* x_ptr = x + idx * x_stride; + float* fn_ptr = fn + n_idx * fn_stride; + float* out_ptr = out + (static_cast(k_split_idx * m) + idx) * out_stride + n_idx; const int m_oob = m < idx + tile_m ? (m - idx) : tile_m; static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I); const int oob_i = (x_stride + ooba_i - 1) / ooba_i * ooba_i; + const int n_oob = hc_mult3 < (n_idx + tile_n) ? (hc_mult3 - n_idx) : tile_n; auto g_a = opus::make_gmem(x_ptr, x_stride * sizeof(DTYPE_I) * m_oob); - auto g_b = opus::make_gmem(fn_ptr, fn_stride * sizeof(float) * hc_mult3); + auto g_b = opus::make_gmem(fn_ptr, fn_stride * sizeof(float) * n_oob); auto g_c = opus::make_gmem(out_ptr, out_stride * sizeof(float) * m_oob); int ga_offset = k_split_offset + (warp_id * mfma_m + lane_id % mfma_m) * x_stride + lane_id / mfma_m * 8; @@ -91,8 +93,8 @@ namespace aiter { // need load LDS[fn_row * 128 + (K_wanted ^ (fn_row & 0xF))] // lane l → bank = (fn_row * 128 + (K_wanted ^ (fn_row & 0xF))) % 32 // K_wanted same to 16 lanes, but fn_row is different(0,1,2,3,...,15) + const int fn_row_base = warp_id * (tile_n / warp_per_block); auto lds_load_fn_tile = [&](int k){ - int fn_row_base = warp_id * (tile_n / warp_per_block); float* s_fn_wr_ptr = k % 2 == 0 ? s_fn : (s_fn + tile_n * tile_k); int s_offset = fn_row_base * tile_k; s_fn_wr_ptr += s_offset; @@ -133,8 +135,10 @@ namespace aiter { fp32xtile v_af; \ for (int i = 0; i < vec_tile; i++) \ v_af[i] = static_cast(v_a[BUF][i]); \ - for (int i = 0; i < vec_tile; i++) \ - sqrsum_part += v_af[i] * v_af[i]; \ + if (n_idx == 0) { \ + for (int i = 0; i < vec_tile; i++) \ + sqrsum_part += v_af[i] * v_af[i]; \ + } \ v_a[BUF] = load_vector_nbytes( \ g_a, ga_offset + ((k) + 2) * tile_k); \ @@ -196,9 +200,11 @@ namespace aiter { } } - float sqrsum_ = cross_row_sum_4(sqrsum_part, lane_id); - if ((warp_id * mfma_m + lane_id < m_oob)) { - sqrsum[k_split_idx * m + idx + warp_id * mfma_m + lane_id] = sqrsum_; + if (n_idx == 0) { + float sqrsum_ = cross_row_sum_4(sqrsum_part, lane_id); + if ((warp_id * mfma_m + lane_id < m_oob)) { + sqrsum[k_split_idx * m + idx + warp_id * mfma_m + lane_id] = sqrsum_; + } } for (int n = 0; n < repeat_n; n++) { @@ -210,9 +216,10 @@ namespace aiter { AITER_DISPATCH_FLOATING16_TYPES(x.scalar_type(), "mhc_pre_gemm_sqrsum", [&] { \ using DTYPE_I = typename t2opus::type; \ const int tile_m = m_per_block; \ + int n_blocks = (hc_mult3 + tile_n - 1) / tile_n; \ + dim3 grid(m_blocks, n_blocks, split_k); \ TORCH_CHECK(hc_hidden_size % (tile_k * split_k) == 0, "hc_hidden_size must be divisible by tile_k * split_k"); \ TORCH_CHECK(hc_hidden_size >= (tile_k * split_k) * 2, "hc_hidden_size must >= tile_k * split_k * 2 stages prefetch"); \ - TORCH_CHECK(hc_mult3 <= tile_n, "hc_mult3 must be less than or equal to tile_n"); \ mhc_pre_gemm_sqrsum_kernel<<>>( \ reinterpret_cast(out.data_ptr()), \ reinterpret_cast(sqrsum.data_ptr()), \ @@ -230,9 +237,17 @@ namespace aiter { #define MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH(tile_k) \ if (tile_k == 64) { \ - MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 64); \ + if (cu_num * 2 > m_blocks * split_k) { \ + MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 64); \ + } else { \ + MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 64); \ + } \ } else if (tile_k == 128) { \ - MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 128); \ + if (cu_num > m_blocks * split_k) { \ + MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 128); \ + } else { \ + MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 128); \ + } \ } else { \ TORCH_CHECK(false, "tile_k must be 64 or 128"); \ } @@ -256,12 +271,12 @@ namespace aiter { const int block_size = 256; const int warp_size = 64; const int m_per_block = block_size / warp_size * 16; - int n_blocks = (m + m_per_block - 1) / m_per_block; + int m_blocks = (m + m_per_block - 1) / m_per_block; + const int cu_num = get_num_cu_func(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x)); const hipStream_t stream = at::hip::getCurrentHIPStream(); - dim3 grid(n_blocks, split_k); dim3 block(block_size); MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH(tile_k); @@ -282,6 +297,13 @@ namespace aiter { data = reduce_op(opus::mov_dpp(data, opus::number<0x124>{}), data); data = reduce_op(opus::mov_dpp(data, opus::number<0x128>{}), data); return data; + // float res = data; + // asm volatile("s_nop 1"); + // asm volatile("v_add_f32 %0, %1, %1 " "row_ror:4 row_mask:0xf bank_mask:0xf bound_ctrl:1" : "=&v"(res) : "v"(res)); + // asm volatile("s_nop 1"); + // asm volatile("v_add_f32 %0, %1, %1 " "row_ror:8 row_mask:0xf bank_mask:0xf bound_ctrl:1" : "=&v"(res) : "v"(res)); + // asm volatile("s_nop 0"); + // return res; } template @@ -304,7 +326,8 @@ namespace aiter { float hc_sinkhorn_eps, float hc_post_mult_value, int sinkhorn_repeat, - int n_splits + int n_splits, + int sub_hidden_size ) { using opus::operator""_I; @@ -322,6 +345,7 @@ namespace aiter { using floatx8_t = opus::vector_t; using halfx8_t = opus::vector_t; const int m_idx = num_rows * blockIdx.x; + const int k_offset = sub_hidden_size * blockIdx.y; const int land_id = threadIdx.x % warp_size; const int m_oob = m < m_idx + num_rows ? (m - m_idx) : num_rows; auto sigmoid = [](float x) { return 1.0f / (1.0f + __expf(-x)); }; @@ -389,10 +413,10 @@ namespace aiter { const int res_rowhc_id = threadIdx.x % (num_rows * hc_mult); const int residual_hc_stride = residual_stride / hc_mult; - DTYPE_I* residual_ptr = residual + static_cast(m_idx) * static_cast(residual_stride); - auto buffer_res = opus::make_gmem(residual_ptr, (m - m_idx) * residual_stride * sizeof(DTYPE_I)); - DTYPE_I* layer_input_ptr = layer_input + static_cast(m_idx) * static_cast(hidden_size); - auto buffer_layer_input = opus::make_gmem(layer_input_ptr, (m - m_idx) * hidden_size * sizeof(DTYPE_I)); + DTYPE_I* residual_ptr = residual + static_cast(m_idx) * static_cast(residual_stride) + k_offset; + auto buffer_res = opus::make_gmem(residual_ptr, (m_oob * residual_stride - k_offset) * sizeof(DTYPE_I)); + DTYPE_I* layer_input_ptr = layer_input + static_cast(m_idx) * static_cast(hidden_size) + k_offset; + auto buffer_layer_input = opus::make_gmem(layer_input_ptr, (m_oob * hidden_size - k_offset) * sizeof(DTYPE_I)); const int lds_res_load_loop = (num_rows * hc_mult * residual_block) / (pre_thread_num * 2); auto lds_load_res_tile = [&](int k){ @@ -419,7 +443,7 @@ namespace aiter { static_assert(num_rows * hc_mult * residual_block % (pre_thread_num * 8) == 0, "num_rows * hc_mult * residual_block must be divisible by pre_thread_num * 8"); - const int out_loop = hidden_size / residual_block; + const int out_loop = sub_hidden_size / residual_block; const int row_hc_step = pre_thread_num / (num_rows * hc_mult) * 8; const int row_hc_iter = threadIdx.x / (num_rows * hc_mult); for(int i = 0; i < out_loop; i++) { @@ -448,7 +472,7 @@ namespace aiter { } } } - else { + else if (k_offset == 0){ // _pre_split_mixes_fwd (post & comb) float post_mix_v; if (land_id < num_rows * hc_mult) { @@ -466,14 +490,6 @@ namespace aiter { comb_mix_v =comb_mix_v * hc_scale[2] + hc_base[land_id % hc_mult2 + 2 * hc_mult]; } - auto row_reduce = [&](float data) { - return reduce_in_4threads(data, sum_f); - }; - - auto col_reduce = [&](float data) { - return reduce_cross_4threads(data, sum_f); - }; - // comb = comb.softmax(-1) + eps float row_max = reduce_in_4threads(comb_mix_v, fmaxf); comb_mix_v = expf(comb_mix_v - row_max); @@ -497,10 +513,19 @@ namespace aiter { } #define MHC_PRE_BIG_FUSE_KERNEL_IMPL(block_size, hc_mult, num_rows, residual_block) \ - dim3 grid((m + num_rows - 1) / num_rows); \ - dim3 block(block_size); \ TORCH_CHECK(hidden_size % residual_block == 0, "hidden_size must be divisible by residual_block"); \ TORCH_CHECK(hidden_size >= residual_block * 2, "hidden_size must be >= residual_block * 2 stages prefetch"); \ + int m_blocks = (m + num_rows - 1) / num_rows; \ + int num_tg_cu = 32 / (block_size / WARP_SIZE); \ + int max_k_blocks = cu_num * num_tg_cu / m_blocks; \ + if (max_k_blocks < 1) max_k_blocks = 1; \ + int k_blocks = max_k_blocks; \ + for(; k_blocks > 1; k_blocks--) { \ + if (hidden_size % (k_blocks * residual_block) == 0 && hidden_size / k_blocks >= residual_block * 2) break; \ + } \ + int sub_hidden_size = hidden_size / k_blocks; \ + dim3 grid(m_blocks, k_blocks); \ + dim3 block(block_size); \ AITER_DISPATCH_FLOATING16_TYPES(layer_input.scalar_type(), "mhc_pre_big_fuse", [&] { \ using DTYPE_I = typename t2opus::type; \ mhc_pre_big_fuse_kernel<<>>( \ @@ -521,15 +546,16 @@ namespace aiter { hc_sinkhorn_eps, \ hc_post_mult_value, \ sinkhorn_repeat, \ - n_splits \ + n_splits, \ + sub_hidden_size \ ); \ }); #define MHC_PRE_BIG_FUSE_KERNEL_DISPATCH(m) \ - if (m <= cu_num * 12) { \ - MHC_PRE_BIG_FUSE_KERNEL_IMPL(64 + 64 * 4, 4, 2, 256); \ + if (m <= cu_num * 12 || get_gpu_arch() != "gfx942") { \ + MHC_PRE_BIG_FUSE_KERNEL_IMPL((64 + 64 * 4), 4, 2, 256); \ } else { \ - MHC_PRE_BIG_FUSE_KERNEL_IMPL(64 + 64 * 2, 4, 2, 128); \ + MHC_PRE_BIG_FUSE_KERNEL_IMPL((64 + 64 * 2), 4, 2, 128); \ } void mhc_pre_big_fuse( diff --git a/op_tests/test_mhc.py b/op_tests/test_mhc.py index 9054470aa7..3f23699397 100644 --- a/op_tests/test_mhc.py +++ b/op_tests/test_mhc.py @@ -627,7 +627,7 @@ def test_mhc_post(m, hidden_size, hc_mult): "-m", type=int, nargs="*", - default=[512, 1024, 2048, 8192, 65536], + default=[1, 32, 64, 128, 256, 512, 1024, 2048, 8192, 65536], help="""M. e.g.: -m 32""", )