diff --git a/csrc/topk.cu b/csrc/topk.cu
index 364ecc21e532..b0f612ba6e4b 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -82,18 +82,73 @@ void launch_persistent_topk(const torch::Tensor& logits,
   size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
   if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;
 
+  // Query occupancy for the instantiation that will actually launch;
+  // overestimating it deadlocks the cooperative barrier.
   int occupancy = 1;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-      &occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
-      smem_size);
+  cudaError_t occ_err = cudaSuccess;
+  if (vec_size == 4) {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
+        smem_size);
+  } else if (vec_size == 2) {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<2>, P::kThreadsPerBlock,
+        smem_size);
+  } else {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<1>, P::kThreadsPerBlock,
+        smem_size);
+  }
+  TORCH_CHECK(occ_err == cudaSuccess,
+              "persistent_topk occupancy query failed: ",
+              cudaGetErrorString(occ_err));
   if (occupancy < 1) occupancy = 1;
-  uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
+  // The cooperative spin-wait barrier only runs when at least one row hits
+  // the radix path (seq_len > RADIX_THRESHOLD). Below that, non-CTA-0 CTAs
+  // early-exit, so oversubscription can't deadlock and headroom is wasted.
+  const bool needs_cooperative =
+      static_cast<int64_t>(max_seq_len) > P::RADIX_THRESHOLD;
+
+  const uint32_t hw_resident_cap =
+      static_cast<uint32_t>(num_sms) * static_cast<uint32_t>(occupancy);
+  uint32_t max_resident_ctas = hw_resident_cap;
+  if (needs_cooperative) {
+    // Reserve one CTA per SM when occupancy allows; fall back to a single
+    // CTA when occupancy == 1 (the most deadlock-prone case — any straggler
+    // kernel that takes the only slot on one SM hangs the barrier). Never
+    // drop below one full group's worth.
+    uint32_t headroom = (occupancy > 1) ? static_cast<uint32_t>(num_sms) : 1u;
+    if (max_resident_ctas >= headroom + ctas_per_group) {
+      max_resident_ctas -= headroom;
+    }
+  }
   uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
                                  static_cast<uint32_t>(num_rows));
   if (num_groups == 0) num_groups = 1;
   uint32_t total_ctas = num_groups * ctas_per_group;
+  // If the cooperative launch wouldn't fit, fall back to FilteredTopK
+  // instead of deadlocking. Only relevant when needs_cooperative.
+  if (needs_cooperative && total_ctas > hw_resident_cap) {
+    TORCH_CHECK(max_smem_per_block >= 128 * 1024,
+                "persistent_topk would oversubscribe and the FilteredTopK "
+                "fallback requires >=128KB smem per block (have ",
+                max_smem_per_block, "). total_ctas=", total_ctas,
+                " > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
+                ", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
+                ", smem=", smem_size, ").");
+    cudaError_t status =
+        vllm::FilteredTopKRaggedTransform(
+            logits.data_ptr(), output.data_ptr(),
+            lengths.data_ptr(), static_cast<int64_t>(num_rows),
+            static_cast<int64_t>(TopK), static_cast<int64_t>(stride),
+            stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "FilteredTopK fallback failed: ", cudaGetErrorString(status));
+    return;
+  }
+
   size_t state_bytes = num_groups * sizeof(P::RadixRowState);
   TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
               "workspace too small, need ", state_bytes, " bytes");