
Support TensorScatter (24) - CUDA #27446

Merged
titaiwangms merged 3 commits into main from titaiwang/support_tensor_scatter_cuda
Feb 26, 2026

Conversation

titaiwangms (Contributor) commented Feb 25, 2026

This pull request adds a new CUDA kernel implementation for the TensorScatter operator in ONNX Runtime, targeting opset 24. The implementation includes kernel registration, device-side logic, and comprehensive input validation, supporting both "linear" and "circular" scatter modes. The operator is also documented and tested, including negative and out-of-bounds input scenarios.

New Operator Implementation:

  • Introduced the TensorScatter CUDA kernel (onnxruntime/core/providers/cuda/llm/tensorscatter.cc, tensorscatter.h, tensorscatter_impl.cu, tensorscatter_impl.h), supporting both "linear" and "circular" modes, with detailed input validation and device-side scatter logic.

Kernel Registration and Integration:

  • Registered the TensorScatter kernel for CUDA in opset 24 within the execution provider and kernel registry (onnxruntime/core/providers/cuda/cuda_execution_provider.cc).

Documentation:

  • Added the TensorScatter operator to the operator kernel documentation, specifying supported types and attributes (docs/OperatorKernels.md).

Testing and Validation:

  • Added unit tests for TensorScatter, including negative tests for invalid write_indices and out-of-bounds conditions in both "linear" and "circular" modes (onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc).

Copilot AI left a comment

Pull request overview

This pull request adds CUDA execution provider support for the new TensorScatter operator (opset 24). The implementation enables efficient tensor scatter operations on CUDA with support for both linear and circular update modes. The operator is designed for cache update scenarios in LLM inference, allowing updates to be scattered into a target tensor along a specified axis.

Changes:

  • Registered TensorScatter operator (opset 24) in the CUDA execution provider
  • Implemented CUDA kernel with efficient element-wise parallelization and template-based circular/linear mode dispatch
  • Added host-side operator logic with input validation, shape checking, and tensor copying before scatter updates

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Summary per file:

  • onnxruntime/core/providers/cuda/cuda_execution_provider.cc: Added TensorScatter operator registration for opset 24 in the CUDA EP
  • onnxruntime/core/providers/cuda/llm/tensorscatter.h: Header defining the TensorScatter class interface with axis and circular mode attributes
  • onnxruntime/core/providers/cuda/llm/tensorscatter.cc: Host-side implementation with validation, memcpy, and kernel dispatch logic
  • onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h: CUDA kernel interface declaration for the TensorScatterImpl function
  • onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu: CUDA kernel implementation with element-size-based dispatch and circular/linear mode templates


Comment threads on onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu (four outdated, one current)
Copilot AI left a comment

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
Comment thread onnxruntime/core/providers/cuda/llm/tensorscatter.cc
Comment thread onnxruntime/core/providers/cuda/llm/tensorscatter.cc
@titaiwangms titaiwangms added the ep:CUDA issues related to the CUDA execution provider label Feb 25, 2026
@titaiwangms titaiwangms merged commit 39bdea4 into main Feb 26, 2026
98 of 99 checks passed
@titaiwangms titaiwangms deleted the titaiwang/support_tensor_scatter_cuda branch February 26, 2026 17:25
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(host_write_indices.data(), write_indices,
                                     static_cast<size_t>(batch_size) * sizeof(int64_t),
                                     cudaMemcpyDeviceToHost, Stream(context)));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
Member commented:

Such synchronization operations will cause issues during CUDA graph capture. Usually, in GPU ops, we don't bring data to the host to perform validation if the input is already in CUDA memory. It is probably better to use CUDA_KERNEL_ASSERT() within the kernel to asynchronously report invalid data.
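For illustration, in-kernel validation along the lines the reviewer suggests might look like the sketch below. The kernel name, parameter names, and grid layout are hypothetical, and the scatter body is elided; CUDA_KERNEL_ASSERT is the macro the reviewer names, wrapping a device-side assert.

```cuda
// Hypothetical sketch: validate write_indices on the device instead of
// copying them to the host and synchronizing the stream (which breaks
// CUDA graph capture and stalls the pipeline).
__global__ void TensorScatterKernelSketch(const int64_t* write_indices,
                                          int64_t max_sequence_length) {
  const int64_t w = write_indices[blockIdx.y];  // illustrative indexing
  // Fails asynchronously with a device-side assert if the index is
  // invalid, with no host round-trip before the launch.
  CUDA_KERNEL_ASSERT(w >= 0 && w < max_sequence_length);
  // ... scatter logic ...
}
```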

titaiwangms (author) replied:

I see. I will submit a follow-up. How did you find out about this?

Member replied:

Just happened to look at the PR post merge.

Member commented:

Also, regarding:

"The sync is the real cost, but it happens before the scatter kernel launch, not in the middle of a kernel pipeline"

While this is true, a stream sync here means it must block until all asynchronous work queued on this stream completes.


Labels

ep:CUDA issues related to the CUDA execution provider


4 participants