diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 41811201cbf0e..cc18ece351705 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -494,6 +494,7 @@ Do not modify directly.*
 |||[7, 21]|**T** = tensor(float)|
 |Tanh|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
+|TensorScatter|*in* past_cache:**T**<br> *in* update:**T**<br> *in* write_indices:**tensor(int64)**<br> *out* present_cache:**T**|24+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |TfIdfVectorizer|*in* X:**T**<br> *out* Y:**T1**|9+|**T** = tensor(int32), tensor(int64), tensor(string)<br> **T1** = tensor(float)|
 |ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[10, 21]|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 2a67b008f849d..74b8f8e468097 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1406,6 +1406,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
// Opset 24
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, ConstantOfShape);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, int32_t, DequantizeLinear);
@@ -3524,6 +3525,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// opset 24
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo
+
+namespace onnxruntime {
+
+// Kernel registration for TensorScatter (ONNX opset 24, CPU EP).
+// MayInplace(0, 0) tells the allocation planner that output 0 (present_cache)
+// may reuse the buffer of input 0 (past_cache), enabling the in-place fast
+// path in Compute that skips the full cache copy.
+// "T" accepts all tensor element types, matching the ONNX op schema.
+ONNX_CPU_OPERATOR_KERNEL(
+    TensorScatter,
+    24,
+    KernelDefBuilder()
+        .MayInplace(0, 0)
+        .TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
+    TensorScatter);
+
+// Reads and validates the op attributes once at kernel-construction time.
+//   axis (int, default -2): dimension of past_cache along which to scatter.
+//   mode (string, default "linear"): "linear" (bounds-checked write) or
+//   "circular" (positions wrap modulo max_sequence_length).
+TensorScatter::TensorScatter(const OpKernelInfo& info) : OpKernel(info) {
+  // Note: template arguments are required here; GetAttrOrDefault cannot
+  // deduce the attribute type from the default value alone.
+  axis_ = info.GetAttrOrDefault<int64_t>("axis", -2);
+  std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear");
+  ORT_ENFORCE(mode == "linear" || mode == "circular",
+              "TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'");
+  circular_ = (mode == "circular");
+}
+
+// Scatters `update` into a copy of `past_cache` along the sequence axis.
+//
+// Inputs:
+//   0: past_cache    (T)             rank >= 2
+//   1: update        (T)             same rank; matches past_cache on every
+//                                    dimension except the sequence axis
+//   2: write_indices (tensor(int64)) optional, shape [batch_size]; per-batch
+//                                    start position (treated as 0 when absent)
+// Output:
+//   0: present_cache (T)             same shape as past_cache
+//
+// When MayInplace(0, 0) was honored by the allocation planner, the output
+// aliases the input and the initial copy is skipped.
+Status TensorScatter::Compute(OpKernelContext* context) const {
+  const Tensor* past_cache = context->Input<Tensor>(0);
+  const Tensor* update = context->Input<Tensor>(1);
+  const Tensor* write_indices_tensor = context->Input<Tensor>(2);  // optional
+
+  ORT_ENFORCE(past_cache != nullptr && update != nullptr,
+              "TensorScatter: past_cache and update must not be null");
+
+  const auto& cache_shape = past_cache->Shape();
+  const auto& update_shape = update->Shape();
+  const int ndim = static_cast<int>(cache_shape.NumDimensions());
+
+  ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions");
+  ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(),
+              "TensorScatter: past_cache and update must have the same number of dimensions");
+
+  // Resolve axis (handles negative values). Axis 0 is the batch dimension and
+  // is never a valid scatter axis.
+  int axis = static_cast<int>(axis_);
+  if (axis < 0) axis += ndim;
+  ORT_ENFORCE(axis > 0 && axis < ndim,
+              "TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis);
+
+  // Validate shapes: all dimensions must match except the axis dimension.
+  const int64_t batch_size = cache_shape[0];
+  const int64_t max_sequence_length = cache_shape[axis];
+  const int64_t sequence_length = update_shape[axis];
+
+  ORT_ENFORCE(sequence_length <= max_sequence_length,
+              "TensorScatter: update sequence_length (", sequence_length,
+              ") exceeds max_sequence_length (", max_sequence_length, ")");
+
+  for (int d = 0; d < ndim; ++d) {
+    if (d != axis) {
+      ORT_ENFORCE(cache_shape[d] == update_shape[d],
+                  "TensorScatter: shape mismatch in dimension ", d,
+                  ": past_cache=", cache_shape[d], " vs update=", update_shape[d]);
+    }
+  }
+
+  // Validate write_indices if provided.
+  const int64_t* write_indices = nullptr;
+  if (write_indices_tensor != nullptr) {
+    ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 &&
+                    write_indices_tensor->Shape()[0] == batch_size,
+                "TensorScatter: write_indices must have shape [batch_size]");
+    write_indices = write_indices_tensor->Data<int64_t>();
+  }
+
+  // Allocate output with the same shape as past_cache.
+  Tensor* present_cache = context->Output(0, cache_shape);
+
+  // Step 1: Copy past_cache -> present_cache (skipped when aliased in place).
+  const auto element_size = past_cache->DataType()->Size();
+  const size_t total_bytes = SafeInt<size_t>(cache_shape.Size()) * element_size;
+  const auto* src_raw = past_cache->DataRaw();
+  auto* dst_raw = present_cache->MutableDataRaw();
+  if (dst_raw != src_raw) {
+    LOGS(context->Logger(), WARNING) << "TensorScatter: in-place optimization not activated, copying past_cache to present_cache ("
+                                     << total_bytes << " bytes)";
+    memcpy(dst_raw, src_raw, total_bytes);
+  }
+
+  // Step 2: Scatter the update into present_cache.
+  //
+  // Layout: (batch_size, D1, ..., D_{axis-1}, max_seq_len, D_{axis+1}, ..., D_{n-1})
+  //
+  // We decompose the tensor into:
+  //   prefix_count = product of dims[0:axis]   (number of prefix slices)
+  //   suffix_bytes = product of dims[axis+1:] * element_size (bytes per single sequence position)
+  //
+  // For each prefix slice we determine batch_idx = prefix_linear_idx / prefix_stride_for_batch,
+  // look up write_indices[batch_idx], and memcpy sequence_length suffix-sized chunks.
+
+  int64_t prefix_count = 1;
+  for (int d = 0; d < axis; ++d) {
+    prefix_count *= cache_shape[d];
+  }
+
+  int64_t suffix_count = 1;
+  for (int d = axis + 1; d < ndim; ++d) {
+    suffix_count *= cache_shape[d];
+  }
+  const size_t suffix_bytes = SafeInt<size_t>(suffix_count) * element_size;
+
+  // prefix_stride_for_batch: number of prefix elements per batch element
+  // e.g., shape (B, D1, ..., D_{axis-1}, ...) -> stride = D1 * ... * D_{axis-1}
+  int64_t prefix_stride_for_batch = 1;
+  for (int d = 1; d < axis; ++d) {
+    prefix_stride_for_batch *= cache_shape[d];
+  }
+
+  const size_t cache_axis_stride = SafeInt<size_t>(max_sequence_length) * suffix_bytes;
+  const size_t update_axis_stride = SafeInt<size_t>(sequence_length) * suffix_bytes;
+  auto* dst_bytes = static_cast<uint8_t*>(dst_raw);
+  const auto* update_raw = static_cast<const uint8_t*>(update->DataRaw());
+
+  for (int64_t p = 0; p < prefix_count; ++p) {
+    int64_t batch_idx = p / prefix_stride_for_batch;
+    int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
+    ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", batch_idx, "] = ", wi, " is negative");
+
+    ptrdiff_t update_offset = static_cast<ptrdiff_t>(SafeInt<size_t>(p) * update_axis_stride);
+    ptrdiff_t cache_offset = static_cast<ptrdiff_t>(SafeInt<size_t>(p) * cache_axis_stride);
+    const uint8_t* update_base = update_raw + update_offset;
+    uint8_t* cache_base = dst_bytes + cache_offset;
+
+    if (!circular_) {
+      ORT_ENFORCE(wi + sequence_length <= max_sequence_length,
+                  "TensorScatter linear mode: write_indices[", batch_idx, "] + sequence_length (",
+                  wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")");
+      // Single contiguous memcpy for the whole slice.
+      ptrdiff_t wi_offset = static_cast<ptrdiff_t>(SafeInt<size_t>(wi) * suffix_bytes);
+      size_t copy_len = SafeInt<size_t>(sequence_length) * suffix_bytes;
+      memcpy(cache_base + wi_offset, update_base, copy_len);
+    } else {
+      // Circular: each sequence position wraps independently.
+      for (int64_t s = 0; s < sequence_length; ++s) {
+        int64_t cache_pos = (wi + s) % max_sequence_length;
+        ptrdiff_t dst_off = static_cast<ptrdiff_t>(SafeInt<size_t>(cache_pos) * suffix_bytes);
+        ptrdiff_t src_off = static_cast<ptrdiff_t>(SafeInt<size_t>(s) * suffix_bytes);
+        memcpy(cache_base + dst_off, update_base + src_off, suffix_bytes);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/llm/tensorscatter.h b/onnxruntime/core/providers/cpu/llm/tensorscatter.h
new file mode 100644
index 0000000000000..cb3bd84501476
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/llm/tensorscatter.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+
+// CPU kernel implementing the ONNX TensorScatter operator (opset 24).
+// Copies past_cache into present_cache (in place when the planner reuses the
+// buffer) and writes `update` at per-batch positions along the sequence axis,
+// either bounds-checked ("linear") or wrapping ("circular").
+class TensorScatter final : public OpKernel {
+ public:
+  explicit TensorScatter(const OpKernelInfo& info);
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  int64_t axis_;   // "axis" attribute (default -2): sequence dimension to scatter along
+  bool circular_;  // true when the "mode" attribute is "circular"
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
new file mode 100644
index 0000000000000..34d72dab3d31b
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
@@ -0,0 +1,300 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+#include "core/graph/model.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/session/inference_session.h"
+#include "core/session/IOBinding.h"
+#include "test/unittest_util/framework_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "test/util/include/test_environment.h"
+
+namespace onnxruntime {
+namespace test {
+
+// From ONNX spec example: tensorscatter (4D, linear mode).
+// Each batch writes one sequence position at its own write index.
+TEST(TensorScatterTest, Linear_4D) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  // past_cache: shape (2, 1, 4, 5)
+  test.AddInput<float>("past_cache", {2, 1, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0});
+
+  // update: shape (2, 1, 1, 5)
+  test.AddInput<float>("update", {2, 1, 1, 5},
+                       {5, 5, 5, 5, 5,
+                        1, 1, 1, 1, 1});
+
+  // write_indices: shape (2,)
+  test.AddInput<int64_t>("write_indices", {2}, {1, 2});
+
+  // present_cache: shape (2, 1, 4, 5)
+  test.AddOutput<float>("present_cache", {2, 1, 4, 5},
+                        {1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                         1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 4, 3, 2, 1, 0});
+
+  test.Run();
+}
+
+// From ONNX spec example: tensorscatter_3d (3D, default axis=-2 -> axis=1).
+// Exercises the default "linear" mode and default axis attribute.
+TEST(TensorScatterTest, Linear_3D) {
+  OpTester test("TensorScatter", 24);
+
+  // past_cache: shape (3, 4, 5)
+  test.AddInput<float>("past_cache", {3, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1});
+
+  // update: shape (3, 2, 5)
+  test.AddInput<float>("update", {3, 2, 5},
+                       {4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
+                        6, 6, 6, 6, 6, 7, 7, 7, 7, 7,
+                        2, 2, 2, 2, 2, 3, 3, 3, 3, 3});
+
+  // write_indices: shape (3,)
+  test.AddInput<int64_t>("write_indices", {3}, {1, 2, 0});
+
+  // present_cache: shape (3, 4, 5)
+  test.AddOutput<float>("present_cache", {3, 4, 5},
+                        {1, 2, 3, 4, 5, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 4, 3, 2, 1,
+                         1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7,
+                         2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1});
+
+  test.Run();
+}
+
+// From ONNX spec example: tensorscatter_circular (4D, circular mode).
+// Batch 1 wraps: positions 3 then 0.
+TEST(TensorScatterTest, Circular_4D) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "circular");
+
+  // past_cache: shape (2, 1, 4, 5)
+  test.AddInput<float>("past_cache", {2, 1, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0});
+
+  // update: shape (2, 1, 2, 5)
+  test.AddInput<float>("update", {2, 1, 2, 5},
+                       {5, 5, 5, 5, 5, 6, 6, 6, 6, 6,
+                        1, 1, 1, 1, 1, 2, 2, 2, 2, 2});
+
+  // write_indices: shape (2,)
+  test.AddInput<int64_t>("write_indices", {2}, {1, 3});
+
+  // present_cache: shape (2, 1, 4, 5)
+  // Batch 0: wi=1, seq_len=2 -> positions 1,2 (no wrap)
+  // Batch 1: wi=3, seq_len=2 -> positions 3, 0 (wraps around)
+  test.AddOutput<float>("present_cache", {2, 1, 4, 5},
+                        {1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 4, 3, 2, 1, 0,
+                         2, 2, 2, 2, 2, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 1, 1, 1, 1, 1});
+
+  test.Run();
+}
+
+// No write_indices (defaults to zero) — prefill scenario.
+TEST(TensorScatterTest, Linear_NoWriteIndices) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  // past_cache: shape (1, 1, 4, 3)
+  test.AddInput<float>("past_cache", {1, 1, 4, 3},
+                       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+  // update: shape (1, 1, 2, 3) — writes at position 0 (default)
+  test.AddInput<float>("update", {1, 1, 2, 3},
+                       {1, 2, 3, 4, 5, 6});
+
+  // Explicitly mark the optional write_indices input as absent.
+  test.AddOptionalInputEdge<int64_t>();
+
+  // present_cache: positions 0,1 filled, 2,3 untouched
+  test.AddOutput<float>("present_cache", {1, 1, 4, 3},
+                        {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0});
+
+  test.Run();
+}
+
+// Float16 type test — verifies the byte-wise scatter works for a
+// non-float element type.
+TEST(TensorScatterTest, Linear_Float16) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  std::vector<float> past_f = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  std::vector<float> update_f = {99, 98, 97};
+  std::vector<float> expected_f = {1, 2, 3, 99, 98, 97, 7, 8, 9, 10, 11, 12};
+
+  std::vector<MLFloat16> past_fp16(past_f.size());
+  std::vector<MLFloat16> update_fp16(update_f.size());
+  std::vector<MLFloat16> expected_fp16(expected_f.size());
+  for (size_t i = 0; i < past_f.size(); ++i) past_fp16[i] = MLFloat16(past_f[i]);
+  for (size_t i = 0; i < update_f.size(); ++i) update_fp16[i] = MLFloat16(update_f[i]);
+  for (size_t i = 0; i < expected_f.size(); ++i) expected_fp16[i] = MLFloat16(expected_f[i]);
+
+  // shape (1, 4, 3), axis=-2 -> axis=1
+  test.AddInput<MLFloat16>("past_cache", {1, 4, 3}, past_fp16);
+  test.AddInput<MLFloat16>("update", {1, 1, 3}, update_fp16);
+  test.AddInput<int64_t>("write_indices", {1}, {1});
+  test.AddOutput<MLFloat16>("present_cache", {1, 4, 3}, expected_fp16);
+
+  test.Run();
+}
+
+// Explicit axis attribute test (axis=1 on a 3D tensor).
+TEST(TensorScatterTest, Linear_ExplicitAxis) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+  test.AddAttribute<int64_t>("axis", 1);
+
+  // shape (2, 3, 2) — axis=1 means the seq dim is dim 1 (size 3)
+  test.AddInput<float>("past_cache", {2, 3, 2},
+                       {0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0});
+
+  // update: shape (2, 1, 2)
+  test.AddInput<float>("update", {2, 1, 2},
+                       {1, 2,
+                        3, 4});
+
+  test.AddInput<int64_t>("write_indices", {2}, {0, 2});
+
+  test.AddOutput<float>("present_cache", {2, 3, 2},
+                        {1, 2, 0, 0, 0, 0,
+                         0, 0, 0, 0, 3, 4});
+
+  test.Run();
+}
+
+// Circular wrap-around with a multi-position update that crosses the end
+// of the cache.
+TEST(TensorScatterTest, Circular_WrapAround) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "circular");
+
+  // shape (1, 4, 2), axis=-2 -> axis=1, max_seq=4
+  test.AddInput<float>("past_cache", {1, 4, 2},
+                       {10, 11, 20, 21, 30, 31, 40, 41});
+
+  // update: 3 positions starting at wi=2 -> positions 2, 3, 0 (wraps)
+  test.AddInput<float>("update", {1, 3, 2},
+                       {1, 2, 3, 4, 5, 6});
+
+  test.AddInput<int64_t>("write_indices", {1}, {2});
+
+  // pos 2->1,2  pos 3->3,4  pos 0->5,6 (wrapped)
+  test.AddOutput<float>("present_cache", {1, 4, 2},
+                        {5, 6, 20, 21, 1, 2, 3, 4});
+
+  test.Run();
+}
+
+// IO-binding test: bind the same buffer as both input[0] (past_cache) and
+// output[0] (present_cache) to verify the MayInplace(0,0) in-place path.
+TEST(TensorScatterTest, InPlace_IOBinding) {
+  // Build a TensorScatter model programmatically.
+  std::unordered_map<std::string, int> domain_to_version;
+  domain_to_version[onnxruntime::kOnnxDomain] = 24;
+  std::vector<ONNX_NAMESPACE::FunctionProto> model_specific_functions;
+  auto p_model = std::make_unique<Model>(
+      "tensorscatter_inplace_test", true, ModelMetaData(), PathString(),
+      IOnnxRuntimeOpSchemaRegistryList(), domain_to_version,
+      model_specific_functions, DefaultLoggingManager().DefaultLogger(),
+      ModelOptions(true, true));
+  onnxruntime::Graph& graph = p_model->MainGraph();
+
+  // Define types.
+  ONNX_NAMESPACE::TypeProto tensor_float;
+  tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+  ONNX_NAMESPACE::TypeProto tensor_int64;
+  tensor_int64.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+
+  // Inputs.
+  auto& past_cache_arg = graph.GetOrCreateNodeArg("past_cache", &tensor_float);
+  auto& update_arg = graph.GetOrCreateNodeArg("update", &tensor_float);
+  auto& write_indices_arg = graph.GetOrCreateNodeArg("write_indices", &tensor_int64);
+  std::vector<NodeArg*> input_defs = {&past_cache_arg, &update_arg, &write_indices_arg};
+
+  // Output.
+  auto& present_cache_arg = graph.GetOrCreateNodeArg("present_cache", &tensor_float);
+  std::vector<NodeArg*> output_defs = {&present_cache_arg};
+
+  // Attributes: mode = "linear".
+  NodeAttributes attrs = {
+      {"mode", utils::MakeAttribute("mode", std::string("linear"))}};
+
+  auto& node = graph.AddNode("ts_node", "TensorScatter", "TensorScatter in-place test",
+                             input_defs, output_defs, &attrs, onnxruntime::kOnnxDomain);
+  node.SetExecutionProviderType(kCpuExecutionProvider);
+
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  // Serialize and load into InferenceSession.
+  std::string model_str;
+  p_model->ToProto().SerializeToString(&model_str);
+  std::stringstream sstr(model_str);
+
+  SessionOptions so;
+  so.session_logid = "TensorScatterInPlaceTest";
+  InferenceSession session(so, GetEnvironment());
+  ASSERT_STATUS_OK(session.Load(sstr));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  // Allocate the shared buffer for past_cache / present_cache.
+  // Shape: (1, 4, 3) — 1 batch, 4 seq positions, 3 features.
+  std::vector<int64_t> cache_dims = {1, 4, 3};
+  std::vector<float> cache_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+  // update shape: (1, 1, 3) — write 1 position.
+  std::vector<int64_t> update_dims = {1, 1, 3};
+  std::vector<float> update_data = {99, 98, 97};
+
+  // write_indices: write at position 2.
+  std::vector<int64_t> wi_dims = {1};
+  std::vector<int64_t> wi_data = {2};
+
+  auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
+
+  // Create the shared OrtValue backed by cache_data for both input and output.
+  OrtValue shared_cache;
+  CreateMLValue<float>(cache_dims, cache_data.data(), cpu_alloc->Info(), &shared_cache);
+
+  OrtValue update_val;
+  CreateMLValue<float>(cpu_alloc, update_dims, update_data, &update_val);
+
+  OrtValue wi_val;
+  CreateMLValue<int64_t>(cpu_alloc, wi_dims, wi_data, &wi_val);
+
+  // Set up IO binding: same OrtValue for input past_cache and output present_cache.
+  std::unique_ptr<IOBinding> io_binding;
+  ASSERT_STATUS_OK(session.NewIOBinding(&io_binding));
+  ASSERT_STATUS_OK(io_binding->BindInput("past_cache", shared_cache));
+  ASSERT_STATUS_OK(io_binding->BindInput("update", update_val));
+  ASSERT_STATUS_OK(io_binding->BindInput("write_indices", wi_val));
+  ASSERT_STATUS_OK(io_binding->BindOutput("present_cache", shared_cache));
+
+  RunOptions run_options;
+  ASSERT_STATUS_OK(session.Run(run_options, *io_binding));
+
+  // Verify that the shared buffer was updated in-place.
+  // Original: {1,2,3, 4,5,6, 7,8,9, 10,11,12}
+  // After scatter at position 2: {1,2,3, 4,5,6, 99,98,97, 10,11,12}
+  std::vector<float> expected = {1, 2, 3, 4, 5, 6, 99, 98, 97, 10, 11, 12};
+  const auto& output = io_binding->GetOutputs()[0];
+  auto span = output.Get<Tensor>().DataAsSpan<float>();
+  ASSERT_EQ(static_cast<size_t>(span.size()), expected.size());
+  for (size_t i = 0; i < expected.size(); ++i) {
+    EXPECT_FLOAT_EQ(span[i], expected[i]) << "mismatch at index " << i;
+  }
+
+  // Verify the buffer pointer is still the same (in-place).
+  EXPECT_EQ(output.Get<Tensor>().Data<float>(), cache_data.data())
+      << "Output should alias the original cache_data buffer";
+}
+
+} // namespace test
+} // namespace onnxruntime