diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 41811201cbf0e..cc18ece351705 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -494,6 +494,7 @@ Do not modify directly.*
 |||[7, 21]|**T** = tensor(float)|
 |Tanh|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
+|TensorScatter|*in* past_cache:**T**<br> *in* update:**T**<br> *in* write_indices:**tensor(int64)**<br> *out* present_cache:**T**|24+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |TfIdfVectorizer|*in* X:**T**<br> *out* Y:**T1**|9+|**T** = tensor(int32), tensor(int64), tensor(string)<br> **T1** = tensor(float)|
 |ThresholdedRelu|*in* X:**T**<br>
*out* Y:**T**|22+|**T** = tensor(float)| |||[10, 21]|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 2a67b008f849d..74b8f8e468097 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1406,6 +1406,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, // Opset 24 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, TensorScatter); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, 24, int32_t, DequantizeLinear); @@ -3524,6 +3525,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // opset 24 BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo + +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL( + TensorScatter, + 24, + KernelDefBuilder() + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + TensorScatter); + +TensorScatter::TensorScatter(const OpKernelInfo& info) : OpKernel(info) { + axis_ = info.GetAttrOrDefault("axis", -2); + std::string mode = info.GetAttrOrDefault("mode", "linear"); + ORT_ENFORCE(mode == "linear" || mode == "circular", + "TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'"); + circular_ = (mode == "circular"); +} + +Status TensorScatter::Compute(OpKernelContext* context) const { + const Tensor* 
past_cache = context->Input(0); + const Tensor* update = context->Input(1); + const Tensor* write_indices_tensor = context->Input(2); // optional + + ORT_ENFORCE(past_cache != nullptr && update != nullptr, + "TensorScatter: past_cache and update must not be null"); + + const auto& cache_shape = past_cache->Shape(); + const auto& update_shape = update->Shape(); + const int ndim = static_cast(cache_shape.NumDimensions()); + + ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions"); + ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(), + "TensorScatter: past_cache and update must have the same number of dimensions"); + + // Resolve axis (handles negative values). + int axis = static_cast(axis_); + if (axis < 0) axis += ndim; + ORT_ENFORCE(axis > 0 && axis < ndim, + "TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis); + + // Validate shapes: all dimensions must match except the axis dimension. + const int64_t batch_size = cache_shape[0]; + const int64_t max_sequence_length = cache_shape[axis]; + const int64_t sequence_length = update_shape[axis]; + + ORT_ENFORCE(sequence_length <= max_sequence_length, + "TensorScatter: update sequence_length (", sequence_length, + ") exceeds max_sequence_length (", max_sequence_length, ")"); + + for (int d = 0; d < ndim; ++d) { + if (d != axis) { + ORT_ENFORCE(cache_shape[d] == update_shape[d], + "TensorScatter: shape mismatch in dimension ", d, + ": past_cache=", cache_shape[d], " vs update=", update_shape[d]); + } + } + + // Validate write_indices if provided. + const int64_t* write_indices = nullptr; + if (write_indices_tensor != nullptr) { + ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 && + write_indices_tensor->Shape()[0] == batch_size, + "TensorScatter: write_indices must have shape [batch_size]"); + write_indices = write_indices_tensor->Data(); + } + + // Allocate output with the same shape as past_cache. 
+ Tensor* present_cache = context->Output(0, cache_shape); + + // Step 1: Copy past_cache -> present_cache. + const auto element_size = past_cache->DataType()->Size(); + const size_t total_bytes = SafeInt(cache_shape.Size()) * element_size; + const auto* src_raw = past_cache->DataRaw(); + auto* dst_raw = present_cache->MutableDataRaw(); + if (dst_raw != src_raw) { + LOGS(context->Logger(), WARNING) << "TensorScatter: in-place optimization not activated, copying past_cache to present_cache (" + << total_bytes << " bytes)"; + memcpy(dst_raw, src_raw, total_bytes); + } + + // Step 2: Scatter the update into present_cache. + // + // Layout: (batch_size, D1, ..., D_{axis-1}, max_seq_len, D_{axis+1}, ..., D_{n-1}) + // + // We decompose the tensor into: + // prefix_count = product of dims[0:axis] (number of prefix slices) + // suffix_bytes = product of dims[axis+1:] * element_size (bytes per single sequence position) + // + // For each prefix slice we determine batch_idx = prefix_linear_idx / prefix_stride_for_batch, + // look up write_indices[batch_idx], and memcpy sequence_length suffix-sized chunks. + + int64_t prefix_count = 1; + for (int d = 0; d < axis; ++d) { + prefix_count *= cache_shape[d]; + } + + int64_t suffix_count = 1; + for (int d = axis + 1; d < ndim; ++d) { + suffix_count *= cache_shape[d]; + } + const size_t suffix_bytes = SafeInt(suffix_count) * element_size; + + // prefix_stride_for_batch: number of prefix elements per batch element + // e.g., shape (B, D1, ..., D_{axis-1}, ...) -> stride = D1 * ... 
* D_{axis-1} + int64_t prefix_stride_for_batch = 1; + for (int d = 1; d < axis; ++d) { + prefix_stride_for_batch *= cache_shape[d]; + } + + const size_t cache_axis_stride = SafeInt(max_sequence_length) * suffix_bytes; + const size_t update_axis_stride = SafeInt(sequence_length) * suffix_bytes; + auto* dst_bytes = static_cast(dst_raw); + const auto* update_raw = static_cast(update->DataRaw()); + + for (int64_t p = 0; p < prefix_count; ++p) { + int64_t batch_idx = p / prefix_stride_for_batch; + int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0; + ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", batch_idx, "] = ", wi, " is negative"); + + ptrdiff_t update_offset = static_cast(SafeInt(p) * update_axis_stride); + ptrdiff_t cache_offset = static_cast(SafeInt(p) * cache_axis_stride); + const uint8_t* update_base = update_raw + update_offset; + uint8_t* cache_base = dst_bytes + cache_offset; + + if (!circular_) { + ORT_ENFORCE(wi + sequence_length <= max_sequence_length, + "TensorScatter linear mode: write_indices[", batch_idx, "] + sequence_length (", + wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")"); + // Single contiguous memcpy for the whole slice. + ptrdiff_t wi_offset = static_cast(SafeInt(wi) * suffix_bytes); + size_t copy_len = SafeInt(sequence_length) * suffix_bytes; + memcpy(cache_base + wi_offset, update_base, copy_len); + } else { + // Circular: each sequence position wraps independently. 
+ for (int64_t s = 0; s < sequence_length; ++s) { + int64_t cache_pos = (wi + s) % max_sequence_length; + ptrdiff_t dst_off = static_cast(SafeInt(cache_pos) * suffix_bytes); + ptrdiff_t src_off = static_cast(SafeInt(s) * suffix_bytes); + memcpy(cache_base + dst_off, update_base + src_off, suffix_bytes); + } + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/tensorscatter.h b/onnxruntime/core/providers/cpu/llm/tensorscatter.h new file mode 100644 index 0000000000000..cb3bd84501476 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/tensorscatter.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class TensorScatter final : public OpKernel { + public: + TensorScatter(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + int64_t axis_; + bool circular_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc new file mode 100644 index 0000000000000..34d72dab3d31b --- /dev/null +++ b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+#include "core/graph/model.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/session/inference_session.h"
+#include "core/session/IOBinding.h"
+#include "test/unittest_util/framework_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "test/util/include/test_environment.h"
+
+namespace onnxruntime {
+namespace test {
+
+// From ONNX spec example: tensorscatter (4D, linear mode)
+TEST(TensorScatterTest, Linear_4D) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  // past_cache: shape (2, 1, 4, 5)
+  test.AddInput<float>("past_cache", {2, 1, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0});
+
+  // update: shape (2, 1, 1, 5)
+  test.AddInput<float>("update", {2, 1, 1, 5},
+                       {5, 5, 5, 5, 5,
+                        1, 1, 1, 1, 1});
+
+  // write_indices: shape (2,)
+  test.AddInput<int64_t>("write_indices", {2}, {1, 2});
+
+  // present_cache: shape (2, 1, 4, 5)
+  test.AddOutput<float>("present_cache", {2, 1, 4, 5},
+                        {1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                         1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 4, 3, 2, 1, 0});
+
+  test.Run();
+}
+
+// From ONNX spec example: tensorscatter_3d (3D, default axis=-2 -> axis=1)
+TEST(TensorScatterTest, Linear_3D) {
+  OpTester test("TensorScatter", 24);
+
+  // past_cache: shape (3, 4, 5)
+  test.AddInput<float>("past_cache", {3, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1});
+
+  // update: shape (3, 2, 5)
+  test.AddInput<float>("update", {3, 2, 5},
+                       {4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
+                        6, 6, 6, 6, 6, 7, 7, 7, 7, 7,
+                        2, 2, 2, 2, 2, 3, 3, 3, 3, 3});
+
+  // write_indices: shape (3,)
+  test.AddInput<int64_t>("write_indices", {3}, {1, 2, 0});
+
+  // present_cache: shape (3, 4, 5)
+  test.AddOutput<float>("present_cache", {3, 4, 5},
+                        {1, 2, 3, 4, 5, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 4, 3, 2, 1,
+                         1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7,
+                         2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 8, 7, 6, 5, 4, 5, 4, 3, 2, 1});
+
+  test.Run();
+}
+
+// From ONNX spec example: tensorscatter_circular (4D, circular mode)
+TEST(TensorScatterTest, Circular_4D) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "circular");
+
+  // past_cache: shape (2, 1, 4, 5)
+  test.AddInput<float>("past_cache", {2, 1, 4, 5},
+                       {1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0,
+                        1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 4, 3, 2, 1, 0});
+
+  // update: shape (2, 1, 2, 5)
+  test.AddInput<float>("update", {2, 1, 2, 5},
+                       {5, 5, 5, 5, 5, 6, 6, 6, 6, 6,
+                        1, 1, 1, 1, 1, 2, 2, 2, 2, 2});
+
+  // write_indices: shape (2,)
+  test.AddInput<int64_t>("write_indices", {2}, {1, 3});
+
+  // present_cache: shape (2, 1, 4, 5)
+  // Batch 0: wi=1, seq_len=2 -> positions 1,2 (no wrap)
+  // Batch 1: wi=3, seq_len=2 -> positions 3, 0 (wraps around)
+  test.AddOutput<float>("present_cache", {2, 1, 4, 5},
+                        {1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 4, 3, 2, 1, 0,
+                         2, 2, 2, 2, 2, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 1, 1, 1, 1, 1});
+
+  test.Run();
+}
+
+// No write_indices (defaults to zero) -- prefill scenario.
+TEST(TensorScatterTest, Linear_NoWriteIndices) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  // past_cache: shape (1, 1, 4, 3)
+  test.AddInput<float>("past_cache", {1, 1, 4, 3},
+                       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+  // update: shape (1, 1, 2, 3) -- writes at position 0 (default)
+  test.AddInput<float>("update", {1, 1, 2, 3},
+                       {1, 2, 3, 4, 5, 6});
+
+  test.AddOptionalInputEdge<int64_t>();
+
+  // present_cache: positions 0,1 filled, 2,3 untouched
+  test.AddOutput<float>("present_cache", {1, 1, 4, 3},
+                        {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0});
+
+  test.Run();
+}
+
+// Float16 type test
+TEST(TensorScatterTest, Linear_Float16) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  std::vector<float> past_f = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  std::vector<float> update_f = {99, 98, 97};
+  std::vector<float> expected_f = {1, 2, 3, 99, 98, 97, 7, 8, 9, 10, 11, 12};
+
+  std::vector<MLFloat16> past_fp16(past_f.size());
+  std::vector<MLFloat16> update_fp16(update_f.size());
+  std::vector<MLFloat16> expected_fp16(expected_f.size());
+  for (size_t i = 0; i < past_f.size(); ++i) past_fp16[i] = MLFloat16(past_f[i]);
+  for (size_t i = 0; i < update_f.size(); ++i) update_fp16[i] = MLFloat16(update_f[i]);
+  for (size_t i = 0; i < expected_f.size(); ++i) expected_fp16[i] = MLFloat16(expected_f[i]);
+
+  // shape (1, 4, 3), axis=-2 -> axis=1
+  test.AddInput<MLFloat16>("past_cache", {1, 4, 3}, past_fp16);
+  test.AddInput<MLFloat16>("update", {1, 1, 3}, update_fp16);
+  test.AddInput<int64_t>("write_indices", {1}, {1});
+  test.AddOutput<MLFloat16>("present_cache", {1, 4, 3}, expected_fp16);
+
+  test.Run();
+}
+
+// Explicit axis attribute test (axis=1 on a 3D tensor)
+TEST(TensorScatterTest, Linear_ExplicitAxis) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+  test.AddAttribute("axis", static_cast<int64_t>(1));
+
+  // shape (2, 3, 2) -- axis=1 means the seq dim is dim 1 (size 3)
+  test.AddInput<float>("past_cache", {2, 3, 2},
+                       {0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0});
+
+  // update: shape (2, 1, 2)
+  test.AddInput<float>("update", {2, 1, 2},
+                       {1, 2,
+                        3, 4});
+
+  test.AddInput<int64_t>("write_indices", {2}, {0, 2});
+
+  test.AddOutput<float>("present_cache", {2, 3, 2},
+                        {1, 2, 0, 0, 0, 0,
+                         0, 0, 0, 0, 3, 4});
+
+  test.Run();
+}
+
+// Circular wrap-around with multi-position update
+TEST(TensorScatterTest, Circular_WrapAround) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "circular");
+
+  // shape (1, 4, 2), axis=-2 -> axis=1, max_seq=4
+  test.AddInput<float>("past_cache", {1, 4, 2},
+                       {10, 11, 20, 21, 30, 31, 40, 41});
+
+  // update: 3 positions starting at wi=2 -> positions 2, 3, 0 (wraps)
+  test.AddInput<float>("update", {1, 3, 2},
+                       {1, 2, 3, 4, 5, 6});
+
+  test.AddInput<int64_t>("write_indices", {1}, {2});
+
+  // pos 2->1,2  pos 3->3,4  pos 0->5,6 (wrapped)
+  test.AddOutput<float>("present_cache", {1, 4, 2},
+                        {5, 6, 20, 21, 1, 2, 3, 4});
+
+  test.Run();
+}
+
+// IO-binding test: bind the same buffer as both input[0] (past_cache) and
+// output[0] (present_cache) to verify the MayInplace(0,0) in-place path.
+TEST(TensorScatterTest, InPlace_IOBinding) {
+  // Build a TensorScatter model programmatically.
+  std::unordered_map<std::string, int> domain_to_version;
+  domain_to_version[onnxruntime::kOnnxDomain] = 24;
+  std::vector<ONNX_NAMESPACE::FunctionProto> model_specific_functions;
+  auto p_model = std::make_unique<Model>(
+      "tensorscatter_inplace_test", true, ModelMetaData(), PathString(),
+      IOnnxRuntimeOpSchemaRegistryList(), domain_to_version,
+      model_specific_functions, DefaultLoggingManager().DefaultLogger(),
+      ModelOptions(true, true));
+  onnxruntime::Graph& graph = p_model->MainGraph();
+
+  // Define types.
+  ONNX_NAMESPACE::TypeProto tensor_float;
+  tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+  ONNX_NAMESPACE::TypeProto tensor_int64;
+  tensor_int64.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+
+  // Inputs.
+  auto& past_cache_arg = graph.GetOrCreateNodeArg("past_cache", &tensor_float);
+  auto& update_arg = graph.GetOrCreateNodeArg("update", &tensor_float);
+  auto& write_indices_arg = graph.GetOrCreateNodeArg("write_indices", &tensor_int64);
+  std::vector<NodeArg*> input_defs = {&past_cache_arg, &update_arg, &write_indices_arg};
+
+  // Output.
+  auto& present_cache_arg = graph.GetOrCreateNodeArg("present_cache", &tensor_float);
+  std::vector<NodeArg*> output_defs = {&present_cache_arg};
+
+  // Attributes: mode = "linear".
+  NodeAttributes attrs = {
+      {"mode", utils::MakeAttribute("mode", std::string("linear"))}};
+
+  auto& node = graph.AddNode("ts_node", "TensorScatter", "TensorScatter in-place test",
+                             input_defs, output_defs, &attrs, onnxruntime::kOnnxDomain);
+  node.SetExecutionProviderType(kCpuExecutionProvider);
+
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  // Serialize and load into InferenceSession.
+  std::string model_str;
+  p_model->ToProto().SerializeToString(&model_str);
+  std::stringstream sstr(model_str);
+
+  SessionOptions so;
+  so.session_logid = "TensorScatterInPlaceTest";
+  InferenceSession session(so, GetEnvironment());
+  ASSERT_STATUS_OK(session.Load(sstr));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  // Allocate the shared buffer for past_cache / present_cache.
+  // Shape: (1, 4, 3) -- 1 batch, 4 seq positions, 3 features.
+  std::vector<int64_t> cache_dims = {1, 4, 3};
+  std::vector<float> cache_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+  // update shape: (1, 1, 3) -- write 1 position.
+  std::vector<int64_t> update_dims = {1, 1, 3};
+  std::vector<float> update_data = {99, 98, 97};
+
+  // write_indices: write at position 2.
+  std::vector<int64_t> wi_dims = {1};
+  std::vector<int64_t> wi_data = {2};
+
+  auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
+
+  // Create the shared OrtValue backed by cache_data for both input and output.
+  OrtValue shared_cache;
+  CreateMLValue<float>(cache_dims, cache_data.data(), cpu_alloc->Info(), &shared_cache);
+
+  OrtValue update_val;
+  CreateMLValue<float>(cpu_alloc, update_dims, update_data, &update_val);
+
+  OrtValue wi_val;
+  CreateMLValue<int64_t>(cpu_alloc, wi_dims, wi_data, &wi_val);
+
+  // Set up IO binding: same OrtValue for input past_cache and output present_cache.
+  std::unique_ptr<IOBinding> io_binding;
+  ASSERT_STATUS_OK(session.NewIOBinding(&io_binding));
+  ASSERT_STATUS_OK(io_binding->BindInput("past_cache", shared_cache));
+  ASSERT_STATUS_OK(io_binding->BindInput("update", update_val));
+  ASSERT_STATUS_OK(io_binding->BindInput("write_indices", wi_val));
+  ASSERT_STATUS_OK(io_binding->BindOutput("present_cache", shared_cache));
+
+  RunOptions run_options;
+  ASSERT_STATUS_OK(session.Run(run_options, *io_binding));
+
+  // Verify that the shared buffer was updated in-place.
+  // Original: {1,2,3, 4,5,6, 7,8,9, 10,11,12}
+  // After scatter at position 2: {1,2,3, 4,5,6, 99,98,97, 10,11,12}
+  std::vector<float> expected = {1, 2, 3, 4, 5, 6, 99, 98, 97, 10, 11, 12};
+  const auto& output = io_binding->GetOutputs()[0];
+  auto span = output.Get<Tensor>().DataAsSpan<float>();
+  ASSERT_EQ(static_cast<size_t>(span.size()), expected.size());
+  for (size_t i = 0; i < expected.size(); ++i) {
+    EXPECT_FLOAT_EQ(span[i], expected[i]) << "mismatch at index " << i;
+  }
+
+  // Verify the buffer pointer is still the same (in-place).
+  EXPECT_EQ(output.Get<Tensor>().Data<float>(), cache_data.data())
+      << "Output should alias the original cache_data buffer";
+}
+
+}  // namespace test
+}  // namespace onnxruntime