diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index 03b39e19ed748..f3c6b18f8e753 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -34,17 +34,18 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + const auto input_defs = Node().InputDefs(); + size_t num_inputs = input_defs.size(); const auto* X = context->Input(0); const auto* W = context->Input(1); uint8_t input_offset = 0; uint8_t filter_offset = 0; - if (num_inputs >= 3) { + if (num_inputs >= 3 && input_defs[2]->Exists()) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); input_offset = *(X_Zero_Point->Data()); } - if (num_inputs >= 4) { + if (num_inputs >= 4 && input_defs[3]->Exists()) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); filter_offset = *(W_Zero_Point->Data()); diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index a5378fa3cefd7..c98d9e28b2f46 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,5 +254,45 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } +TEST(ConvIntegerTest, NoXZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddOptionalInputEdge(); + test.AddInput("w_zero_point", {}, {1}); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {16, 20, + 28, 32}); + test.Run(); +} + +// provide optional input with empty name for w. tests that input args == 4 but the w_zero_point does not exist. +TEST(ConvIntegerTest, NoWZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOptionalInputEdge(); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {24, 32, + 48, 56}); + test.Run(); +} } // namespace test } // namespace onnxruntime