diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 591d2c6806ea7..abdcc81586909 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -950,6 +950,7 @@ Do not modify directly.*
 |||[6, 7]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Tanh|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|TensorScatter|*in* past_cache:**T**<br> *in* update:**T**<br> *in* write_indices:**tensor(int64)**<br> *out* present_cache:**T**|24+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br> **T1** = tensor(int64)|
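Editor's note: the new row above documents the op's signature but not its semantics. As a reading aid (not part of the PR), here is a minimal scalar sketch of what TensorScatter computes for the common rank-3 KV-cache layout [batch, max_seq_len, hidden] with the default axis = -2. The helper name and the flat row-major indexing are illustrative assumptions.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical scalar reference for TensorScatter over a rank-3 cache
// [B, S_max, H] scattering along the sequence axis (axis = -2).
std::vector<float> TensorScatterReference(
    const std::vector<float>& past_cache,       // [B, S_max, H], row-major
    const std::vector<float>& update,           // [B, S, H], row-major
    const std::vector<int64_t>& write_indices,  // [B]; empty means all zeros
    int64_t B, int64_t S_max, int64_t S, int64_t H, bool circular) {
  std::vector<float> present = past_cache;  // output starts as a copy of the cache
  for (int64_t b = 0; b < B; ++b) {
    int64_t wi = write_indices.empty() ? 0 : write_indices[b];
    if (wi < 0 || (!circular && wi + S > S_max)) throw std::out_of_range("write_indices");
    for (int64_t s = 0; s < S; ++s) {
      // Linear mode appends at wi; circular mode wraps around the ring buffer.
      int64_t pos = circular ? (wi + s) % S_max : wi + s;
      for (int64_t h = 0; h < H; ++h) {
        present[(b * S_max + pos) * H + h] = update[(b * S + s) * H + h];
      }
    }
  }
  return present;
}
```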
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index b4cb6c6bd122c..3b334fb507f7e 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1629,6 +1629,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, U
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast);
 #endif
 
+// Opset 24.
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
+
 #endif
 
 template <>
@@ -2703,6 +2706,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+
+      // Opset 24
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
 #endif
   };
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter.cc b/onnxruntime/core/providers/cuda/llm/tensorscatter.cc
new file mode 100644
index 0000000000000..f889042b785d9
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter.cc
@@ -0,0 +1,145 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/llm/tensorscatter.h"
+#include "core/providers/cuda/llm/tensorscatter_impl.h"
+#include "core/providers/cuda/cuda_common.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+ONNX_OPERATOR_KERNEL_EX(
+    TensorScatter,
+    kOnnxDomain,
+    24,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .MayInplace(0, 0)
+        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+    TensorScatter);
+
+TensorScatter::TensorScatter(const OpKernelInfo& info) : CudaKernel(info) {
+  axis_ = info.GetAttrOrDefault<int64_t>("axis", -2);
+  std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear");
+  ORT_ENFORCE(mode == "linear" || mode == "circular",
+              "TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'");
+  circular_ = (mode == "circular");
+}
+
+Status TensorScatter::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* past_cache = context->Input<Tensor>(0);
+  const Tensor* update = context->Input<Tensor>(1);
+  const Tensor* write_indices_tensor = context->Input<Tensor>(2);  // optional
+
+  ORT_ENFORCE(past_cache != nullptr && update != nullptr,
+              "TensorScatter: past_cache and update must not be null");
+
+  const auto& cache_shape = past_cache->Shape();
+  const auto& update_shape = update->Shape();
+  const int ndim = static_cast<int>(cache_shape.NumDimensions());
+
+  ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions");
+  ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(),
+              "TensorScatter: past_cache and update must have the same number of dimensions");
+
+  // Resolve axis (handles negative values).
+  int axis = static_cast<int>(axis_);
+  if (axis < 0) axis += ndim;
+  ORT_ENFORCE(axis > 0 && axis < ndim,
+              "TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis);
+
+  // Validate shapes: all dimensions must match except the axis dimension.
+  const int64_t batch_size = cache_shape[0];
+  const int64_t max_sequence_length = cache_shape[axis];
+  const int64_t sequence_length = update_shape[axis];
+
+  ORT_ENFORCE(sequence_length <= max_sequence_length,
+              "TensorScatter: update sequence_length (", sequence_length,
+              ") exceeds max_sequence_length (", max_sequence_length, ")");
+
+  for (int d = 0; d < ndim; ++d) {
+    if (d != axis) {
+      ORT_ENFORCE(cache_shape[d] == update_shape[d],
+                  "TensorScatter: shape mismatch in dimension ", d,
+                  ": past_cache=", cache_shape[d], " vs update=", update_shape[d]);
+    }
+  }
+
+  // Validate write_indices if provided.
+  const int64_t* write_indices = nullptr;
+  if (write_indices_tensor != nullptr) {
+    ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 &&
+                    write_indices_tensor->Shape()[0] == batch_size,
+                "TensorScatter: write_indices must have shape [batch_size]");
+    write_indices = write_indices_tensor->Data<int64_t>();
+
+    // Copy write_indices to host for validation (batch_size elements, negligible overhead).
+    std::vector<int64_t> host_write_indices(static_cast<size_t>(batch_size));
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(host_write_indices.data(), write_indices,
+                        static_cast<size_t>(batch_size) * sizeof(int64_t),
+                        cudaMemcpyDeviceToHost, Stream(context)));
+    CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
+
+    for (int64_t b = 0; b < batch_size; ++b) {
+      int64_t wi = host_write_indices[static_cast<size_t>(b)];
+      ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", b, "] = ", wi, " is negative");
+      if (!circular_) {
+        ORT_ENFORCE(wi + sequence_length <= max_sequence_length,
+                    "TensorScatter linear mode: write_indices[", b, "] + sequence_length (",
+                    wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")");
+      }
+    }
+  }
+
+  // Allocate output with the same shape as past_cache.
+  Tensor* present_cache = context->Output(0, cache_shape);
+
+  // Step 1: Copy past_cache -> present_cache.
+  const void* src_raw = past_cache->DataRaw();
+  void* dst_raw = present_cache->MutableDataRaw();
+  if (dst_raw != src_raw) {
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(dst_raw, src_raw, past_cache->SizeInBytes(),
+                        cudaMemcpyDeviceToDevice, Stream(context)));
+  }
+
+  // Bail out early if nothing to scatter.
+  if (sequence_length == 0) {
+    return Status::OK();
+  }
+
+  // Step 2: Scatter the update into present_cache.
+  const size_t element_size = past_cache->DataType()->Size();
+
+  int64_t prefix_count = 1;
+  for (int d = 0; d < axis; ++d) {
+    prefix_count *= cache_shape[d];
+  }
+
+  int64_t suffix_count = 1;
+  for (int d = axis + 1; d < ndim; ++d) {
+    suffix_count *= cache_shape[d];
+  }
+
+  int64_t prefix_stride_for_batch = 1;
+  for (int d = 1; d < axis; ++d) {
+    prefix_stride_for_batch *= cache_shape[d];
+  }
+
+  return TensorScatterImpl(
+      Stream(context),
+      dst_raw,
+      update->DataRaw(),
+      write_indices,
+      element_size,
+      prefix_count,
+      prefix_stride_for_batch,
+      max_sequence_length,
+      sequence_length,
+      suffix_count,
+      circular_);
+}
+
+}  // namespace cuda
+}  // namespace onnxruntime
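Editor's note: ComputeInternal above flattens the N-D cache into prefix x sequence x suffix blocks before launching the kernel. A small standalone sketch (illustrative, not part of the PR) evaluates those three products for a concrete rank-4 shape, mirroring the three loops at the end of the function.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirror of the stride computation in ComputeInternal for a rank-4 cache
// [batch, num_heads, max_seq_len, head_size] with axis = 2 (the default -2).
int main() {
  const std::vector<int64_t> cache_shape = {2, 8, 128, 64};  // illustrative shape
  const int axis = 2;
  int64_t prefix_count = 1, suffix_count = 1, prefix_stride_for_batch = 1;
  for (int d = 0; d < axis; ++d) prefix_count *= cache_shape[d];  // 2 * 8 = 16
  for (size_t d = axis + 1; d < cache_shape.size(); ++d) suffix_count *= cache_shape[d];  // 64
  for (int d = 1; d < axis; ++d) prefix_stride_for_batch *= cache_shape[d];  // 8
  // The kernel recovers batch_idx as prefix_idx / prefix_stride_for_batch,
  // e.g. prefix_idx 11 (batch 1, head 3) -> 11 / 8 = 1.
  printf("prefix=%lld suffix=%lld batch_stride=%lld\n",
         (long long)prefix_count, (long long)suffix_count, (long long)prefix_stride_for_batch);
  return 0;
}
```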
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter.h b/onnxruntime/core/providers/cuda/llm/tensorscatter.h
new file mode 100644
index 0000000000000..cde0f658fc786
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter.h
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+class TensorScatter final : public CudaKernel {
+ public:
+  explicit TensorScatter(const OpKernelInfo& info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  int64_t axis_;
+  bool circular_;
+};
+
+}  // namespace cuda
+}  // namespace onnxruntime
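Editor's note on the implementation file that follows: because the scatter never interprets the values it moves, the dispatcher keys on sizeof(T) alone, collapsing all thirteen supported ONNX types onto four integer widths. A few illustrative checks (standard C++ types only; ORT's float16/bfloat16 take the 2-byte path by the same size argument, and sizeof(bool) == 1 is an assumption that holds on all platforms ORT targets):

```cpp
#include <cstdint>

// Representative size equivalences behind the element-size dispatch:
static_assert(sizeof(bool) == sizeof(int8_t), "bool and [u]int8 -> 1-byte kernel");
static_assert(sizeof(float) == sizeof(int32_t), "float and [u]int32 -> 4-byte kernel");
static_assert(sizeof(double) == sizeof(int64_t), "double and [u]int64 -> 8-byte kernel");

int main() { return 0; }
```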
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
new file mode 100644
index 0000000000000..f9ebf3918db4e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
@@ -0,0 +1,118 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/llm/tensorscatter_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T, bool circular>
+__global__ void _TensorScatterKernel(
+    T* output_data,
+    const T* update_data,
+    const int64_t* write_indices,
+    int64_t prefix_count,
+    int64_t prefix_stride_for_batch,
+    int64_t max_seq_len,
+    int64_t seq_len,
+    int64_t suffix_count,
+    size_t total_elements) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, total_elements);
+
+  int64_t seq_suffix = seq_len * suffix_count;
+  int64_t prefix_idx = id / seq_suffix;
+  int64_t remainder = id - prefix_idx * seq_suffix;
+  int64_t seq_idx = remainder / suffix_count;
+  int64_t suffix_idx = remainder - seq_idx * suffix_count;
+
+  int64_t batch_idx = prefix_idx / prefix_stride_for_batch;
+  int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
+  // write_indices are validated on the host before kernel launch.
+  int64_t cache_pos;
+  if (circular) {
+    cache_pos = (wi + seq_idx) % max_seq_len;
+  } else {
+    cache_pos = wi + seq_idx;
+  }
+
+  int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
+  output_data[out_offset] = update_data[id];
+}
+
+template <typename T>
+Status _TensorScatterDispatchCircular(
+    cudaStream_t stream,
+    T* output_data,
+    const T* update_data,
+    const int64_t* write_indices,
+    int64_t prefix_count,
+    int64_t prefix_stride_for_batch,
+    int64_t max_seq_len,
+    int64_t seq_len,
+    int64_t suffix_count,
+    bool circular) {
+  size_t total_elements = static_cast<size_t>(prefix_count) * static_cast<size_t>(seq_len) * static_cast<size_t>(suffix_count);
+  if (total_elements == 0) return Status::OK();
+
+  int blocksPerGrid = static_cast<int>(CeilDiv(total_elements, static_cast<size_t>(GridDim::maxThreadsPerBlock)));
+
+  if (circular) {
+    _TensorScatterKernel<T, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+        output_data, update_data, write_indices,
+        prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
+        total_elements);
+  } else {
+    _TensorScatterKernel<T, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+        output_data, update_data, write_indices,
+        prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
+        total_elements);
+  }
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
+Status TensorScatterImpl(
+    cudaStream_t stream,
+    void* output_data,
+    const void* update_data,
+    const int64_t* write_indices,
+    size_t element_size,
+    int64_t prefix_count,
+    int64_t prefix_stride_for_batch,
+    int64_t max_seq_len,
+    int64_t seq_len,
+    int64_t suffix_count,
+    bool circular) {
+  switch (element_size) {
+    case sizeof(int8_t):
+      return _TensorScatterDispatchCircular<int8_t>(
+          stream, reinterpret_cast<int8_t*>(output_data),
+          reinterpret_cast<const int8_t*>(update_data), write_indices,
+          prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+    case sizeof(int16_t):
+      return _TensorScatterDispatchCircular<int16_t>(
+          stream, reinterpret_cast<int16_t*>(output_data),
+          reinterpret_cast<const int16_t*>(update_data), write_indices,
+          prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+    case sizeof(int32_t):
+      return _TensorScatterDispatchCircular<int32_t>(
+          stream, reinterpret_cast<int32_t*>(output_data),
+          reinterpret_cast<const int32_t*>(update_data), write_indices,
+          prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+    case sizeof(int64_t):
+      return _TensorScatterDispatchCircular<int64_t>(
+          stream, reinterpret_cast<int64_t*>(output_data),
+          reinterpret_cast<const int64_t*>(update_data), write_indices,
+          prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);
+
+    default:
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported element size for TensorScatter: ", element_size);
+  }
+}
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h
new file mode 100644
index 0000000000000..8b6ce5b53b219
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+Status TensorScatterImpl(
+    cudaStream_t stream,
+    void* output_data,
+    const void* update_data,
+    const int64_t* write_indices,
+    size_t element_size,
+    int64_t prefix_count,
+    int64_t prefix_stride_for_batch,
+    int64_t max_seq_len,
+    int64_t seq_len,
+    int64_t suffix_count,
+    bool circular);
+
+}  // namespace cuda
+}  // namespace onnxruntime
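Editor's note: the kernel above picks cache_pos per mode, wrapping modulo max_seq_len in circular mode while linear mode relies on the host-side bounds check in tensorscatter.cc. A tiny standalone illustration (not part of the PR) of the wrap-around behavior:

```cpp
#include <cstdint>
#include <cstdio>

// Worked example of the cache_pos rule in _TensorScatterKernel:
// max_seq_len = 4, write_index = 3, seq_len = 2. Circular mode wraps the
// second token to slot 0; linear mode never reaches the kernel with these
// values because the host-side check rejects them first (3 + 2 > 4).
int main() {
  const int64_t max_seq_len = 4, wi = 3, seq_len = 2;
  for (int64_t s = 0; s < seq_len; ++s) {
    printf("token %lld -> slot %lld\n",
           (long long)s, (long long)((wi + s) % max_seq_len));
  }
  // Output: token 0 -> slot 3
  //         token 1 -> slot 0
  return 0;
}
```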
diff --git a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
index 34d72dab3d31b..bc68c29c1bc7b 100644
--- a/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
@@ -296,5 +296,51 @@ TEST(TensorScatterTest, InPlace_IOBinding) {
       << "Output should alias the original cache_data buffer";
 }
 
+// Negative write_indices should fail validation.
+TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  test.AddInput<float>("past_cache", {1, 4, 3},
+                       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  test.AddInput<float>("update", {1, 1, 3}, {1, 2, 3});
+  test.AddInput<int64_t>("write_indices", {1}, {-1});
+  test.AddOutput<float>("present_cache", {1, 4, 3},
+                        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+}
+
+// Linear mode: write_indices + sequence_length > max_sequence_length should fail.
+TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "linear");
+
+  // max_seq=4, update seq_len=2, wi=3 -> 3+2=5 > 4
+  test.AddInput<float>("past_cache", {1, 4, 3},
+                       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  test.AddInput<float>("update", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
+  test.AddInput<int64_t>("write_indices", {1}, {3});
+  test.AddOutput<float>("present_cache", {1, 4, 3},
+                        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length");
+}
+
+// Circular mode: negative write_indices should still fail.
+TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
+  OpTester test("TensorScatter", 24);
+  test.AddAttribute("mode", "circular");
+
+  test.AddInput<float>("past_cache", {1, 4, 2},
+                       {0, 0, 0, 0, 0, 0, 0, 0});
+  test.AddInput<float>("update", {1, 1, 2}, {1, 2});
+  test.AddInput<int64_t>("write_indices", {1}, {-2});
+  test.AddOutput<float>("present_cache", {1, 4, 2},
+                        {0, 0, 0, 0, 0, 0, 0, 0});
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+}
+
 } // namespace test
 } // namespace onnxruntime
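Editor's note: the pre-existing InPlace_IOBinding test (context at the top of this hunk) exercises the kernel's MayInplace(0, 0) hint. For reference, a hedged client-side sketch using the public ORT C++ API; the function, its arguments, and the assumption that "update" and "write_indices" are bound by the caller are illustrative, not taken from this PR.

```cpp
#include <onnxruntime_cxx_api.h>
#include <cstdint>
#include <vector>

// Sketch: bind one CUDA buffer as both past_cache and present_cache so the
// kernel's MayInplace(0, 0) hint can elide the initial cache copy.
// Assumes the session's model exposes I/O names matching this op's signature
// and that the remaining inputs are already bound on `binding`.
void RunCacheInPlace(Ort::Session& session, Ort::IoBinding& binding,
                     float* cache_ptr, size_t cache_elem_count,
                     const std::vector<int64_t>& cache_shape) {
  Ort::MemoryInfo mem_info("Cuda", OrtDeviceAllocator, /*device_id=*/0, OrtMemTypeDefault);
  Ort::Value cache = Ort::Value::CreateTensor<float>(
      mem_info, cache_ptr, cache_elem_count, cache_shape.data(), cache_shape.size());
  binding.BindInput("past_cache", cache);
  binding.BindOutput("present_cache", cache);  // alias: output may reuse the input buffer
  session.Run(Ort::RunOptions{}, binding);
}
```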