diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index 48e7b35a12..9340d05416 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -6,6 +6,7 @@ template struct MlaMetadataV12Traits { @@ -18,18 +19,21 @@ struct MlaMetadataV12Traits static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_; static constexpr int32_t kFixedOverheadNumBlocks = 16; static constexpr int32_t kIsSparse = kIsSparse_; + static constexpr int32_t kLdsBatchInfo = kLdsBatchInfo_; }; template __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ void kn_get_mla_metadata_v1_2(MlaMetadataV1KernelParameter params) { + using QoState = QoState; + extern __shared__ uint8_t p_smem[]; int32_t* p_lds_seqlens_qo = reinterpret_cast(p_smem); - int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + params.num_batches; - int32_t* p_lds_partial_info = p_lds_seqlens_kv + params.num_batches; + int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : params.num_batches); + int32_t* p_lds_partial_info = p_lds_seqlens_kv + (Traits::kLdsBatchInfo ? params.num_batches : 0); - QoState qo_state( + QoState qo_state( params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr); const int32_t lane_idx = ck_tile::get_lane_id(); @@ -43,15 +47,20 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ ? (bid / params.ori_seqlen_qo / params.qk_batch_ratio) : (bid / params.qk_batch_ratio); const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; - const int32_t seqlen_kv = kv_end - params.p_seqlens_kv_indptr[bid_ori]; + const int32_t seqlen_kv = Traits::kIsSparse ? + min(kv_end - params.p_seqlens_kv_indptr[bid_ori], params.topk) : + (kv_end - params.p_seqlens_kv_indptr[bid_ori]); - p_lds_seqlens_kv[bid] = Traits::kIsSparse ? min(seqlen_kv, params.topk) : seqlen_kv; + if constexpr (Traits::kLdsBatchInfo) + { + p_lds_seqlens_kv[bid] = seqlen_kv; + } const int32_t num_blocks = integer_divide_ceil_power2( seqlen_kv, params.kv_granularity, params.kv_granularity_log2); sum_blocks += num_blocks; - if constexpr(Traits::kUniSeqlenQo == -1) + if constexpr(QoState::is_unique() == false) { p_lds_seqlens_qo[bid] = params.p_seqlens_qo_indptr[bid_ori + 1] - params.p_seqlens_qo_indptr[bid_ori]; @@ -84,7 +93,9 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ int32_t curr_kv_begin = 0; // The size of 1st element equals to the end loc of the 1st element. - int32_t curr_kv_end = p_lds_seqlens_kv[0]; + int32_t curr_kv_end = Traits::kLdsBatchInfo ? p_lds_seqlens_kv[0] : + Traits::kIsSparse ? min(params.p_seqlens_kv_indptr[1], params.topk) : + params.p_seqlens_kv_indptr[1]; int32_t curr_kv_seqlen = curr_kv_end - curr_kv_begin; int32_t num_works = 0; @@ -226,7 +237,19 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ { if(curr_sub_head_idx == 0) { - curr_kv_seqlen = p_lds_seqlens_kv[curr_batch]; + if constexpr (Traits::kLdsBatchInfo) + { + curr_kv_seqlen = p_lds_seqlens_kv[curr_batch]; + } + else + { + const int32_t bid_ori = Traits::kIsSparse + ? (curr_batch / params.ori_seqlen_qo / params.qk_batch_ratio) + : (curr_batch / params.qk_batch_ratio); + curr_kv_seqlen = + params.p_seqlens_kv_indptr[bid_ori + 1] - params.p_seqlens_kv_indptr[bid_ori]; + curr_kv_seqlen = Traits::kIsSparse ? min(curr_kv_seqlen, params.topk) : curr_kv_seqlen; + } curr_kv_begin = Traits::kIsSparse ? (curr_kv_begin + params.topk) : curr_kv_end; curr_kv_end = curr_kv_begin + curr_kv_seqlen; @@ -302,12 +325,28 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ template void dispatch_mla_metadata_v1_2_device(const MlaMetadataV1KernelParameter& params, const hipStream_t stream, + const int32_t max_seqlen_qo, const int32_t warp_size, const int32_t lds_size) { - using Traits = MlaMetadataV12Traits; const dim3 grid = dim3(1, 1, 1); - kn_get_mla_metadata_v1_2<<>>(params); + + using DummyTraits = MlaMetadataV12Traits; + const int32_t lds_bytes_per_batch = sizeof(int32_t) * (QoState::is_unique() ? 1 : 2); + const int32_t max_qo_tiles = kQoSplits ? (ck_tile::integer_divide_ceil(max_seqlen_qo, kPackedQoLenPerWg)) : 1; + const int32_t lds_bytes_partial_info = kQoSplits ? params.num_cu * max_qo_tiles * sizeof(int32_t) : 0; + const int32_t max_lds_batch_size = (lds_size - lds_bytes_partial_info) / lds_bytes_per_batch; + + if (params.num_batches <= max_lds_batch_size) + { + using Traits = MlaMetadataV12Traits; + kn_get_mla_metadata_v1_2<<>>(params); + } + else + { + using Traits = MlaMetadataV12Traits; + kn_get_mla_metadata_v1_2<<>>(params); + } } void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] @@ -393,5 +432,5 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba params.uni_seqlen_qo, topk, dispatch_mla_metadata_v1_2_device( - params, stream, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor)); + params, stream, max_seqlen_qo, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor)); } diff --git a/csrc/kernels/mla/metadata/v1_comm.cuh b/csrc/kernels/mla/metadata/v1_comm.cuh index a5bac9de5c..e166d4a362 100644 --- a/csrc/kernels/mla/metadata/v1_comm.cuh +++ b/csrc/kernels/mla/metadata/v1_comm.cuh @@ -236,7 +236,7 @@ public: p_seqlens_qo_indptr_(p_seqlens_qo_indptr) { } - CK_TILE_DEVICE constexpr bool is_unique() + CK_TILE_HOST_DEVICE static constexpr bool is_unique() { return Traits::kUniSeqlenQo >= 0; }