[Relay][QNN] Support for non scalar zero points in qnn.conv2d #8620

Merged 11 commits on Aug 7, 2021

91 changes: 71 additions & 20 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -94,6 +94,25 @@ def get_scalar_from_constant(expr):
return value.item(0)


def _shift(data, zero_point, out_dtype):
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
if out_dtype == "uint8":
shift = 128
elif out_dtype == "int8":
shift = -128
else:
raise ValueError("Unsupported out dtype.")
data_modified = relay.cast(data, "int32")
data_modified = relay.add(data_modified, relay.const(shift, "int32"))
data_modified = relay.cast(data_modified, out_dtype)
if isinstance(zero_point, relay.Constant):
zero_point_val = get_scalar_from_constant(zero_point)
zero_point_modified = relay.const(zero_point_val + shift, "int32")
else:
zero_point_modified = zero_point + relay.const(shift, "int32")
return (data_modified, zero_point_modified)
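
For intuition, here is a minimal NumPy sketch (illustration only, not part of the patch) of the identity behind `_shift`: moving both the tensor and its zero point by the same 128 flips the storage dtype between uint8 and int8 without changing the real value each element represents.

```python
import numpy as np

# Illustration only: shifting a uint8 tensor and its zero point by -128
# preserves the dequantized value, because
# scale * (q - zp) == scale * ((q - 128) - (zp - 128)).
scale, zp = 0.5, 140
q = np.array([0, 50, 200, 255], dtype=np.uint8)

real_before = scale * (q.astype(np.int32) - zp)

q_int8 = (q.astype(np.int32) - 128).astype(np.int8)  # data_modified
zp_int8 = zp - 128                                    # zero_point_modified

real_after = scale * (q_int8.astype(np.int32) - zp_int8)
np.testing.assert_allclose(real_before, real_after)
```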


# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
"""Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
@@ -161,22 +180,6 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
result : tvm.relay.Expr
The legalized expr
"""

def _shift(data, zero_point, out_dtype):
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
if out_dtype == "uint8":
shift = 128
elif out_dtype == "int8":
shift = -128
else:
raise ValueError("Unsupported out dtype.")
data_modified = relay.cast(data, "int32")
data_modified = relay.add(data_modified, relay.const(shift, "int32"))
data_modified = relay.cast(data_modified, out_dtype)
zero_point_val = get_scalar_from_constant(zero_point)
zero_point_modified = relay.const(zero_point_val + shift, "int32")
return (data_modified, zero_point_modified)

# Collect the dtypes.
data_dtype = types[0].dtype
kernel_dtype = types[1].dtype
@@ -205,6 +208,54 @@ def _shift(data, zero_point, out_dtype):
)


# Helper function to change dtypes to int8 x int8. Cuda dp4a instructions prefer this setting.
def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
"""Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the
dtypes are already good, we don't transform. Otherwise, we shift the tensor values and zero points
to change the dtype.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the dtypes.
data_dtype = types[0].dtype
kernel_dtype = types[1].dtype

# Collect the input exprs.
data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs

# dp4a supports i8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
if data_dtype == "int8" and kernel_dtype == "int8":
return None

# Shift input if necessary.
if data_dtype == "uint8":
# Compute (QA - 128) and (zp_a - 128)
data, input_zero_point = _shift(data, input_zero_point, "int8")

# Shift kernel if necessary.
if kernel_dtype == "uint8":
# Compute (QW - 128) and (zp_w - 128)
kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8")

# Call qnn.conv2d with modified inputs and zero points.
new_attrs = {k: attrs[k] for k in attrs.keys()}
return relay_op(
data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
)


# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
"""Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
@@ -339,11 +390,11 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register("cuda")
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
# CUDA prefers both datatypes to be int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)


@qnn_dense_legalize.register("cuda")
def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
# CUDA prefers both datatypes to be int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
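
As a usage sketch with made-up shapes and quantization parameters, the new CUDA path can be observed by running the QNN Legalize pass under a CUDA target; a uint8 x uint8 qnn.conv2d should come out as an int8 x int8 one with shifted operands (exact pass plumbing may differ slightly across TVM versions):

```python
import tvm
from tvm import relay

# Sketch with hypothetical shapes/qparams: uint8 x uint8 qnn.conv2d is
# rewritten to an int8 x int8 one when legalized for CUDA (dp4a-friendly).
data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
kernel = relay.var("kernel", shape=(32, 64, 3, 3), dtype="uint8")
conv = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=relay.const(128, "int32"),
    kernel_zero_point=relay.const(120, "int32"),
    input_scale=relay.const(0.07, "float32"),
    kernel_scale=relay.const(0.05, "float32"),
    kernel_size=(3, 3), channels=32,
)
mod = tvm.IRModule.from_expr(conv)
with tvm.target.Target("cuda"):
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.Legalize()(mod)
print(mod)  # expect cast/add/cast shifts feeding an int8 x int8 qnn.conv2d
```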
113 changes: 102 additions & 11 deletions src/relay/qnn/op/convolution.cc
@@ -65,7 +65,6 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
}
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
if (param->groups == 1) {
@@ -293,7 +292,11 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_
auto multiplied_t2 = reduced_t2;
auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
if (!IsConstScalar(kernel_zero_point)) {
multiplied_t2 = Multiply(MakeRepeat(kernel_zero_point, channel_multiplier, 0), reduced_t2);
} else {
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
}
}

// Reduce the C dimension. Find the dimension.
@@ -378,6 +381,25 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
return MakeConstantScalar(DataType::Int(32), scalar_term4);
}

/*
* \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence
*        for non-constant zero_points.
* \param input_zero_point The Expr for the input zero point.
* \param kernel_zero_point The Expr for the kernel zero point.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(r, s) zp_a * zp_w
*/
Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
int kernel_h, int kernel_w) {
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w);
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
return Multiply(scalar_term4, variable_term4);
}

/*
* \brief Calculates the first term in the qnn.conv2d lowering sequence.
* \param data The input expr.
@@ -457,6 +479,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
auto multiplied_t2 = reduced_t2;
auto one_scalar = MakeConstantScalar(DataType::Int(32), 1);
if (!IsEqualScalar(kernel_zero_point, one_scalar)) {
if (!IsConstScalar(kernel_zero_point)) {
Layout layout(param->data_layout);
int channel_axis = layout.IndexOf(LayoutAxis::Get('C'));
reduced_t2 = MakeRepeat(reduced_t2, out_channels, channel_axis);
}
multiplied_t2 = Multiply(kernel_zero_point, reduced_t2);
}
return multiplied_t2;
@@ -531,6 +558,27 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
return MakeConstantScalar(DataType::Int(32), scalar_term4);
}

/*
* \brief Calculates the fourth term in the qnn.conv2d lowering sequence
*        for non-constant zero_points.
* \param input_zero_point The Expr for the input zero point.
* \param kernel_zero_point The Expr for the kernel zero point.
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
int kernel_h, int kernel_w) {
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
return Multiply(scalar_term4, variable_term4);
}

/*
* \brief Combines different terms of qnn conv2d lowering.
* \param term1 The term1 of qnn conv2d lowering.
@@ -656,9 +704,24 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
GetWorkload(arg_types, param);

// Extract the integer zero points.
auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
// Zero points are allowed to be non-scalar. Let's check if that's the case.
bool dynamic_zp = false;
// Use -1 zero point as a default for dynamic.
int input_zero_point_int = -1;
int kernel_zero_point_int = -1;

// Take the static path only when both zero points are constant scalars.
if (IsConstScalar(input_zero_point) && (IsConstScalar(kernel_zero_point))) {
// Extract the integer zero points.
input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
} else {
// Make kernel_zero_point expression a 1-D tensor for consistent shape.
kernel_zero_point = Reshape(kernel_zero_point, {
-1,
});
dynamic_zp = true;
}
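
At the Relay level, this is what the new path admits; a sketch with hypothetical shapes, where the kernel zero point is a vector with one entry per output channel (the Reshape above normalizes whatever shape it arrives in to 1-D):

```python
import numpy as np
from tvm import relay

# Sketch with hypothetical shapes: the kernel zero point may now be a 1-D
# constant with one entry per output channel instead of a scalar.
data = relay.var("data", shape=(1, 8, 32, 32), dtype="uint8")
kernel = relay.var("kernel", shape=(16, 8, 3, 3), dtype="uint8")
per_channel_zp = relay.const(np.arange(16, dtype="int32"))  # one zp per out channel
conv = relay.qnn.op.conv2d(
    data, kernel,
    input_zero_point=relay.const(5, "int32"),   # input zp must stay scalar
    kernel_zero_point=per_channel_zp,           # previously had to be scalar
    input_scale=relay.const(0.1, "float32"),
    kernel_scale=relay.const(0.2, "float32"),
    kernel_size=(3, 3), channels=16,
)
```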

// Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d
// For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to
@@ -668,8 +731,26 @@
ICHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
(param->groups != 1 && !is_depthwise(param))) {
// Check if qnn supports the conv2d parameters. If not, fallback to regular conv2d.
bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1);
bool supported_groups = (param->groups == 1 || is_depthwise(param));
bool conv2d_params_supported = supported_dilation && supported_groups;

// If we need to fall back to default conv2d, kernel zp may need to be broadcast to kernel_layout.
// Otherwise, we broadcast it to data_layout for qnn lowering.
if (dynamic_zp) {
if (!conv2d_params_supported) {
Layout kernel_layout(param->kernel_layout);
int kernel_axis = kernel_layout.IndexOf(LayoutAxis::Get("O"));
kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {kernel_axis});
} else {
Layout data_layout(param->data_layout);
int channel_axis = data_layout.IndexOf(LayoutAxis::Get("C"));
kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis});
}
}

if (!conv2d_params_supported) {
return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
} else if (is_depthwise(param)) {
ICHECK_NE(channel_multiplier, -1);
@@ -679,8 +760,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
kernel_w, channel_multiplier);
auto term3 =
DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier);
auto term4 =
DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w);
Expr term4;
if (dynamic_zp) {
term4 = DepthwiseConv2DFourthTerm(input_zero_point, kernel_zero_point, kernel_h, kernel_w);
} else {
term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h,
kernel_w);
}
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
kernel_zero_point_int);
}
@@ -690,8 +776,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
auto term2 =
Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
kernel_w);
Expr term4;
if (dynamic_zp) {
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
} else {
term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
kernel_w);
}
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
kernel_zero_point_int);
}
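
A NumPy sketch (illustration only) that checks the four-term lowering numerically for a 1x1 kernel and a per-output-channel kernel zero point; at a single spatial position such a convolution is just a matrix-vector product:

```python
import numpy as np

# Check the four-term lowering for a 1x1 "conv" (dot product over channels)
# with a per-output-channel kernel zero point.
rng = np.random.default_rng(0)
C, O = 8, 4                       # in/out channels, kernel_h = kernel_w = 1
A = rng.integers(0, 256, size=C).astype(np.int32)       # one spatial position
W = rng.integers(0, 256, size=(O, C)).astype(np.int32)
zp_a = 7
zp_w = rng.integers(0, 10, size=O).astype(np.int32)     # vector zero point

direct = (W - zp_w[:, None]) @ (A - zp_a)

term1 = W @ A
term2 = zp_w * A.sum()            # per-output-channel, as in Conv2DSecondTerm
term3 = W.sum(axis=1) * zp_a
term4 = C * zp_a * zp_w           # Conv2DFourthTerm with kernel_h = kernel_w = 1
np.testing.assert_array_equal(direct, term1 - term2 - term3 + term4)
```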
27 changes: 11 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
@@ -4919,8 +4919,6 @@ def verify_eyelike(indata):

target_skips = {
"cuda": [
"test_basic_convinteger",
"test_convinteger_with_padding",
"test_range_float_type_positive_delta_expanded",
"test_range_int32_type_positive_delta_expanded",
"test_mod_mixed_sign_float16",
@@ -5375,7 +5373,6 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
tvm.testing.assert_allclose(real, expected, rtol=1e-5)


@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_convinteger(target, dev):
def verify_convinteger(
@@ -5389,26 +5386,22 @@ def verify_convinteger(
auto_pad="NOTSET",
dtype="uint8",
):

x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype)
w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype)
x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
x_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype)
w_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype)

ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
input_nodes = [
helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)),
helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)),
helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []),
helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []),
]
input_names = [
"x",
"w",
"x_zero_point",
"w_zero_point",
initializer = [
helper.make_tensor("x_zero_point", ONNX_DTYPE, [], x_zero_point_array),
helper.make_tensor("w_zero_point", ONNX_DTYPE, [], w_zero_point_array),
]
input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array]
input_names = ["x", "w", "x_zero_point", "w_zero_point"]
input_values = [x_array, w_array]

if padding is None:
## autopadding with unset default attributes
@@ -5443,11 +5436,12 @@ def verify_convinteger(
[node],
"convinteger_test",
inputs=input_nodes,
initializer=initializer,
outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))],
)
model = helper.make_model(graph, producer_name="convinteger_test")
# opt_level=1 will cause an error
verify_with_ort_with_inputs(model, input_values, opt_level=2, target=target, dev=dev)
verify_with_ort_with_inputs(model, input_values, target=target, dev=dev, opt_level=2)

def repeat(N, D):
return tuple([N for _ in range(D)])
@@ -5556,7 +5550,8 @@ def repeat(N, D):
test_scatter()
test_lrn()
test_instance_norm()
test_upsample()
test_upsample_nearest()
test_upsample_bilinear()
test_forward_min()
test_forward_max()
test_forward_mean()