Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions onnxruntime/core/providers/webgpu/tensor/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,30 +103,44 @@ Status Split::ComputeInternal(ComputeContext& context) const {
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
after_dims_excluding_split, split_sizes));

SplitProgram program{static_cast<uint32_t>(axis)};
program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

// Create all output tensors first (required for ONNX node contract)
auto output_dimensions = input_shape.AsShapeVector();
std::vector<Tensor*> all_outputs;
std::vector<int> non_empty_output_indices;

for (int i = 0; i < num_outputs; ++i) {
// Update the size of dimension for axis we're splitting on.
auto split_size = narrow<int>(split_sizes[i]);
output_dimensions[narrow<size_t>(axis)] = split_size;

Tensor* output = context.Output(i, TensorShape{output_dimensions});
program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
all_outputs.push_back(output);

// Only include non-empty outputs in the GPU program
if (split_size > 0) {
non_empty_output_indices.push_back(i);
}
}

uint32_t input_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
// Early return if the input tensor is empty.
if (input_size == 0) {
// Early return if the input tensor is empty or all outputs are empty.
if (input_size == 0 || non_empty_output_indices.empty()) {
return Status::OK();
}

SplitProgram program{static_cast<uint32_t>(axis)};
program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

// Only add non-empty outputs to the program
for (int output_idx : non_empty_output_indices) {
program.AddOutput({all_outputs[output_idx], ProgramTensorMetadataDependency::Rank});
}

uint32_t previous_sum = 0;
std::vector<uint32_t> sizes_in_split_axis;
// sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
for (auto split_size : split_sizes) {
previous_sum += onnxruntime::narrow<uint32_t>(split_size);
// sizes_in_split_axis are the cumulative sizes of the NON-EMPTY splits in the split axis.
for (int output_idx : non_empty_output_indices) {
previous_sum += onnxruntime::narrow<uint32_t>(split_sizes[output_idx]);
sizes_in_split_axis.push_back(previous_sum);
}

Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,23 @@ TEST(SplitOperatorTest, ZeroSizeInput) {
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
}

// Splits a non-empty [2, 3] tensor along axis 1 using explicit split sizes that
// include a zero, so one of the requested outputs has no elements at all.
TEST(SplitOperatorTest, ZeroSizeOutput) {
  constexpr int64_t axis = 1;

  // Requested extents along the split axis; the leading 0 produces an empty first output.
  std::vector<int64_t> split_sizes{0, 2, 1};

  // Source tensor: shape {2, 3} with six float values.
  ShapeAndFloatData source{{2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}};

  // Expected results, in output order.
  std::vector<ShapeAndFloatData> expected;
  expected.push_back({{2, 0}, {}});                    // zero-size output, carries no data
  expected.push_back({{2, 2}, {1.f, 2.f, 4.f, 5.f}});  // split of size 2
  expected.push_back({{2, 1}, {3.f, 6.f}});            // split of size 1

  // The listed execution providers are excluded from this test. The two trailing
  // boolean flags are forwarded as-is; their meaning is defined by RunTest.
  RunTest<float>(axis, split_sizes, source, expected,
                 {kTensorrtExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider},
                 false, true);
}

// test a split of a dimension that has leading and trailing dimensions
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
constexpr int64_t axis = 1;
Expand Down
Loading