diff --git a/csrc/smxx/mla_combine.cu b/csrc/smxx/mla_combine.cu index afd54d3..61691a1 100644 --- a/csrc/smxx/mla_combine.cu +++ b/csrc/smxx/mla_combine.cu @@ -13,11 +13,11 @@ using namespace cute; template __global__ void __launch_bounds__(NUM_THREADS) flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { - // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] + // grid_shape: [num_q_heads*s_q / BLOCK_SIZE_M, batch_size] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m - const int batch_idx = blockIdx.x; - const int m_block_idx = blockIdx.y; + const int batch_idx = blockIdx.y; + const int m_block_idx = blockIdx.x; const int warp_idx = threadIdx.x / 32; const int lane_idx = threadIdx.x % 32; @@ -189,7 +189,7 @@ void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; cudaLaunchConfig_t combine_kernel_config = { - dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1), + dim3(cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), params.b, 1), dim3(NUM_THREADS, 1, 1), smem_size, stream,