diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 44c46ae988af..862377887fec 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -93,13 +93,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
                     name="conv2d_nchw_spatial_pack.arm_cpu",
+                    plevel=10,
                 )
 
-                if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
+                if (
+                    topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
+                    and kernel.shape[1] >= 64
+                ):
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_int8),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_int8),
                         name="conv2d_nchw_int8.arm_cpu",
+                        plevel=15,
                     )
                 else:
                     strategy.add_implementation(
@@ -383,12 +388,16 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
         strategy.add_implementation(
             wrap_compute_conv2d_gemm(native_compute),
-            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
+            wrap_topi_schedule(
+                topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
+            ),
             name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
         )
         strategy.add_implementation(
             wrap_compute_conv2d_gemm(interleaved_compute),
-            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
+            wrap_topi_schedule(
+                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
+            ),
             name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
         )
     else:
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 409768fc8f75..eb719dd66777 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -27,6 +27,7 @@
 from ..nn import conv2d_alter_layout, conv2d_legalize
 from ..utils import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from ..x86.conv2d_int8 import _get_default_config_int8
 from .conv2d_int8 import is_int8_hw_support
 from .arm_utils import get_tiling_B_interleaved_t
 from ..generic.conv2d import conv2d_alter_int8_common
@@ -101,9 +102,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         # we then assume it's not necessary to alter this op.
         return None
     cfg = dispatch_ctx.query(target, workload)
-    if cfg.is_fallback:  # if is fallback, clear query cache and return None
-        autotvm.task.clear_fallback_cache(target, workload)
-        return None
 
     topi_tmpl = workload[0]
     new_attrs = {k: attrs[k] for k in attrs.keys()}
@@ -346,6 +344,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
     if topi_tmpl == "conv2d_NCHWc_int8.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+        out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
+
+        n_elems = 8
+
         if cfg.is_fallback:
             _get_default_config_int8(
                 cfg,
@@ -357,12 +360,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
                 out_dtype,
                 False,
                 data_layout,
+                int32_lanes=4,
             )
 
-        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
-        out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
         ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-        n_elems = 8
+
+        if cfg.is_fallback:
+            # ic_bn needs to be divided by n_elems below
+            ic_bn = max(ic_bn, n_elems)
 
         # update new attrs
         new_attrs["channels"] = out_channel
@@ -395,6 +400,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
     if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
+        # TODO(masahi): This schedule can easily result in a tensorization error
+        # if used in the fallback mode
+        if cfg.is_fallback:  # if is fallback, clear query cache and return None
+            autotvm.task.clear_fallback_cache(target, workload)
+            return None
+
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
@@ -411,6 +422,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             inputs[0], new_kernel_expr, **new_attrs
         )
     if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
+        # TODO(masahi): This schedule can easily result in a tensorization error
+        # if used in the fallback mode
+        if cfg.is_fallback:  # if is fallback, clear query cache and return None
+            autotvm.task.clear_fallback_cache(target, workload)
+            return None
+
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index 8d9c47966113..d09433b16a78 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -297,6 +297,12 @@ def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs):
     return _schedule_conv2d_NHWC_quantized(cfg, outs, True)
 
 
+@autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu")
+def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs):
+    """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved"""
+    return _schedule_conv2d_NHWC_quantized(cfg, outs, True)
+
+
 # Native schedules: those schedule won't interleave A (which is left in its native form).
 # The weights are interleaved and transposed
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu")
@@ -330,3 +336,9 @@ def compute_conv2d_NHWC_quantized_native_without_transform(
 def schedule_conv2d_NHWC_quantized_native(cfg, outs):
     """Interface for native schedule_conv2d_NHWC_quantized"""
     return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
+
+
+@autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu")
+def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs):
+    """Interface for native schedule_conv2d_NHWC_quantized"""
+    return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py
index 9234581f1d5b..032c0e2e236b 100644
--- a/python/tvm/topi/x86/conv2d_alter_op.py
+++ b/python/tvm/topi/x86/conv2d_alter_op.py
@@ -159,6 +159,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
                 out_dtype,
                 False,
                 data_layout,
+                int32_lanes=16,
             )
 
         batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py
index 075723303841..b0edb02b0804 100644
--- a/python/tvm/topi/x86/conv2d_int8.py
+++ b/python/tvm/topi/x86/conv2d_int8.py
@@ -34,7 +34,16 @@
 
 
 def _get_default_config_int8(
-    cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
+    cfg,
+    data,
+    kernel,
+    strides,
+    padding,
+    dilation,
+    out_dtype,
+    is_depthwise=False,
+    layout="NCHW",
+    int32_lanes=4,
 ):
     """
     Get default schedule config for the workload
@@ -50,11 +59,11 @@ def _get_default_config_int8(
         is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
         if is_kernel_1x1:
             conv2d_generic.fallback_schedule_cpu_1x1_int8(
-                cfg, wkl, int32_lanes=16, num_int8_elements=4
+                cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
             )
         else:
             conv2d_generic.fallback_schedule_cpu_common_int8(
-                cfg, wkl, int32_lanes=16, num_int8_elements=4
+                cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
             )
 
 
@@ -163,6 +172,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
             padding,
             dilation,
             out_dtype,
+            int32_lanes=16,
         )
 
     # Pack data if raw 4-D data is provided.
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 68ee840d15ea..8186d3c178d6 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -16,9 +16,11 @@
 # under the License.
 import sys
 from typing import List
+import numpy as np
 
 import pytest
 import tvm
+from tvm import relay
 from tvm import meta_schedule as ms
 from tvm.ir.module import IRModule
 from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
@@ -149,5 +151,49 @@ def extract_task_qbert():
         assert "vnni" in annotations["schedule_rule"]
 
 
+def extract_task_arm_conv2d_nchwc():
+    data_shape = (1, 64, 128, 128)
+    weight_shape = (32, 64, 1, 1)
+    bias_shape = (weight_shape[0],)
+    padding = (1, 1)
+
+    data = relay.var("data", shape=data_shape, dtype="int8")
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=bias_shape, dtype="int32")
+    conv2d = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=weight_shape[2:],
+        channels=weight_shape[0],
+        padding=padding,
+        strides=(1, 1),
+        out_dtype="int32",
+    )
+    bias_add = relay.nn.bias_add(conv2d, bias)
+    relay_mod = tvm.IRModule.from_expr(bias_add)
+
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+    tune_tasks = list(
+        filter(
+            lambda task: "conv2d" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    assert len(tune_tasks) == 1
+
+    relay_func = list(tune_tasks[0].mod.functions.values())[0]
+    out_type = relay_func.body.checked_type
+
+    # Check that the output is in NCHWc layout
+    assert list(out_type.shape) == [1, 8, 130, 130, 4]
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))