diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 591d2c6806ea7..abdcc81586909 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -950,6 +950,7 @@ Do not modify directly.*
|||[6, 7]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Tanh|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|TensorScatter|*in* past_cache:**T**<br> *in* update:**T**<br> *in* write_indices:**tensor(int64)**<br> *out* present_cache:**T**|24+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br> **T1** = tensor(int64)|
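As a mental model for the new table row: TensorScatter copies `past_cache` into `present_cache`, then writes each batch's `update` slice starting at that batch's `write_indices` position along the scatter axis (default -2), wrapping around in circular mode. Below is a minimal host-side sketch of those semantics for a rank-3 cache, assuming a `[batch, max_seq, hidden]` layout; the function and parameter names are ours, for illustration only:

```cpp
#include <cstdint>
#include <vector>

// Reference semantics only; the real op accepts tensors of any rank >= 2.
void TensorScatterRef(std::vector<float>& present_cache,          // pre-filled with past_cache
                      const std::vector<float>& update,           // [batch, seq_len, hidden]
                      const std::vector<int64_t>& write_indices,  // [batch]; treated as 0 when absent
                      int64_t batch, int64_t max_seq, int64_t seq_len, int64_t hidden,
                      bool circular) {
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t s = 0; s < seq_len; ++s) {
      int64_t pos = write_indices[b] + s;
      if (circular) pos %= max_seq;  // circular mode wraps; linear mode must stay in bounds
      for (int64_t h = 0; h < hidden; ++h) {
        present_cache[(b * max_seq + pos) * hidden + h] = update[(b * seq_len + s) * hidden + h];
      }
    }
  }
}
```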
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index b4cb6c6bd122c..3b334fb507f7e 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1629,6 +1629,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, U
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast);
#endif
+// Opset 24.
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
+
#endif
template <>
@@ -2703,6 +2706,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+
+ // Opset 24
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter.cc b/onnxruntime/core/providers/cuda/llm/tensorscatter.cc
new file mode 100644
index 0000000000000..f889042b785d9
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter.cc
@@ -0,0 +1,145 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/llm/tensorscatter.h"
+#include "core/providers/cuda/llm/tensorscatter_impl.h"
+#include "core/providers/cuda/cuda_common.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+ONNX_OPERATOR_KERNEL_EX(
+ TensorScatter,
+ kOnnxDomain,
+ 24,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .MayInplace(0, 0)
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ TensorScatter);
+
+TensorScatter::TensorScatter(const OpKernelInfo& info) : CudaKernel(info) {
+ axis_ = info.GetAttrOrDefault<int64_t>("axis", -2);
+ std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear");
+ ORT_ENFORCE(mode == "linear" || mode == "circular",
+ "TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'");
+ circular_ = (mode == "circular");
+}
+
+Status TensorScatter::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* past_cache = context->Input<Tensor>(0);
+ const Tensor* update = context->Input<Tensor>(1);
+ const Tensor* write_indices_tensor = context->Input<Tensor>(2);  // optional
+
+ ORT_ENFORCE(past_cache != nullptr && update != nullptr,
+ "TensorScatter: past_cache and update must not be null");
+
+ const auto& cache_shape = past_cache->Shape();
+ const auto& update_shape = update->Shape();
+ const int ndim = static_cast<int>(cache_shape.NumDimensions());
+
+ ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions");
+ ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(),
+ "TensorScatter: past_cache and update must have the same number of dimensions");
+
+ // Resolve axis (handles negative values).
+ int axis = static_cast<int>(axis_);
+ if (axis < 0) axis += ndim;
+ ORT_ENFORCE(axis > 0 && axis < ndim,
+ "TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis);
+
+ // Validate shapes: all dimensions must match except the axis dimension.
+ const int64_t batch_size = cache_shape[0];
+ const int64_t max_sequence_length = cache_shape[axis];
+ const int64_t sequence_length = update_shape[axis];
+
+ ORT_ENFORCE(sequence_length <= max_sequence_length,
+ "TensorScatter: update sequence_length (", sequence_length,
+ ") exceeds max_sequence_length (", max_sequence_length, ")");
+
+ for (int d = 0; d < ndim; ++d) {
+ if (d != axis) {
+ ORT_ENFORCE(cache_shape[d] == update_shape[d],
+ "TensorScatter: shape mismatch in dimension ", d,
+ ": past_cache=", cache_shape[d], " vs update=", update_shape[d]);
+ }
+ }
+
+ // Validate write_indices if provided.
+ const int64_t* write_indices = nullptr;
+ if (write_indices_tensor != nullptr) {
+ ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 &&
+ write_indices_tensor->Shape()[0] == batch_size,
+ "TensorScatter: write_indices must have shape [batch_size]");
+ write_indices = write_indices_tensor->Data<int64_t>();
+
+ // Copy write_indices to host for validation. This forces a stream sync, but only batch_size elements are copied.
+ std::vector<int64_t> host_write_indices(static_cast<size_t>(batch_size));
+ CUDA_RETURN_IF_ERROR(
+ cudaMemcpyAsync(host_write_indices.data(), write_indices,
+ static_cast(batch_size) * sizeof(int64_t),
+ cudaMemcpyDeviceToHost, Stream(context)));
+ CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
+
+ for (int64_t b = 0; b < batch_size; ++b) {
+ int64_t wi = host_write_indices[static_cast<size_t>(b)];
+ ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", b, "] = ", wi, " is negative");
+ if (!circular_) {
+ ORT_ENFORCE(wi + sequence_length <= max_sequence_length,
+ "TensorScatter linear mode: write_indices[", b, "] + sequence_length (",
+ wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")");
+ }
+ }
+ }
+
+ // Allocate output with the same shape as past_cache.
+ Tensor* present_cache = context->Output(0, cache_shape);
+
+ // Step 1: Copy past_cache -> present_cache.
+ const void* src_raw = past_cache->DataRaw();
+ void* dst_raw = present_cache->MutableDataRaw();
+ if (dst_raw != src_raw) {
+ CUDA_RETURN_IF_ERROR(
+ cudaMemcpyAsync(dst_raw, src_raw, past_cache->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream(context)));
+ }
+
+ // Bail out early if nothing to scatter.
+ if (sequence_length == 0) {
+ return Status::OK();
+ }
+
+ // Step 2: Scatter the update into present_cache.
+ const size_t element_size = past_cache->DataType()->Size();
+
+ int64_t prefix_count = 1;
+ for (int d = 0; d < axis; ++d) {
+ prefix_count *= cache_shape[d];
+ }
+
+ int64_t suffix_count = 1;
+ for (int d = axis + 1; d < ndim; ++d) {
+ suffix_count *= cache_shape[d];
+ }
+
+ int64_t prefix_stride_for_batch = 1;
+ for (int d = 1; d < axis; ++d) {
+ prefix_stride_for_batch *= cache_shape[d];
+ }
+
+ return TensorScatterImpl(
+ Stream(context),
+ dst_raw,
+ update->DataRaw(),
+ write_indices,
+ element_size,
+ prefix_count,
+ prefix_stride_for_batch,
+ max_sequence_length,
+ sequence_length,
+ suffix_count,
+ circular_);
+}
+
+} // namespace cuda
+} // namespace onnxruntime
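To see how `prefix_count`, `suffix_count`, and `prefix_stride_for_batch` relate for a typical KV cache, consider a `[batch, heads, max_seq, head_dim]` cache scattered along `axis = 2`. A standalone sketch with hypothetical shape numbers:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical cache shape [B, H, S, D] scattered along S (axis = 2).
  const int64_t B = 2, H = 8, D = 64;
  const int64_t prefix_count = B * H;         // product of dims before the axis
  const int64_t suffix_count = D;             // product of dims after the axis
  const int64_t prefix_stride_for_batch = H;  // product of dims 1..axis-1
  std::printf("prefix_count=%lld suffix_count=%lld\n",
              (long long)prefix_count, (long long)suffix_count);

  // The kernel recovers batch and head from a flattened prefix index like this:
  for (int64_t prefix_idx : {0, 7, 8, 15}) {
    std::printf("prefix %lld -> batch %lld, head %lld\n",
                (long long)prefix_idx,
                (long long)(prefix_idx / prefix_stride_for_batch),
                (long long)(prefix_idx % prefix_stride_for_batch));
  }
  return 0;
}
```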
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter.h b/onnxruntime/core/providers/cuda/llm/tensorscatter.h
new file mode 100644
index 0000000000000..cde0f658fc786
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter.h
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+class TensorScatter final : public CudaKernel {
+ public:
+ explicit TensorScatter(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int64_t axis_;
+ bool circular_;
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
new file mode 100644
index 0000000000000..f9ebf3918db4e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
@@ -0,0 +1,118 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/llm/tensorscatter_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T, bool circular>
+__global__ void _TensorScatterKernel(
+ T* output_data,
+ const T* update_data,
+ const int64_t* write_indices,
+ int64_t prefix_count,
+ int64_t prefix_stride_for_batch,
+ int64_t max_seq_len,
+ int64_t seq_len,
+ int64_t suffix_count,
+ size_t total_elements) {
+ CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, total_elements);
+
+ int64_t seq_suffix = seq_len * suffix_count;
+ int64_t prefix_idx = id / seq_suffix;
+ int64_t remainder = id - prefix_idx * seq_suffix;
+ int64_t seq_idx = remainder / suffix_count;
+ int64_t suffix_idx = remainder - seq_idx * suffix_count;
+
+ int64_t batch_idx = prefix_idx / prefix_stride_for_batch;
+ int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
+ // write_indices are validated on the host before kernel launch.
+ int64_t cache_pos;
+ if (circular) {
+ cache_pos = (wi + seq_idx) % max_seq_len;
+ } else {
+ cache_pos = wi + seq_idx;
+ }
+
+ int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
+ output_data[out_offset] = update_data[id];
+}
+
+template <typename T>
+Status _TensorScatterDispatchCircular(
+ cudaStream_t stream,
+ T* output_data,
+ const T* update_data,
+ const int64_t* write_indices,
+ int64_t prefix_count,
+ int64_t prefix_stride_for_batch,
+ int64_t max_seq_len,
+ int64_t seq_len,
+ int64_t suffix_count,
+ bool circular) {
+ size_t total_elements = static_cast<size_t>(prefix_count) * static_cast<size_t>(seq_len) * static_cast<size_t>(suffix_count);
+ if (total_elements == 0) return Status::OK();
+
+ int blocksPerGrid = static_cast<int>(CeilDiv(total_elements, static_cast<size_t>(GridDim::maxThreadsPerBlock)));
+
+ if (circular) {
+ _TensorScatterKernel<T, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+ output_data, update_data, write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
+ total_elements);
+ } else {
+ _TensorScatterKernel<T, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+ output_data, update_data, write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
+ total_elements);
+ }
+
+ return CUDA_CALL(cudaGetLastError());
+}
+
+Status TensorScatterImpl(
+ cudaStream_t stream,
+ void* output_data,
+ const void* update_data,
+ const int64_t* write_indices,
+ size_t element_size,
+ int64_t prefix_count,
+ int64_t prefix_stride_for_batch,
+ int64_t max_seq_len,
+ int64_t seq_len,
+ int64_t suffix_count,
+ bool circular) {
+ switch (element_size) {
+ case sizeof(int8_t):
+ return _TensorScatterDispatchCircular<int8_t>(
+ stream, reinterpret_cast<int8_t*>(output_data),
+ reinterpret_cast<const int8_t*>(update_data), write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+ case sizeof(int16_t):
+ return _TensorScatterDispatchCircular<int16_t>(
+ stream, reinterpret_cast<int16_t*>(output_data),
+ reinterpret_cast<const int16_t*>(update_data), write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+ case sizeof(int32_t):
+ return _TensorScatterDispatchCircular<int32_t>(
+ stream, reinterpret_cast<int32_t*>(output_data),
+ reinterpret_cast<const int32_t*>(update_data), write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+ case sizeof(int64_t):
+ return _TensorScatterDispatchCircular<int64_t>(
+ stream, reinterpret_cast<int64_t*>(output_data),
+ reinterpret_cast<const int64_t*>(update_data), write_indices,
+ prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported element size for TensorScatter: ", element_size);
+ }
+}
+
+} // namespace cuda
+} // namespace onnxruntime
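One design note on the dispatch above: the kernel never interprets element values, it only relocates them, so every type of a given width shares one instantiation; float, int32, and uint32 all take the `int32_t` case, and float16/bfloat16 take `int16_t`. A trivial host-side illustration of why the bit-level round trip is lossless:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // What the int32_t path effectively does to one float element.
  float src = 3.25f, dst = 0.0f;
  int32_t bits;
  std::memcpy(&bits, &src, sizeof(bits));  // reinterpret the bytes, no numeric conversion
  std::memcpy(&dst, &bits, sizeof(bits));
  assert(dst == 3.25f);
  return 0;
}
```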
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h
new file mode 100644
index 0000000000000..8b6ce5b53b219
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+Status TensorScatterImpl(
+ cudaStream_t stream,
+ void* output_data,
+ const void* update_data,
+ const int64_t* write_indices,
+ size_t element_size,
+ int64_t prefix_count,
+ int64_t prefix_stride_for_batch,
+ int64_t max_seq_len,
+ int64_t seq_len,
+ int64_t suffix_count,
+ bool circular);
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
index 34d72dab3d31b..bc68c29c1bc7b 100644
--- a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
@@ -296,5 +296,51 @@ TEST(TensorScatterTest, InPlace_IOBinding) {
<< "Output should alias the original cache_data buffer";
}
+// Negative write_indices should fail validation.
+TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
+ OpTester test("TensorScatter", 24);
+ test.AddAttribute("mode", "linear");
+
+ test.AddInput<float>("past_cache", {1, 4, 3},
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ test.AddInput<float>("update", {1, 1, 3}, {1, 2, 3});
+ test.AddInput<int64_t>("write_indices", {1}, {-1});
+ test.AddOutput<float>("present_cache", {1, 4, 3},
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+ test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+}
+
+// Linear mode: write_indices + sequence_length > max_sequence_length should fail.
+TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
+ OpTester test("TensorScatter", 24);
+ test.AddAttribute("mode", "linear");
+
+ // max_seq=4, update seq_len=2, wi=3 -> 3+2=5 > 4
+ test.AddInput<float>("past_cache", {1, 4, 3},
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ test.AddInput<float>("update", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
+ test.AddInput<int64_t>("write_indices", {1}, {3});
+ test.AddOutput<float>("present_cache", {1, 4, 3},
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+ test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length");
+}
+
+// Circular mode: negative write_indices should still fail.
+TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
+ OpTester test("TensorScatter", 24);
+ test.AddAttribute("mode", "circular");
+
+ test.AddInput<float>("past_cache", {1, 4, 2},
+ {0, 0, 0, 0, 0, 0, 0, 0});
+ test.AddInput<float>("update", {1, 1, 2}, {1, 2});
+ test.AddInput<int64_t>("write_indices", {1}, {-2});
+ test.AddOutput<float>("present_cache", {1, 4, 2},
+ {0, 0, 0, 0, 0, 0, 0, 0});
+
+ test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+}
+
} // namespace test
} // namespace onnxruntime
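The three new tests all target validation failures; a complementary happy-path case for circular wraparound would look roughly like the sketch below. This is not part of the diff (it assumes the same includes and namespaces as the test file above), and the element values are arbitrary:

```cpp
// Sketch only: circular mode writing past the end wraps to the front.
TEST(TensorScatterTest, Circular_Wraparound) {
  OpTester test("TensorScatter", 24);
  test.AddAttribute("mode", "circular");

  // max_seq = 4, seq_len = 2, wi = 3 -> updates land at positions 3 and (3 + 1) % 4 = 0.
  test.AddInput<float>("past_cache", {1, 4, 1}, {10, 20, 30, 40});
  test.AddInput<float>("update", {1, 2, 1}, {1, 2});
  test.AddInput<int64_t>("write_indices", {1}, {3});
  test.AddOutput<float>("present_cache", {1, 4, 1}, {2, 20, 30, 1});

  test.Run();
}
```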