22 changes: 11 additions & 11 deletions csrc/kernels/mla/metadata/v1_0_device.cuh
@@ -15,15 +15,15 @@ __device__ int32_t get_local_splits(int32_t seqlen_kv,
int32_t ex_splits =
seqlen_kv /
196; // magic num 196. Experiments shows 196 per splits can get better performance.
-return ck_tile::min(ck_tile::min(ex_splits, num_splits_per_cu), num_splits);
+return opus::min(opus::min(ex_splits, num_splits_per_cu), num_splits);
#endif
}
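The hunk above only swaps the namespace; the clamp logic is unchanged: an experiment-derived estimate (one split per ~196 KV tokens) is capped by both the per-CU budget and the global split limit. A minimal host-side sketch, assuming `opus::min` behaves like `std::min` on `int32_t` (values below are illustrative, not from this PR):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative stand-in for get_local_splits on the host.
int32_t local_splits_sketch(int32_t seqlen_kv, int32_t num_splits, int32_t num_splits_per_cu)
{
    const int32_t ex_splits = seqlen_kv / 196; // heuristic: ~196 KV tokens per split
    return std::min(std::min(ex_splits, num_splits_per_cu), num_splits);
}

// Example: seqlen_kv = 4096 -> ex_splits = 20; with num_splits_per_cu = 8
// and num_splits = 16, the result is min(min(20, 8), 16) = 8.
```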

template <bool DP_MODE = false>
-__launch_bounds__(ck_tile::get_warp_size(), 1) __global__
+__launch_bounds__(opus::get_warp_size(), 1) __global__
void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params)
{
-const int32_t lane_idx = ck_tile::get_lane_id();
+const int32_t lane_idx = opus::lane_id();

MlaWorkInfo* p_work_info_set = reinterpret_cast<MlaWorkInfo*>(params.p_work_info_set_raw);

@@ -44,7 +44,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__

int32_t num_splits_per_cu = (params.num_cu + params.num_batches - 1) / params.num_batches;

-for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size())
+for(int32_t bid = lane_idx; bid < params.num_batches; bid += opus::get_warp_size())
{
const int32_t bid_ori = bid / params.qk_batch_ratio;

@@ -67,8 +67,8 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
seqlen_kv, params.kv_granularity, params.kv_granularity_log2);
const int32_t num_splits =
get_local_splits(seqlen_kv, params.num_splits, num_splits_per_cu);
-const int32_t payload = ck_tile::integer_divide_ceil(num_blocks, num_splits);
-int32_t split_local = ck_tile::integer_divide_ceil(num_blocks, payload);
+const int32_t payload = integer_divide_ceil(num_blocks, num_splits);
+int32_t split_local = integer_divide_ceil(num_blocks, payload);
int32_t tail = seqlen_kv % (payload * params.kv_granularity);
if(tail <= 4 && tail != 0 && split_local > 1)
{
@@ -95,7 +95,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
int32_t work_per_cu = work_end / params.num_cu;
int32_t work_res = work_end % params.num_cu;

-for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size())
+for(int32_t bid = lane_idx; bid < params.num_batches; bid += opus::get_warp_size())
{
const int32_t bid_ori = bid / params.qk_batch_ratio;

@@ -126,7 +126,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
work_info.qo_end = work_info.qo_start + params.uni_seqlen_qo;
work_info.kv_start = kv_begin + (sid * payload * params.kv_granularity);
work_info.kv_end =
-ck_tile::min(work_info.kv_start + payload * params.kv_granularity, kv_end);
+opus::min(work_info.kv_start + payload * params.kv_granularity, kv_end);
work_info.kv_offset = kv_end - work_info.kv_end;
if(work_info.kv_offset <= 4 && split_local > 1)
{
@@ -143,21 +143,21 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}

int32_t reduce_end = params.p_reduce_indptr[params.num_batches];
-for(int32_t work_id = lane_idx + 1; work_id < work_res; work_id += ck_tile::get_warp_size())
+for(int32_t work_id = lane_idx + 1; work_id < work_res; work_id += opus::get_warp_size())
{
params.p_work_indptr[work_id] = min(work_id * (work_per_cu + 1), work_end);
}

int32_t stage = work_res * (work_per_cu + 1);

for(int32_t work_id = work_res + lane_idx; work_id < params.num_cu + 1;
-work_id += ck_tile::get_warp_size())
+work_id += opus::get_warp_size())
{
params.p_work_indptr[work_id] = stage + (work_id - work_res) * work_per_cu;
}

for(int32_t reduce_id = params.num_batches + lane_idx; reduce_id <= params.fixed_num_batches;
-reduce_id += ck_tile::get_warp_size())
+reduce_id += opus::get_warp_size())
{
params.p_reduce_indptr[reduce_id] = reduce_end;
}
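For context on the payload/tail hunks above, a hedged walk-through, assuming `integer_divide_ceil(a, b) == (a + b - 1) / b` as in ck_tile:

```cpp
// kv_granularity = 16, seqlen_kv = 193 -> num_blocks = ceil(193 / 16) = 13.
// num_splits = 4:
//   payload     = integer_divide_ceil(13, 4) = 4 KV blocks per split
//   split_local = integer_divide_ceil(13, 4) = 4 splits actually used
//   tail        = 193 % (4 * 16) = 1
// tail is nonzero, <= 4, and split_local > 1, so the guarded branch folds
// the 1-token remainder into the neighboring split instead of scheduling a
// split that would process a single token.
```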
121 changes: 62 additions & 59 deletions csrc/kernels/mla/metadata/v1_1_device.cuh

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions csrc/kernels/mla/metadata/v1_1_host.cuh
@@ -53,8 +53,8 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
const int32_t cluster_size = [&]() {
const int32_t avg_packed_qo_len = sum_packed_qo_len / num_batches;
const int32_t cluster_size =
-ck_tile::integer_divide_ceil(avg_packed_qo_len, Traits::kPackedQoLenPerWg);
-return ck_tile::min(cluster_size, Traits::kMaxClusterSize);
+integer_divide_ceil(avg_packed_qo_len, Traits::kPackedQoLenPerWg);
+return std::min(cluster_size, Traits::kMaxClusterSize);
}();
TORCH_CHECK(
(dev_prop.multiProcessorCount % cluster_size) == 0, __func__, ": Invalid cluster_size!");
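A hedged numeric sketch of the cluster sizing above (the `Traits` constants below are illustrative placeholders, not values from this PR):

```cpp
// sum_packed_qo_len = 8192, num_batches = 16 -> avg_packed_qo_len = 512.
// kPackedQoLenPerWg = 128: integer_divide_ceil(512, 128) = 4.
// kMaxClusterSize   = 8:   std::min(4, 8)                = 4 -> cluster_size.
// The TORCH_CHECK then requires the CU count to be a multiple of 4:
// 80 CUs pass (80 % 4 == 0); 78 CUs would trip the check.
```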
@@ -71,8 +71,8 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
for(const auto& binfo : batch_infos)
{
const int32_t packed_qo_len = binfo.qo_len * num_heads;
-const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q);
-const int32_t packed_qo_tile_len = ck_tile::min(packed_qo_len, cluster_len_q);
+const int32_t num_qo_tiles = integer_divide_ceil(packed_qo_len, cluster_len_q);
+const int32_t packed_qo_tile_len = std::min(packed_qo_len, cluster_len_q);

num_qo_clusters_indptr.push_back(num_qo_clusters_indptr.back() + num_qo_tiles);

@@ -86,8 +86,8 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
num_heads,
is_causal);
// always assume that each batch of tile will be splited once along kv.
-const int32_t kv_len_splited = ck_tile::integer_least_multiple(
-ck_tile::integer_divide_ceil(kv_len_valid, 2), kv_granularity);
+const int32_t kv_len_splited = integer_least_multiple(
+integer_divide_ceil(kv_len_valid, 2), kv_granularity);
workload_sum += 2 * cal_cost(packed_qo_tile_len, kv_len_splited) + kv_granularity;
}
}
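The estimate above assumes each tile is split once along KV: halve the valid KV length, then round up to the granularity. A small worked example, assuming `integer_least_multiple(a, m)` rounds `a` up to the next multiple of `m`:

```cpp
// kv_len_valid = 1000, kv_granularity = 16:
//   integer_divide_ceil(1000, 2)    = 500
//   integer_least_multiple(500, 16) = 512 -> kv_len_splited
// The batch then contributes
//   2 * cal_cost(packed_qo_tile_len, 512) + 16
// to workload_sum: two half-length splits plus one granularity of overhead.
```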
@@ -125,7 +125,7 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
const int32_t qo_len = binfo.qo_len;
const int32_t kv_len = binfo.kv_len;
const int32_t packed_qo_len = qo_len * num_heads;
-const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q);
+const int32_t num_qo_tiles = integer_divide_ceil(packed_qo_len, cluster_len_q);
const int32_t qo_batch_start = p_seqlens_qo_indptr[bid];
const int32_t kv_batch_start = p_seqlens_kv_indptr[bid];
const int32_t kv_batch_end = p_seqlens_kv_indptr[bid + 1];
@@ -145,15 +145,15 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
const int32_t remaining_capability_top =
cal_kv_len(workload_limit_global - accum_cost_top, cluster_len_q);
const int32_t num_splits_estimated =
-ck_tile::integer_divide_ceil(remaining_kv_len, remaining_capability_top);
+integer_divide_ceil(remaining_kv_len, remaining_capability_top);
// For the case of #splits==2, make sure that the tailing tile is smaller than
// Traits::kSplitTolerance.
const bool split_kv =
(num_splits_estimated == 2)
? ((remaining_kv_len - remaining_capability_top) > Traits::kSplitTolerance)
: (num_splits_estimated > 1);
-const int32_t kv_len_limit_floor = ck_tile::integer_least_multiple(
-ck_tile::integer_divide_ceil(kv_len, num_clusters), kv_granularity);
+const int32_t kv_len_limit_floor = integer_least_multiple(
+integer_divide_ceil(kv_len, num_clusters), kv_granularity);

do
{
Expand All @@ -164,15 +164,15 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
cal_kv_len(workload_limit_global - accum_cost, cluster_len_q);
const int32_t kv_len_limit_local = [&]() {
const int32_t limit_ori =
-ck_tile::max(remaining_capability, kv_len_limit_floor);
+std::max(remaining_capability, kv_len_limit_floor);
const int32_t tail_size = (remaining_kv_len > limit_ori)
? (remaining_kv_len - limit_ori)
: 0x7fffffff;
const int32_t limit_fin =
(tail_size <= Traits::kSplitTolerance) ? remaining_kv_len : limit_ori;
return limit_fin;
}();
-const int32_t kv_len_consuming = ck_tile::min(remaining_kv_len, kv_len_limit_local);
+const int32_t kv_len_consuming = std::min(remaining_kv_len, kv_len_limit_local);
const int32_t cost = cal_cost(cluster_len_q, kv_len_consuming);
#if PRINT_DBG
printf("[metadata] cost heap updated: cid=%d, pre_cost=%d, new_cost=%d, "
@@ -191,7 +191,7 @@ get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch siz
work_info.batch_idx = bid;
work_info.qo_start = tid * cluster_len_q + qo_batch_start;
work_info.qo_end =
-ck_tile::min(work_info.qo_start + cluster_len_q, qo_batch_start + qo_len);
+std::min(work_info.qo_start + cluster_len_q, qo_batch_start + qo_len);
work_info.kv_start = kv_start_local + kv_batch_start;
work_info.kv_end = work_info.kv_start + kv_len_consuming;
work_info.kv_offset = kv_batch_end - work_info.kv_end;
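A hedged sketch of the tail-tolerance rule in the loop above (`Traits::kSplitTolerance` is given an illustrative value here):

```cpp
// kSplitTolerance = 64 (illustrative):
//
// Case 1: remaining_kv_len = 1056, limit_ori = 1024
//   tail_size = 1056 - 1024 = 32 <= 64
//   -> limit_fin = 1056: the whole remainder goes to one cluster rather
//      than leaving a 32-token split behind.
//
// Case 2: remaining_kv_len = 2048, limit_ori = 1024
//   tail_size = 2048 - 1024 = 1024 > 64
//   -> limit_fin = 1024: the batch keeps splitting across clusters.
```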
38 changes: 19 additions & 19 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -22,7 +22,7 @@ struct MlaMetadataV12Traits
};

template <typename Traits>
-__launch_bounds__(ck_tile::get_warp_size(), 1) __global__
+__launch_bounds__(opus::get_warp_size(), 1) __global__
void kn_get_mla_metadata_v1_2(MlaMetadataV1KernelParameter params)
{
using QoState = QoState<Traits>;
@@ -69,12 +69,12 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}
};

-const int32_t lane_idx = ck_tile::get_lane_id();
+const int32_t lane_idx = opus::lane_id();

MlaWorkInfo* p_work_info_set = reinterpret_cast<MlaWorkInfo*>(params.p_work_info_set_raw);

int32_t sum_blocks = 0;
-for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size())
+for(int32_t bid = lane_idx; bid < num_batches; bid += opus::get_warp_size())
{
const int32_t bid_ori = Traits::kIsSparse ? (bid / ori_seqlen_qo / params.qk_batch_ratio)
: (bid / params.qk_batch_ratio);
@@ -101,7 +101,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}

sum_blocks =
-aiter::warpReduce<aiter::AddFunctor, decltype(sum_blocks), ck_tile::get_warp_size()>(
+aiter::warpReduce<aiter::AddFunctor, decltype(sum_blocks), opus::get_warp_size()>(
sum_blocks);

if(lane_idx == 0)
@@ -115,7 +115,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}

// expected payload handled by each cu part.
-const int32_t payload = ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) +
+const int32_t payload = integer_divide_ceil(sum_blocks, params.num_splits) +
params.fixed_over_head_num_blocks;
const int32_t page_size = params.page_size;
int32_t curr_batch = 0; // batch ID of the batch which is under review
@@ -144,7 +144,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
{
const int32_t num_qo_tiles = get_num_qo_tiles(curr_batch);
const int32_t qo_tile_size =
-ck_tile::integer_divide_ceil(qo_state.get_seqlen(curr_batch), num_qo_tiles);
+integer_divide_ceil(qo_state.get_seqlen(curr_batch), num_qo_tiles);
const int32_t num_kv_blocks = integer_divide_ceil_power2(
curr_kv_seqlen, params.kv_granularity, params.kv_granularity_log2);
const int32_t remain_kv_blocks = num_kv_blocks - curr_kv_block;
@@ -162,7 +162,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
-work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
+work_info.qo_end = opus::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity);
if(page_size == 1)
Expand All @@ -178,21 +178,21 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
1;
}
}
-batch_tail = ck_tile::max(batch_tail, 0);
-work_info.kv_end = ck_tile::min(
+batch_tail = opus::max(batch_tail, 0);
+work_info.kv_end = opus::min(
work_info.kv_start + (remain_kv_blocks * params.kv_granularity),
curr_kv_end - batch_tail);
if((curr_kv_end - work_info.kv_end < params.tail_done_threshold &&
curr_kv_end - work_info.kv_end > 0) ||
cur_tail_done)
{
-work_info.kv_end = ck_tile::min(curr_kv_end - batch_tail, curr_kv_end);
+work_info.kv_end = opus::min(curr_kv_end - batch_tail, curr_kv_end);
}
work_info.kv_offset = curr_kv_end - work_info.kv_end;
}
else
{
-work_info.kv_end = ck_tile::min(
+work_info.kv_end = opus::min(
work_info.kv_start + (remain_kv_blocks * params.kv_granularity),
curr_kv_end);
work_info.kv_offset =
@@ -227,7 +227,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
// record a work in work_info_set
if(curr_n_split_idx > 0)
{
-for(int32_t idx = lane_idx; idx < num_splits; idx += ck_tile::get_warp_size())
+for(int32_t idx = lane_idx; idx < num_splits; idx += opus::get_warp_size())
{
fill_work_info(idx);
}
@@ -303,7 +303,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
work_info.batch_idx = curr_batch;
work_info.qo_start =
qo_state.get_begin(curr_batch) + curr_qo_tile_idx * qo_tile_size;
-work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
+work_info.qo_end = opus::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start =
curr_kv_begin + (curr_kv_block * params.kv_granularity);
Expand All @@ -320,21 +320,21 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
1;
}
}
-batch_tail = ck_tile::max(batch_tail, 0);
-work_info.kv_end = ck_tile::min(
+batch_tail = opus::max(batch_tail, 0);
+work_info.kv_end = opus::min(
work_info.kv_start + (consuming_blks * params.kv_granularity),
curr_kv_end - batch_tail);
if(curr_kv_end - work_info.kv_end < params.tail_done_threshold)
{
cur_tail_done = true;
work_info.kv_end =
-ck_tile::min(curr_kv_end, curr_kv_end - batch_tail);
+opus::min(curr_kv_end, curr_kv_end - batch_tail);
}
work_info.kv_offset = curr_kv_end - work_info.kv_end;
}
else
{
-work_info.kv_end = ck_tile::min(
+work_info.kv_end = opus::min(
work_info.kv_start + (consuming_blks * params.kv_granularity),
curr_kv_end);
work_info.kv_offset =
@@ -373,7 +373,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
}

for(int32_t i = tot_qo_tiles + lane_idx; i < params.reduce_indptr_size;
-i += ck_tile::get_warp_size())
+i += opus::get_warp_size())
{
params.p_reduce_indptr[i] = last_reduce_indptr;
}
@@ -393,7 +393,7 @@ void dispatch_mla_metadata_v1_2_device(const MlaMetadataV1KernelParameter& param
const int32_t lds_bytes_per_batch =
sizeof(int32_t) * (QoState<DummyTraits>::is_unique() ? 1 : 2);
const int32_t max_qo_tiles =
-kQoSplits ? (ck_tile::integer_divide_ceil(max_seqlen_qo, kPackedQoLenPerWg)) : 1;
+kQoSplits ? (integer_divide_ceil(max_seqlen_qo, kPackedQoLenPerWg)) : 1;
const int32_t max_lds_batch_size = lds_size / lds_bytes_per_batch;

if(params.num_batches <= max_lds_batch_size)
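The dispatch above chooses the LDS-resident path when per-batch bookkeeping fits in shared memory. A hedged capacity estimate, assuming 64 KiB of LDS per workgroup (typical for AMD CDNA CUs, but not stated in this diff):

```cpp
// lds_size = 64 * 1024 = 65536 bytes.
// Unique qo state:     lds_bytes_per_batch = sizeof(int32_t) * 1 = 4
//                      max_lds_batch_size  = 65536 / 4 = 16384 batches
// Non-unique qo state: lds_bytes_per_batch = 8 -> 8192 batches
// params.num_batches within the cap takes the LDS-cached branch; larger
// batch counts presumably fall back to a variant that re-reads global memory.
```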
2 changes: 1 addition & 1 deletion csrc/kernels/mla/metadata/v1_2_host.cuh
@@ -111,7 +111,7 @@ void kn_generate_ps_metadata(std::vector<int32_t>& seqlens_qo_indptr,
const int32_t effective_kv_length =
is_causal ? std::min(kv_length - qo_length + local_qo_end, kv_length) : kv_length;
const int32_t num_units =
-ck_tile::integer_divide_ceil(effective_kv_length, kvlen_granularity);
+integer_divide_ceil(effective_kv_length, kvlen_granularity);
// const int32_t num_units =
// offset_div(effective_kv_length, kvlen_granularity, SPLIT_KV_OVERHEAD);
query_tiles.push_back({batch_idx,
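A worked example of the causal clamp above: a query tile ending at `local_qo_end` only attends to keys up to its own right-aligned position:

```cpp
// qo_length = 128, kv_length = 1024, local_qo_end = 64 (tile ends mid-query):
//   effective_kv_length = std::min(1024 - 128 + 64, 1024) = 960
// With kvlen_granularity = 16:
//   num_units = integer_divide_ceil(960, 16) = 60
// A non-causal batch uses the full kv_length: 1024 -> 64 units.
```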