-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Support TensorScatter (24) - CUDA #27446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/cuda/llm/tensorscatter.h" | ||
| #include "core/providers/cuda/llm/tensorscatter_impl.h" | ||
| #include "core/providers/cuda/cuda_common.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX( | ||
| TensorScatter, | ||
| kOnnxDomain, | ||
| 24, | ||
| kCudaExecutionProvider, | ||
| (*KernelDefBuilder::Create()) | ||
| .MayInplace(0, 0) | ||
| .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
| TensorScatter); | ||
|
|
||
| TensorScatter::TensorScatter(const OpKernelInfo& info) : CudaKernel(info) { | ||
| axis_ = info.GetAttrOrDefault<int64_t>("axis", -2); | ||
| std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear"); | ||
|
Check warning on line 23 in onnxruntime/core/providers/cuda/llm/tensorscatter.cc
|
||
| ORT_ENFORCE(mode == "linear" || mode == "circular", | ||
| "TensorScatter: mode must be 'linear' or 'circular', got '", mode, "'"); | ||
| circular_ = (mode == "circular"); | ||
| } | ||
|
|
||
| Status TensorScatter::ComputeInternal(OpKernelContext* context) const { | ||
| const Tensor* past_cache = context->Input<Tensor>(0); | ||
| const Tensor* update = context->Input<Tensor>(1); | ||
| const Tensor* write_indices_tensor = context->Input<Tensor>(2); // optional | ||
|
|
||
| ORT_ENFORCE(past_cache != nullptr && update != nullptr, | ||
| "TensorScatter: past_cache and update must not be null"); | ||
|
|
||
| const auto& cache_shape = past_cache->Shape(); | ||
| const auto& update_shape = update->Shape(); | ||
| const int ndim = static_cast<int>(cache_shape.NumDimensions()); | ||
|
|
||
| ORT_ENFORCE(ndim >= 2, "TensorScatter: past_cache must have at least 2 dimensions"); | ||
| ORT_ENFORCE(update_shape.NumDimensions() == cache_shape.NumDimensions(), | ||
| "TensorScatter: past_cache and update must have the same number of dimensions"); | ||
|
|
||
| // Resolve axis (handles negative values). | ||
| int axis = static_cast<int>(axis_); | ||
| if (axis < 0) axis += ndim; | ||
| ORT_ENFORCE(axis > 0 && axis < ndim, | ||
| "TensorScatter: axis must be in [1, ndim-1] after normalization, got ", axis); | ||
|
|
||
| // Validate shapes: all dimensions must match except the axis dimension. | ||
| const int64_t batch_size = cache_shape[0]; | ||
| const int64_t max_sequence_length = cache_shape[axis]; | ||
| const int64_t sequence_length = update_shape[axis]; | ||
|
|
||
| ORT_ENFORCE(sequence_length <= max_sequence_length, | ||
| "TensorScatter: update sequence_length (", sequence_length, | ||
| ") exceeds max_sequence_length (", max_sequence_length, ")"); | ||
|
|
||
| for (int d = 0; d < ndim; ++d) { | ||
| if (d != axis) { | ||
| ORT_ENFORCE(cache_shape[d] == update_shape[d], | ||
| "TensorScatter: shape mismatch in dimension ", d, | ||
| ": past_cache=", cache_shape[d], " vs update=", update_shape[d]); | ||
| } | ||
| } | ||
|
|
||
| // Validate write_indices if provided. | ||
| const int64_t* write_indices = nullptr; | ||
| if (write_indices_tensor != nullptr) { | ||
| ORT_ENFORCE(write_indices_tensor->Shape().NumDimensions() == 1 && | ||
| write_indices_tensor->Shape()[0] == batch_size, | ||
| "TensorScatter: write_indices must have shape [batch_size]"); | ||
| write_indices = write_indices_tensor->Data<int64_t>(); | ||
|
|
||
| // Copy write_indices to host for validation (batch_size elements, negligible overhead). | ||
| std::vector<int64_t> host_write_indices(static_cast<size_t>(batch_size)); | ||
|
Check warning on line 77 in onnxruntime/core/providers/cuda/llm/tensorscatter.cc
|
||
| CUDA_RETURN_IF_ERROR( | ||
| cudaMemcpyAsync(host_write_indices.data(), write_indices, | ||
| static_cast<size_t>(batch_size) * sizeof(int64_t), | ||
| cudaMemcpyDeviceToHost, Stream(context))); | ||
| CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context))); | ||
|
|
||
|
titaiwangms marked this conversation as resolved.
|
||
| for (int64_t b = 0; b < batch_size; ++b) { | ||
| int64_t wi = host_write_indices[static_cast<size_t>(b)]; | ||
| ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", b, "] = ", wi, " is negative"); | ||
| if (!circular_) { | ||
| ORT_ENFORCE(wi + sequence_length <= max_sequence_length, | ||
|
titaiwangms marked this conversation as resolved.
|
||
| "TensorScatter linear mode: write_indices[", b, "] + sequence_length (", | ||
| wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Allocate output with the same shape as past_cache. | ||
| Tensor* present_cache = context->Output(0, cache_shape); | ||
|
|
||
| // Step 1: Copy past_cache -> present_cache. | ||
| const void* src_raw = past_cache->DataRaw(); | ||
| void* dst_raw = present_cache->MutableDataRaw(); | ||
| if (dst_raw != src_raw) { | ||
| CUDA_RETURN_IF_ERROR( | ||
| cudaMemcpyAsync(dst_raw, src_raw, past_cache->SizeInBytes(), | ||
| cudaMemcpyDeviceToDevice, Stream(context))); | ||
| } | ||
|
|
||
| // Bail out early if nothing to scatter. | ||
| if (sequence_length == 0) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // Step 2: Scatter the update into present_cache. | ||
| const size_t element_size = past_cache->DataType()->Size(); | ||
|
|
||
| int64_t prefix_count = 1; | ||
| for (int d = 0; d < axis; ++d) { | ||
| prefix_count *= cache_shape[d]; | ||
| } | ||
|
|
||
| int64_t suffix_count = 1; | ||
| for (int d = axis + 1; d < ndim; ++d) { | ||
| suffix_count *= cache_shape[d]; | ||
| } | ||
|
|
||
| int64_t prefix_stride_for_batch = 1; | ||
| for (int d = 1; d < axis; ++d) { | ||
| prefix_stride_for_batch *= cache_shape[d]; | ||
| } | ||
|
|
||
| return TensorScatterImpl( | ||
| Stream(context), | ||
| dst_raw, | ||
| update->DataRaw(), | ||
| write_indices, | ||
| element_size, | ||
| prefix_count, | ||
| prefix_stride_for_batch, | ||
| max_sequence_length, | ||
| sequence_length, | ||
| suffix_count, | ||
| circular_); | ||
| } | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include "core/common/common.h" | ||
| #include "core/providers/cuda/cuda_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| class TensorScatter final : public CudaKernel { | ||
| public: | ||
| TensorScatter(const OpKernelInfo& info); | ||
| Status ComputeInternal(OpKernelContext* context) const override; | ||
|
|
||
| private: | ||
| int64_t axis_; | ||
| bool circular_; | ||
| }; | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
118 changes: 118 additions & 0 deletions
118
onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/cuda/llm/tensorscatter_impl.h" | ||
| #include "core/providers/cuda/cu_inc/common.cuh" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| template <typename T, bool circular> | ||
| __global__ void _TensorScatterKernel( | ||
| T* output_data, | ||
| const T* update_data, | ||
| const int64_t* write_indices, | ||
| int64_t prefix_count, | ||
| int64_t prefix_stride_for_batch, | ||
| int64_t max_seq_len, | ||
| int64_t seq_len, | ||
| int64_t suffix_count, | ||
| size_t total_elements) { | ||
| CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, total_elements); | ||
|
|
||
| int64_t seq_suffix = seq_len * suffix_count; | ||
| int64_t prefix_idx = id / seq_suffix; | ||
| int64_t remainder = id - prefix_idx * seq_suffix; | ||
| int64_t seq_idx = remainder / suffix_count; | ||
| int64_t suffix_idx = remainder - seq_idx * suffix_count; | ||
|
|
||
| int64_t batch_idx = prefix_idx / prefix_stride_for_batch; | ||
| int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0; | ||
| // write_indices are validated on the host before kernel launch. | ||
| int64_t cache_pos; | ||
| if (circular) { | ||
| cache_pos = (wi + seq_idx) % max_seq_len; | ||
| } else { | ||
| cache_pos = wi + seq_idx; | ||
| } | ||
|
titaiwangms marked this conversation as resolved.
|
||
|
|
||
| int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx; | ||
| output_data[out_offset] = update_data[id]; | ||
| } | ||
|
|
||
| template <typename T> | ||
| Status _TensorScatterDispatchCircular( | ||
| cudaStream_t stream, | ||
| T* output_data, | ||
| const T* update_data, | ||
| const int64_t* write_indices, | ||
| int64_t prefix_count, | ||
| int64_t prefix_stride_for_batch, | ||
| int64_t max_seq_len, | ||
| int64_t seq_len, | ||
| int64_t suffix_count, | ||
| bool circular) { | ||
| size_t total_elements = static_cast<size_t>(prefix_count) * static_cast<size_t>(seq_len) * static_cast<size_t>(suffix_count); | ||
| if (total_elements == 0) return Status::OK(); | ||
|
|
||
| int blocksPerGrid = static_cast<int>(CeilDiv(total_elements, static_cast<size_t>(GridDim::maxThreadsPerBlock))); | ||
|
|
||
| if (circular) { | ||
| _TensorScatterKernel<T, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>( | ||
| output_data, update_data, write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, | ||
| total_elements); | ||
| } else { | ||
| _TensorScatterKernel<T, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>( | ||
| output_data, update_data, write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, | ||
| total_elements); | ||
| } | ||
|
|
||
| return CUDA_CALL(cudaGetLastError()); | ||
| } | ||
|
|
||
| Status TensorScatterImpl( | ||
| cudaStream_t stream, | ||
| void* output_data, | ||
| const void* update_data, | ||
| const int64_t* write_indices, | ||
| size_t element_size, | ||
| int64_t prefix_count, | ||
| int64_t prefix_stride_for_batch, | ||
| int64_t max_seq_len, | ||
| int64_t seq_len, | ||
| int64_t suffix_count, | ||
| bool circular) { | ||
| switch (element_size) { | ||
| case sizeof(int8_t): | ||
| return _TensorScatterDispatchCircular<int8_t>( | ||
| stream, reinterpret_cast<int8_t*>(output_data), | ||
| reinterpret_cast<const int8_t*>(update_data), write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular); | ||
|
|
||
| case sizeof(int16_t): | ||
| return _TensorScatterDispatchCircular<int16_t>( | ||
| stream, reinterpret_cast<int16_t*>(output_data), | ||
| reinterpret_cast<const int16_t*>(update_data), write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular); | ||
|
|
||
| case sizeof(int32_t): | ||
| return _TensorScatterDispatchCircular<int32_t>( | ||
| stream, reinterpret_cast<int32_t*>(output_data), | ||
| reinterpret_cast<const int32_t*>(update_data), write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular); | ||
|
|
||
| case sizeof(int64_t): | ||
| return _TensorScatterDispatchCircular<int64_t>( | ||
| stream, reinterpret_cast<int64_t*>(output_data), | ||
| reinterpret_cast<const int64_t*>(update_data), write_indices, | ||
| prefix_count, prefix_stride_for_batch, max_seq_len, seq_len, suffix_count, circular); | ||
|
|
||
| default: | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported element size for TensorScatter: ", element_size); | ||
| } | ||
| } | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/cuda/shared_inc/cuda_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace cuda { | ||
|
|
||
| Status TensorScatterImpl( | ||
| cudaStream_t stream, | ||
| void* output_data, | ||
| const void* update_data, | ||
| const int64_t* write_indices, | ||
| size_t element_size, | ||
| int64_t prefix_count, | ||
| int64_t prefix_stride_for_batch, | ||
| int64_t max_seq_len, | ||
| int64_t seq_len, | ||
| int64_t suffix_count, | ||
| bool circular); | ||
|
|
||
| } // namespace cuda | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Such synchronization operations are going to cause issues during cuda graph capture. Usually, on GPU ops , we don't bring data to the host to perform validations if the input is already on CUDA memory. It is probably better to use CUDA_KERNEL_ASSERT() within the kernel to asynchronously report invalid data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. I will submit a follow up. How did you find out about this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just happened to look at the PR post merge.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also re-
"The sync is the real cost, but it happens before the scatter kernel launch, not in the middle of a kernel pipeline"
While this is true, a stream sync here would mean that it needs to block untill all asynchronous work queued on this stream complete.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#27484