From 958d1c2a9803006d456307b0ad5c80403f858024 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 5 Mar 2026 14:39:20 +0800 Subject: [PATCH] webgpu: support bool for Expand, Flatten, Gather and Unsqueeze --- .../core/providers/webgpu/tensor/expand.cc | 68 +++++++++++++---- .../core/providers/webgpu/tensor/expand.h | 8 +- .../core/providers/webgpu/tensor/flatten.cc | 10 +-- .../core/providers/webgpu/tensor/gather.cc | 75 +++++++++++++------ .../core/providers/webgpu/tensor/unsqueeze.cc | 4 +- .../providers/webgpu/webgpu_supported_types.h | 13 ++++ .../test/providers/cpu/nn/flatten_op_test.cc | 7 ++ .../test/providers/cpu/tensor/expand_test.cc | 56 ++++++++++++++ .../providers/cpu/tensor/unsqueeze_op_test.cc | 9 +++ 9 files changed, 205 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 653250279bf38..0dacd589cbba8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -14,6 +14,38 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"); + if (Inputs()[0].var_type == ProgramVariableDataType::Boolx4) { + const auto& input_indices = shader.AddIndices("input_indices"); + const auto& output_indices = shader.AddIndices("output_indices"); + if (input_last_dim_divisible_by_4_) { + // The last dims of input shape and output shape are all divisible by 4. + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + } else if (output_last_dim_divisible_by_4_) { + // The last dim of output shape is divisible by 4, and the last dim of input shape is 1. + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << " let value = vec4(" << input.GetByOffset("input_offset / 4") << "[input_offset % 4]);\n" + << " " << output.SetByOffset("global_idx", "value"); + } else { + shader.MainFunctionBody() << " var output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset_0 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << " let input_offset_1 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << " let input_offset_2 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << " let input_offset_3 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" + << " let value = vec4(" + << input.GetByOffset("input_offset_0 / 4") << "[input_offset_0 % 4], " + << input.GetByOffset("input_offset_1 / 4") << "[input_offset_1 % 4], " + << input.GetByOffset("input_offset_2 / 4") << "[input_offset_2 % 4], " + << input.GetByOffset("input_offset_3 / 4") << "[input_offset_3 % 4]);\n" + << output.SetByOffset("global_idx", "value"); + } + return Status::OK(); + } if (input.NumComponents() != output.NumComponents()) { const auto& output_indices = shader.AddIndices("output_indices"); shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" @@ -40,23 +72,33 @@ Status Expand::ComputeInternal(ComputeContext& context) const { auto* output_tensor = context.Output(0, output_shape); bool is_int64 = input_tensor->DataType() == DataTypeImpl::GetType(); - const int components_i = (input_shape.IsScalar() || is_int64) ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4 - : 1; - const int components_o = (output_shape.IsScalar() || is_int64) ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 - : 1; - uint32_t data_size = onnxruntime::narrow(output_shape.Size() / components_o); + // Check if either input is boolean + // For boolean inputs, we need to handle them differently in the shader. This is because `bool` is not a valid type in + // storage buffer. We have to use a `u32` to represent 4 boolean values. + bool is_bool = input_tensor->DataType() == DataTypeImpl::GetType(); + bool input_last_dim_divisible_by_4 = (!(input_shape.IsScalar() || is_int64)) && (input_shape[input_shape.NumDimensions() - 1] % 4 == 0); + bool output_last_dim_divisible_by_4 = (!(output_shape.IsScalar() || is_int64)) && (output_shape[output_shape.NumDimensions() - 1] % 4 == 0); + const int components_i = (is_bool || input_last_dim_divisible_by_4) ? 4 : 1; + const int components_o = (is_bool || output_last_dim_divisible_by_4) ? 4 : 1; + uint32_t data_size = onnxruntime::narrow((output_shape.Size() + components_o - 1) / components_o); if (data_size == 0) { return Status::OK(); } - ExpandProgram program{}; - program - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}) - .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + ExpandProgram program{input_last_dim_divisible_by_4, output_last_dim_divisible_by_4}; + program.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ {data_size}, }); - if (components_i != components_o) { + if (is_bool) { + program.CacheHint(std::to_string(static_cast(input_last_dim_divisible_by_4)), std::to_string(static_cast(output_last_dim_divisible_by_4))) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, {data_size}, components_o}}) + .AddIndices(std::move(input_shape)); + } else { + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}); + } + if (is_bool || components_i != components_o) { program.AddIndices(std::move(output_shape)); } return context.RunProgram(program); @@ -64,7 +106,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { template KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) { - const auto& type_constraints = GetOpTypeConstraints(enable_int64, false); + const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); @@ -85,7 +127,7 @@ KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) { template KernelCreateInfo CreateExpandKernelInfo(bool enable_int64) { - const auto& type_constraints = GetOpTypeConstraints(enable_int64, false); + const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h index 3532640d01439..cb458f48aea7c 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.h +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -11,11 +11,17 @@ namespace webgpu { class ExpandProgram final : public Program { public: - ExpandProgram() : Program{"Expand"} {} + ExpandProgram(const bool input_last_dim_divisible_by_4, const bool output_last_dim_divisible_by_4) : Program{"Expand"}, + input_last_dim_divisible_by_4_{input_last_dim_divisible_by_4}, + output_last_dim_divisible_by_4_{output_last_dim_divisible_by_4} {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + bool input_last_dim_divisible_by_4_; + bool output_last_dim_divisible_by_4_; }; class Expand final : public WebGpuKernel { diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc index 11ded865b6be2..9b8217b7d1182 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.cc +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -15,7 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes()) .InputMemoryType(OrtMemTypeCPU, 1), Flatten); @@ -26,7 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes()) .InputMemoryType(OrtMemTypeCPU, 1), Flatten); @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes()) .InputMemoryType(OrtMemTypeCPU, 1), Flatten); @@ -48,7 +48,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes()) .InputMemoryType(OrtMemTypeCPU, 1), Flatten); @@ -59,7 +59,7 @@ ONNX_OPERATOR_KERNEL_EX( kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes()) .InputMemoryType(OrtMemTypeCPU, 1), Flatten); diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 39d07991f3c5a..b3e5c7b4e8310 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -9,32 +9,51 @@ namespace onnxruntime { namespace webgpu { Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& data = shader.AddInput("data", ShaderUsage::UseIndicesTypeAlias); const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias); + const auto& data_indices = shader.AddIndices("data_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output_indices = shader.AddIndices("output_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + bool is_bool = Inputs()[0].var_type == ProgramVariableDataType::Boolx4; shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " var indices_indices = input_indices_indices_t(0);\n"; - for (int i = 0; i < indices.Rank(); i++) { - shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; - } - shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" - << " if (idx < 0) {\n" - << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n" - << " }\n" - << " var data_indices : data_indices_t;\n"; - for (int i = 0, j = 0; i < data.Rank(); i++) { - if (static_cast(i) == axis_) { - shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; - j += indices.Rank(); + << " var idx : input_indices_value_t;\n" + << " var output_indices : output_indices_indices_t;\n" + << " var indices_indices : input_indices_indices_t;\n" + << " var data_indices : data_indices_indices_t;\n" + << " var value : output_value_t;\n" + << " var data_offset : u32;\n"; + for (int comp = 0; comp < (is_bool ? 4 : 1); comp++) { + shader.MainFunctionBody() << " output_indices = " << output_indices.OffsetToIndices(is_bool ? (std::to_string(comp) + " + 4 * global_idx") : "global_idx") << ";\n"; + + for (int i = 0; i < indices.Rank(); i++) { + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output_indices.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + + shader.MainFunctionBody() << " idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(" << data_indices.IndicesGet("uniforms.data_indices_shape", axis_) << ");\n" + << " }\n"; + + for (int i = 0, j = 0; i < data_indices.Rank(); i++) { + if (static_cast(i) == axis_) { + shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, output_indices.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.MainFunctionBody() << " data_offset = " << data_indices.IndicesToOffset("data_indices") << ";\n"; + if (is_bool) { + shader.MainFunctionBody() << " value[" << comp << "] = " << data.GetByOffset("data_offset / 4") << "[data_offset % 4];\n"; } else { - shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; - j++; + shader.MainFunctionBody() << " value = " << data.GetByOffset("data_offset") << ";\n"; } } - shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", "value"); return Status::OK(); } @@ -47,14 +66,22 @@ Status Gather::ComputeInternal(ComputeContext& context) const { return Status::OK(); } + bool is_bool = p.input_tensor->DataType() == DataTypeImpl::GetType(); + if (is_bool) { + // Shader will pack four bools into one uint, so we consider the types of input and output as vec4. + data_size = (data_size + 3) / 4; + } + uint32_t axis = static_cast(p.axis); GatherProgram program{axis}; program - .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, (is_bool ? 4 : 1)}, {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank, {data_size}, (is_bool ? 4 : 1)}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .CacheHint(std::to_string(axis)) + .AddIndices(p.input_tensor->Shape()) + .AddIndices(p.output_tensor->Shape()) .AddUniformVariables({{data_size}}); return context.RunProgram(program); } @@ -71,9 +98,9 @@ Status Gather::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ KERNEL_CLASS); -WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) -WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) -WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberAndBoolTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberAndBoolTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberAndBoolTypes()) } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc index 09439dd33dc56..104fcf1812af8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -10,7 +10,7 @@ namespace webgpu { template KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) { - const auto& type_constraints = GetOpTypeConstraints(enable_int64, false); + const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); @@ -45,7 +45,7 @@ KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) { template KernelCreateInfo CreateUnsqueezeKernelInfo(bool enable_int64) { - const auto& type_constraints = GetOpTypeConstraints(enable_int64, false); + const auto& type_constraints = GetOpTypeConstraints(enable_int64, true); KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h index 30ec269495782..1efbda00ec869 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -20,6 +20,14 @@ using SupportedFloats = float, MLFloat16>; +using SupportedNumberAndBoolTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t, + bool>; + inline const std::vector& WebGpuSupportedNumberTypes() { static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); return supportedDataTypes; @@ -30,6 +38,11 @@ inline const std::vector& WebGpuSupportedFloatTypes() { return supportedDataTypes; } +inline const std::vector& WebGpuSupportedNumberAndBoolTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + inline const std::vector& GetOpTypeConstraints(bool enable_int64 = false, bool enable_bool = false) { static std::vector base_types{ DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc index 11a9f626d5709..b4eb1228966bb 100644 --- a/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/flatten_op_test.cc @@ -45,6 +45,13 @@ TEST_F(FlattenOpTest, Flatten_invalid_axis) { test_.Run(OpTester::ExpectResult::kExpectFailure, "Invalid value(5) for attribute 'axis'"); } +TEST_F(FlattenOpTest, Flatten_axis2_bool) { + test_.AddAttribute("axis", 2L); + test_.AddInput("data", {2L, 2L, 2L, 3L}, {false, true, false, true, true, false, true, false, false, false, false, true, true, false, true, true, true, false, false, true, true, true, false, true}); + test_.AddOutput("output", {4L, 6L}, {false, true, false, true, true, false, true, false, false, false, false, true, true, false, true, true, true, false, false, true, true, true, false, true}); + test_.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST_F(FlattenOpTest, Flatten_axis3) { test_.AddAttribute("axis", 3L); test_.AddInput("data", {2L, 3L, 4L, 5L}, data0_); diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index f1ff112956188..cb2774a22f99c 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -247,6 +247,62 @@ TEST(ExpandOpTest, Expand_3x1x8_float) { test.Run(); } +TEST(ExpandOpTest, Expand_3x3_bool) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1}, {true}); + test.AddInput("data_1", {2}, {3, 3}); + test.AddOutput("result", {3, 3}, + {true, true, true, + true, true, true, + true, true, true}); + test.Run(); +} + +TEST(ExpandOpTest, Expand_3x1_bool) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3}, {false, true, false}); + test.AddInput("data_1", {2}, {3, 1}); + test.AddOutput("result", {3, 3}, + {false, true, false, + false, true, false, + false, true, false}); + test.Run(); +} + +TEST(ExpandOpTest, Expand_1x3_bool) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {false, true, false}); + test.AddInput("data_1", {2}, {1, 3}); + test.AddOutput("result", {3, 3}, + {false, false, false, + true, true, true, + false, false, false}); + test.Run(); +} + +TEST(ExpandOpTest, Expand_1x4_bool) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {false, true, false}); + test.AddInput("data_1", {2}, {1, 4}); + test.AddOutput("result", {3, 4}, + {false, false, false, false, + true, true, true, true, + false, false, false, false}); + test.Run(); +} + +TEST(ExpandOpTest, Expand_4x1_bool) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1, 4}, {false, true, false, false}); + test.AddInput("data_1", {2}, {4, 1}); + test.AddOutput("result", {4, 4}, + {false, true, false, false, + false, true, false, false, + false, true, false, false, + false, true, false, false}); + test.Run(); +} + #ifndef USE_TENSORRT TEST(ExpandOpTest, Expand_scalar_float) { OpTester test("Expand", 8); diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index 308d7f18f3821..1dcbf3b8466d4 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -44,6 +44,15 @@ TEST(UnsqueezeOpTest, Unsqueeze_1_int64) { } #endif +TEST(UnsqueezeOpTest, Unsqueeze_1_bool) { + OpTester test("Unsqueeze"); + + test.AddAttribute("axes", std::vector{1}); + test.AddInput("input", {2, 3, 4}, {true, false, true, false, false, true, false, true, false, true, true, false, true, false, false, true, true, true, true, false, true, false, false, true}); + test.AddOutput("output", {2, 1, 3, 4}, {true, false, true, false, false, true, false, true, false, true, true, false, true, false, false, true, true, true, true, false, true, false, false, true}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(UnsqueezeOpTest, Unsqueeze_2) { OpTester test("Unsqueeze");