Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
template <int32_t kPackedQoLenPerWg_,
bool kQoSplits_,
int32_t kUniSeqlenQo_,
bool kLdsBatchInfo_,
bool kIsSparse_ = false>
struct MlaMetadataV12Traits
{
Expand All @@ -18,18 +19,21 @@ struct MlaMetadataV12Traits
static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_;
static constexpr int32_t kFixedOverheadNumBlocks = 16;
static constexpr int32_t kIsSparse = kIsSparse_;
static constexpr int32_t kLdsBatchInfo = kLdsBatchInfo_;
};

template <typename Traits>
__launch_bounds__(ck_tile::get_warp_size(), 1) __global__
void kn_get_mla_metadata_v1_2(MlaMetadataV1KernelParameter params)
{
using QoState = QoState<Traits>;

extern __shared__ uint8_t p_smem[];
int32_t* p_lds_seqlens_qo = reinterpret_cast<int32_t*>(p_smem);
int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + params.num_batches;
int32_t* p_lds_partial_info = p_lds_seqlens_kv + params.num_batches;
int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : params.num_batches);
int32_t* p_lds_partial_info = p_lds_seqlens_kv + (Traits::kLdsBatchInfo ? params.num_batches : 0);

QoState<Traits> qo_state(
QoState qo_state(
params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr);

const int32_t lane_idx = ck_tile::get_lane_id();
Expand All @@ -43,15 +47,20 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
? (bid / params.ori_seqlen_qo / params.qk_batch_ratio)
: (bid / params.qk_batch_ratio);
const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1];
const int32_t seqlen_kv = kv_end - params.p_seqlens_kv_indptr[bid_ori];
const int32_t seqlen_kv = Traits::kIsSparse ?
min(kv_end - params.p_seqlens_kv_indptr[bid_ori], params.topk) :
(kv_end - params.p_seqlens_kv_indptr[bid_ori]);

p_lds_seqlens_kv[bid] = Traits::kIsSparse ? min(seqlen_kv, params.topk) : seqlen_kv;
if constexpr (Traits::kLdsBatchInfo)
{
p_lds_seqlens_kv[bid] = seqlen_kv;
}

const int32_t num_blocks = integer_divide_ceil_power2(
seqlen_kv, params.kv_granularity, params.kv_granularity_log2);
sum_blocks += num_blocks;

if constexpr(Traits::kUniSeqlenQo == -1)
if constexpr(QoState::is_unique() == false)
{
p_lds_seqlens_qo[bid] =
params.p_seqlens_qo_indptr[bid_ori + 1] - params.p_seqlens_qo_indptr[bid_ori];
Expand Down Expand Up @@ -84,7 +93,9 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__

int32_t curr_kv_begin = 0;
// The size of 1st element equals to the end loc of the 1st element.
int32_t curr_kv_end = p_lds_seqlens_kv[0];
int32_t curr_kv_end = Traits::kLdsBatchInfo ? p_lds_seqlens_kv[0] :
Traits::kIsSparse ? min(params.p_seqlens_kv_indptr[1], params.topk) :
params.p_seqlens_kv_indptr[1];
int32_t curr_kv_seqlen = curr_kv_end - curr_kv_begin;

int32_t num_works = 0;
Expand Down Expand Up @@ -226,7 +237,19 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
{
if(curr_sub_head_idx == 0)
{
curr_kv_seqlen = p_lds_seqlens_kv[curr_batch];
if constexpr (Traits::kLdsBatchInfo)
{
curr_kv_seqlen = p_lds_seqlens_kv[curr_batch];
}
else
{
const int32_t bid_ori = Traits::kIsSparse
? (curr_batch / params.ori_seqlen_qo / params.qk_batch_ratio)
: (curr_batch / params.qk_batch_ratio);
curr_kv_seqlen =
params.p_seqlens_kv_indptr[bid_ori + 1] - params.p_seqlens_kv_indptr[bid_ori];
curr_kv_seqlen = Traits::kIsSparse ? min(curr_kv_seqlen, params.topk) : curr_kv_seqlen;
}
curr_kv_begin =
Traits::kIsSparse ? (curr_kv_begin + params.topk) : curr_kv_end;
curr_kv_end = curr_kv_begin + curr_kv_seqlen;
Expand Down Expand Up @@ -302,12 +325,28 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
template <int32_t kPackedQoLenPerWg, bool kQoSplits, int32_t kUniSeqlenQo, bool kIsSparse>
void dispatch_mla_metadata_v1_2_device(const MlaMetadataV1KernelParameter& params,
const hipStream_t stream,
const int32_t max_seqlen_qo,
const int32_t warp_size,
const int32_t lds_size)
{
using Traits = MlaMetadataV12Traits<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>;
const dim3 grid = dim3(1, 1, 1);
kn_get_mla_metadata_v1_2<Traits><<<grid, warp_size, lds_size, stream>>>(params);

using DummyTraits = MlaMetadataV12Traits<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, true, kIsSparse>;
const int32_t lds_bytes_per_batch = sizeof(int32_t) * (QoState<DummyTraits>::is_unique() ? 1 : 2);
const int32_t max_qo_tiles = kQoSplits ? (ck_tile::integer_divide_ceil(max_seqlen_qo, kPackedQoLenPerWg)) : 1;
const int32_t lds_bytes_partial_info = kQoSplits ? params.num_cu * max_qo_tiles * sizeof(int32_t) : 0;
const int32_t max_lds_batch_size = (lds_size - lds_bytes_partial_info) / lds_bytes_per_batch;

if (params.num_batches <= max_lds_batch_size)
{
using Traits = MlaMetadataV12Traits<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, true, kIsSparse>;
kn_get_mla_metadata_v1_2<Traits><<<grid, warp_size, lds_size, stream>>>(params);
}
else
{
using Traits = MlaMetadataV12Traits<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, false, kIsSparse>;
kn_get_mla_metadata_v1_2<Traits><<<grid, warp_size, lds_size, stream>>>(params);
}
}

void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
Expand Down Expand Up @@ -393,5 +432,5 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
params.uni_seqlen_qo,
topk,
dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
params, stream, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor));
params, stream, max_seqlen_qo, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor));
}
2 changes: 1 addition & 1 deletion csrc/kernels/mla/metadata/v1_comm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public:
p_seqlens_qo_indptr_(p_seqlens_qo_indptr)
{ }

CK_TILE_DEVICE constexpr bool is_unique()
CK_TILE_HOST_DEVICE static constexpr bool is_unique()
{
return Traits::kUniSeqlenQo >= 0;
}
Expand Down