diff --git a/onnxruntime/core/providers/webgpu/tensor/split.cc b/onnxruntime/core/providers/webgpu/tensor/split.cc
index 150b0beb897f5..f6de34dcf120c 100644
--- a/onnxruntime/core/providers/webgpu/tensor/split.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/split.cc
@@ -103,30 +103,44 @@ Status Split::ComputeInternal(ComputeContext& context) const {
   ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes));
 
-  SplitProgram program{static_cast<uint32_t>(axis)};
-  program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});
-
+  // Create all output tensors first (required for ONNX node contract)
   auto output_dimensions = input_shape.AsShapeVector();
+  std::vector<Tensor*> all_outputs;
+  std::vector<int> non_empty_output_indices;
+
   for (int i = 0; i < num_outputs; ++i) {
     // Update the size of dimension for axis we're splitting on.
     auto split_size = narrow<int64_t>(split_sizes[i]);
     output_dimensions[narrow<size_t>(axis)] = split_size;
     Tensor* output = context.Output(i, TensorShape{output_dimensions});
-    program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
+    all_outputs.push_back(output);
+
+    // Only include non-empty outputs in the GPU program
+    if (split_size > 0) {
+      non_empty_output_indices.push_back(i);
+    }
   }
 
   uint32_t input_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
-  // Early return if the input tensor is empty.
-  if (input_size == 0) {
+  // Early return if the input tensor is empty or all outputs are empty.
+  if (input_size == 0 || non_empty_output_indices.empty()) {
     return Status::OK();
   }
 
+  SplitProgram program{static_cast<uint32_t>(axis)};
+  program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});
+
+  // Only add non-empty outputs to the program
+  for (int output_idx : non_empty_output_indices) {
+    program.AddOutput({all_outputs[output_idx], ProgramTensorMetadataDependency::Rank});
+  }
+
   uint32_t previous_sum = 0;
   std::vector<uint32_t> sizes_in_split_axis;
-  // sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
-  for (auto split_size : split_sizes) {
-    previous_sum += onnxruntime::narrow<uint32_t>(split_size);
+  // sizes_in_split_axis are the cumulative sizes of the NON-EMPTY splits in the split axis.
+  for (int output_idx : non_empty_output_indices) {
+    previous_sum += onnxruntime::narrow<uint32_t>(split_sizes[output_idx]);
     sizes_in_split_axis.push_back(previous_sum);
   }
 
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index 1c2a86bb808b5..8db1c4d1fef2e 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -397,6 +397,23 @@ TEST(SplitOperatorTest, ZeroSizeInput) {
   RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
 }
 
+TEST(SplitOperatorTest, ZeroSizeOutput) {
+  constexpr int64_t axis = 1;
+  std::vector<ShapeAndFloatData> outputs;
+
+  // Non-zero input that will be split to produce zero-size outputs
+  ShapeAndFloatData input = {{2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}};
+
+  // Split sizes: 0, 2, 1 - first output will have zero size
+  std::vector<int64_t> splits{0, 2, 1};
+
+  outputs.push_back({{2, 0}, {}});  // Zero-size output
+  outputs.push_back({{2, 2}, {1.f, 2.f, 4.f, 5.f}});
+  outputs.push_back({{2, 1}, {3.f, 6.f}});
+
+  RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}, false, true);
+}
+
 // test a split of a dimension that has leading and trailing dimensions
 TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
   constexpr int64_t axis = 1;