From 2ea16f1424c1390d2417891f4768785fb2cd5a08 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:17:34 +0000 Subject: [PATCH 1/8] Initial plan From f58bc38784f0c7c931728a4c82eb827ae0cce668 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:36:31 +0000 Subject: [PATCH 2/8] Fix cascading CUDA errors in TopK benchmark by using cudaGetLastError consistently Two fixes to prevent stale CUDA errors from cascading across TopK algorithm benchmarks: 1. CUDA_CHECK_LAUNCH: Remove the NDEBUG-conditional branch that used cudaPeekAtLastError() in release builds. Now consistently uses cudaGetLastError() which clears the error after detection, preventing stale errors from propagating to subsequent kernel launches. 2. BENCHMARK_KERNEL: Add cudaGetLastError() in the catch handler to explicitly clear the CUDA error state after a benchmark failure, ensuring the next algorithm starts with a clean error state. These changes fix the issue where all TopK algorithms fail in cascade for vocab_size=151936 (Qwen2 models). The root cause was that a genuine cooperative launch failure in one algorithm would leave a stale error that was falsely detected by subsequent algorithms' CUDA_CHECK_LAUNCH calls in release builds. 
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- src/cuda/cuda_common.h | 14 -------------- src/cuda/cuda_topk_benchmark.cuh | 1 + 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/cuda/cuda_common.h b/src/cuda/cuda_common.h index 7e283a6a4d..b487e4734f 100644 --- a/src/cuda/cuda_common.h +++ b/src/cuda/cuda_common.h @@ -129,19 +129,6 @@ class CudaError : public std::runtime_error { } \ } while (0) -#ifdef NDEBUG -#define CUDA_CHECK_LAUNCH() \ - do { \ - cudaError_t err = cudaPeekAtLastError(); \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA launch error in " << __func__ << " at " \ - << __FILE__ << ":" << __LINE__ << " - " \ - << cudaGetErrorString(err); \ - throw Generators::CudaError(ss.str(), err); \ - } \ - } while (0) -#else #define CUDA_CHECK_LAUNCH() \ do { \ cudaError_t err = cudaGetLastError(); \ @@ -153,6 +140,5 @@ class CudaError : public std::runtime_error { throw Generators::CudaError(ss.str(), err); \ } \ } while (0) -#endif } // namespace Generators diff --git a/src/cuda/cuda_topk_benchmark.cuh b/src/cuda/cuda_topk_benchmark.cuh index 9766e7a9be..5bc5bc42ca 100644 --- a/src/cuda/cuda_topk_benchmark.cuh +++ b/src/cuda/cuda_topk_benchmark.cuh @@ -79,6 +79,7 @@ static const char* TopkAlgoToString(TopkAlgo algo) { best_algo = algo_enum; \ } \ } catch (const Generators::CudaError& e) { \ + cudaGetLastError(); \ std::cerr << "Benchmarking failed for " << TopkAlgoToString(algo_enum) \ << " kernel with k=" << k << ", batch_size=" << batch_size \ << ", vocab_size=" << vocab_size << ". 
Error: " << e.what() << std::endl; \ From cd5667c499854f5b04d8954804433db468c3af2a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:16:22 +0000 Subject: [PATCH 3/8] Changes before error encountered Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- src/cuda/cuda_topk_benchmark.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/cuda/cuda_topk_benchmark.cuh b/src/cuda/cuda_topk_benchmark.cuh index 5bc5bc42ca..3382e26520 100644 --- a/src/cuda/cuda_topk_benchmark.cuh +++ b/src/cuda/cuda_topk_benchmark.cuh @@ -102,6 +102,12 @@ static TopkAlgo BenchmarkAndSelectBestAlgo(TopkData* topk_data, int vocab_size, int batch_size, int k) { + // Clear any stale CUDA errors from previous operations to prevent false failures. + // Successful CUDA API calls do NOT clear the thread-local error state, so a stale + // error (e.g., from TopkData construction or prior inference) can persist and be + // falsely detected by CUDA_CHECK_LAUNCH() inside the benchmark kernels. + cudaGetLastError(); + float min_latency = std::numeric_limits<float>::max(); TopkAlgo best_algo = TopkAlgo::UNKNOWN; From 126ab897b41366632161764d2eae76ef1f665220 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:26:12 +0000 Subject: [PATCH 4/8] Fix cooperative kernel benchmark fallback and stale CUDA error handling Two fixes to prevent TopK benchmark failures on GPUs with fewer SMs (e.g., RTX 4050 with 18 SMs): 1. Clear stale CUDA errors at the start of BenchmarkAndSelectBestAlgo. The IsSupportedCooperative check correctly uses cudaOccupancyMaxActiveBlocksPerMultiprocessor to allow multiple blocks per SM for cooperative launches (e.g., 60 blocks on 18 SMs with 4 blocks/SM occupancy = 72 max). 
However, stale CUDA errors from prior operations persist in the thread-local error state and are falsely detected by CUDA_CHECK_LAUNCH() inside benchmark kernels. 2. Fix hybrid_sort fallback: when cooperative kernel benchmarks all fail at runtime (caught by try-catch), best_algo remains UNKNOWN but hybrid_sort was being skipped because IsSupported returned true. Now hybrid_sort is also tried when best_algo == UNKNOWN, providing a robust fallback path before resorting to full_sort. Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- src/cuda/cuda_topk_benchmark.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/cuda/cuda_topk_benchmark.cuh b/src/cuda/cuda_topk_benchmark.cuh index 3382e26520..5d739bdcf5 100644 --- a/src/cuda/cuda_topk_benchmark.cuh +++ b/src/cuda/cuda_topk_benchmark.cuh @@ -154,9 +154,11 @@ static TopkAlgo BenchmarkAndSelectBestAlgo(TopkData* topk_data, }); } - // Candidate: Hybrid Sort. This is a robust fallback. We benchmark it if either the cooperative - // kernels are not supported, or if the vocab size is small, where hybrid can sometimes be faster. - if (!use_iterative_sort && !use_cascaded_sort && !use_flash_convergent || vocab_size <= 4096) { + // Candidate: Hybrid Sort. This is a robust fallback. We benchmark it if the cooperative + // kernels are not supported, if their benchmarks all failed at runtime (best_algo is still + // UNKNOWN despite IsSupported returning true), or if the vocab size is small, where hybrid + // can sometimes be faster. 
+ if (best_algo == TopkAlgo::UNKNOWN || vocab_size <= 4096) { if (hybrid_sort::IsSupported(batch_size, vocab_size, k)) { BENCHMARK_KERNEL(TopkAlgo::HYBRID, [&]() { hybrid_sort::RunTopK(topk_data, stream, scores_in, vocab_size, batch_size, k); From 764de802be08fede3031a3de3ff90612afeb92a7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Feb 2026 11:43:55 -0800 Subject: [PATCH 5/8] refine --- src/cuda/cuda_common.h | 14 ++++++++++++++ src/cuda/cuda_topk_benchmark.cuh | 1 - 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/cuda/cuda_common.h b/src/cuda/cuda_common.h index b487e4734f..f749879d2e 100644 --- a/src/cuda/cuda_common.h +++ b/src/cuda/cuda_common.h @@ -129,6 +129,19 @@ class CudaError : public std::runtime_error { } \ } while (0) +#ifdef NDEBUG +#define CUDA_CHECK_LAUNCH() \ + do { \ + cudaError_t err = cudaPeekAtLastError(); \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA launch error in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << " - " \ + << cudaGetErrorString(err); \ + throw Generators::CudaError(ss.str(), cudaGetLastError()); \ + } \ + } while (0) +#else #define CUDA_CHECK_LAUNCH() \ do { \ cudaError_t err = cudaGetLastError(); \ @@ -140,5 +153,6 @@ class CudaError : public std::runtime_error { throw Generators::CudaError(ss.str(), err); \ } \ } while (0) +#endif } // namespace Generators diff --git a/src/cuda/cuda_topk_benchmark.cuh b/src/cuda/cuda_topk_benchmark.cuh index 5d739bdcf5..c19bf65e41 100644 --- a/src/cuda/cuda_topk_benchmark.cuh +++ b/src/cuda/cuda_topk_benchmark.cuh @@ -79,7 +79,6 @@ static const char* TopkAlgoToString(TopkAlgo algo) { best_algo = algo_enum; \ } \ } catch (const Generators::CudaError& e) { \ - cudaGetLastError(); \ std::cerr << "Benchmarking failed for " << TopkAlgoToString(algo_enum) \ << " kernel with k=" << k << ", batch_size=" << batch_size \ << ", vocab_size=" << vocab_size << ". 
Error: " << e.what() << std::endl; \ From d2e7de0bf7c5de73e3a219e16e1058ecfc832ea4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Feb 2026 13:54:46 -0800 Subject: [PATCH 6/8] Add CUDA error checking to all kernel launches and API calls - Wrap all unchecked cudaMemset/cudaMemcpy/cudaStreamSync/cudaEventRecord/ cudaEventSync calls with CUDA_CHECK() - Add CUDA_CHECK_LAUNCH() after every kernel launch wrapper function - Add #include cuda_common.h where macros are used directly - Restructure cuda_common.h: move macros before RAII helpers that use them - Wrap cudaMalloc/cudaMallocHost/cudaEventCreate/cudaStreamCreate in RAII helpers with CUDA_CHECK() - Remove unused OnCudaError/CudaCheck legacy error helpers - Make CUDA_CHECK clear stale error via cudaGetLastError() before throwing - Make CUDA_CHECK_LAUNCH (NDEBUG) consistent: clear with cudaGetLastError(), throw with the detected error code --- src/cuda/beam_search_scorer_cuda.cpp | 9 +- src/cuda/beam_search_scorer_cuda.cu | 7 ++ src/cuda/beam_search_topk.cu | 11 +- src/cuda/cuda_common.h | 115 +++++++++----------- src/cuda/cuda_topk.cu | 2 + src/cuda/cuda_topk_per_batch_radix_sort.cuh | 2 +- src/cuda/model_kernels.cu | 11 ++ src/cuda/search_cuda.cpp | 23 ++-- src/cuda/search_cuda.cu | 7 ++ 9 files changed, 105 insertions(+), 82 deletions(-) diff --git a/src/cuda/beam_search_scorer_cuda.cpp b/src/cuda/beam_search_scorer_cuda.cpp index 61fc86accb..964dcf430a 100644 --- a/src/cuda/beam_search_scorer_cuda.cpp +++ b/src/cuda/beam_search_scorer_cuda.cpp @@ -4,6 +4,7 @@ #include "generators.h" #include "search.h" #include "search_cuda.h" +#include "cuda_common.h" #include "beam_search_scorer_cuda.cuh" #include "beam_search_scorer_cuda.h" #include "interface.h" @@ -22,7 +23,7 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters, state_cpu_->not_done_count_ = parameters.search.batch_size; state_cpu_->hypothesis_buffer_used_ = 0; state_gpu_ = CudaMallocArray(1); - 
cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_); + CUDA_CHECK(cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_)); size_t batch_beam_size = state_cpu_->batch_size_ * state_cpu_->num_beams_; @@ -63,7 +64,7 @@ void BeamSearchScorer_Cuda::Process(Sequences& sequences, next_tokens, next_indices, stream_); - cudaEventRecord(event_process_complete_, stream_); + CUDA_CHECK(cudaEventRecord(event_process_complete_, stream_)); cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_, *state_gpu_, @@ -76,7 +77,7 @@ void BeamSearchScorer_Cuda::Process(Sequences& sequences, } bool BeamSearchScorer_Cuda::IsDoneLater() const { - cudaEventSynchronize(event_process_complete_); + CUDA_CHECK(cudaEventSynchronize(event_process_complete_)); return state_cpu_->not_done_count_ == 0; } @@ -90,7 +91,7 @@ DeviceSpan BeamSearchScorer_Cuda::GetBeamHypothesis(size_t batch_id, si cuda_host_unique_ptr hypothesis_length = CudaMallocHostArray(1); cuda_host_unique_ptr hypothesis_score = CudaMallocHostArray(1); cuda::LaunchBeamSearchScorer_GetHypothesisPtr(batch_id, beam_id, beam_hyps_, hypothesis_ptr.get(), hypothesis_length.get(), hypothesis_score.get(), stream_); - CudaCheck() == cudaStreamSynchronize(stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); std::span hypothesis(*hypothesis_ptr.get(), *hypothesis_length.get()); // Translate the hypothesis span back to the original device buffer span return hypothesis_buffer_.subspan(hypothesis.data() - hypothesis_buffer_.Span().data(), hypothesis.size()); diff --git a/src/cuda/beam_search_scorer_cuda.cu b/src/cuda/beam_search_scorer_cuda.cu index d7247d55e4..58c72b45d7 100644 --- a/src/cuda/beam_search_scorer_cuda.cu +++ b/src/cuda/beam_search_scorer_cuda.cu @@ -5,6 +5,7 @@ #include #include #include "span.h" +#include "cuda_common.h" #include "beam_search_scorer_cuda.cuh" namespace Generators { @@ 
-49,6 +50,7 @@ void LaunchInitializeBeamHypotheses(std::span beam_hyps, length_penalty, beams.data(), num_beams); + CUDA_CHECK_LAUNCH(); } __device__ void BeamHypotheses::Add(const int32_t* hypothesis, int hypothesis_length, float sum_logprobs) { @@ -187,6 +189,7 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu, next_scores.data(), next_tokens.data(), next_indices.data()); + CUDA_CHECK_LAUNCH(); } __global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state, @@ -254,6 +257,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp next_sequences.data(), sequence_length, next_beam_tokens.data()); + CUDA_CHECK_LAUNCH(); } __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, @@ -298,6 +302,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, beam_hyps.data(), hypothesis_buffer.data(), final_beam_scores.data()); + CUDA_CHECK_LAUNCH(); } __global__ void BeamSearchScorer_GetHypothesisPtr(size_t batch_id, @@ -326,6 +331,7 @@ void LaunchBeamSearchScorer_GetHypothesisPtr(size_t batch_id, hypothesis_ptr, hypothesis_length, hypothesis_score); + CUDA_CHECK_LAUNCH(); } __global__ void InitScoresKernel(float* beam_scores, @@ -347,6 +353,7 @@ void LaunchInitScoresKernel( constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; InitScoresKernel<<>>(beam_scores, num_beams, total_elements); + CUDA_CHECK_LAUNCH(); } } // namespace cuda diff --git a/src/cuda/beam_search_topk.cu b/src/cuda/beam_search_topk.cu index 32da76fa2b..0186017f5a 100644 --- a/src/cuda/beam_search_topk.cu +++ b/src/cuda/beam_search_topk.cu @@ -5,6 +5,7 @@ #include #include #include "beam_search_topk.h" +#include "cuda_common.h" namespace Generators { namespace cuda { @@ -124,17 +125,20 @@ void LaunchBeamSearchOnlineTopKStage2Kernel( if (parts_per_beam <= 32) { BeamSearchOnlineTopKStage2Kernel<<>>( topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, 
output_indices); + CUDA_CHECK_LAUNCH(); return; } if (parts_per_beam <= 64) { BeamSearchOnlineTopKStage2Kernel<<>>( topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices); + CUDA_CHECK_LAUNCH(); return; } BeamSearchOnlineTopKStage2Kernel<<>>( topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices); + CUDA_CHECK_LAUNCH(); return; } @@ -160,12 +164,13 @@ void TopKLauncherMaxK( dim3 grid(batch_beam_size, voc_parts); - cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel, + CUDA_CHECK(cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel, cudaFuncAttributePreferredSharedMemoryCarveout, - cudaSharedmemCarveoutMaxL1); + cudaSharedmemCarveoutMaxL1)); BeamSearchOnlineTopKStage1Kernel <<>>(input, K, vocab_size, (vocab_size + voc_parts - 1) / voc_parts, output_values_tmp, output_indices_tmp); + CUDA_CHECK_LAUNCH(); LaunchBeamSearchOnlineTopKStage2Kernel( output_values_tmp, @@ -242,6 +247,7 @@ void LaunchBatchTopKKernel(const T* topk_scores, } else { BatchTopKKernelLauncher(64); } + CUDA_CHECK_LAUNCH(); } template void LaunchBatchTopKKernel(const float* topk_scores, @@ -301,6 +307,7 @@ void BeamSearchTopK( num_beams, k, stream); + CUDA_CHECK_LAUNCH(); } } // namespace cuda diff --git a/src/cuda/cuda_common.h b/src/cuda/cuda_common.h index f749879d2e..ade52d6195 100644 --- a/src/cuda/cuda_common.h +++ b/src/cuda/cuda_common.h @@ -16,22 +16,65 @@ namespace Generators { cudaStream_t GetStream(); -void OnCudaError(cudaError_t error); +#define CeilDiv(a, b) ((a + (b - 1)) / b) -struct CudaCheck { - void operator==(cudaError_t error) { - if (error != cudaSuccess) - OnCudaError(error); - } +class CudaError : public std::runtime_error { + public: + explicit CudaError(const std::string& msg, cudaError_t code) + : std::runtime_error(msg), code_(code) {} + + cudaError_t code() const noexcept { return code_; } + + private: + cudaError_t code_; }; +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = 
(call); \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA error in " << __func__ << " at " << __FILE__ \ + << ":" << __LINE__ << " - " << cudaGetErrorString(err); \ + (void)cudaGetLastError(); \ + throw Generators::CudaError(ss.str(), err); \ + } \ + } while (0) + +#ifdef NDEBUG +#define CUDA_CHECK_LAUNCH() \ + do { \ + cudaError_t err = cudaPeekAtLastError(); \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA launch error in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << " - " \ + << cudaGetErrorString(err); \ + (void)cudaGetLastError(); \ + throw Generators::CudaError(ss.str(), err); \ + } \ + } while (0) +#else +#define CUDA_CHECK_LAUNCH() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA launch error in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << " - " \ + << cudaGetErrorString(err); \ + throw Generators::CudaError(ss.str(), err); \ + } \ + } while (0) +#endif + struct cuda_event_holder { cuda_event_holder() { - cudaEventCreate(&v_); + CUDA_CHECK(cudaEventCreate(&v_)); } cuda_event_holder(unsigned flags) { - cudaEventCreateWithFlags(&v_, flags); + CUDA_CHECK(cudaEventCreateWithFlags(&v_, flags)); } ~cuda_event_holder() { @@ -48,7 +91,7 @@ struct cuda_event_holder { struct cuda_stream_holder { void Create() { assert(!v_); - cudaStreamCreate(&v_); + CUDA_CHECK(cudaStreamCreate(&v_)); } ~cuda_stream_holder() { @@ -81,7 +124,7 @@ using cuda_host_unique_ptr = std::unique_ptr; template cuda_host_unique_ptr CudaMallocHostArray(size_t count, std::span* p_span = nullptr) { T* p; - ::cudaMallocHost(&p, sizeof(T) * count); + CUDA_CHECK(::cudaMallocHost(&p, sizeof(T) * count)); if (p_span) *p_span = std::span(p, count); return cuda_host_unique_ptr{p}; @@ -99,60 +142,10 @@ using cuda_unique_ptr = std::unique_ptr; template cuda_unique_ptr CudaMallocArray(size_t count, std::span* p_span = nullptr) { T* p; - ::cudaMalloc(&p, 
sizeof(T) * count); + CUDA_CHECK(::cudaMalloc(&p, sizeof(T) * count)); if (p_span) *p_span = std::span(p, count); return cuda_unique_ptr{p}; } -#define CeilDiv(a, b) ((a + (b - 1)) / b) - -class CudaError : public std::runtime_error { - public: - explicit CudaError(const std::string& msg, cudaError_t code) - : std::runtime_error(msg), code_(code) {} - - cudaError_t code() const noexcept { return code_; } - - private: - cudaError_t code_; -}; - -#define CUDA_CHECK(call) \ - do { \ - cudaError_t err = (call); \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA error in " << __func__ << " at " << __FILE__ \ - << ":" << __LINE__ << " - " << cudaGetErrorString(err); \ - throw Generators::CudaError(ss.str(), err); \ - } \ - } while (0) - -#ifdef NDEBUG -#define CUDA_CHECK_LAUNCH() \ - do { \ - cudaError_t err = cudaPeekAtLastError(); \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA launch error in " << __func__ << " at " \ - << __FILE__ << ":" << __LINE__ << " - " \ - << cudaGetErrorString(err); \ - throw Generators::CudaError(ss.str(), cudaGetLastError()); \ - } \ - } while (0) -#else -#define CUDA_CHECK_LAUNCH() \ - do { \ - cudaError_t err = cudaGetLastError(); \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA launch error in " << __func__ << " at " \ - << __FILE__ << ":" << __LINE__ << " - " \ - << cudaGetErrorString(err); \ - throw Generators::CudaError(ss.str(), err); \ - } \ - } while (0) -#endif - } // namespace Generators diff --git a/src/cuda/cuda_topk.cu b/src/cuda/cuda_topk.cu index 2a13ae899f..1f9160c4d9 100644 --- a/src/cuda/cuda_topk.cu +++ b/src/cuda/cuda_topk.cu @@ -6,6 +6,7 @@ #include #include "cuda_topk.h" +#include "cuda_common.h" #include "cuda_topk_benchmark_cache.h" #include "cuda_topk_benchmark.cuh" #include "cuda_topk_common.cuh" @@ -153,6 +154,7 @@ void TopkDataCompact::CompactOutput(int batch_size, int k, cudaStream_t stream) dim3 block(256); CompactStridedData<<>>(topk_scores, 
topk_scores_compact.get(), k, batch_size, topk_stride); CompactStridedData<<>>(topk_indices, topk_indices_compact.get(), k, batch_size, topk_stride); + CUDA_CHECK_LAUNCH(); } // Main dispatcher for Top-K. It implements the caching and benchmarking logic to select and run the best algorithm. diff --git a/src/cuda/cuda_topk_per_batch_radix_sort.cuh b/src/cuda/cuda_topk_per_batch_radix_sort.cuh index 6fb4329517..b594bc76cd 100644 --- a/src/cuda/cuda_topk_per_batch_radix_sort.cuh +++ b/src/cuda/cuda_topk_per_batch_radix_sort.cuh @@ -76,7 +76,7 @@ void RunTopK(TopkData* data, cudaStream_t stream, const float* scores_in, int vo // Populate workspace buffers with the current batch item's scores and indices. FillInput<<>>(current_scores_in, workspace_scores, workspace_indices, vocab_size); // Launch the CUB radix sort. It sorts from the workspace directly into the final output buffers. - cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_storage_bytes, workspace_scores, final_scores_out, workspace_indices, final_indices_out, vocab_size, 0, sizeof(float) * 8, stream); + CUDA_CHECK(cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_storage_bytes, workspace_scores, final_scores_out, workspace_indices, final_indices_out, vocab_size, 0, sizeof(float) * 8, stream)); } data->topk_scores = final_scores_buffer; diff --git a/src/cuda/model_kernels.cu b/src/cuda/model_kernels.cu index 09f42ae903..0f96276b8a 100644 --- a/src/cuda/model_kernels.cu +++ b/src/cuda/model_kernels.cu @@ -7,6 +7,7 @@ #include #include #include +#include "cuda_common.h" namespace Generators { namespace cuda { @@ -36,6 +37,7 @@ void Launch_UpdatePositionIds(T* positions, int batch_beam_size, int total_lengt // For batch size > 1 we increment position ids by 1... 
continuous decoding is not supported UpdatePositionIds<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(positions, batch_beam_size); } + CUDA_CHECK_LAUNCH(); } template void Launch_UpdatePositionIds(int32_t* positions, int batch_beam_size, int total_length, int new_kv_length, cudaStream_t stream); @@ -77,6 +79,7 @@ void Launch_UpdateAttentionMask(T* next_mask_data, T* mask_data, int batch_beam_ int blocks = (batch_beam_size * total_length + threads - 1) / threads; CopyAndUpdateAttentionMask<<>>(next_mask_data, mask_data, batch_beam_size, new_kv_length, total_length); } + CUDA_CHECK_LAUNCH(); } template void Launch_UpdateAttentionMask(int32_t* next_mask_data, int32_t* mask_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream); @@ -96,6 +99,7 @@ void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_siz int block_size = 256; int num_blocks = (batch_beam_size * vocab_size + block_size - 1) / block_size; AddLogitsMask<<>>(batch_logits, batch_beam_size, vocab_size, logits_mask); + CUDA_CHECK_LAUNCH(); } __global__ void ConvertFp16ToFp32(const half* src, float* dst, int count) { @@ -108,6 +112,7 @@ void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t int block_size = 256; int num_blocks = (count + block_size - 1) / block_size; ConvertFp16ToFp32<<>>(reinterpret_cast(fp16), fp32, count); + CUDA_CHECK_LAUNCH(); } __global__ void ConvertFp32ToFp16(const float* src, half* dst, int count) { @@ -120,6 +125,7 @@ void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t int block_size = 256; int num_blocks = (count + block_size - 1) / block_size; ConvertFp32ToFp16<<>>(fp32, reinterpret_cast(fp16), count); + CUDA_CHECK_LAUNCH(); } __global__ void ConvertInt32ToInt64(const int32_t* src, int64_t* dst, int count) { @@ -133,6 +139,7 @@ void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_ int block_size = 256; int 
num_blocks = (count + block_size - 1) / block_size; ConvertInt32ToInt64<<>>(src, dst, count); + CUDA_CHECK_LAUNCH(); } namespace { @@ -197,6 +204,7 @@ void ReorderPastStatesKernelLauncher(void* out_buffer, num_heads, max_length, chunked_head_size); + CUDA_CHECK_LAUNCH(); } } @@ -258,6 +266,7 @@ void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, input_seq_length, max_seq_length, current_length); + CUDA_CHECK_LAUNCH(); } template @@ -325,6 +334,7 @@ void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, max_length, sequence_length); } + CUDA_CHECK_LAUNCH(); } template void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, @@ -419,6 +429,7 @@ void LaunchFinalizeCrossQK(cudaStream_t stream, cross_qk_output, cache_indir_data); } + CUDA_CHECK_LAUNCH(); } template void LaunchFinalizeCrossQK(cudaStream_t stream, diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index b306d0f473..cc4d5c01cf 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -5,6 +5,7 @@ #include "interface.h" #include "search.h" #include "search_cuda.h" +#include "cuda_common.h" #include "beam_search_scorer_cuda.cuh" #include "beam_search_scorer_cuda.h" #include "beam_search_topk.h" @@ -13,12 +14,6 @@ namespace Generators { -void OnCudaError(cudaError_t error) { - printf("Cuda Error: %s\n", cudaGetErrorString(error)); - assert(false); - throw std::exception(); -} - Search_Cuda::Search_Cuda(const GeneratorParams& params) : Search{params} { auto batch_beam_size = params.BatchBeamSize(); @@ -69,14 +64,14 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) topk_buffer_ = CudaMallocArray(topk_buffer_size); static_assert(sizeof(float) == sizeof(int32_t)); // The topk_buffer assumes these match, fix for float16 - cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), GetStream()); + CUDA_CHECK(cudaMemsetAsync(topk_buffer_.get(), 0, topk_buffer_size * sizeof(float), GetStream())); } BeamSearch_Cuda::~BeamSearch_Cuda() = 
default; void Search_Cuda::ResetDone() { *done_cpu_ = false; - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); + CUDA_CHECK(cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream())); } DeviceSpan Search_Cuda::GetLogits() const { @@ -105,8 +100,8 @@ void BeamSearch_Cuda::SelectTop() { // Copy next_token_scores to CPU auto next_token_scores_cpu = CudaMallocHostArray(params_->BatchBeamSize() * params_->config.model.vocab_size); - cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, GetStream()); - CudaCheck() == cudaStreamSynchronize(GetStream()); + CUDA_CHECK(cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, GetStream())); + CUDA_CHECK(cudaStreamSynchronize(GetStream())); auto beam_scores = beam_scorer_->GetNextScores(); @@ -140,7 +135,7 @@ void BeamSearch_Cuda::SelectTop() { } else assert(false); - CudaCheck() == cudaStreamSynchronize(GetStream()); + CUDA_CHECK(cudaStreamSynchronize(GetStream())); size_t size = params_->BatchBeamSize() * 2; std::span next_scores{topk_next_scores_.get(), size}; @@ -170,7 +165,7 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_seen_.data(), eos_token_ids_.Span().data(), static_cast(eos_token_ids_.Span().size()), params_->config.model.pad_token_id, done_cpu_.get(), GetStream()); // Append tokens - cudaStreamSynchronize(GetStream()); + CUDA_CHECK(cudaStreamSynchronize(GetStream())); if (!*done_cpu_) { cuda::Launch_AppendNextTokensToSequences(next_tokens_buffer_.Span(), sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream()); 
sequences_.AfterAppendNextTokens(next_tokens_buffer_, params_->BatchBeamSize()); @@ -246,7 +241,7 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetNextSequences().Span(), params_->search.batch_size, params_->search.num_beams, sequences_.max_length_, GetStream()); cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->search.batch_size, params_->search.num_beams, sequences_.max_length_, GetStream()); sequences_.AfterAppendNextTokens(next_tokens, params_->search.batch_size); // next_tokens is batch_size - cudaStreamSynchronize(GetStream()); + CUDA_CHECK(cudaStreamSynchronize(GetStream())); } void GreedySearch_Cuda::RewindTo(size_t index) { @@ -254,7 +249,7 @@ void GreedySearch_Cuda::RewindTo(size_t index) { if (index > 0) cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast(params_->BatchBeamSize()), static_cast(index), sequences_.max_length_, GetStream()); else - cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), GetStream()); + CUDA_CHECK(cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), GetStream())); sequences_.RewindTo(index); } diff --git a/src/cuda/search_cuda.cu b/src/cuda/search_cuda.cu index e72689291b..f5b52fe9bd 100644 --- a/src/cuda/search_cuda.cu +++ b/src/cuda/search_cuda.cu @@ -27,6 +27,7 @@ void Launch_ExpandInputSequences(const std::span input_sequences, std:: const int total_elements = static_cast(input_sequences.size()); const int new_length = total_elements / batch_size; ExpandInputSequences<<<1, 1, 0, stream>>>(input_sequences.data(), sequences.data(), batch_size, beam_size, new_length, max_length); + CUDA_CHECK_LAUNCH(); } __global__ void AppendNextTokensToSequences(const int32_t* next_tokens, int32_t* sequences, int batch_beam_size, int past_length, int new_length, int max_length) { @@ -45,6 +46,7 @@ void 
Launch_AppendNextTokensToSequences(std::span next_tokens, st const int gridSize = (total_elements + blockSize - 1) / blockSize; const int new_length = total_elements / batch_beam_size; AppendNextTokensToSequences<<>>(next_tokens.data(), sequences.data(), batch_beam_size, past_length, new_length, max_length); + CUDA_CHECK_LAUNCH(); } __global__ void GetLastTokens(int32_t* next_tokens, const int32_t* sequences, int batch_beam_size, int sequence_length, int max_length) { @@ -59,6 +61,7 @@ void Launch_GetLastTokens(int32_t* next_tokens, const int32_t* sequences, int ba const int blockSize = std::min(batch_beam_size, 256); const int gridSize = (batch_beam_size + blockSize - 1) / blockSize; GetLastTokens<<>>(next_tokens, sequences, batch_beam_size, sequence_length, max_length); + CUDA_CHECK_LAUNCH(); } __global__ void ArgMax(cub::KeyValuePair* argmaxen, int32_t* next_tokens, int batch_size) { @@ -110,6 +113,7 @@ __global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, b void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, cudaStream_t stream) { CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_seen, eos_token_ids, eos_token_count, pad_token_id, done_cpu); + CUDA_CHECK_LAUNCH(); } __global__ void AddProbsKernel(float* log_probs, @@ -133,6 +137,7 @@ void LaunchAddProbsKernel(float* log_probs, constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; AddProbsKernel<<>>(log_probs, cum_log_probs, vocab_size, total_elements); + CUDA_CHECK_LAUNCH(); } __global__ void SetScoreProcessor(float* next_token_scores, int batch_beam_size, int vocab_size, int token, float score) { @@ -149,6 +154,7 @@ void LaunchSetScoreProcessor(float* next_token_scores, int batch_beam_size, int const int gridSize = (total_elements + blockSize - 1) / blockSize; SetScoreProcessor<<>>(next_token_scores, 
batch_beam_size, vocab_size, token, score); + CUDA_CHECK_LAUNCH(); } __global__ void RepetitionPenaltyProcessor(const int32_t* sequences, float* next_token_scores, int max_sequence_length, int vocab_size, int total_elements, int current_sequence_length, float repetition_penalty) { @@ -179,6 +185,7 @@ void LaunchRepetitionPenaltyProcessor(const int32_t* sequences, float* next_toke const int gridSize = (total_elements + blockSize - 1) / blockSize; RepetitionPenaltyProcessor<<>>(sequences, next_token_scores, max_sequence_length, vocab_size, total_elements, current_sequence_length, repetition_penalty); + CUDA_CHECK_LAUNCH(); } } // namespace cuda From 5b0e36675799375fce3dc2afc34b6b5ec79f0b0b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Feb 2026 14:09:14 -0800 Subject: [PATCH 7/8] refine --- src/cuda/beam_search_topk.cu | 4 ++-- src/cuda/cuda_common.h | 18 ++---------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/cuda/beam_search_topk.cu b/src/cuda/beam_search_topk.cu index 0186017f5a..34abfa6fdb 100644 --- a/src/cuda/beam_search_topk.cu +++ b/src/cuda/beam_search_topk.cu @@ -165,8 +165,8 @@ void TopKLauncherMaxK( dim3 grid(batch_beam_size, voc_parts); CUDA_CHECK(cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, - cudaSharedmemCarveoutMaxL1)); + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxL1)); BeamSearchOnlineTopKStage1Kernel <<>>(input, K, vocab_size, (vocab_size + voc_parts - 1) / voc_parts, output_values_tmp, output_indices_tmp); diff --git a/src/cuda/cuda_common.h b/src/cuda/cuda_common.h index ade52d6195..76cd502307 100644 --- a/src/cuda/cuda_common.h +++ b/src/cuda/cuda_common.h @@ -41,32 +41,18 @@ class CudaError : public std::runtime_error { } \ } while (0) -#ifdef NDEBUG -#define CUDA_CHECK_LAUNCH() \ - do { \ - cudaError_t err = cudaPeekAtLastError(); \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA launch error in 
" << __func__ << " at " \ - << __FILE__ << ":" << __LINE__ << " - " \ - << cudaGetErrorString(err); \ - (void)cudaGetLastError(); \ - throw Generators::CudaError(ss.str(), err); \ - } \ - } while (0) -#else #define CUDA_CHECK_LAUNCH() \ do { \ - cudaError_t err = cudaGetLastError(); \ + cudaError_t err = cudaPeekAtLastError(); \ if (err != cudaSuccess) { \ std::stringstream ss; \ ss << "CUDA launch error in " << __func__ << " at " \ << __FILE__ << ":" << __LINE__ << " - " \ << cudaGetErrorString(err); \ + (void)cudaGetLastError(); \ throw Generators::CudaError(ss.str(), err); \ } \ } while (0) -#endif struct cuda_event_holder { cuda_event_holder() { From ce8bab987c5bc51ab099abe1804500b40700aac7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Feb 2026 15:55:32 -0800 Subject: [PATCH 8/8] Address PR review feedback: FillInput check, redundant check, hybrid_sort condition - Add missing CUDA_CHECK_LAUNCH() after FillInput kernel launch in cuda_topk_per_batch_radix_sort.cuh - Remove redundant CUDA_CHECK_LAUNCH() after LaunchBatchTopKKernel call in BeamSearchTopK (already checked inside LaunchBatchTopKKernel) - Expand hybrid_sort benchmark condition to also run when none of the cooperative kernels are supported, not just when best_algo is UNKNOWN --- src/cuda/beam_search_topk.cu | 1 - src/cuda/cuda_topk_benchmark.cuh | 2 +- src/cuda/cuda_topk_per_batch_radix_sort.cuh | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cuda/beam_search_topk.cu b/src/cuda/beam_search_topk.cu index 34abfa6fdb..4c930413d1 100644 --- a/src/cuda/beam_search_topk.cu +++ b/src/cuda/beam_search_topk.cu @@ -307,7 +307,6 @@ void BeamSearchTopK( num_beams, k, stream); - CUDA_CHECK_LAUNCH(); } } // namespace cuda diff --git a/src/cuda/cuda_topk_benchmark.cuh b/src/cuda/cuda_topk_benchmark.cuh index c19bf65e41..e7519daaba 100644 --- a/src/cuda/cuda_topk_benchmark.cuh +++ b/src/cuda/cuda_topk_benchmark.cuh @@ -157,7 +157,7 @@ static TopkAlgo 
BenchmarkAndSelectBestAlgo(TopkData* topk_data, // kernels are not supported, if their benchmarks all failed at runtime (best_algo is still // UNKNOWN despite IsSupported returning true), or if the vocab size is small, where hybrid // can sometimes be faster. - if (best_algo == TopkAlgo::UNKNOWN || vocab_size <= 4096) { + if (best_algo == TopkAlgo::UNKNOWN || (!use_iterative_sort && !use_cascaded_sort && !use_flash_convergent) || vocab_size <= 4096) { if (hybrid_sort::IsSupported(batch_size, vocab_size, k)) { BENCHMARK_KERNEL(TopkAlgo::HYBRID, [&]() { hybrid_sort::RunTopK(topk_data, stream, scores_in, vocab_size, batch_size, k); diff --git a/src/cuda/cuda_topk_per_batch_radix_sort.cuh b/src/cuda/cuda_topk_per_batch_radix_sort.cuh index b594bc76cd..e1c0567119 100644 --- a/src/cuda/cuda_topk_per_batch_radix_sort.cuh +++ b/src/cuda/cuda_topk_per_batch_radix_sort.cuh @@ -75,6 +75,7 @@ void RunTopK(TopkData* data, cudaStream_t stream, const float* scores_in, int vo // Populate workspace buffers with the current batch item's scores and indices. FillInput<<>>(current_scores_in, workspace_scores, workspace_indices, vocab_size); + CUDA_CHECK_LAUNCH(); // Launch the CUB radix sort. It sorts from the workspace directly into the final output buffers. CUDA_CHECK(cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_storage_bytes, workspace_scores, final_scores_out, workspace_indices, final_indices_out, vocab_size, 0, sizeof(float) * 8, stream)); }