Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 120 additions & 4 deletions csrc/smxx/get_mla_metadata.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,125 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
}
}

// Low shared memory fallback: single-threaded sequential computation for large batches.
//
// Computes the same outputs as the shared-memory path without any dynamic
// shared memory, by walking the batch sequentially on one thread:
//   * num_splits_ptr[0 .. batch_size]: running (prefix) count of splits per sequence.
//   * tile_scheduler_metadata_ptr[part * TileSchedulerMetaDataSize + 0..4] for each
//     of the num_sm_parts partitions:
//       [0] index of the first sequence handled by the partition (-1 if none is left)
//       [1] first block index inside that sequence
//       [2] index of the last sequence handled by the partition
//       [3] end block index inside that last sequence
//       [4] split index at which the partition starts inside its first sequence
//
// Launch layout: the host launcher uses <<<1, 32, 0, stream>>>; every thread other
// than thread 0 exits immediately.
// Precondition (unchanged from the original code): each partition's metadata slot
// is written with a single 16-byte store, so tile_scheduler_metadata_ptr and
// TileSchedulerMetaDataSize * sizeof(int) are assumed to be 16-byte aligned.
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel_low_smem(__grid_constant__ const GetDecodingMetadataParams params) {
    // Only thread 0 performs computation; the remaining lanes have nothing to do.
    if (threadIdx.x != 0) {
        return;
    }

    const int *seqlens_k_ptr = params.seqlens_k_ptr;
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
    int *num_splits_ptr = params.num_splits_ptr;
    const int batch_size = params.batch_size;
    const int block_size_n = params.block_size_n;
    const int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
    const int num_sm_parts = params.num_sm_parts;
    const int topk = params.topk;

    // Effective KV length of one sequence: topk == -1 means "use the per-sequence
    // length", any other value overrides it for every sequence.
    auto seq_len_k = [=](int batch_idx) -> int {
        return topk == -1 ? seqlens_k_ptr[batch_idx] : topk;
    };
    // Number of KV blocks for a given length. The first token index is always 0,
    // so this is last_block_idx + 1; a zero-length sequence still counts as 1 block.
    auto blocks_for_len = [=](int s_k) -> int {
        const int last_token_idx = max(s_k - 1, 0);
        return last_token_idx / block_size_n + 1;
    };

    // Pass 1: total work, including the fixed per-sequence overhead.
    int total_num_blocks = 0;
    for (int i = 0; i < batch_size; ++i) {
        total_num_blocks += blocks_for_len(seq_len_k(i)) + fixed_overhead_num_blocks;
    }

    // Work budget of each partition.
    const int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;

    // Pass 2: hand out sequences (or pieces of sequences) to partitions in order.
    int now_idx = 0;          // sequence currently being assigned
    int now_block = 0;        // first not-yet-assigned block of that sequence
    int now_n_split_idx = 0;  // number of splits already made in that sequence
    int cum_num_splits = 0;

    num_splits_ptr[0] = 0;

    for (int part = 0; part < num_sm_parts; ++part) {
        // Snapshot of where this partition starts.
        const int begin_idx = (now_idx >= batch_size ? -1 : now_idx);
        const int begin_block = now_block;
        const int begin_n_split_idx = now_n_split_idx;

        int remain_payload = payload;

        while (now_idx < batch_size) {
            const int num_blocks = blocks_for_len(seq_len_k(now_idx));
            const int now_remain_blocks = num_blocks - now_block;
            const int cost = now_remain_blocks + fixed_overhead_num_blocks;

            if (remain_payload >= cost) {
                // The rest of this sequence fits: close it and move to the next one.
                cum_num_splits += now_n_split_idx + 1;
                num_splits_ptr[now_idx + 1] = cum_num_splits;
                remain_payload -= cost;
                ++now_idx;
                now_block = 0;
                now_n_split_idx = 0;
            } else {
                // Only part of the sequence fits: take what the budget allows
                // (after paying the fixed overhead) and split here.
                if (remain_payload - fixed_overhead_num_blocks > 0) {
                    now_block += remain_payload - fixed_overhead_num_blocks;
                    ++now_n_split_idx;
                    remain_payload = 0;
                }
                break;
            }
        }

        // Where this partition ends: inside the current sequence if it was split,
        // otherwise at the end of the previously completed sequence.
        int end_idx;
        int end_block;
        if (now_block > 0) {
            end_idx = now_idx;
            end_block = now_block;
        } else {
            end_idx = now_idx - 1;
            if (end_idx < 0) {
                // No sequence has been consumed yet (e.g. empty batch): there is no
                // previous sequence to look up, so do not read seqlens_k_ptr[-1].
                end_block = 0;
            } else {
                const int prev_s_k = seq_len_k(end_idx);
                end_block = (prev_s_k == 0 ? 0 : blocks_for_len(prev_s_k));
            }
        }

        // Build the 16-byte record as a real int4 value instead of type-punning a
        // 4-byte-aligned local int array.
        int *part_meta = tile_scheduler_metadata_ptr + part * TileSchedulerMetaDataSize;
        *reinterpret_cast<int4 *>(part_meta) = make_int4(begin_idx, begin_block, end_idx, end_block);
        part_meta[4] = begin_n_split_idx;
    }

    // All sequences must have been fully assigned once every partition is filled.
    FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}

// Host-side launcher for the MLA decoding metadata computation.
//
// Picks between two kernels based on the dynamic shared memory the fast kernel
// would need, sizeof(int) * (batch_size * 5 + 1) bytes:
//   * if that fits in the device's opt-in per-block shared memory limit, launch
//     get_mla_metadata_kernel with that much dynamic shared memory;
//   * otherwise launch get_mla_metadata_kernel_low_smem, which uses none.
// Both kernels are launched asynchronously on `stream` as one block of 32 threads.
//
// params: kernel arguments (passed to the kernel by value).
// stream: CUDA stream to launch on.
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
    // Computed in size_t so that very large batch sizes (the case the fallback
    // exists for) cannot overflow a 32-bit int before the comparison below.
    const size_t smem_size = sizeof(int) * (static_cast<size_t>(params.batch_size) * 5 + 1);

    int dev = 0;
    int max_smem = 0;
    CHECK_CUDA(cudaGetDevice(&dev));
    CHECK_CUDA(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));

    if (smem_size <= static_cast<size_t>(max_smem)) {
        // Fast path: shared memory available, use high-performance kernel.
        // The cast is safe: smem_size <= max_smem, which is an int.
        CHECK_CUDA(cudaFuncSetAttribute(
            get_mla_metadata_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            static_cast<int>(smem_size)));
        get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
        CHECK_CUDA_KERNEL_LAUNCH();
    } else {
        // Fallback path: batch size exceeds shared memory limit.
        // Diagnostics go to stderr (unbuffered) so they do not mix with program output.
        fprintf(stderr,
                "[WARNING] batch_size=%d exceeds shared mem limit, "
                "falling back to low-smem kernel\n",
                params.batch_size);
        // No cudaFuncSetAttribute call is needed: the kernel uses no dynamic shared memory.
        get_mla_metadata_kernel_low_smem<<<1, 32, 0, stream>>>(params);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
}