Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,24 +146,24 @@
const auto* updates = context.Input<Tensor>(2);
const auto& input_shape = input->Shape();
const auto& indices_shape = indices->Shape();
auto indices_rank = indices_shape.NumDimensions();
auto last_index_dimension = static_cast<uint32_t>(indices_shape[indices_rank - 1]);
auto num_updates_elements = static_cast<uint32_t>(input_shape.SizeFromDimension(last_index_dimension));
// TODO: support bool with components 4.
const size_t components = 1;
auto output_size = static_cast<uint32_t>((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components);
auto* output = context.Output(0, input_shape);
if (output_size == 0) {
// If the output tensor is empty, we can return early.
return Status::OK();
}
MLDataType data_type = input->DataType();
const void* source = input->DataRaw();
void* target = output->MutableDataRaw();
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
if (target != source) {
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output));
}
if (indices_shape.Size() == 0) {
// If the indices are empty, we can return early.
return Status::OK();
}
auto indices_rank = indices_shape.NumDimensions();
auto last_index_dimension = static_cast<uint32_t>(indices_shape[indices_rank - 1]);
auto num_updates_elements = static_cast<uint32_t>(input_shape.SizeFromDimension(last_index_dimension));
// TODO: support bool with components 4.

Check warning on line 163 in onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc:163: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
const size_t components = 1;
auto output_size = static_cast<uint32_t>((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components);
MLDataType data_type = input->DataType();
ScatterNDProgram program(reduction_, data_type);
program
.CacheHint(static_cast<uint32_t>(reduction_))
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,16 @@ TEST(ScatterNDOpTest, ScatterND_18_max) {
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Test for ScatterND with empty indices - output should be same as input
TEST(ScatterNDOpTest, ScatterND_empty_indices) {
// Test with float data type and minimal empty case
OpTester test1("ScatterND", 11);
test1.AddInput<float>("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
test1.AddInput<int64_t>("indices", {0, 1}, {}); // Empty indices tensor - no indices to process
test1.AddInput<float>("updates", {0, 3}, {}); // Empty updates tensor
test1.AddOutput<float>("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
}

} // namespace test
} // namespace onnxruntime
Loading