diff --git a/csrc/smxx/get_mla_metadata.cu b/csrc/smxx/get_mla_metadata.cu
index d46fe53..4c37d5c 100644
--- a/csrc/smxx/get_mla_metadata.cu
+++ b/csrc/smxx/get_mla_metadata.cu
@@ -88,9 +88,125 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
     }
 }
 
+// Low shared memory fallback: single-threaded sequential computation for large batches
+__global__ void __launch_bounds__(32, 1, 1)
+get_mla_metadata_kernel_low_smem(__grid_constant__ const GetDecodingMetadataParams params) {
+    int *seqlens_k_ptr = params.seqlens_k_ptr;
+    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
+    int *num_splits_ptr = params.num_splits_ptr;
+    int batch_size = params.batch_size;
+    int block_size_n = params.block_size_n;
+    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
+    int num_sm_parts = params.num_sm_parts;
+
+    // Only thread 0 performs computation to minimize shared memory usage
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    // Compute total_num_blocks
+    int total_num_blocks = 0;
+    for (int i = 0; i < batch_size; ++i) {
+        int cur_s_k = (params.topk == -1 ? seqlens_k_ptr[i] : params.topk);
+        int last_token_idx = max(cur_s_k - 1, 0);
+        int cur_first_block_idx = 0;  // first_token_idx = 0
+        int cur_last_block_idx = last_token_idx / block_size_n;
+        int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
+        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
+    }
+
+    int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
+
+    // Allocate splits and write tile_scheduler_metadata
+    int now_idx = 0;
+    int now_block = 0;
+    int now_n_split_idx = 0;
+    int cum_num_splits = 0;
+
+    num_splits_ptr[0] = 0;
+
+    for (int part = 0; part < num_sm_parts; ++part) {
+        int tile_scheduler_metadata0[4];
+        int tile_scheduler_metadata1;
+
+        tile_scheduler_metadata0[0] = (now_idx >= batch_size ? -1 : now_idx);
+        tile_scheduler_metadata0[1] = now_block;  // begin block index (nonzero after a split)
+        tile_scheduler_metadata1 = now_n_split_idx;
+
+        int remain_payload = payload;
+
+        while (now_idx < batch_size) {
+            int cur_s_k = (params.topk == -1 ? seqlens_k_ptr[now_idx] : params.topk);
+            int last_token_idx = max(cur_s_k - 1, 0);
+            int cur_first_block_idx = 0;
+            int cur_last_block_idx = last_token_idx / block_size_n;
+            int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
+
+            int now_remain_blocks = num_blocks - now_block;
+
+            if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
+                cum_num_splits += now_n_split_idx + 1;
+                num_splits_ptr[now_idx + 1] = cum_num_splits;
+                remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
+                ++now_idx;
+                now_block = 0;
+                now_n_split_idx = 0;
+            } else {
+                if (remain_payload - fixed_overhead_num_blocks > 0) {
+                    now_block += remain_payload - fixed_overhead_num_blocks;
+                    ++now_n_split_idx;
+                    remain_payload = 0;
+                }
+                break;
+            }
+        }
+
+        int prev_idx = now_idx - 1;
+        tile_scheduler_metadata0[2] = (now_block > 0 ? now_idx : prev_idx);
+        if (now_block > 0) {
+            tile_scheduler_metadata0[3] = now_block;
+        } else {
+            int prev_s_k = (params.topk == -1 ? seqlens_k_ptr[prev_idx] : params.topk);
+            int last_token_idx_prev = max(prev_s_k - 1, 0);
+            int prev_last_block_idx = last_token_idx_prev / block_size_n;
+            tile_scheduler_metadata0[3] = (prev_s_k == 0 ? 0 : prev_last_block_idx + 1);
+        }
+
+        *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + part * TileSchedulerMetaDataSize) =
+            *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
+        tile_scheduler_metadata_ptr[part * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
+    }
+
+    FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
+}
+
 void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
-    int smem_size = sizeof(int) * (params.batch_size*5+1);
-    CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
-    CHECK_CUDA_KERNEL_LAUNCH();
+    int smem_size = sizeof(int) * (params.batch_size * 5 + 1);
+
+    int max_smem = 0;
+    int dev = 0;
+    CHECK_CUDA(cudaGetDevice(&dev));
+    CHECK_CUDA(cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
+
+    if (smem_size <= max_smem) {
+        // Fast path: shared memory available, use high-performance kernel
+        CHECK_CUDA(cudaFuncSetAttribute(
+            get_mla_metadata_kernel,
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            smem_size));
+        get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
+        CHECK_CUDA_KERNEL_LAUNCH();
+    } else {
+        // Fallback path: batch size exceeds shared memory limit
+        printf("[WARNING] batch_size=%d exceeds shared mem limit, "
+               "falling back to low-smem kernel\n",
+               params.batch_size);
+        fflush(stdout);
+        CHECK_CUDA(cudaFuncSetAttribute(
+            get_mla_metadata_kernel_low_smem,
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            0));
+        get_mla_metadata_kernel_low_smem<<<1, 32, 0, stream>>>(params);
+        CHECK_CUDA_KERNEL_LAUNCH();
+    }
 }
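
For context on when the fallback engages: the fast-path kernel needs `sizeof(int) * (batch_size * 5 + 1)` bytes of dynamic shared memory, so the crossover batch size follows directly from the device's opt-in per-block limit. A minimal standalone sketch (not part of this change; the 232448-byte limit is an assumed example value in the range of the Hopper opt-in maximum, whereas the launcher queries the real limit at runtime):

```cpp
#include <cstdio>

int main() {
    // Assumed opt-in per-block shared memory limit, for illustration only.
    // The launcher reads the real value via
    // cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev).
    const int max_smem = 232448;  // bytes

    // smem_size = sizeof(int) * (batch_size * 5 + 1) must fit in max_smem;
    // solve for the largest batch_size that still takes the fast path.
    const int max_batch = (max_smem / (int)sizeof(int) - 1) / 5;
    printf("low-smem fallback triggers for batch_size > %d\n", max_batch);
    return 0;
}
```

With that example limit the fast path covers batches of up to 11622 requests, so the fallback only matters for very large decode batches.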
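The scheduling loop itself is easier to trace on the host. Below is a CPU-side sketch (illustration only: `seqlens_k`, `block_size_n`, `fixed_overhead_num_blocks`, and `num_sm_parts` are made-up toy values, `ceil_div` stands in for `cutlass::ceil_div`, and the `topk` branch is omitted as if `topk == -1`) that mirrors the kernel's payload-balancing bookkeeping and checks the same end-of-loop invariant the kernel asserts:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for cutlass::ceil_div (host-only illustration).
static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Toy inputs, chosen so that one request gets split across SM parts.
    std::vector<int> seqlens_k = {300, 40, 1000, 12};
    const int block_size_n = 64;
    const int fixed_overhead_num_blocks = 5;
    const int num_sm_parts = 3;
    const int batch_size = (int)seqlens_k.size();

    // Number of k-blocks a request occupies (same formula as the kernel).
    auto num_blocks_of = [&](int idx) {
        int last_token_idx = std::max(seqlens_k[idx] - 1, 0);
        return last_token_idx / block_size_n + 1;
    };

    int total_num_blocks = 0;
    for (int i = 0; i < batch_size; ++i)
        total_num_blocks += num_blocks_of(i) + fixed_overhead_num_blocks;
    const int payload = ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;

    std::vector<int> num_splits(batch_size + 1, 0);
    int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;

    for (int part = 0; part < num_sm_parts; ++part) {
        int begin_idx = (now_idx >= batch_size ? -1 : now_idx);
        int begin_block = now_block;
        int remain_payload = payload;
        while (now_idx < batch_size) {
            int now_remain_blocks = num_blocks_of(now_idx) - now_block;
            if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
                // Request finishes inside this part: close out its split count.
                cum_num_splits += now_n_split_idx + 1;
                num_splits[now_idx + 1] = cum_num_splits;
                remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
                ++now_idx;
                now_block = 0;
                now_n_split_idx = 0;
            } else {
                // Out of payload: leave the tail of this request to the next part.
                if (remain_payload - fixed_overhead_num_blocks > 0) {
                    now_block += remain_payload - fixed_overhead_num_blocks;
                    ++now_n_split_idx;
                }
                break;
            }
        }
        printf("part %d: begin_idx=%d begin_block=%d -> next idx=%d block=%d\n",
               part, begin_idx, begin_block, now_idx, now_block);
    }

    // The invariant FLASH_DEVICE_ASSERT checks in the kernel.
    assert(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
    for (int i = 0; i <= batch_size; ++i)
        printf("num_splits[%d] = %d\n", i, num_splits[i]);
    return 0;
}
```

With these toy values the 1000-token request is split across parts 1 and 2 (`now_n_split_idx` ticking up before the break), which is exactly the case the `num_splits` prefix sum has to account for.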