Skip to content

Commit

Permalink
[QNN] Legalization for Intel x86 QNN Conv2D
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Sep 5, 2019
1 parent a25bed2 commit 12c6e3d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 34 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
from __future__ import absolute_import as _abs
from .qnn import *
from .op import register_qnn_legalize
from . import _qnn
from . import legalizations
from . import op_attrs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,36 @@
import tvm
from tvm import relay
from tvm.api import min_value, max_value
from tvm.relay.qnn.op import register_qnn_legalize
from .. import op as reg
from topi.util import get_const_int

# Hook qnn.conv2d into the QNN legalization registry; actual work is
# delegated to the target-dispatched generic function below.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    """Legalize a QNN conv2d op by dispatching to the target-specific handler.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of the convolution being legalized.
    inputs : list of tvm.relay.Expr
        Arguments of the Relay expression to be legalized.
    types : list of types
        Input and output types of the expression.

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expression (None from the generic fallback means
        no transformation is applied).
    """
    return qnn_conv2d_legalize(attrs, inputs, types)

# Target-dispatched entry point for QNN Conv2D legalization. Specific
# targets (e.g. 'cpu' below) override this via .register().
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Generic fallback: no legalization is performed (returns None)."""
    return None

# Intel x86 QNN Conv2D legalization function.
@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize(attrs, inputs, types):
"""Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good,
Expand Down Expand Up @@ -72,12 +93,6 @@ def _shift_quantized_tensor(data, shift, out_dtype):
data_modified = relay.cast(data_modified, out_dtype)
return data_modified

channels_expr = attrs['channels']
if isinstance(channels_expr, tvm.expr.IntImm):
channels = channels_expr.value
if channels == 1001:
return None

# Collect the dtypes.
data_dtype = types[0].dtype
kernel_dtype = types[1].dtype
Expand Down Expand Up @@ -108,23 +123,3 @@ def _shift_quantized_tensor(data, shift, out_dtype):
new_attrs['input_zero_point'] = input_zp
new_attrs['kernel_zero_point'] = kernel_zp
return relay.qnn.op.conv2d(data, kernel, **new_attrs)

@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    """Legalizes QNN conv2d op.

    Thin registry hook: forwards all arguments to the target-dispatched
    generic function ``qnn_conv2d_legalize``.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr (None from the generic default means no change)
    """
    return qnn_conv2d_legalize(attrs, inputs, types)
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Relay operators"""
"""The attributes node used for QNN operators"""

from ....attrs import Attrs
from ...base import register_relay_attr_node
Expand Down
6 changes: 2 additions & 4 deletions tests/python/relay/test_pass_qnn_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def expected():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_qnn_legalize_qnn_conv2d():

def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype):
def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype):
low = -128
Expand All @@ -109,7 +108,6 @@ def get_output(func, golden_inputs):
graph, lib, params = relay.build(func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
# mod.set_input("kernel", golden_weight)
mod.set_input(**params)
mod.run()
res = mod.get_output(0).asnumpy()
Expand Down Expand Up @@ -140,11 +138,11 @@ def get_output(func, golden_inputs):

mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
ref_mod = relay.qnn.transform.QnnToRelay()(mod)
ref_mod = relay.qnn.transform.CanonicalizeOps()(mod)

with tvm.target.create('llvm'):
qnn_mod = relay.qnn.transform.Legalize()(mod)
qnn_mod = relay.qnn.transform.QnnToRelay()(qnn_mod)
qnn_mod = relay.qnn.transform.CanonicalizeOps()(qnn_mod)

verify(ref_mod, qnn_mod, data_shape, data_dtype, kernel_shape, kernel_dtype)

Expand Down

0 comments on commit 12c6e3d

Please sign in to comment.