diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 961517f863fb..3226240fbe39 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -94,6 +94,25 @@ def get_scalar_from_constant(expr):
     return value.item(0)
 
 
+def _shift(data, zero_point, out_dtype):
+    """Shifts (adds/subtracts) the qnn tensor by +/-128."""
+    if out_dtype == "uint8":
+        shift = 128
+    elif out_dtype == "int8":
+        shift = -128
+    else:
+        raise ValueError("Unsupported out dtype.")
+    data_modified = relay.cast(data, "int32")
+    data_modified = relay.add(data_modified, relay.const(shift, "int32"))
+    data_modified = relay.cast(data_modified, out_dtype)
+    if isinstance(zero_point, relay.Constant):
+        zero_point_val = get_scalar_from_constant(zero_point)
+        zero_point_modified = relay.const(zero_point_val + shift, "int32")
+    else:
+        zero_point_modified = zero_point + relay.const(shift, "int32")
+    return (data_modified, zero_point_modified)
+
+
 # Helper function for lowering in the abscence of fast Int8 arithmetic units.
 def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
@@ -161,22 +180,6 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
     result : tvm.relay.Expr
         The legalized expr
     """
-
-    def _shift(data, zero_point, out_dtype):
-        """Shifts (add/subtracts) the qnn tensor with +/-128)"""
-        if out_dtype == "uint8":
-            shift = 128
-        elif out_dtype == "int8":
-            shift = -128
-        else:
-            raise ValueError("Unsupported out dtype.")
-        data_modified = relay.cast(data, "int32")
-        data_modified = relay.add(data_modified, relay.const(shift, "int32"))
-        data_modified = relay.cast(data_modified, out_dtype)
-        zero_point_val = get_scalar_from_constant(zero_point)
-        zero_point_modified = relay.const(zero_point_val + shift, "int32")
-        return (data_modified, zero_point_modified)
-
     # Collect the dtypes.
     data_dtype = types[0].dtype
     kernel_dtype = types[1].dtype
@@ -205,6 +208,54 @@ def _shift(data, zero_point, out_dtype):
     )
 
 
+# Helper function to change dtypes to int8 x int8. CUDA dp4a instructions prefer this setting.
+def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
+    """Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the
+    dtypes are already good, we don't transform. Else, we shift the tensor values and zero points
+    to change the dtype.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
+
+    # dp4a supports i8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
+    if data_dtype == "int8" and kernel_dtype == "int8":
+        return None
+
+    # Shift input if necessary.
+    if data_dtype == "uint8":
+        # Compute (QA - 128) and (zp_a - 128)
+        data, input_zero_point = _shift(data, input_zero_point, "int8")
+
+    # Shift kernel if necessary.
+    if kernel_dtype == "uint8":
+        # Compute (QW - 128) and (zp_w - 128)
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8")
+
+    # Call qnn.conv2d with modified inputs and zero points.
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
+
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
     """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
@@ -339,11 +390,11 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
 
 @qnn_conv2d_legalize.register("cuda")
 def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
-    # CUDA prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    # CUDA prefers both datatypes to be int8.
+    return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)
 
 
 @qnn_dense_legalize.register("cuda")
 def _qnn_dense_legalize_cuda(attrs, inputs, types):
-    # CUDA prefers the dtypes to be same.
-    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    # CUDA prefers both datatypes to be int8.
+    return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index a5161358865a..cf5266485f2e 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -65,7 +65,6 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     }
   }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
@@ -293,7 +292,11 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_
   auto multiplied_t2 = reduced_t2;
   auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
   if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
-    multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
+    if (!IsConstScalar(kernel_zero_point)) {
+      multiplied_t2 = Multiply(MakeRepeat(kernel_zero_point, channel_multiplier, 0), reduced_t2);
+    } else {
+      multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
+    }
   }
 
   // Reduce the C dimension. Find the dimension.
@@ -378,6 +381,25 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
   return MakeConstantScalar(DataType::Int(32), scalar_term4);
 }
 
+/*
+ * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence
+          for non-constant zero_points.
+ * \param input_zero_point The Expr for the input zero point.
+ * \param kernel_zero_point The Expr for the kernel zero point.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \return The sequence of Relay operators for term4.
+ * \note The term4 looks like this
+ *
+ *       Sigma(r, s) zp_a * zp_w
+ */
+Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
+                               int kernel_h, int kernel_w) {
+  Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w);
+  Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
+  return Multiply(scalar_term4, variable_term4);
+}
+
 /*
  * \brief Calculates the first term in the qnn.conv2d lowering sequence.
  * \param data The input expr.
@@ -457,6 +479,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
   auto multiplied_t2 = reduced_t2;
   auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
   if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
+    if (!IsConstScalar(kernel_zero_point)) {
+      Layout layout(param->data_layout);
+      int channel_axis = layout.IndexOf(LayoutAxis::Get('C'));
+      reduced_t2 = MakeRepeat(reduced_t2, out_channels, channel_axis);
+    }
     multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
   }
   return multiplied_t2;
@@ -531,6 +558,27 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
   return MakeConstantScalar(DataType::Int(32), scalar_term4);
 }
 
+/*
+ * \brief Calculates the fourth term in the qnn.conv2d lowering sequence
+          for non-constant zero_points.
+ * \param input_zero_point The Expr for the input zero point.
+ * \param kernel_zero_point The Expr for the kernel zero point.
+ * \param in_channels The number of input channels.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \return The sequence of Relay operators for term4.
+ * \note The term4 looks like this
+ *
+ *       Sigma(c,r,s) zp_a * zp_w
+ *
+ */
+Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
+                      int kernel_h, int kernel_w) {
+  Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
+  Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
+  return Multiply(scalar_term4, variable_term4);
+}
+
 /*
  * \brief Combines different terms of qnn conv2d lowering.
  * \param term1 The term1 of qnn conv2d lowering.
@@ -656,9 +704,24 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
       GetWorkload(arg_types, param);
 
-  // Extract the integer zero points.
-  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
-  auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
+  // Zero points are allowed to be non-scalar. Let's check if that's the case.
+  bool dynamic_zp = false;
+  // Use -1 zero point as a default for dynamic.
+  int input_zero_point_int = -1;
+  int kernel_zero_point_int = -1;
+
+  // Zero points can either be constant scalars or arbitrary expressions.
+  if (IsConstScalar(input_zero_point) && (IsConstScalar(kernel_zero_point))) {
+    // Extract the integer zero points.
+    input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
+    kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
+  } else {
+    // Make kernel_zero_point expression a 1-D tensor for consistent shape.
+    kernel_zero_point = Reshape(kernel_zero_point, {
+        -1,
+    });
+    dynamic_zp = true;
+  }
 
   // Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d
   // For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to
@@ -668,8 +731,26 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   ICHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
-  if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
-      (param->groups != 1 && !is_depthwise(param))) {
+  // Check if qnn supports the conv2d parameters. If not, fall back to regular conv2d.
+  bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1);
+  bool supported_groups = (param->groups == 1 || is_depthwise(param));
+  bool conv2d_params_supported = supported_dilation && supported_groups;
+
+  // If we need to fall back to default conv2d, the kernel zero point may need to be broadcast
+  // to kernel_layout. Otherwise, we broadcast it to data_layout for the qnn lowering.
+  if (dynamic_zp) {
+    if (!conv2d_params_supported) {
+      Layout kernel_layout(param->kernel_layout);
+      int kernel_axis = kernel_layout.IndexOf(LayoutAxis::Get("O"));
+      kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {kernel_axis});
+    } else {
+      Layout data_layout(param->data_layout);
+      int channel_axis = data_layout.IndexOf(LayoutAxis::Get("C"));
+      kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis});
+    }
+  }
+
+  if (!conv2d_params_supported) {
     return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
   } else if (is_depthwise(param)) {
     ICHECK_NE(channel_multiplier, -1);
@@ -679,8 +760,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                            kernel_w, channel_multiplier);
     auto term3 = DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels,
                                           channel_multiplier);
-    auto term4 =
-        DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w);
+    Expr term4;
+    if (dynamic_zp) {
+      term4 = DepthwiseConv2DFourthTerm(input_zero_point, kernel_zero_point, kernel_h, kernel_w);
+    } else {
+      term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h,
+                                        kernel_w);
+    }
     return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
                               kernel_zero_point_int);
   }
@@ -690,8 +776,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto term2 = Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w,
                                 out_channels);
   auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
-  auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
-                                kernel_w);
+  Expr term4;
+  if (dynamic_zp) {
+    term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
+  } else {
+    term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
+                             kernel_w);
+  }
   return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
                             kernel_zero_point_int);
 }
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 06832a9cbf62..8422cda42afc 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4919,8 +4919,6 @@ def verify_eyelike(indata):
 
 target_skips = {
"cuda": [ - "test_basic_convinteger", - "test_convinteger_with_padding", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_positive_delta_expanded", "test_mod_mixed_sign_float16", @@ -5375,7 +5373,6 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None): tvm.testing.assert_allclose(real, expected, rtol=1e-5) -@tvm.testing.known_failing_targets("cuda") @tvm.testing.parametrize_targets def test_convinteger(target, dev): def verify_convinteger( @@ -5389,26 +5386,22 @@ def verify_convinteger( auto_pad="NOTSET", dtype="uint8", ): - x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype) w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype) - x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) - w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + x_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) + w_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] input_nodes = [ helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)), helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)), - helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []), - helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []), ] - input_names = [ - "x", - "w", - "x_zero_point", - "w_zero_point", + initializer = [ + helper.make_tensor("x_zero_point", ONNX_DTYPE, [], x_zero_point_array), + helper.make_tensor("w_zero_point", ONNX_DTYPE, [], w_zero_point_array), ] - input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array] + input_names = ["x", "w", "x_zero_point", "w_zero_point"] + input_values = [x_array, w_array] if padding is None: ## autopadding with unset default attributes @@ -5443,11 +5436,12 @@ def verify_convinteger( [node], "convinteger_test", inputs=input_nodes, + initializer=initializer, outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))], ) model = helper.make_model(graph, producer_name="convinteger_test") # opt_level=1 will cause error - verify_with_ort_with_inputs(model, input_values, opt_level=2, target=target, dev=dev) + verify_with_ort_with_inputs(model, input_values, target=target, dev=dev, opt_level=2) def repeat(N, D): return tuple([N for _ in range(D)]) @@ -5556,7 +5550,8 @@ def repeat(N, D): test_scatter() test_lrn() test_instance_norm() - test_upsample() + test_upsample_nearest() + test_upsample_bilinear() test_forward_min() test_forward_max() test_forward_mean() diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3a81e6e7b47a..3736350cbfe1 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -49,10 +49,21 @@ def get_ref_func( groups, channels=None, ): + if isinstance(input_zero_point, (int, float)): + input_zero_point = relay.const(input_zero_point, "int32") + if isinstance(kernel_zero_point, (int, float)): + kernel_zero_point = relay.const(kernel_zero_point, "int32") + else: + # Kernel zero point expression requires manual broadcasting for some layouts. 
+        if kernel_layout == "OIHW":
+            kernel_zero_point = relay.reshape(kernel_zero_point, [-1, 1, 1, 1])
+        elif kernel_layout == "HWOI":
+            kernel_zero_point = relay.reshape(kernel_zero_point, [1, 1, -1, 1])
+
     casted_data = relay.op.cast(data, "int32")
     casted_kernel = relay.op.cast(kernel, "int32")
-    shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32"))
-    shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32"))
+    shifted_data = relay.op.subtract(casted_data, input_zero_point)
+    shifted_kernel = relay.op.subtract(casted_kernel, kernel_zero_point)
     func = relay.op.nn.conv2d(
         shifted_data,
         shifted_kernel,
@@ -88,11 +99,16 @@ def get_qnn_func(
     channels,
     groups,
 ):
+    if isinstance(input_zero_point, (int, float)):
+        input_zero_point = relay.const(input_zero_point, "int32")
+    if isinstance(kernel_zero_point, (int, float)):
+        kernel_zero_point = relay.const(kernel_zero_point, "int32")
+
     func = relay.qnn.op.conv2d(
         data,
         kernel,
-        input_zero_point=relay.const(input_zero_point, "int32"),
-        kernel_zero_point=relay.const(kernel_zero_point, "int32"),
+        input_zero_point=input_zero_point,
+        kernel_zero_point=kernel_zero_point,
         input_scale=relay.const(input_scale, "float32"),
         kernel_scale=relay.const(kernel_scale, "float32"),
         kernel_size=kernel_size,
@@ -419,6 +435,62 @@ def test_both_zero_point():
         verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
 
+def test_dynamic_zero_point():
+    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
+
+        # uint8 input with non-static zero points.
+        data_shape = (2, 4, 2, 4)
+        data_dtype = "uint8"
+        kernel_shape = (3, 4, 2, 2)
+        kernel_dtype = "uint8"
+        input_zero_point = relay.op.multiply(
+            relay.const(2, dtype="int32"), relay.const(2, dtype="int32")
+        )
+        kernel_zero_point = relay.const(np.random.randint(10, size=[3]), "int32")
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
+        # int8 input
+        data_shape = (2, 4, 2, 4)
+        data_dtype = "int8"
+        kernel_shape = (3, 4, 2, 2)
+        kernel_dtype = "int8"
+        ref_func, qnn_func = get_funcs(
+            data_shape=data_shape,
+            data_dtype=data_dtype,
+            kernel_shape=kernel_shape,
+            kernel_dtype=kernel_dtype,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
+            input_scale=1.0,
+            kernel_scale=1.0,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+            out_dtype="int32",
+        )
+        verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
+
+
 def test_layout():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
@@ -888,13 +960,17 @@ def test_depthwise_depth_multiplier():
         data_dtype = "uint8"
         kernel_shape = (4, 1, 3, 3)
         kernel_dtype = "uint8"
+        input_zero_point = relay.op.multiply(
+            relay.const(2, dtype="int32"), relay.const(2, dtype="int32")
+        )
+        kernel_zero_point = relay.const(np.random.randint(10, size=[4]), "int32")
         ref_func, qnn_func = get_funcs(
             data_shape=data_shape,
             data_dtype=data_dtype,
             kernel_shape=kernel_shape,
             kernel_dtype=kernel_dtype,
-            input_zero_point=5,
-            kernel_zero_point=3,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
             input_scale=1.0,
             kernel_scale=1.0,
             kernel_size=(3, 3),
@@ -905,6 +981,7 @@ def test_depthwise_depth_multiplier():
             kernel_layout="OIHW",
             out_dtype="int32",
             groups=4,
+            channels=4,
         )
         verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype)
 
@@ -919,8 +996,8 @@ def test_depthwise_depth_multiplier():
             data_dtype=data_dtype,
             kernel_shape=kernel_shape,
             kernel_dtype=kernel_dtype,
-            input_zero_point=5,
-            kernel_zero_point=3,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
             input_scale=1.0,
             kernel_scale=1.0,
             kernel_size=(3, 3),
@@ -946,8 +1023,8 @@ def test_depthwise_depth_multiplier():
             data_dtype=data_dtype,
             kernel_shape=kernel_shape,
             kernel_dtype=kernel_dtype,
-            input_zero_point=5,
-            kernel_zero_point=3,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
             input_scale=1.0,
             kernel_scale=1.0,
             kernel_size=(3, 3),
@@ -971,8 +1048,8 @@ def test_depthwise_depth_multiplier():
             data_dtype=data_dtype,
             kernel_shape=kernel_shape,
             kernel_dtype=kernel_dtype,
-            input_zero_point=5,
-            kernel_zero_point=3,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
             input_scale=1.0,
             kernel_scale=1.0,
             kernel_size=(3, 3),
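
Note on the arithmetic behind the new CUDA legalization path: the identity that lets
helper_change_dtypes_to_int8 rewrite uint8 operands as int8 for dp4a can be checked outside of
TVM. The short NumPy sketch below is illustrative only and is not part of the patch (the variable
names are made up for this example). It verifies that shifting a uint8 tensor and its zero point
down by 128, which is what _shift(data, zero_point, "int8") emits as Relay ops, leaves the
dequantized values unchanged.

    # Check the identity: scale * (q_u8 - zp_u8) == scale * ((q_u8 - 128) - (zp_u8 - 128))
    import numpy as np

    q_u8 = np.random.randint(0, 256, size=(2, 4, 2, 4), dtype=np.uint8)
    zp_u8 = 77  # an arbitrary uint8 zero point
    scale = 0.05

    # Mirror of _shift(..., "int8"): cast up to int32, add -128, cast back down to int8.
    q_i8 = (q_u8.astype(np.int32) - 128).astype(np.int8)
    zp_i8 = zp_u8 - 128

    dequant_u8 = scale * (q_u8.astype(np.int32) - zp_u8)
    dequant_i8 = scale * (q_i8.astype(np.int32) - zp_i8)
    np.testing.assert_allclose(dequant_u8, dequant_i8)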