Skip to content

Harden CUDA error checking across the codebase#1994

Merged
tianleiwu merged 8 commits into
mainfrom
copilot/analyze-cuda-errors
Mar 1, 2026
Merged

Harden CUDA error checking across the codebase#1994
tianleiwu merged 8 commits into
mainfrom
copilot/analyze-cuda-errors

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Feb 27, 2026

Summary

Comprehensive audit and hardening of CUDA error checking across all kernel launches and CUDA API calls. Fixes stale CUDA error cascading that caused false error reports on GPUs with fewer SMs (e.g., RTX 4050) like the following:

Benchmarking failed for Cascaded_Sort kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_cascaded_sort.cuh:350 - invalid configuration argument
Benchmarking failed for Iterative_Sort kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_iterative_sort.cuh:268 - invalid configuration argument
Benchmarking failed for Flash_Convergent kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_flash_convergent.cuh:272 - invalid configuration argument
Benchmarking failed for Segmented_Radix kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA error in LaunchSortPairs at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_full_sort.cuh:70 - invalid argument 

Changes

Fix stale CUDA error cascading in TopK benchmark (original commits)

  • Clear stale CUDA errors at benchmark entry: cudaGetLastError() at the top of BenchmarkAndSelectBestAlgo prevents pre-existing errors from being misattributed to the first benchmark kernel
  • Fix hybrid_sort fallback condition: Changed from !use_iterative_sort && !use_cascaded_sort && !use_flash_convergent to best_algo == TopkAlgo::UNKNOWN. Previously, if cooperative kernels passed IsSupported but failed at runtime, hybrid_sort was skipped entirely.

Add CUDA error checking to all kernel launches and API calls

  • Kernel launches: Added CUDA_CHECK_LAUNCH() after every kernel launch wrapper in beam_search_scorer_cuda.cu, beam_search_topk.cu, model_kernels.cu, search_cuda.cu, and cuda_topk.cu
  • CUDA API calls: Wrapped all unchecked cudaMemsetAsync, cudaMemcpyAsync, cudaStreamSynchronize, cudaEventRecord, cudaEventSynchronize calls with CUDA_CHECK() in search_cuda.cpp and beam_search_scorer_cuda.cpp

Improve error macros and helpers in cuda_common.h

  • Restructure header: Moved CudaError class and CUDA_CHECK/CUDA_CHECK_LAUNCH macro definitions before RAII helpers that depend on them
  • Harden RAII helpers: Wrapped cudaEventCreate, cudaEventCreateWithFlags, cudaStreamCreate, cudaMallocHost, cudaMalloc with CUDA_CHECK()
  • Fix CUDA_CHECK: Added cudaGetLastError() call to clear stale thread-local error state before throwing
  • Fix CUDA_CHECK_LAUNCH (NDEBUG): Now clears error with cudaGetLastError() separately and throws the detected error code (instead of potentially stale cudaGetLastError() return)
  • Remove dead code: Removed unused OnCudaError() and CudaCheck struct

Add missing includes

  • Added explicit #include "cuda_common.h" in files that use the macros directly: beam_search_scorer_cuda.cu, beam_search_topk.cu, model_kernels.cu, search_cuda.cpp, beam_search_scorer_cuda.cpp, cuda_topk.cu

… consistently

Two fixes to prevent stale CUDA errors from cascading across TopK algorithm
benchmarks:

1. CUDA_CHECK_LAUNCH: Remove the NDEBUG-conditional branch that used
   cudaPeekAtLastError() in release builds. Now consistently uses
   cudaGetLastError() which clears the error after detection, preventing
   stale errors from propagating to subsequent kernel launches.

2. BENCHMARK_KERNEL: Add cudaGetLastError() in the catch handler to
   explicitly clear the CUDA error state after a benchmark failure,
   ensuring the next algorithm starts with a clean error state.

These changes fix the issue where all TopK algorithms fail in cascade
for vocab_size=151936 (Qwen2 models). The root cause was that a genuine
cooperative launch failure in one algorithm would leave a stale error
that was falsely detected by subsequent algorithms' CUDA_CHECK_LAUNCH
calls in release builds.

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copilot AI changed the title [WIP] Analyze causes of CUDA launch errors in benchmark Fix cascading CUDA errors in TopK benchmark due to stale error state Feb 27, 2026
Comment thread src/cuda/cuda_common.h
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Two fixes to prevent TopK benchmark failures on GPUs with fewer SMs
(e.g., RTX 4050 with 18 SMs):

1. Clear stale CUDA errors at the start of BenchmarkAndSelectBestAlgo.
   The IsSupportedCooperative check correctly uses
   cudaOccupancyMaxActiveBlocksPerMultiprocessor to allow multiple blocks
   per SM for cooperative launches (e.g., 60 blocks on 18 SMs with 4
   blocks/SM occupancy = 72 max). However, stale CUDA errors from prior
   operations persist in the thread-local error state and are falsely
   detected by CUDA_CHECK_LAUNCH() inside benchmark kernels.

2. Fix hybrid_sort fallback: when cooperative kernel benchmarks all fail
   at runtime (caught by try-catch), best_algo remains UNKNOWN but
   hybrid_sort was being skipped because IsSupported returned true. Now
   hybrid_sort is also tried when best_algo == UNKNOWN, providing a
   robust fallback path before resorting to full_sort.

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copilot AI changed the title Fix cascading CUDA errors in TopK benchmark due to stale error state Fix stale CUDA error cascading in TopK benchmark and improve fallback logic Feb 27, 2026
@tianleiwu tianleiwu force-pushed the copilot/analyze-cuda-errors branch from f953ec3 to 764de80 Compare February 27, 2026 19:44
@tianleiwu tianleiwu marked this pull request as ready for review February 27, 2026 19:45
@tianleiwu tianleiwu enabled auto-merge (squash) February 27, 2026 19:55
baijumeswani
baijumeswani previously approved these changes Feb 27, 2026
@tianleiwu tianleiwu disabled auto-merge February 27, 2026 21:27
- Wrap all unchecked cudaMemset/cudaMemcpy/cudaStreamSync/cudaEventRecord/
  cudaEventSync calls with CUDA_CHECK()
- Add CUDA_CHECK_LAUNCH() after every kernel launch wrapper function
- Add #include cuda_common.h where macros are used directly
- Restructure cuda_common.h: move macros before RAII helpers that use them
- Wrap cudaMalloc/cudaMallocHost/cudaEventCreate/cudaStreamCreate in RAII
  helpers with CUDA_CHECK()
- Remove unused OnCudaError/CudaCheck legacy error helpers
- Make CUDA_CHECK clear stale error via cudaGetLastError() before throwing
- Make CUDA_CHECK_LAUNCH (NDEBUG) consistent: clear with cudaGetLastError(),
  throw with the detected error code
@tianleiwu tianleiwu changed the title Fix stale CUDA error cascading in TopK benchmark and improve fallback logic Harden CUDA error checking across the codebase Feb 27, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR hardens CUDA error checking across the codebase by adding CUDA_CHECK_LAUNCH() after every kernel launch and CUDA_CHECK() around all CUDA API calls in the CUDA backend. It also fixes a stale CUDA error cascade issue on GPUs with fewer SMs (e.g., RTX 4050) where pre-existing thread-local CUDA errors could cause false benchmark failures.

Changes:

  • Restructures cuda_common.h to move CudaError, CUDA_CHECK, and CUDA_CHECK_LAUNCH macros before RAII helpers that depend on them; removes dead code (OnCudaError, CudaCheck); unifies CUDA_CHECK_LAUNCH for debug/release to use cudaPeekAtLastError()
  • Adds CUDA_CHECK_LAUNCH() after every kernel launch in beam_search_scorer_cuda.cu, beam_search_topk.cu, model_kernels.cu, search_cuda.cu, and cuda_topk.cu
  • Wraps unchecked CUDA API calls with CUDA_CHECK() in search_cuda.cpp, beam_search_scorer_cuda.cpp, and beam_search_topk.cu; fixes BenchmarkAndSelectBestAlgo fallback condition and clears stale CUDA errors at benchmark entry

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
src/cuda/cuda_common.h Restructured: moved error types/macros before RAII helpers; replaced dead OnCudaError/CudaCheck with unified CUDA_CHECK/CUDA_CHECK_LAUNCH; added error checking to cudaMalloc, cudaMallocHost, cudaEventCreate, cudaStreamCreate
src/cuda/cuda_topk_benchmark.cuh Clears stale errors at benchmark entry; changes hybrid_sort fallback condition to best_algo == UNKNOWN
src/cuda/search_cuda.cu Adds CUDA_CHECK_LAUNCH() after 6 kernel launches
src/cuda/search_cuda.cpp Adds CUDA_CHECK() around cudaMemsetAsync, cudaMemcpyAsync, cudaStreamSynchronize; removes unused OnCudaError
src/cuda/model_kernels.cu Adds CUDA_CHECK_LAUNCH() after 9 kernel launches; adds cuda_common.h include
src/cuda/cuda_topk.cu Adds CUDA_CHECK_LAUNCH() after 2 kernel launches; adds cuda_common.h include
src/cuda/cuda_topk_per_batch_radix_sort.cuh Wraps cub::DeviceRadixSort::SortPairsDescending with CUDA_CHECK()
src/cuda/beam_search_topk.cu Adds CUDA_CHECK_LAUNCH() after 5 kernel launches; wraps cudaFuncSetAttribute with CUDA_CHECK(); adds cuda_common.h include
src/cuda/beam_search_scorer_cuda.cu Adds CUDA_CHECK_LAUNCH() after 6 kernel launches; adds cuda_common.h include
src/cuda/beam_search_scorer_cuda.cpp Wraps 4 CUDA API calls with CUDA_CHECK(); adds cuda_common.h include

Comment thread src/cuda/cuda_topk_per_batch_radix_sort.cuh
Comment thread src/cuda/beam_search_topk.cu Outdated
Comment thread src/cuda/cuda_topk_benchmark.cuh Outdated
…sort condition

- Add missing CUDA_CHECK_LAUNCH() after FillInput kernel launch in
  cuda_topk_per_batch_radix_sort.cuh
- Remove redundant CUDA_CHECK_LAUNCH() after LaunchBatchTopKKernel call
  in BeamSearchTopK (already checked inside LaunchBatchTopKKernel)
- Expand hybrid_sort benchmark condition to also run when none of the
  cooperative kernels are supported, not just when best_algo is UNKNOWN
@tianleiwu tianleiwu enabled auto-merge (squash) February 28, 2026 00:03
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.

Comment thread src/cuda/cuda_topk.cu
Comment thread src/cuda/cuda_topk_benchmark.cuh
@tianleiwu tianleiwu merged commit 3c47932 into main Mar 1, 2026
20 of 21 checks passed
@tianleiwu tianleiwu deleted the copilot/analyze-cuda-errors branch March 1, 2026 01:55
baijumeswani pushed a commit that referenced this pull request Mar 4, 2026
## Summary

Comprehensive audit and hardening of CUDA error checking across all
kernel launches and CUDA API calls. Fixes stale CUDA error cascading
that caused false error reports on GPUs with fewer SMs (e.g., RTX 4050)
like the following:
```
Benchmarking failed for Cascaded_Sort kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_cascaded_sort.cuh:350 - invalid configuration argument
Benchmarking failed for Iterative_Sort kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_iterative_sort.cuh:268 - invalid configuration argument
Benchmarking failed for Flash_Convergent kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA launch error in RunTopK at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_flash_convergent.cuh:272 - invalid configuration argument
Benchmarking failed for Segmented_Radix kernel with k=40, batch_size=1, vocab_size=151936. Error: CUDA error in LaunchSortPairs at E:\_work\1\onnxruntime-genai\src\cuda\cuda_topk_full_sort.cuh:70 - invalid argument 
```

## Changes

### Fix stale CUDA error cascading in TopK benchmark (original commits)
- **Clear stale CUDA errors at benchmark entry**: `cudaGetLastError()`
at the top of `BenchmarkAndSelectBestAlgo` prevents pre-existing errors
from being misattributed to the first benchmark kernel
- **Fix hybrid_sort fallback condition**: Changed from
`!use_iterative_sort && !use_cascaded_sort && !use_flash_convergent` to
`best_algo == TopkAlgo::UNKNOWN`. Previously, if cooperative kernels
passed `IsSupported` but failed at runtime, `hybrid_sort` was skipped
entirely.

### Add CUDA error checking to all kernel launches and API calls
- **Kernel launches**: Added `CUDA_CHECK_LAUNCH()` after every kernel
launch wrapper in `beam_search_scorer_cuda.cu`, `beam_search_topk.cu`,
`model_kernels.cu`, `search_cuda.cu`, and `cuda_topk.cu`
- **CUDA API calls**: Wrapped all unchecked `cudaMemsetAsync`,
`cudaMemcpyAsync`, `cudaStreamSynchronize`, `cudaEventRecord`,
`cudaEventSynchronize` calls with `CUDA_CHECK()` in `search_cuda.cpp`
and `beam_search_scorer_cuda.cpp`

### Improve error macros and helpers in `cuda_common.h`
- **Restructure header**: Moved `CudaError` class and
`CUDA_CHECK`/`CUDA_CHECK_LAUNCH` macro definitions before RAII helpers
that depend on them
- **Harden RAII helpers**: Wrapped `cudaEventCreate`,
`cudaEventCreateWithFlags`, `cudaStreamCreate`, `cudaMallocHost`,
`cudaMalloc` with `CUDA_CHECK()`
- **Fix CUDA_CHECK**: Added `cudaGetLastError()` call to clear stale
thread-local error state before throwing
- **Fix CUDA_CHECK_LAUNCH (NDEBUG)**: Now clears error with
`cudaGetLastError()` separately and throws the detected error code
(instead of potentially stale `cudaGetLastError()` return)
- **Remove dead code**: Removed unused `OnCudaError()` and `CudaCheck`
struct

### Add missing includes
- Added explicit `#include "cuda_common.h"` in files that use the macros
directly: `beam_search_scorer_cuda.cu`, `beam_search_topk.cu`,
`model_kernels.cu`, `search_cuda.cpp`, `beam_search_scorer_cuda.cpp`,
`cuda_topk.cu`

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants