Remove cudaStreamSynchronize from CUDA LLM ops for CUDA graph capture compatibility #27484
Merged
titaiwangms merged 1 commit into main on Feb 27, 2026
Conversation
Contributor
Pull request overview
This PR refactors CUDA LLM operations to eliminate synchronous host-device memory transfers and stream synchronizations (cudaStreamSynchronize), enabling CUDA graph capture compatibility and improving performance. The refactoring moves validation logic from host-side (CPU) to device-side (GPU) using CUDA_KERNEL_ASSERT for asynchronous error checking.
Changes:
- Replaces synchronous host-side validation with asynchronous device-side CUDA_KERNEL_ASSERT in attention mask processing and TensorScatter operations
- Updates corresponding test cases to restrict validation tests to CPU-only execution, since CUDA now validates asynchronously and won't synchronously return errors
- Removes host-device memory copies and cudaStreamSynchronize calls that would block CUDA graph capture
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.h | Updated function signature to remove validation_result parameter and added documentation about asynchronous CUDA_KERNEL_ASSERT validation |
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu | Replaced host-side validation buffer logic with CUDA_KERNEL_ASSERT for mask validity checks (starts with True, contiguous values) |
| onnxruntime/core/providers/cuda/llm/attention.cc | Removed validation buffer allocation, host-side memory copy, stream synchronization, and error checking loop for boolean attention masks |
| onnxruntime/core/providers/cuda/llm/tensorscatter.cc | Removed host-side validation, memory copy, and stream synchronization for write_indices; added comment about CUDA_KERNEL_ASSERT validation |
| onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu | Added CUDA_KERNEL_ASSERT checks for write_indices (non-negative) and cache_pos bounds in linear mode |
| onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc | Updated negative test cases to run CPU-only by explicitly passing DefaultCpuExecutionProvider, with comments explaining CUDA validates asynchronously |
Comments suppressed due to low confidence (1)
onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu:41
- The bounds checks on `write_indices` and `cache_pos` in `_TensorScatterKernel` rely solely on `CUDA_KERNEL_ASSERT`, which expands to `assert(...)` (or a no-op on Apple/HIP) and is compiled out in typical `NDEBUG` release builds. In those configurations, a caller can supply negative or out-of-range `write_indices` (now only shape-validated on the host), causing `cache_pos` to be negative or `>= max_seq_len` and leading to out-of-bounds writes to `output_data[out_offset]`, potentially corrupting adjacent GPU buffers and leaking or overwriting other data. Replace these `assert`-based checks with runtime validation that remains active in release builds (either by restoring/keeping host-side range checks or adding device-side conditional checks that safely reject or clamp invalid indices without relying on `assert`).
```cpp
int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
CUDA_KERNEL_ASSERT(wi >= 0);
int64_t cache_pos;
if (circular) {
  cache_pos = (wi + seq_idx) % max_seq_len;
} else {
  cache_pos = wi + seq_idx;
  CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
}
int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
output_data[out_offset] = update_data[id];
```
hariharans29 approved these changes on Feb 27, 2026
This pull request refactors validation logic for CUDA attention masks and tensor scatter operations to move error checking from host-side (CPU) to device-side (GPU) using CUDA kernel assertions (`CUDA_KERNEL_ASSERT`). This change eliminates synchronous host-device memory transfers and stream synchronizations, improving performance and simplifying code. Corresponding test cases are updated to only expect validation failures on the CPU, as CUDA errors are now asynchronous.

Key changes:
Attention mask validation (GQA path):
- Host-side validation is removed from `attention.cc`; mask validity (right-padding, contiguous True/False) is now checked asynchronously via `CUDA_KERNEL_ASSERT` in the CUDA kernel. [1] [2] [3]
- Kernel signatures drop the `validation_result` buffer and rely on device assertions for mask validation. Documentation is updated to reflect this asynchronous error checking. [1] [2] [3] [4] [5] [6] [7]

TensorScatter `write_indices` validation:
- Host-side validation is removed for `write_indices` in `tensorscatter.cc`; index bounds checking is now performed asynchronously inside the CUDA kernel via `CUDA_KERNEL_ASSERT`. [1] [2]

Test updates:
- Negative test cases for `TensorScatter` are updated to run only on CPU, since CUDA now validates asynchronously and will not synchronously return errors to the host. [1] [2] [3] [4]