Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, U
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast);
#endif

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);

#endif

template <>
Expand Down Expand Up @@ -2703,6 +2706,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Unsqueeze)>,

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
#endif
};

Expand Down
127 changes: 127 additions & 0 deletions onnxruntime/core/providers/cuda/llm/tensorscatter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/llm/tensorscatter.h"
#include "core/providers/cuda/llm/tensorscatter_impl.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

ONNX_OPERATOR_KERNEL_EX(
TensorScatter,
kOnnxDomain,
24,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
TensorScatter);

TensorScatter::TensorScatter(const OpKernelInfo& info) : CudaKernel(info) {
axis_ = info.GetAttrOrDefault<int64_t>("axis", -2);
std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear");

Check warning on line 23 in onnxruntime/core/providers/cuda/llm/tensorscatter.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/llm/tensorscatter.cc:23: Add #include <string> for string [build/include_what_you_use] [4]
ORT_ENFORCE(mode == "linear" || mode == "circular",
"TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'");
circular_ = (mode == "circular");
}

Status TensorScatter::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_cache = context->Input<Tensor>(0);
const Tensor* update = context->Input<Tensor>(1);
const Tensor* write_indices_tensor = context->Input<Tensor>(2); // optional

ORT_ENFORCE(past_cache != nullptr && update != nullptr,
"TensorScatter: past_cache and update must not be null");

const auto& cache_shape = past_cache->Shape();
const auto& update_shape = update->Shape();
const int ndim = static_cast<int>(cache_shape.NumDimensions());

ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions");
ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(),
"TensorScatter: past_cache and update must have the same number of dimensions");

// Resolve axis (handles negative values).
int axis = static_cast<int>(axis_);
if (axis < 0) axis += ndim;
ORT_ENFORCE(axis > 0 && axis < ndim,
"TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis);

// Validate shapes: all dimensions must match except the axis dimension.
const int64_t batch_size = cache_shape[0];
const int64_t max_sequence_length = cache_shape[axis];
const int64_t sequence_length = update_shape[axis];

ORT_ENFORCE(sequence_length <= max_sequence_length,
"TensorScatter: update sequence_length (", sequence_length,
") exceeds max_sequence_length (", max_sequence_length, ")");

for (int d = 0; d < ndim; ++d) {
if (d != axis) {
ORT_ENFORCE(cache_shape[d] == update_shape[d],
"TensorScatter: shape mismatch in dimension ", d,
": past_cache=", cache_shape[d], " vs update=", update_shape[d]);
}
}

// Validate write_indices if provided.
const int64_t* write_indices = nullptr;
if (write_indices_tensor != nullptr) {
ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 &&
write_indices_tensor->Shape()[0] == batch_size,
"TensorScatter: write_indices must have shape [batch_size]");
write_indices = write_indices_tensor->Data<int64_t>();
}

// Allocate output with the same shape as past_cache.
Tensor* present_cache = context->Output(0, cache_shape);

// Step 1: Copy past_cache -> present_cache.
const void* src_raw = past_cache->DataRaw();
void* dst_raw = present_cache->MutableDataRaw();
if (dst_raw != src_raw) {
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(dst_raw, src_raw, past_cache->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
}

// Bail out early if nothing to scatter.
if (sequence_length == 0) {
return Status::OK();
}

// Step 2: Scatter the update into present_cache.
const size_t element_size = past_cache->DataType()->Size();

int64_t prefix_count = 1;
for (int d = 0; d < axis; ++d) {
prefix_count *= cache_shape[d];
}

int64_t suffix_count = 1;
for (int d = axis + 1; d < ndim; ++d) {
suffix_count *= cache_shape[d];
}

int64_t prefix_stride_for_batch = 1;
for (int d = 1; d < axis; ++d) {
prefix_stride_for_batch *= cache_shape[d];
}

return TensorScatterImpl(
Stream(context),
dst_raw,
update->DataRaw(),
write_indices,
element_size,
prefix_count,
prefix_stride_for_batch,
max_sequence_length,
sequence_length,
suffix_count,
circular_);
}

} // namespace cuda
} // namespace onnxruntime
22 changes: 22 additions & 0 deletions onnxruntime/core/providers/cuda/llm/tensorscatter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace cuda {

class TensorScatter final : public CudaKernel {
public:
TensorScatter(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

private:
int64_t axis_;
bool circular_;
};

} // namespace cuda
} // namespace onnxruntime
122 changes: 122 additions & 0 deletions onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/llm/tensorscatter_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"

namespace onnxruntime {
namespace cuda {

template <typename T, bool circular>
__global__ void _TensorScatterKernel(
T* output_data,
const T* update_data,
const int64_t* write_indices,
int64_t prefix_count,
int64_t prefix_stride_for_batch,
int64_t max_seq_len,
int64_t seq_len,
int64_t suffix_count,
size_t total_elements) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, total_elements);

int64_t seq_suffix = seq_len * suffix_count;
int64_t prefix_idx = id / seq_suffix;
int64_t remainder = id - prefix_idx * seq_suffix;
int64_t seq_idx = remainder / suffix_count;
int64_t suffix_idx = remainder - seq_idx * suffix_count;

int64_t batch_idx = prefix_idx / prefix_stride_for_batch;
int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
// Clamp negative write indices to 0 (can't throw from device code,
// following the ScatterND clamping pattern).
if (wi < 0) wi = 0;
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
int64_t cache_pos;
if (circular) {
cache_pos = (wi + seq_idx) % max_seq_len;
} else {
cache_pos = wi + seq_idx;
// Clamp to valid range to prevent out-of-bounds writes.
if (cache_pos >= max_seq_len) cache_pos = max_seq_len - 1;
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
}
Comment thread
titaiwangms marked this conversation as resolved.

int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
output_data[out_offset] = update_data[id];
}

template <typename T>
Status _TensorScatterDispatchCircular(
cudaStream_t stream,
T* output_data,
const T* update_data,
const int64_t* write_indices,
int64_t prefix_count,
int64_t prefix_stride_for_batch,
int64_t max_seq_len,
int64_t seq_len,
int64_t suffix_count,
bool circular) {
size_t total_elements = static_cast<size_t>(prefix_count) * static_cast<size_t>(seq_len) * static_cast<size_t>(suffix_count);
if (total_elements == 0) return Status::OK();

int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(total_elements) / GridDim::maxThreadsPerBlock));
Comment thread
titaiwangms marked this conversation as resolved.
Outdated

if (circular) {
_TensorScatterKernel<T, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output_data, update_data, write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
total_elements);
} else {
_TensorScatterKernel<T, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output_data, update_data, write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count,
total_elements);
}

return Status::OK();
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
}

Status TensorScatterImpl(
cudaStream_t stream,
void* output_data,
const void* update_data,
const int64_t* write_indices,
size_t element_size,
int64_t prefix_count,
int64_t prefix_stride_for_batch,
int64_t max_seq_len,
int64_t seq_len,
int64_t suffix_count,
bool circular) {
switch (element_size) {
case sizeof(int8_t):
return _TensorScatterDispatchCircular<int8_t>(
stream, reinterpret_cast<int8_t*>(output_data),
reinterpret_cast<const int8_t*>(update_data), write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);

case sizeof(int16_t):
return _TensorScatterDispatchCircular<int16_t>(
stream, reinterpret_cast<int16_t*>(output_data),
reinterpret_cast<const int16_t*>(update_data), write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);

case sizeof(int32_t):
return _TensorScatterDispatchCircular<int32_t>(
stream, reinterpret_cast<int32_t*>(output_data),
reinterpret_cast<const int32_t*>(update_data), write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);

case sizeof(int64_t):
return _TensorScatterDispatchCircular<int64_t>(
stream, reinterpret_cast<int64_t*>(output_data),
reinterpret_cast<const int64_t*>(update_data), write_indices,
prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular);

default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported element size for TensorScatter: ", element_size);
}
}

} // namespace cuda
} // namespace onnxruntime
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/cuda/llm/tensorscatter_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace cuda {

Status TensorScatterImpl(
cudaStream_t stream,
void* output_data,
const void* update_data,
const int64_t* write_indices,
size_t element_size,
int64_t prefix_count,
int64_t prefix_stride_for_batch,
int64_t max_seq_len,
int64_t seq_len,
int64_t suffix_count,
bool circular);

} // namespace cuda
} // namespace onnxruntime
Loading