Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions csrc/topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,73 @@ void launch_persistent_topk(const torch::Tensor& logits,
size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;

// Query occupancy for the instantiation that will actually launch;
// overestimating it deadlocks the cooperative barrier.
int occupancy = 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
smem_size);
cudaError_t occ_err = cudaSuccess;
if (vec_size == 4) {
occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
smem_size);
} else if (vec_size == 2) {
occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<TopK, 2>, P::kThreadsPerBlock,
smem_size);
} else {
occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
smem_size);
}
TORCH_CHECK(occ_err == cudaSuccess,
"persistent_topk occupancy query failed: ",
cudaGetErrorString(occ_err));
if (occupancy < 1) occupancy = 1;

uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
// The cooperative spin-wait barrier only runs when at least one row hits
// the radix path (seq_len > RADIX_THRESHOLD). Below that, non-CTA-0 CTAs
// early-exit, so oversubscription can't deadlock and headroom is wasted.
const bool needs_cooperative =
static_cast<uint32_t>(max_seq_len) > P::RADIX_THRESHOLD;

const uint32_t hw_resident_cap =
static_cast<uint32_t>(num_sms) * static_cast<uint32_t>(occupancy);
uint32_t max_resident_ctas = hw_resident_cap;
if (needs_cooperative) {
// Reserve one CTA per SM when occupancy allows; fall back to a single
// CTA when occupancy == 1 (the most deadlock-prone case — any straggler
// kernel that takes the only slot on one SM hangs the barrier). Never
// drop below one full group's worth.
uint32_t headroom = (occupancy > 1) ? static_cast<uint32_t>(num_sms) : 1u;
if (max_resident_ctas >= headroom + ctas_per_group) {
max_resident_ctas -= headroom;
}
}
uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
static_cast<uint32_t>(num_rows));
if (num_groups == 0) num_groups = 1;
uint32_t total_ctas = num_groups * ctas_per_group;

// If the cooperative launch wouldn't fit, fall back to FilteredTopK
// instead of deadlocking. Only relevant when needs_cooperative.
if (needs_cooperative && total_ctas > hw_resident_cap) {
TORCH_CHECK(max_smem_per_block >= 128 * 1024,
"persistent_topk would oversubscribe and the FilteredTopK "
"fallback requires >=128KB smem per block (have ",
max_smem_per_block, "). total_ctas=", total_ctas,
" > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
", smem=", smem_size, ").");
cudaError_t status =
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride),
stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK fallback failed: ", cudaGetErrorString(status));
return;
}

size_t state_bytes = num_groups * sizeof(P::RadixRowState);
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");
Expand Down
Loading