diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index bd88eb1b6b353..33f5b5e5853a5 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -219,6 +219,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
             "PagedAttention": self._infer_PagedAttention,
             "PythonOp": self._infer_PythonOp,
+            "QLinearAdd": self._infer_QLinearBinary,
+            "QLinearMul": self._infer_QLinearBinary,
             "QuantizeLinear": self._infer_QuantizeLinear,
             "QuickGelu": self._infer_FastGelu,
             "RelativePositionBias": self._infer_RelativePositionBias,
@@ -490,6 +492,8 @@ def _onnx_infer_single_node(self, node):
             "SkipSimplifiedLayerNormalization",
             "SparseAttention",
             "SkipGroupNorm",
+            "QLinearAdd",
+            "QLinearMul",
         ]
 
         if not skip_infer:
@@ -1040,6 +1044,20 @@ def _infer_QuantizeLinear(self, node):  # noqa: N802
         vi = self.known_vi_[node.output[0]]
         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
+    def _infer_QLinearBinary(self, node):  # noqa: N802
+        # Get the output data type from the first input to QLinearAdd / QLinearMul.
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        # The data tensors A and B are the first and fourth inputs; the rest are scales and zero points.
+        input_1_shape = self._get_shape(node, 0)
+        input_2_shape = self._get_shape(node, 3)
+
+        # Compute the broadcast output shape.
+        new_shape = self._broadcast_shapes(input_1_shape, input_2_shape)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
     def _infer_Einsum(self, node):  # noqa: N802
         # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
         equation = get_attribute(node, "equation")
diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
index d311b4b8517cf..5267ffcc65ab7 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
@@ -644,6 +644,87 @@ def test_matmulnbits(self):
         ]
         self._check_shapes(graph, inferred.graph, expected_shapes)
 
+    def test_qlinear_binary(self):
+        """
+        Test the QLinearAdd contrib op ('com.microsoft' domain).
+        Check that the output shape is propagated from the op's data inputs with broadcasting.
+        """
+        initializers = [
+            helper.make_tensor(
+                "A_scale",
+                TensorProto.FLOAT,
+                [],
+                [0.7],
+            ),
+            helper.make_tensor(
+                "A_zero_point",
+                TensorProto.UINT8,
+                [],
+                [158],
+            ),
+            helper.make_tensor(
+                "B_scale",
+                TensorProto.FLOAT,
+                [],
+                [0.02],
+            ),
+            helper.make_tensor(
+                "B_zero_point",
+                TensorProto.UINT8,
+                [],
+                [5],
+            ),
+            helper.make_tensor(
+                "C_scale",
+                TensorProto.FLOAT,
+                [],
+                [0.26],
+            ),
+            helper.make_tensor(
+                "C_zero_point",
+                TensorProto.UINT8,
+                [],
+                [0],
+            ),
+        ]
+
+        nodes = [
+            helper.make_node(
+                "QLinearAdd",
+                inputs=[
+                    "A",
+                    "A_scale",
+                    "A_zero_point",
+                    "B",
+                    "B_scale",
+                    "B_zero_point",
+                    "C_scale",
+                    "C_zero_point",
+                ],
+                outputs=["C"],
+                domain="com.microsoft",
+            ),
+        ]
+
+        inputs = [
+            helper.make_tensor_value_info("A", TensorProto.UINT8, ["b", 4, 128]),
+            helper.make_tensor_value_info("B", TensorProto.UINT8, ["b", 1, 4, 1, 128]),
+        ]
+
+        outputs = [
+            helper.make_tensor_value_info("C", TensorProto.UNDEFINED, None),
+        ]
+
+        graph = helper.make_graph(nodes, "QLinearAdd_Test", inputs, outputs, initializers)
+        model = helper.make_model(graph)
+
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+
+        expected_shapes = [
+            helper.make_tensor_value_info("C", TensorProto.UINT8, ["b", 1, 4, 4, 128]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
 
 class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
     def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):