Support TensorScatter (24) - CUDA #27446
Pull request overview
This pull request adds CUDA execution provider support for the new TensorScatter operator (opset 24). The implementation enables efficient tensor scatter operations on CUDA with support for both linear and circular update modes. The operator is designed for cache update scenarios in LLM inference, allowing updates to be scattered into a target tensor along a specified axis.
Changes:
- Registered TensorScatter operator (opset 24) in the CUDA execution provider
- Implemented CUDA kernel with efficient element-wise parallelization and template-based circular/linear mode dispatch
- Added host-side operator logic with input validation, shape checking, and tensor copying before scatter updates
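As a rough illustration of the linear and circular modes described above, here is a minimal CPU reference sketch in C++. It assumes a dense float cache laid out as [batch, max_seq, hidden] with the scatter axis fixed to the sequence dimension, and it assumes write_indices has already been validated. The function name TensorScatterRef and all parameter names are hypothetical and are not taken from the PR. It follows the copy-then-scatter order mentioned in the changes list.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference, not the PR's code.
// past_cache: [batch, max_seq, hidden], update: [batch, seq, hidden],
// write_indices: [batch], all flattened in row-major order.
std::vector<float> TensorScatterRef(const std::vector<float>& past_cache,
                                    const std::vector<float>& update,
                                    const std::vector<int64_t>& write_indices,
                                    int64_t batch, int64_t max_seq,
                                    int64_t seq, int64_t hidden,
                                    bool circular) {
  // Step 1: the output starts as a copy of the existing cache.
  std::vector<float> present = past_cache;
  // Step 2: overwrite the rows selected by write_indices.
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t s = 0; s < seq; ++s) {
      int64_t dst = write_indices[b] + s;
      if (circular) {
        dst %= max_seq;  // wrap around the end of the cache
      }
      for (int64_t h = 0; h < hidden; ++h) {
        present[(b * max_seq + dst) * hidden + h] =
            update[(b * seq + s) * hidden + h];
      }
    }
  }
  return present;
}
```

In a GPU version, each (b, s, h) triple of the loop nest would typically map to one thread, which is one way to read the "element-wise parallelization" in the overview; that mapping is an assumption here, not a description of the actual kernel.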
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/cuda_execution_provider.cc | Added TensorScatter operator registration for opset 24 in CUDA EP |
| onnxruntime/core/providers/cuda/llm/tensorscatter.h | Header defining TensorScatter class interface with axis and circular mode attributes |
| onnxruntime/core/providers/cuda/llm/tensorscatter.cc | Host-side implementation with validation, memcpy, and kernel dispatch logic |
| onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h | CUDA kernel interface declaration for TensorScatterImpl function |
| onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu | CUDA kernel implementation with element-size-based dispatch and circular/linear mode templates |
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
    cudaMemcpyAsync(host_write_indices.data(), write_indices,
                    static_cast<size_t>(batch_size) * sizeof(int64_t),
                    cudaMemcpyDeviceToHost, Stream(context)));
    CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
Synchronization operations like this will cause issues during CUDA graph capture. Usually, in GPU ops, we don't bring data to the host for validation if the input is already in CUDA memory. It is probably better to use CUDA_KERNEL_ASSERT() within the kernel to report invalid data asynchronously.
I see. I will submit a follow-up. How did you find out about this?
Just happened to look at the PR post merge.
Also, re:
"The sync is the real cost, but it happens before the scatter kernel launch, not in the middle of a kernel pipeline"
While this is true, a stream sync here means the host has to block until all asynchronous work already queued on this stream completes.
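To make the suggestion in the thread above concrete, the following is a small sketch of the per-batch bounds predicate that a device-side assertion could evaluate instead of copying write_indices back to the host. It is written as plain C++ so that it compiles without CUDA. The function name IsWriteIndexValid and the exact rules are assumptions: linear mode is taken to require that the whole update fits without wrapping, and circular mode is taken to require only a start index inside the cache. The PR's actual validation rules may differ.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical predicate, not the PR's code. In a kernel, a check like
// this would be evaluated per batch entry inside an assertion macro,
// so no device-to-host copy or stream synchronization is needed.
inline bool IsWriteIndexValid(int64_t write_index, int64_t seq_len,
                              int64_t max_seq_len, bool circular) {
  if (write_index < 0) {
    return false;  // negative start positions are rejected in both modes
  }
  if (circular) {
    // Assumed rule: the start must lie inside the cache; the write may wrap.
    return write_index < max_seq_len;
  }
  // Assumed rule: in linear mode the last written row must stay in bounds.
  return write_index + seq_len <= max_seq_len;
}
```

The trade-off is that a failed device-side assertion surfaces asynchronously and typically only in builds where such assertions are enabled, whereas the host-side check returns an error status immediately at the cost of the synchronization discussed above.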
This pull request adds a new CUDA kernel implementation for the TensorScatter operator in ONNX Runtime, targeting opset 24. The implementation includes kernel registration, device-side logic, and comprehensive input validation, supporting both "linear" and "circular" scatter modes. The operator is also documented and tested, including negative and out-of-bounds input scenarios.

New Operator Implementation:
- Added the TensorScatter CUDA kernel (onnxruntime/core/providers/cuda/llm/tensorscatter.cc, tensorscatter.h, tensorscatter_impl.cu, tensorscatter_impl.h), supporting both "linear" and "circular" modes, with detailed input validation and device-side scatter logic.

Kernel Registration and Integration:
- Registered the TensorScatter kernel for CUDA in opset 24 within the execution provider and kernel registry (onnxruntime/core/providers/cuda/cuda_execution_provider.cc).

Documentation:
- Added the TensorScatter operator to the operator kernel documentation, specifying supported types and attributes (docs/OperatorKernels.md).

Testing and Validation:
- Added tests for TensorScatter, including negative tests for invalid write_indices and out-of-bounds conditions in both "linear" and "circular" modes (onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc).