Skip to content

Commit

Permalink
[Relay][QNN] Support for non scalar zero points in qnn.conv2d (apache…
Browse files Browse the repository at this point in the history
…#8620)

* conv2d working, fixing conv2d_depthwise

* Depthwise conv2d working.

* Make convinteger work on cuda.

* Simplify code and add tests.

* Formatting.

* Fixed fallback broadcasting.

* Fix fallback broadcasting.

* Formatting.

* Fix lint

* Merge with new test parameterization.
  • Loading branch information
Josh Fromm authored and mehrdadh committed Aug 11, 2021
1 parent 010bf86 commit e8e9b63
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 59 deletions.
91 changes: 71 additions & 20 deletions python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ def get_scalar_from_constant(expr):
return value.item(0)


def _shift(data, zero_point, out_dtype):
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
if out_dtype == "uint8":
shift = 128
elif out_dtype == "int8":
shift = -128
else:
raise ValueError("Unsupported out dtype.")
data_modified = relay.cast(data, "int32")
data_modified = relay.add(data_modified, relay.const(shift, "int32"))
data_modified = relay.cast(data_modified, out_dtype)
if isinstance(zero_point, relay.Constant):
zero_point_val = get_scalar_from_constant(zero_point)
zero_point_modified = relay.const(zero_point_val + shift, "int32")
else:
zero_point_modified = zero_point + relay.const(shift, "int32")
return (data_modified, zero_point_modified)


# Helper function for lowering in the abscence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
"""Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
Expand Down Expand Up @@ -161,22 +180,6 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
result : tvm.relay.Expr
The legalized expr
"""

def _shift(data, zero_point, out_dtype):
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
if out_dtype == "uint8":
shift = 128
elif out_dtype == "int8":
shift = -128
else:
raise ValueError("Unsupported out dtype.")
data_modified = relay.cast(data, "int32")
data_modified = relay.add(data_modified, relay.const(shift, "int32"))
data_modified = relay.cast(data_modified, out_dtype)
zero_point_val = get_scalar_from_constant(zero_point)
zero_point_modified = relay.const(zero_point_val + shift, "int32")
return (data_modified, zero_point_modified)

# Collect the dtypes.
data_dtype = types[0].dtype
kernel_dtype = types[1].dtype
Expand Down Expand Up @@ -205,6 +208,54 @@ def _shift(data, zero_point, out_dtype):
)


# Helper function to change dtypes to int8 x int8. Cuda dp4a instructions prefer this setting.
def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
"""Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the
dtypes are already good, we dont transform. Else, we shift the tensor values and zero points
to change the dtype.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the dtypes.
data_dtype = types[0].dtype
kernel_dtype = types[1].dtype

# Collect the input exprs.
data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs

# dp4a supports i8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
if data_dtype == "int8" and kernel_dtype == "int8":
return None

# Shift input if necessary.
if data_dtype == "uint8":
# Compute (QA + 128) and (zp_a + 128)
data, input_zero_point = _shift(data, input_zero_point, "int8")

# Shift kernel if necessary.
if kernel_dtype == "uint8":
# Compute (QA - 128) and (zp_a - 128)
kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8")

# Call qnn.conv2d with modified inputs and zero points.
new_attrs = {k: attrs[k] for k in attrs.keys()}
return relay_op(
data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
)


# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
"""Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
Expand Down Expand Up @@ -339,11 +390,11 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register("cuda")
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
# CUDA prefers both datatypes to be int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)


@qnn_dense_legalize.register("cuda")
def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
# CUDA prefers both datatypes to be the int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
113 changes: 102 additions & 11 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
}
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
if (param->groups == 1) {
Expand Down Expand Up @@ -293,7 +292,11 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_
auto multiplied_t2 = reduced_t2;
auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
if (!IsConstScalar(kernel_zero_point)) {
multiplied_t2 = Multiply(MakeRepeat(kernel_zero_point, channel_multiplier, 0), reduced_t2);
} else {
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
}
}

// Reduce the C dimension. Find the dimension.
Expand Down Expand Up @@ -378,6 +381,25 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
return MakeConstantScalar(DataType::Int(32), scalar_term4);
}

/*
* \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence
for non-constant zero_points.
* \param input_zero_point The Expr for the input zero point.
* \param kernel_zero_point The Expr for the kernel zero point.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(r, s) zp_a * zp_w
*/
Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
int kernel_h, int kernel_w) {
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w);
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
return Multiply(scalar_term4, variable_term4);
}

/*
* \brief Calculates the first term in the qnn.conv2d lowering sequence.
* \param data The input expr.
Expand Down Expand Up @@ -457,6 +479,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
auto multiplied_t2 = reduced_t2;
auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
if (!IsConstScalar(kernel_zero_point)) {
Layout layout(param->data_layout);
int channel_axis = layout.IndexOf(LayoutAxis::Get('C'));
reduced_t2 = MakeRepeat(reduced_t2, out_channels, channel_axis);
}
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
}
return multiplied_t2;
Expand Down Expand Up @@ -531,6 +558,27 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
return MakeConstantScalar(DataType::Int(32), scalar_term4);
}

/*
* \brief Calculates the fourth term in the qnn.conv2d lowering sequence
for non-constant zero_points.
* \param input_zero_point The Expr for the input zero point.
* \param kernel_zero_point The Expr for the kernel zero point.
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
int kernel_h, int kernel_w) {
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
return Multiply(scalar_term4, variable_term4);
}

/*
* \brief Combines different terms of qnn conv2d lowering.
* \param term1 The term1 of qnn conv2d lowering.
Expand Down Expand Up @@ -656,9 +704,24 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
GetWorkload(arg_types, param);

// Extract the integer zero points.
auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
// zero points are allowed to be non-scalar. Let's check if that's the case.
bool dynamic_zp = false;
// Use -1 zero point as a default for dynamic.
int input_zero_point_int = -1;
int kernel_zero_point_int = -1;

// Input zero point can either be a constant or a scalar expression.
if (IsConstScalar(input_zero_point) && (IsConstScalar(kernel_zero_point))) {
// Extract the integer zero points.
input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
} else {
// Make kernel_zero_point expression a 1-D tensor for consistent shape.
kernel_zero_point = Reshape(kernel_zero_point, {
-1,
});
dynamic_zp = true;
}

// Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d
// For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to
Expand All @@ -668,8 +731,26 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
ICHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
(param->groups != 1 && !is_depthwise(param))) {
// Check if qnn supports the conv2d parameters. If not, fallback to regular conv2d.
bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1);
bool supported_groups = (param->groups == 1 || is_depthwise(param));
bool conv2d_params_supported = supported_dilation && supported_groups;

// If we need to fall back to default conv2d, kernel zp may need to be broadcast to kernel_layout.
// Otherwise, we broadcast it to data_layout for qnn lowering.
if (dynamic_zp) {
if (!conv2d_params_supported) {
Layout kernel_layout(param->kernel_layout);
int kernel_axis = kernel_layout.IndexOf(LayoutAxis::Get("O"));
kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {kernel_axis});
} else {
Layout data_layout(param->data_layout);
int channel_axis = data_layout.IndexOf(LayoutAxis::Get("C"));
kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis});
}
}

if (!conv2d_params_supported) {
return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
} else if (is_depthwise(param)) {
ICHECK_NE(channel_multiplier, -1);
Expand All @@ -679,8 +760,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
kernel_w, channel_multiplier);
auto term3 =
DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier);
auto term4 =
DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w);
Expr term4;
if (dynamic_zp) {
term4 = DepthwiseConv2DFourthTerm(input_zero_point, kernel_zero_point, kernel_h, kernel_w);
} else {
term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h,
kernel_w);
}
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
kernel_zero_point_int);
}
Expand All @@ -690,8 +776,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
auto term2 =
Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
kernel_w);
Expr term4;
if (dynamic_zp) {
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
} else {
term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
kernel_w);
}
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
kernel_zero_point_int);
}
Expand Down
27 changes: 11 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4919,8 +4919,6 @@ def verify_eyelike(indata):

target_skips = {
"cuda": [
"test_basic_convinteger",
"test_convinteger_with_padding",
"test_range_float_type_positive_delta_expanded",
"test_range_int32_type_positive_delta_expanded",
"test_mod_mixed_sign_float16",
Expand Down Expand Up @@ -5375,7 +5373,6 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
tvm.testing.assert_allclose(real, expected, rtol=1e-5)


@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_convinteger(target, dev):
def verify_convinteger(
Expand All @@ -5389,26 +5386,22 @@ def verify_convinteger(
auto_pad="NOTSET",
dtype="uint8",
):

x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype)
w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype)
x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
x_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype)
w_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype)

ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
input_nodes = [
helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)),
helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)),
helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []),
helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []),
]
input_names = [
"x",
"w",
"x_zero_point",
"w_zero_point",
initializer = [
helper.make_tensor("x_zero_point", ONNX_DTYPE, [], x_zero_point_array),
helper.make_tensor("w_zero_point", ONNX_DTYPE, [], w_zero_point_array),
]
input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array]
input_names = ["x", "w", "x_zero_point", "w_zero_point"]
input_values = [x_array, w_array]

if padding is None:
## autopadding with unset default attributes
Expand Down Expand Up @@ -5443,11 +5436,12 @@ def verify_convinteger(
[node],
"convinteger_test",
inputs=input_nodes,
initializer=initializer,
outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))],
)
model = helper.make_model(graph, producer_name="convinteger_test")
# opt_level=1 will cause error
verify_with_ort_with_inputs(model, input_values, opt_level=2, target=target, dev=dev)
verify_with_ort_with_inputs(model, input_values, target=target, dev=dev, opt_level=2)

def repeat(N, D):
return tuple([N for _ in range(D)])
Expand Down Expand Up @@ -5556,7 +5550,8 @@ def repeat(N, D):
test_scatter()
test_lrn()
test_instance_norm()
test_upsample()
test_upsample_nearest()
test_upsample_bilinear()
test_forward_min()
test_forward_max()
test_forward_mean()
Expand Down
Loading

0 comments on commit e8e9b63

Please sign in to comment.