diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index e0c3106e14fa..7f68ce2ad5bb 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -124,6 +124,16 @@ class DataType { * \return the result type. */ DataType element_of() const { return with_lanes(1); } + /*! + * \brief Assignment operator. + */ + DataType& operator=(const DataType& rhs) { + if (this == &rhs) { + return *this; + } + data_ = rhs.data_; + return *this; + } /*! * \brief Equal comparator. * \param other The data type to compare against. diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py index a2fbf555e12b..173f31ef08f9 100644 --- a/python/tvm/relay/backend/te_compiler.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -281,25 +281,28 @@ def get_shape(shape): @tvm._ffi.register_func("relay.backend.lower_call") -def lower_call(call, inputs, target): +def lower_call(call, inputs, target, otype=None): """Lower the call expression to op implementation and tensor outputs.""" assert isinstance(call.op, tvm.ir.Op) op = call.op - # Prepare the call_node->checked_type(). For the call node inputs, we ensure that - # the shape is Int32. Following code ensures the same for the output as well. - # TODO(@icemelon9): Support recursive tuple - ret_type = call.checked_type - if isinstance(ret_type, _ty.TensorType): - ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype) - elif isinstance(ret_type, _ty.TupleType): - new_fields = [] - for field in ret_type.fields: - if isinstance(field, _ty.TensorType): - new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) - else: - new_fields.append(field) - ret_type = _ty.TupleType(new_fields) + if otype is not None: + ret_type = otype + else: + # Prepare the call_node->checked_type(). For the call node inputs, we ensure that + # the shape is Int32. Following code ensures the same for the output as well. + # TODO(@icemelon9): Support recursive tuple + ret_type = call.checked_type + if isinstance(ret_type, _ty.TensorType): + ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype) + elif isinstance(ret_type, _ty.TupleType): + new_fields = [] + for field in ret_type.fields: + if isinstance(field, _ty.TensorType): + new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) + else: + new_fields.append(field) + ret_type = _ty.TupleType(new_fields) is_dyn = _ty.is_dynamic(call.checked_type) for arg in call.args: diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index a059c293a0f8..4e54583a3be0 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -19,9 +19,10 @@ from tvm import topi +from .. 
import strategy from ...op.op import register_compute from ...op.op import register_injective_schedule -from ...op.op import register_pattern, OpPattern +from ...op.op import register_strategy, register_pattern, OpPattern @register_compute("qnn.simulated_quantize") @@ -50,3 +51,35 @@ def simulated_dequantize_compute(attrs, inputs, output_type): register_injective_schedule("qnn.simulated_dequantize") register_pattern("qnn.simulated_dequantize", OpPattern.ELEMWISE) + +# qnn.quantize +register_strategy("qnn.quantize", strategy.qnn_quantize_strategy) +register_pattern("qnn.quantize", OpPattern.ELEMWISE) + +# qnn.dequantize +register_strategy("qnn.dequantize", strategy.qnn_dequantize_strategy) +register_pattern("qnn.dequantize", OpPattern.ELEMWISE) + +# qnn.requantize +register_strategy("qnn.requantize", strategy.qnn_requantize_strategy) +register_pattern("qnn.requantize", OpPattern.ELEMWISE) + +# qnn.add +register_strategy("qnn.add", strategy.qnn_add_strategy) +register_pattern("qnn.add", OpPattern.BROADCAST) + +# qnn.concatenate +register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy) +register_pattern("qnn.concatenate", OpPattern.INJECTIVE) + +# qnn.conv2d +register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy) +register_pattern("qnn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + +# qnn.dense +register_strategy("qnn.dense", strategy.qnn_dense_strategy) +register_pattern("qnn.dense", OpPattern.OUT_ELEMWISE_FUSABLE) + +# qnn.batch_matmul +register_strategy("qnn.batch_matmul", strategy.qnn_batch_matmul_strategy) +register_pattern("qnn.batch_matmul", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 1f383851071b..78d6669413ca 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -29,8 +29,6 @@ from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE from tvm.topi.x86.utils import target_has_sse41 -from ... import op as reg -from ...op import OpPattern from . import _make, _requantize @@ -1212,11 +1210,6 @@ def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype=" return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype) -# register fuse pattern for qnn ops -reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) -reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) - - def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zero_point): """Quantized leaky relu. diff --git a/python/tvm/relay/qnn/strategy/__init__.py b/python/tvm/relay/qnn/strategy/__init__.py new file mode 100644 index 000000000000..05778c3e9f86 --- /dev/null +++ b/python/tvm/relay/qnn/strategy/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=wildcard-import +"""QNN op strategies.""" +from __future__ import absolute_import as _abs + +from .generic import * +from . import hexagon diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py new file mode 100644 index 000000000000..57a364f7e057 --- /dev/null +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -0,0 +1,249 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Definition of generic operator strategy.""" + +from tvm.target import override_native_generic_func + + +def wrap_topi_schedule(topi_schedule): + """Wrap TOPI schedule which doesn't use attrs""" + + def wrapper(_attrs, outs, target): + with target: + return topi_schedule(outs) + + return wrapper + + +def wrap_topi_compute(topi_compute): + """Wrap TOPI compute which doesn't use attrs""" + + def wrapper(_attrs, inputs, _out_type): + return [topi_compute(*inputs)] + + return wrapper + + +def wrap_compute_quantize(topi_compute): + """Wrap TOPI compute which use axis and out data type from attrs""" + + def wrapper(attrs, inputs, _out_type): + axis = attrs.axis + out_dtype = attrs.out_dtype + args = [*inputs, axis, out_dtype] + return [topi_compute(*args)] + + return wrapper + + +def wrap_compute_dequantize(topi_compute): + """Wrap TOPI compute which use axis from attrs""" + + def wrapper(attrs, inputs, _out_type): + args = [*inputs, attrs.axis] + return [topi_compute(*args)] + + return wrapper + + +def wrap_topi_qnn_conv2d(topi_compute): + """Wrap TOPI compute which use conv2d attrs and output data type""" + + def wrapper(attrs, inputs, out_type): + out_dtype = out_type.dtype + oshape = out_type.shape + strides = attrs.strides + padding = attrs.padding + dilation = attrs.dilation + if len([*inputs]) == 11: + args = [*inputs, strides, padding, dilation, oshape, out_dtype] + elif len([*inputs]) == 10: + args = [ # QNN Conv2d params: + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + inputs[5], + # Bias argument + None, + # Requantization params: + inputs[6], + inputs[7], + inputs[8], + inputs[9], + # Conv2d attrs: + strides, + padding, + dilation, + oshape, + out_dtype, + ] + else: + assert len([*inputs]) == 6 + args = [ # QNN Conv2d params: + *inputs, + # Bias argument: + None, + # Requantization params: + None, + None, + None, + None, + strides, + padding, + dilation, + oshape, + out_dtype, + ] + return [topi_compute(*args)] + + return wrapper + + +def wrap_topi_qnn_dense(topi_compute): + """Wrap TOPI compute which use qnn.dense attrs""" + + def wrapper(_attrs, inputs, out_type): + out_dtype = out_type.dtype + if len([*inputs]) == 11: + args = [*inputs, out_dtype] + elif len([*inputs]) == 10: + args = [ # QNN Dense params: + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + inputs[5], + # Bias 
argument + None, + # Requantization params: + inputs[6], + inputs[7], + inputs[8], + inputs[9], + out_dtype, + ] + else: + assert len([*inputs]) == 6 + args = [ # QNN Dense params: + *inputs, + # Bias argument: + None, + # Requantization params: + None, + None, + None, + None, + out_dtype, + ] + return [topi_compute(*args)] + + return wrapper + + +def wrap_topi_concatenate(topi_compute): + """Wrap TOPI compute which use qnn.concatenate attrs""" + + def wrapper(attrs, inputs, out_type): + return [topi_compute(inputs, attrs.axis, out_type.dtype)] + + return wrapper + + +def wrap_topi_qnn_batch_matmul(topi_compute): + """Wrap TOPI compute which use qnn.batch_matmul attrs""" + + def wrapper(attrs, inputs, _out_type): + assert len([*inputs]) == 6 + args = [*inputs, attrs.transpose_a, attrs.transpose_b, attrs.out_dtype] + return [topi_compute(*args)] + + return wrapper + + +@override_native_generic_func("qnn_quantize_strategy") +def qnn_quantize_strategy(attrs, inputs, out_type, target): + """qnn.quantize generic strategy""" + raise RuntimeError( + "qnn.quantize is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_dequantize_strategy") +def qnn_dequantize_strategy(attrs, inputs, out_type, target): + """qnn.dequantize generic strategy""" + raise RuntimeError( + "qnn.dequantize is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_requantize_strategy") +def qnn_requantize_strategy(attrs, inputs, out_type, target): + """qnn.requantize generic strategy""" + raise RuntimeError( + "qnn.requantize is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_add_strategy") +def qnn_add_strategy(attrs, inputs, out_type, target): + """qnn.add generic strategy""" + raise RuntimeError( + "qnn.add is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_concatenate_strategy") +def qnn_concatenate_strategy(attrs, inputs, out_type, target): + """qnn.concatenate generic strategy""" + raise RuntimeError( + "qnn.concatenate is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_conv2d_strategy") +def qnn_conv2d_strategy(attrs, inputs, out_type, target): + """qnn.conv2d generic strategy""" + raise RuntimeError( + "qnn.conv2d is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_dense_strategy") +def qnn_dense_strategy(attrs, inputs, out_type, target): + """qnn.dense generic strategy""" + raise RuntimeError( + "qnn.dense is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_batch_matmul_strategy") +def qnn_batch_matmul_strategy(attrs, inputs, out_type, target): + """qnn.batch_matmul generic strategy""" + raise RuntimeError( + "qnn.batch_matmul is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." 
+ ) diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py new file mode 100644 index 000000000000..c7f59cc096fc --- /dev/null +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Definition of Hexagon operator strategy.""" +# pylint: disable=unused-argument,wildcard-import,unused-wildcard-import + +from tvm import topi +from .generic import * +from ... import op as _op +from ...op.strategy.generic import is_depthwise_conv2d + + +@qnn_quantize_strategy.register("hexagon") +def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.quantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_quantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_quantize), + name="qnn_quantize.hexagon", + ) + return strategy + + +@qnn_dequantize_strategy.register("hexagon") +def qnn_dequantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.dequantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dequantize(topi.hexagon.qnn_dequantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize), + name="qnn_dequantize.hexagon", + ) + return strategy + + +@qnn_requantize_strategy.register("hexagon") +def qnn_requantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.requantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_requantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_requantize), + name="qnn_requantize.hexagon", + ) + return strategy + + +@qnn_add_strategy.register("hexagon") +def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.add strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_add), + wrap_topi_schedule(topi.hexagon.schedule_qnn_add), + name="qnn_add.hexagon", + ) + return strategy + + +@qnn_concatenate_strategy.register("hexagon") +def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.concatenate strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_concatenate(topi.hexagon.qnn_concatenate), + wrap_topi_schedule(topi.hexagon.schedule_qnn_concatenate), + name="qnn_concatenate.hexagon", + ) + return strategy + + +@qnn_conv2d_strategy.register("hexagon") +def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.conv2d strategy for Hexagon""" + data = inputs[0] + kernel = inputs[1] + data_layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + groups = attrs.groups + strategy = _op.OpStrategy() + if groups == 1: + if 
data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_topi_qnn_conv2d(topi.hexagon.qnn_conv2d), + wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d), + name="qnn_conv2d.hexagon", + ) + elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): + if data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_topi_qnn_conv2d(topi.hexagon.qnn_depthwise_conv2d), + wrap_topi_schedule(topi.hexagon.schedule_qnn_depthwise_conv2d), + name="qnn_depthwise_conv2d.hexagon", + ) + else: + raise RuntimeError("Unsupported strategy for group qnn.conv2d") + + return strategy + + +@qnn_dense_strategy.register("hexagon") +def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.dense strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_qnn_dense(topi.hexagon.qnn_dense), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dense), + name="qnn_dense.hexagon", + ) + return strategy + + +@qnn_batch_matmul_strategy.register("hexagon") +def qnn_batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.batch_matmul strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_qnn_batch_matmul(topi.hexagon.qnn_batch_matmul), + wrap_topi_schedule(topi.hexagon.schedule_qnn_batch_matmul), + name="qnn_batch_matmul.hexagon", + ) + return strategy diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index a52422f6c1d2..0907ea2ebf85 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -26,6 +26,7 @@ from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum +from tvm.tir import add, subtract, multiply from .schedule import ( Schedule, diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 8e637d2d6564..2767f2d5f779 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -74,6 +74,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace +from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index 2616b9315a9b..bafc6846b6fb 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -25,3 +25,4 @@ ) from .quantize import quantize_compute, tir_quantize_schedule +from .nn import * diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py new file mode 100644 index 000000000000..40cfd0ee96b1 --- /dev/null +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -0,0 +1,667 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hexagon QNN operators""" +# pylint: disable=invalid-name + +import tvm +from tvm import te, topi +from ...utils import get_const_tuple +from ...nn.utils import get_pad_tuple +from ...nn.pad import pad +from ... import tag, nn +from ...x86.concat import concatenate + + +def clip_cast(val, dtype): + # clip + cast: + const_min = tvm.tir.min_value(dtype) + const_max = tvm.tir.max_value(dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) + + +def get_qnn_param(param, indices, axis): + # Account scalar and 1D quantization parameters: + if len(param.shape) == 0: + return param + + param_idx = tvm.tir.indexmod(indices[axis], topi.shape(param)[0]) + return param[param_idx] + + +def default_schedule(outs): + """Simple default schedule for QNN ops. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of dense in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs + s = tvm.te.create_schedule([x.op for x in outs]) + tvm.te.schedule.AutoInlineInjective(s) + return s + + +def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): + """Compute for qnn.quantize + + Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), + out_dtype::min, + out_dtype::max) + """ + + assert len(output_scale.shape) == 0 or len(output_scale.shape) == 1 + assert len(output_zero_point.shape) == 0 or len(output_zero_point.shape) == 1 + + def _compute(*indices): + value = data(*indices) + scale = get_qnn_param(output_scale, indices, axis) + zp = get_qnn_param(output_zero_point, indices, axis) + + val = te.add(te.round(te.div(value, scale)), zp) + return clip_cast(val, out_dtype) + + return te.compute(data.shape, _compute, tag=tag.ELEMWISE) + + +def schedule_qnn_quantize(outs): + """Schedule for qnn.quantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.quantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_dequantize(data, input_scale, input_zero_point, axis): + """Compute for qnn.dequantize + + fp_output = input_scale * (Q_input - input_zero_point) + """ + + def _compute(*indices): + value = data(*indices) + scale = get_qnn_param(input_scale, indices, axis) + zp = get_qnn_param(input_zero_point, indices, axis) + + return te.multiply(scale, te.subtract(value, zp)) + + return te.compute(data.shape, _compute, tag=tag.ELEMWISE) + + +def schedule_qnn_dequantize(outs): + """Schedule for qnn.dequantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.dequantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return default_schedule(outs) + + +def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype): + """Compute for qnn.requantize + + Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + + TODO: support 'rounding' and 'compute_dtype' arguments. + """ + + def _compute(*indices): + value = data(*indices) + + iscale = get_qnn_param(input_scale, indices, axis) + oscale = get_qnn_param(output_scale, indices, axis) + + sub = te.subtract(value, input_zp) + mul = te.div(iscale, oscale) + val = te.add(te.round(te.multiply(mul, sub)), output_zp) + + # clip + cast: + const_min = tvm.tir.min_value(out_dtype) + const_max = tvm.tir.max_value(out_dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) + + return te.compute(data.shape, _compute) + + +def schedule_qnn_requantize(outs): + """Schedule for qnn.requantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.requantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_add( + lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point +): + """Compute for qnn.add + + Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input)) + + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input)) + + TODO: support 'axis' argument. + """ + + assert lhs.dtype == rhs.dtype + dtype = lhs.dtype + + def _compute(*indices): + lvalue = lhs(*indices) + rvalue = rhs(*indices) + q_lv = te.round( + te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point)) + ).astype("int32") + q_rv = te.round( + te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point)) + ).astype("int32") + val = te.add(te.add(q_lv, q_rv), output_zero_point) + + # clip + cast: + const_min = tvm.tir.min_value(dtype) + const_max = tvm.tir.max_value(dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) + + return te.compute(lhs.shape, _compute) + + +def schedule_qnn_add(outs): + """Schedule for qnn.add + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.add + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype): + """Requantize tensor""" + + def _compute(*indices): + value = tensor(*indices) + mul_value = te.round( + te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp)) + ).astype("int32") + rq_value = te.add(mul_value, o_zp) + + return clip_cast(rq_value, out_dtype) + + return te.compute(tensor.shape, _compute) + + +def qnn_concatenate(data, axis, out_dtype): + """Compute for qnn.concatenate + + Parameters + ---------- + data: Array of Tensor + The computation graph description of qnn.concatenate + in the format of an array of tensors. + + axis: int + The axis along which the tensors are concatenated. + + out_dtype: string + Data type of output tensor + + Returns + ------- + out: Tensor + The computation for the op. + """ + + # Get output quantization parameters. + o_scale = data[-2] + o_zp = data[-1] + + # Initially qnn.concatenate had 3 tuples: (1) tuple with input tensors, (2) tuple with input + # scales and (3) tuple with input zero points. 
+ # Last 2 elements in data represent output scale and zero point. + num_of_tuples = 3 + assert ((len(data) - 2) % num_of_tuples) == 0 + args_num = (len(data) - 2) // num_of_tuples + + args = [] + for i in range(args_num): + # Get next tensor and its quantization parameters. + tensor = data[i] + i_scale = data[i + args_num] + i_zp = data[i + args_num * 2] + + # Requantize tensors and add them to the list. + args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype)) + + # Call x86 implementation of concatenate. + return concatenate(args, axis) + + +def schedule_qnn_concatenate(outs): + """Schedule for qnn.concatenate + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.add + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_conv2d( # Conv2d inputs + data, + weight, + # Conv2d quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + # Conv2d attributes: + strides, + padding, + dilation, + oshape, + odtype, +): + """Compute for qnn.conv2d with NCHW layout. + + Output data type should be specified through the 'odtype' parameter. qnn.conv2d leverages int32 + type to store intermediate results. If 'odtype' differs from int32, you need to specify + requantization parameters. + """ + in_channel = data.shape[1] # NCHW layout + kernel_height = weight.shape[2] # OIHW layout + kernel_width = weight.shape[3] # OIHW layout + + height_stride, width_stride = strides + dilation_h, dilation_w = dilation + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w) + ) + + # Subtract zero point from input and then do padding with 0 value + data = te.compute(data.shape, lambda *indices: te.subtract(data(*indices), input_zero_point)) + + # DOPAD + if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0: + pad_before = (0, 0, pad_top, pad_left) + pad_after = (0, 0, pad_down, pad_right) + data_pad = pad(data, pad_before, pad_after, name="data_pad") + else: + data_pad = data + + ic = te.reduce_axis((0, in_channel), name="ic") + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + # axis=0 in get_qnn_param means 'O' dimension in "OIHW" weights layout. 
+ out = te.compute( + oshape, + lambda n, oc, oh, ow: te.sum( + data_pad[ + n, + ic, + oh * height_stride + kh * dilation_h, + ow * width_stride + kw * dilation_w, + ].astype("int32") + * te.subtract( + weight[oc, ic, kh, kw], get_qnn_param(kernel_zero_point, (oc, ic, kh, kw), axis=0) + ).astype("int32"), + axis=[ic, kh, kw], + ), + ) + + # Add bias + if bias is not None: + assert len(out.shape) == len(bias.shape) + assert bias.shape[2] == 1 and bias.shape[3] == 1 + out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 0, 0]) + + # Requantize output of convolution + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'C' dimension. + + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + odtype, + ) + + return out + + +def schedule_qnn_conv2d(outs): + """Schedule for qnn.conv2d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.conv2d + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_depthwise_conv2d( # Conv2d inputs + data, + weight, + # Conv2d quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + # Conv2d attributes: + strides, + padding, + dilation, + oshape, + odtype, +): + """Compute for qnn.conv2d with NCHW layout + + Output data type should be specified through the 'odtype' parameter. qdepthwise nn.conv2d + leverages int32 type to store intermediate results. If 'odtype' differs from int32, you need to + specify requantization parameters. 
+ """ + kernel_height = weight.shape[2] # OIHW layout + kernel_width = weight.shape[3] # OIHW layout + + height_stride, width_stride = strides + dilation_h, dilation_w = dilation + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w) + ) + + # Subtract zero point from input and then do padding with 0 value + data = te.compute(data.shape, lambda *indices: te.subtract(data(*indices), input_zero_point)) + + # DOPAD + if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0: + pad_before = (0, 0, pad_top, pad_left) + pad_after = (0, 0, pad_down, pad_right) + data_pad = pad(data, pad_before, pad_after, name="data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + out = te.compute( + oshape, + lambda n, oc, oh, ow: te.sum( + data_pad[ + n, + oc, + oh * height_stride + kh * dilation_h, + ow * width_stride + kw * dilation_w, + ].astype("int32") + * te.subtract(weight[oc, 0, kh, kw], kernel_zero_point).astype("int32"), + axis=[kh, kw], + ), + ) + + # Add bias + if bias is not None: + assert len(out.shape) == len(bias.shape) + assert bias.shape[2] == 1 and bias.shape[3] == 1 + out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 0, 0]) + + # Requantize output of convolution + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'C' dimension. + + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + odtype, + ) + + return out + + +def schedule_qnn_depthwise_conv2d(outs): + """Schedule for depthwise qnn.conv2d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.conv2d + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_dense( + data, + weight, + # Dense quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + out_dtype, +): + """Compute for qnn.dense + + Output data type should be specified through the 'odtype' parameter. qnn.dense leverages int32 + type to store intermediate results. If 'odtype' differs from int32, you need to specify + requantization parameters. + """ + M, K = get_const_tuple(data.shape) + N, _ = get_const_tuple(weight.shape) + k = te.reduce_axis((0, K), "k") + # This implementation uses "int32" dense output data type. + # axis=0 in get_qnn_param mean 'N' dimension in "NK" weights layout. 
+ out = te.compute( + (M, N), + lambda m, n: te.sum( + te.subtract(data[m, k], input_zero_point).astype("int32") + * te.subtract(weight[n, k], get_qnn_param(kernel_zero_point, (n, k), axis=0)).astype( + "int32" + ), + axis=k, + ), + ) + + # Add bias + if bias is not None: + out = te.compute(out.shape, lambda n, c: out[n, c] + bias[c]) + + # Requantize output of dense + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'N' dimension. + + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + out_dtype, + ) + + return out + + +def schedule_qnn_dense(outs): + """Schedule for qnn.dense + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.dense + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + +def qnn_batch_matmul( + tensor_a, + tensor_b, + # batch_matmul quantization params: + a_zero_point, + b_zero_point, + _a_scale, + _b_scale, + # Attributes + transpose_a, + transpose_b, + out_dtype, +): + """Compute for qnn.batch_matmul""" + + # Preprocess tensor_a: subtract zp + a_sub_zp = te.compute( + tensor_a.shape, lambda *indices: te.subtract(tensor_a(*indices), a_zero_point) + ) + # Preprocess tensor_b: subtract zp + b_sub_zp = te.compute( + tensor_b.shape, lambda *indices: te.subtract(tensor_b(*indices), b_zero_point) + ) + + return nn.batch_matmul(a_sub_zp, b_sub_zp, None, out_dtype, transpose_a, transpose_b) + + +def schedule_qnn_batch_matmul(outs): + """Schedule for qnn.batch_matmul + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.batch_matmul + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 9a0a2bef9a47..e7326ed5dd4d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -123,6 +123,81 @@ Array GetShape(const Array& shape) { return res; } +// Helper class that is used during lowering to TE. +// It matches sequence of Ops and lower them into single TOPI operation. All supported patterns are +// enumerated in "supported_patterns_". 
+class QnnPatternMatcher { + public: + QnnPatternMatcher() + : qnn_conv2d_op_(Op::Get("qnn.conv2d")), + qnn_dense_op_(Op::Get("qnn.dense")), + qnn_requantize_op_(Op::Get("qnn.requantize")), + bias_add_op_(Op::Get("add")) {} + + // Memoize visited operations. + void Register(const CallNode* call_node) { + ICHECK(call_node->op.as<OpNode>()); + Op op = Downcast<Op>(call_node->op); + if (op == qnn_conv2d_op_) { + registered_ops_.push_front(P_QConv2d); + ICHECK(anchor_op_ == nullptr); + anchor_op_ = call_node; + } else if (op == qnn_requantize_op_) { + registered_ops_.push_front(P_QRequantize); + } else if (op == bias_add_op_) { + registered_ops_.push_front(P_BiasAdd); + } else if (op == qnn_dense_op_) { + registered_ops_.push_front(P_QDense); + ICHECK(anchor_op_ == nullptr); + anchor_op_ = call_node; + } else { + registered_ops_.push_front(P_Opaque); + } + } + + // Check whether the given Op is part of a matched pattern. + bool find(const Op& op) { + if (registered_ops_.empty()) return false; + + if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ || + op == qnn_dense_op_) { + for (const auto& pat : supported_patterns_) { + auto it = + std::search(registered_ops_.begin(), registered_ops_.end(), pat.begin(), pat.end()); + if (it != registered_ops_.end()) return true; + } + } + return false; + } + + // Returns whether the given Op is the last one in the pattern sequence. + bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; } + + const CallNode* GetAnchorOp() { return anchor_op_; } + + void Clear() { registered_ops_.clear(); } + + private: + const Op& qnn_conv2d_op_; + const Op& qnn_dense_op_; + const Op& qnn_requantize_op_; + const Op& bias_add_op_; + + // Main (anchor) operation in the primitive (for example qnn.conv2d, qnn.dense, etc.). + const CallNode* anchor_op_ = nullptr; + + enum POper { P_QConv2d, P_QDense, P_BiasAdd, P_QRequantize, P_Opaque }; + + std::deque<POper> registered_ops_; + + const std::vector<std::vector<POper>> supported_patterns_ = { + {P_QDense, P_BiasAdd, P_QRequantize}, // Pattern qnn.dense -> bias_add -> qnn.requantize + {P_QDense, P_QRequantize}, // Pattern qnn.dense -> qnn.requantize + {P_QConv2d, P_BiasAdd, P_QRequantize}, // Pattern qnn.conv2d -> bias_add -> qnn.requantize + {P_QConv2d, P_QRequantize} // Pattern qnn.conv2d -> qnn.requantize + }; +}; + // Lowers Relay primitive Function to TE Compute class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> { public: @@ -213,6 +288,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator inputs; int count_tuple = 0; for (Expr arg : call_node->args) { @@ -224,21 +301,35 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorargs.size(), 1U) - << "Only functions with a single tuple input are allowed, but " << count_tuple - << " were provided."; - } - ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops"; Op op = Downcast<Op>(call_node->op); // TODO(mbs): device_copy cleanup ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; - LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_); - Array<te::Tensor> outputs = lowered_out->outputs; - op_implementations_[op.operator->()] = lowered_out->implementation; + Array<te::Tensor> outputs; + + if (pattern_matcher_.find(op)) { + if (pattern_matcher_.IsLeafOp(op)) { + // Lower the anchor op once the pattern leaf op is reached. + auto anchor_op = pattern_matcher_.GetAnchorOp(); + LoweredOutput lowered_out = + (*flower_call)(GetRef<Call>(anchor_op), inputs, target_, call_node->checked_type()); + outputs = lowered_out->outputs; + Op a_op = 
Downcast<Op>(anchor_op->op); + op_implementations_[a_op.operator->()] = lowered_out->implementation; + + pattern_matcher_.Clear(); + } else { + // Forward inputs as "outputs" for the successor. + readable_name_stream_ << '_' << op->name; + return inputs; + } + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_); + outputs = lowered_out->outputs; + op_implementations_[op.operator->()] = lowered_out->implementation; + } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>(); @@ -294,6 +385,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator pass_seqs; pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize")); pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize")); - relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs, "qnn.Legalize"); return seq; } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index dac5dc69ead5..afa60f1bb4e5 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -885,8 +885,10 @@ class FuseMutator : private MixedModeMutator { Expr Rewrite_(const CallNode* call, const Expr& post) { if (call->op.as<OpNode>()) { static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational"); + static auto fqnncanonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize"); - if (fnoncomputational.get(Downcast<Op>(call->op), false)) { + Op op = Downcast<Op>(call->op); + if (fnoncomputational.get(op, false) && !fqnncanonicalize.count(op)) { return ExprMutator::VisitExpr_(call); } diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py new file mode 100644 index 000000000000..24da1faac697 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +import numpy as np + +import tvm.testing +from tvm import relay +from tvm.contrib.hexagon.session import Session +from tvm.contrib import graph_executor +from tvm.relay.backend import Executor + + +@tvm.testing.requires_hexagon +def test_no_qnn_pass(): + x = relay.var("x", shape=(4, 8), dtype="float32") + op0 = relay.qnn.op.quantize(x, relay.const(2.0), relay.const(10), out_dtype="uint8") + op1 = relay.qnn.op.dequantize(op0, relay.const(0.5), relay.const(5)) + mod = tvm.IRModule.from_expr(op1) + + target_hexagon = tvm.target.hexagon("v68") + # Default compilation flow + with tvm.transform.PassContext(opt_level=3): + opt_mod_1, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) + + # Disable QNN legalization and canonicalization passes + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): + opt_mod_2, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) + + # Check that QNN ops are absent with default compilation flow. + assert "qnn.quantize" not in opt_mod_1.astext(show_meta_data=False) + assert "qnn.dequantize" not in opt_mod_1.astext(show_meta_data=False) + + # Check that QNN ops are present without "qnn.Legalize" passes. + assert "qnn.quantize" in opt_mod_2.astext(show_meta_data=False) + assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False) + + +def execute(executor, data_np, weight_np, bias_np=None): + executor.set_input("data", data_np) + executor.set_input("weight", weight_np) + if bias_np is not None: + executor.set_input("bias", bias_np) + executor.run() + return executor.get_output(0) + + +@tvm.testing.requires_hexagon +def test_qnn_conv2d_rq(hexagon_session: Session): + data_shape = [1, 8, 32, 32] + weight_shape = [16, 8, 3, 3] + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0), out_dtype="int8") + op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") + op2 = relay.qnn.op.conv2d( + op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.078), + kernel_scale=relay.const(0.07), + padding=[0, 0, 0, 0], + channels=16, + kernel_size=[3, 3], + ) + op5 = relay.qnn.op.requantize( + op2, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.21), + output_zero_point=relay.const(61), + out_dtype="int8", + ) + relay_mod = tvm.IRModule.from_expr(op5) + + target_hexagon = tvm.target.hexagon("v68") + target_llvm = tvm.target.Target("llvm") + executor = Executor("graph", {"link-params": True}) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): + hexagon_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + executor=executor, + ) + + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_llvm, host=target_llvm), + executor=executor, + ) + + data_np = np.random.rand(*data_shape) - 0.5 + weight_np = np.random.rand(*weight_shape) - 0.5 + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, data_np, weight_np) + + dev = tvm.cpu(0) + llvm_m = graph_executor.GraphModule(llvm_lowered["default"](dev)) + llvm_out = execute(llvm_m, data_np, weight_np) + + np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + + 
+@tvm.testing.requires_hexagon +def test_qnn_dense_bias_rq(hexagon_session: Session): + data_shape = [8, 8] + weight_shape = [16, 8] + bias_shape = [16] + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + op0 = relay.qnn.op.quantize(data, relay.const(0.08), relay.const(0), out_dtype="int8") + op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") + op2 = relay.qnn.op.dense( + op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.08), + kernel_scale=relay.const(0.07), + units=None, + ) + op3 = relay.qnn.op.quantize(bias, relay.const(0.5), relay.const(0), out_dtype="int32") + op4 = relay.nn.bias_add(op2, op3) + op5 = relay.qnn.op.requantize( + op4, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.212), + output_zero_point=relay.const(10), + out_dtype="int8", + ) + relay_mod = tvm.IRModule.from_expr(op5) + + target_hexagon = tvm.target.hexagon("v68") + target_llvm = tvm.target.Target("llvm") + executor = Executor("graph", {"link-params": True}) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): + hexagon_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + executor=executor, + ) + + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_llvm, host=target_llvm), + executor=executor, + ) + + data_np = np.random.rand(*data_shape) - 0.5 + weight_np = np.random.rand(*weight_shape) - 0.5 + bias_np = np.random.rand(*bias_shape) + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, data_np, weight_np, bias_np) + + dev = tvm.cpu(0) + llvm_m = graph_executor.GraphModule(llvm_lowered["default"](dev)) + llvm_out = execute(llvm_m, data_np, weight_np, bias_np) + + np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + + +if __name__ == "__main__": + tvm.testing.main()