
Remove cudaStreamSynchronize from CUDA LLM ops for CUDA graph capture compatibility#27484

Merged
titaiwangms merged 1 commit into main from titaiwang/refactor_cuda_llm_onnx
Feb 27, 2026

Conversation

@titaiwangms
Contributor

This pull request refactors the validation logic for CUDA attention masks and tensor scatter operations, moving error checking from the host (CPU) to the device (GPU) via CUDA kernel assertions (CUDA_KERNEL_ASSERT). This eliminates synchronous host-device memory transfers and stream synchronizations, improving performance and simplifying the code. The corresponding test cases are updated to expect validation failures only on CPU, since CUDA errors are now raised asynchronously.

Key changes:

Attention mask validation (GQA path):

  • Removes host-side validation and memory copies for boolean attention masks in attention.cc; mask validity (right padding, i.e. contiguous True values followed by contiguous False values) is now checked asynchronously via CUDA_KERNEL_ASSERT in the CUDA kernel.
  • Updates the CUDA kernel and its interface to drop the validation_result buffer and rely on device assertions for mask validation. Documentation is updated to reflect this asynchronous error checking.
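The right-padding invariant that the kernel now asserts can be sketched on the host. The helper below is hypothetical (not code from this PR); it checks the same property the CUDA kernel enforces per row with CUDA_KERNEL_ASSERT: the mask must start with True and be a contiguous run of True values followed only by False values.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side sketch of the invariant the kernel asserts:
// a right-padded boolean attention mask row must begin with true and be a
// contiguous run of true values followed only by false values.
bool IsValidRightPaddedMask(const std::vector<bool>& mask) {
  if (mask.empty() || !mask.front()) return false;  // must start with true
  bool seen_false = false;
  for (bool v : mask) {
    if (!v) {
      seen_false = true;
    } else if (seen_false) {
      return false;  // a true after a false breaks contiguity
    }
  }
  return true;
}
```

On the device the same per-element condition is expressed as an assertion rather than a return value, which is what makes the check asynchronous: the host only learns of a violation when the stream's error state is observed later.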

TensorScatter write_indices validation:

  • Removes host-side validation and synchronization for write_indices in tensorscatter.cc; index bounds checking is now performed asynchronously inside the CUDA kernel via CUDA_KERNEL_ASSERT.

Test updates:

  • Updates the negative test cases for TensorScatter to run only on CPU, since CUDA now validates asynchronously and will not synchronously return errors to the host.

titaiwangms added the ep:CUDA label (issues related to the CUDA execution provider) on Feb 27, 2026
Copilot AI left a comment

Pull request overview

This PR refactors CUDA LLM operations to eliminate synchronous host-device memory transfers and stream synchronizations (cudaStreamSynchronize), enabling CUDA graph capture compatibility and improving performance. The refactoring moves validation logic from host-side (CPU) to device-side (GPU) using CUDA_KERNEL_ASSERT for asynchronous error checking.

Changes:

  • Replaces synchronous host-side validation with asynchronous device-side CUDA_KERNEL_ASSERT in attention mask processing and TensorScatter operations
  • Updates corresponding test cases to restrict validation tests to CPU-only execution, since CUDA now validates asynchronously and won't synchronously return errors
  • Removes host-device memory copies and cudaStreamSynchronize calls that would block CUDA graph capture

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Summary per file:

  • onnxruntime/core/providers/cuda/llm/attention_mask_impl.h: updated the function signature to remove the validation_result parameter; added documentation about asynchronous CUDA_KERNEL_ASSERT validation
  • onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu: replaced the host-side validation buffer logic with CUDA_KERNEL_ASSERT for mask validity checks (starts with True, contiguous values)
  • onnxruntime/core/providers/cuda/llm/attention.cc: removed the validation buffer allocation, host-side memory copy, stream synchronization, and error-checking loop for boolean attention masks
  • onnxruntime/core/providers/cuda/llm/tensorscatter.cc: removed host-side validation, memory copy, and stream synchronization for write_indices; added a comment about CUDA_KERNEL_ASSERT validation
  • onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu: added CUDA_KERNEL_ASSERT checks for write_indices (non-negative) and cache_pos bounds in linear mode
  • onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc: updated negative test cases to run CPU-only by explicitly passing DefaultCpuExecutionProvider, with comments explaining that CUDA validates asynchronously
Comments suppressed due to low confidence (1)

onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu:41

  • The bounds checks on write_indices and cache_pos in _TensorScatterKernel rely solely on CUDA_KERNEL_ASSERT, which expands to assert(...) (or a no-op on Apple/HIP) and is compiled out in typical NDEBUG release builds. In those configurations, a caller can supply negative or out-of-range write_indices (now only shape-validated on the host), causing cache_pos to be negative or >= max_seq_len and leading to out-of-bounds writes to output_data[out_offset], potentially corrupting adjacent GPU buffers and leaking or overwriting other data. Replace these assert-based checks with runtime validation that remains active in release builds (either by restoring/keeping host-side range checks or adding device-side conditional checks that safely reject or clamp invalid indices without relying on assert).
  int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
  CUDA_KERNEL_ASSERT(wi >= 0);
  int64_t cache_pos;
  if (circular) {
    cache_pos = (wi + seq_idx) % max_seq_len;
  } else {
    cache_pos = wi + seq_idx;
    CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
  }

  int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
  output_data[out_offset] = update_data[id];


titaiwangms merged commit 479dd39 into main on Feb 27, 2026
93 of 95 checks passed
titaiwangms deleted the titaiwang/refactor_cuda_llm_onnx branch on February 27, 2026 at 21:39

Labels

ep:CUDA: issues related to the CUDA execution provider

Projects

None yet


3 participants