diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index e56e7ba12e94..2936d29285d2 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -207,21 +207,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_nhwc_dsp.arm_cpu",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
-                has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
+                has_asimd = target.features.has_asimd
+                has_dot_prod = target.features.has_dotprod
                 if has_dot_prod and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
                         name="conv2d_NHWC_quantized_native.arm_cpu",
                     )
-                if is_aarch64 and data.dtype in ["int8", "uint8"]:
+                if has_asimd and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
                         name="conv2d_NHWC_quantized_interleaved.arm_cpu",
                     )
-                if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
+                if (not has_asimd) or (data.dtype not in ["int8", "uint8"]):
                     # TODO(@giuseros)
                     # This strategy errors out for quantized data types when tuning.
                     # Let's use this only for non-aarch64 or non-quantized cases
@@ -281,8 +281,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
-            if is_aarch64 or "+neon" in target.mattr:
+            if target.features.has_asimd:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
                     wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 9bc6efdad00f..ad016bc20089 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -405,18 +405,6 @@ def is_fast_int8_on_intel():
     return target_has_sse42(target.mcpu)
 
 
-def is_fast_int8_on_arm():
-    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "+v8.2a" in target.mattr and "+dotprod" in target.mattr
-
-
-def is_aarch64_arm():
-    """Checks whether we are compiling for an AArch64 target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "aarch64" in target.attrs.get("mtriple", "")
-
-
 ########################
 # ARM CPU legalizations.
 ########################
@@ -425,7 +413,6 @@ def is_aarch64_arm():
 @qnn_conv2d_legalize.register("arm_cpu")
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     target = tvm.target.Target.current(allow_none=False)
-    has_asimd = is_aarch64_arm() or "+neon" in target.mattr
     is_depthwise = relay.op.strategy.is_depthwise_conv2d(
         types[0].shape,
         attrs["data_layout"],
@@ -434,9 +421,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["groups"],
     )
     use_int8_on_arm = (not is_depthwise) and attrs["data_layout"] == "NHWC"
-    has_dotprod = is_fast_int8_on_arm()
-    other_options = use_int8_on_arm or has_dotprod
-    if has_asimd and not other_options:
+    other_options = use_int8_on_arm or target.features.has_dotprod
+    if target.features.has_asimd and not other_options:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
     # ARM prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
@@ -445,8 +431,7 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
 @qnn_dense_legalize.register("arm_cpu")
 def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     target = tvm.target.Target.current(allow_none=False)
-    has_asimd = is_aarch64_arm() or "+neon" in target.mattr
-    if has_asimd and not is_fast_int8_on_arm():
+    if target.features.has_asimd and not target.features.has_dotprod:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
     # ARM prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index 4ab72178b30b..1b2efc61ea56 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -17,57 +17,7 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Arm target utility functions"""
 
-import re
-import tvm
-
-
-def get_arch_version(target_mattr):
-    """Parse the LLVM target -mattr, and return
-    the architecture version in a decimal representation
-    (e.g., if -mattr=v8.4a, return 8.4)
-    """
-
-    arch_version = 8.0
-    m = re.compile(r"\+v(.*)\.(.*)a")
-    for attr in target_mattr:
-        match_obj = m.match(attr)
-        if match_obj:
-            major = int(match_obj.group(1))
-            minor = int(match_obj.group(2))
-            decimal = 10
-            if minor >= 10:
-                decimal = 100
-            arch_version = major + float(minor) / decimal
-
-    return arch_version
-
-
-def is_dotprod_available():
-    """Checks whether the hardware has support for udot/sdot instructions."""
-    target = tvm.target.Target.current(allow_none=False)
-    arch_version = get_arch_version(target.mattr)
-    return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)
-
-
-def is_mmla_available():
-    """Checks whether the hardware has support for ummla/smmla instructions."""
-    target = tvm.target.Target.current(allow_none=False)
-    arch_version = get_arch_version(target.mattr)
-    return arch_version >= 8.6 or (
-        (arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
-    )
-
-
-def is_aarch64_arm():
-    """Checks whether we are compiling for an AArch64 target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "aarch64" in target.attrs.get("mtriple", "")
-
-
-def is_neon_available():
-    """Check if neon instructions are available"""
-    target = tvm.target.Target.current(allow_none=False)
-    return "+neon" in target.mattr
+from tvm.target import Target
 
 
 def get_tiling_B_interleaved_t(interleave_A):
@@ -94,13 +44,15 @@ def get_tiling_B_interleaved_t(interleave_A):
     tile_rows_B: the output tile rows of B'
     tile_cols_B: the output tile columns of B'
     """
-    if is_mmla_available():
+    target = Target.current(allow_none=False)
+
+    if target.features.has_matmul_i8:
         # If smmla/ummla is available, A must be interleaved.
         # Each load from B' will contain 8 elements
         # and we are loading 12 rows of B' (i.e., 12 columns of B)
         tile_rows_B = 12
         tile_cols_B = 8
-    elif is_dotprod_available():
+    elif target.features.has_dotprod:
         # The number of tile rows of B' vary depending on the
         # strategy:
         # * If we are interleaving A, then we select 12 columns from B'(i.e.,
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 8e416be8daa2..04748a4d81fb 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -18,6 +18,7 @@
 # pylint: disable=unused-argument, redefined-builtin
 """GEMM Convolution schedule on ARM"""
 import tvm
+from tvm.target import Target
 from tvm import te
 from tvm.topi import nn
 from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
@@ -29,10 +30,9 @@
     gemm_acc_nx16_int8_int8_int32,
     gemm_acc_2x2_int8_int8_int32,
 )
-from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available
 
 
-def configure_knobs(cfg, M, K):
+def configure_knobs(cfg, M, K, target):
     """Configure auto-tuning knobs for the interleaved strategy"""
 
     x, y = cfg.axis(M // 4), cfg.axis(K // 16)
@@ -48,7 +48,7 @@ def configure_knobs(cfg, M, K):
         cfg["reorder_gemm"] = ReorderEntity([0, 1])
         cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])
 
-    if not is_dotprod_available():
+    if not target.features.has_dotprod:
         cfg.define_knob("gemm_quantized_unroll", [True, False])
         if cfg.is_fallback:
             cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
@@ -133,12 +133,13 @@ def compute_conv2d_gemm_without_weight_transform(
     #  - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
     # In order to have more information
     #
-    if is_mmla_available():
+    target = Target.current(allow_none=False)
+    if target.features.has_matmul_i8:
         # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
         # will contain 8 elements
         tile_rows_A = 8
         tile_cols_A = 8
-    elif is_dotprod_available() and interleave_A:
+    elif target.features.has_dotprod and interleave_A:
         # If dot product has been enabled, and we are interleaving A
         # tile size should be 8x4
         tile_rows_A = 8
@@ -173,7 +174,7 @@ def compute_conv2d_gemm_without_weight_transform(
 
     if interleave_A:
         # Configuration space
-        configure_knobs(cfg, M_padded, K_padded)
+        configure_knobs(cfg, M_padded, K_padded, target)
 
         # Pack the input data
         A_interleaved = te.compute(
@@ -181,7 +182,8 @@
             lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
             name="A_interleaved",
         )
-        if is_mmla_available():
+        target = Target.current(allow_none=False)
+        if target.features.has_matmul_i8:
             # Execute GEMM. In the case of mmla, we need to enforce the tiling
             # from the compute. This is because mmla is doing a tiled computation
             # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
@@ -323,7 +325,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
     k = C_interleaved.op.reduce_axis[0]
     _, M, N = C.shape
     if in_type in ["int8", "uint8"]:
-        if is_mmla_available():
+        target = Target.current(allow_none=False)
+        if target.features.has_matmul_i8:
             gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
             xi_inner, yi_inner = C_interleaved.op.axis[-2:]
             k_outer, k_inner = s[C_interleaved].split(k, 8)
@@ -333,7 +336,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
             s[C_interleaved].tensorize(xi_inner, gemm_acc)
             s[C_interleaved].unroll(xi)
             s[C_interleaved].unroll(yi)
-        elif is_dotprod_available():
+        elif target.features.has_dotprod:
             gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
             xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
                 xi, yi, x_factor=8, y_factor=4
@@ -354,7 +357,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
             s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
             s[C_interleaved].unroll(xi_inner_outer)
 
-        elif is_aarch64_arm():
+        elif target.features.has_asimd:
             s[C_interleaved].reorder(yi, xi)
             K = A_interleaved_input.shape[2]
             assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index 224d21b34d9a..df231c0bc083 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -30,7 +30,7 @@
     schedule_conv2d_gemm_interleaved,
     schedule_conv2d_gemm_native,
 )
-from .arm_utils import get_tiling_B_interleaved_t, is_dotprod_available, is_neon_available
+from .arm_utils import get_tiling_B_interleaved_t
 
 
 def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -124,7 +124,10 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
     is_llvm_support = llvm_version >= 8
 
     # 3) Check target
-    is_target_support = is_neon_available() or is_dotprod_available()
+    current_target = target.Target.current(allow_none=False)
+    is_target_support = bool(
+        current_target.features.has_asimd or current_target.features.has_dotprod
+    )
 
     return is_dtype_support and is_llvm_support and is_target_support
 
@@ -154,9 +157,10 @@ def _callback(op):
         _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
         assert n_elems == 4
         dtype = "uint" if data.dtype == "uint8" else "int"
-        if is_dotprod_available():
+        current_target = target.Target.current(allow_none=False)
+        if current_target.features.has_dotprod:
             intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
-        elif is_neon_available():
+        elif current_target.features.has_asimd:
             assert dtype == "int", "uint8 not supported if dot product is not available"
             intrin = dot_int8_int8_int32_neon()
         else:
diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py
index 58cd11e8cc09..08967f121fd3 100644
--- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py
+++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py
@@ -18,6 +18,7 @@
 """Depthwise convolution schedule for ARM CPU"""
 
 import tvm
+from tvm.target import Target
 from tvm import te
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
@@ -26,7 +27,6 @@
 from ..utils import traverse_inline, get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import smlal_int16_int32
-from .arm_utils import is_aarch64_arm
 from .mprofile.dsp.depthwise_conv2d import (
     depthwise_conv2d_nhwc_dsp_compute,
     depthwise_conv2d_nhwc_dsp_schedule,
@@ -333,12 +333,13 @@ def schedule_conv(conv):
         co, ci = cfg["tile_c"].apply(s, conv, c)
         split_val = cfg["tile_c"].size[-1]
 
+        target = Target.current(allow_none=False)
         use_tensorization = (
             (in_type == "int16")
             and (split_val == 8)
             and (IC % split_val == 0)
             and (channel_multiplier == 1)
-            and is_aarch64_arm()
+            and target.features.has_asimd
         )
 
         data_pad_value = -1
diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc
index fbf55f468313..3cfabb7639df 100644
--- a/src/target/parsers/cpu.cc
+++ b/src/target/parsers/cpu.cc
@@ -20,6 +20,7 @@
 
 #include <string>
 
+#include "aprofile.h"
 #include "mprofile.h"
 
 namespace tvm {
@@ -32,6 +33,10 @@ TargetJSON ParseTarget(TargetJSON target) {
     return mprofile::ParseTarget(target);
   }
 
+  if (aprofile::IsArch(target)) {
+    return aprofile::ParseTarget(target);
+  }
+
   return target;
 }
 
diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py
index b5b9ed6b6ef9..3cf81e971f77 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_network.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_network.py
@@ -20,11 +20,9 @@
 import numpy as np
 import pytest
 
-from tvm import testing
 from tvm import relay
 
-from test_arm_compute_lib.infrastructure import skip_runtime_test, build_and_run, verify
-from test_arm_compute_lib.infrastructure import Device
+from test_arm_compute_lib.infrastructure import Device, skip_runtime_test, build_and_run, verify
 
 
 def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, atol, rtol):
@@ -108,7 +106,12 @@ def get_model():
         return mod, params, inputs
 
     _build_and_run_network(
-        *get_model(), device=device, tvm_ops=4, acl_partitions=21, atol=0.002, rtol=0.01
+        *get_model(),
+        device=device,
+        tvm_ops=4,
+        acl_partitions=21,
+        atol=0.002,
+        rtol=0.01,
     )
 
 
@@ -180,7 +183,12 @@ def get_model():
         return mod, params, inputs
 
     _build_and_run_network(
-        *get_model(), device=device, tvm_ops=3, acl_partitions=30, atol=9, rtol=0
+        *get_model(),
+        device=device,
+        tvm_ops=3,
+        acl_partitions=30,
+        atol=10,
+        rtol=0,
     )
 
 
@@ -207,7 +215,12 @@ def get_model():
         return mod, params, inputs
 
     _build_and_run_network(
-        *get_model(), device=device, tvm_ops=9, acl_partitions=31, atol=8, rtol=0
+        *get_model(),
+        device=device,
+        tvm_ops=9,
+        acl_partitions=31,
+        atol=8,
+        rtol=0,
     )
 
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 7efec2db03b9..ca1adf940029 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -2209,7 +2209,9 @@ def get_conv2d_nchw(
 
 @tvm.testing.requires_arm_dot
 def test_conv2d_int8_alter_dtype_arm():
-    _test_conv2d_int8_alter_dtype("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot")
+    _test_conv2d_int8_alter_dtype(
+        "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod", "sdot"
+    )
 
 
 @tvm.testing.requires_cascadelake
diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py
index 9106c169c869..dc8452710a8a 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -20,24 +20,31 @@
 from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
 from tvm.target import codegen
 
-arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
+llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
     # Testing mcpu type
-    ("c -mcpu=cortex-m4 -keys=arm_cpu", "int8", "int8", False),
-    ("c -mcpu=cortex-m7 -keys=arm_cpu", "int8", "int8", False),
- ("c -mcpu=cortex-m33 -keys=arm_cpu", "int8", "int8", False), - ("c -mcpu=cortex-m55 -keys=arm_cpu", "int8", "int8", False), - ("c -mcpu=cortex-m3 -keys=arm_cpu", "int8", "int8", False), - ("llvm -keys=arm_cpu -mattr=+neon", "int8", "int8", True), - # This fails because of a bug in topi.arm_cpu.arm_utils.get_arch_version - # ("llvm -keys=arm_cpu -mattr=v8.4a,+dotprod", "int8", "int8", True), + (8, "c -mcpu=cortex-m4", "int8", "int8", False), + (8, "c -mcpu=cortex-m7", "int8", "int8", False), + (8, "c -mcpu=cortex-m33", "int8", "int8", False), + (8, "c -mcpu=cortex-m55", "int8", "int8", False), + (8, "c -mcpu=cortex-m3", "int8", "int8", False), + (7, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", False), + (8, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True), + (9, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True), + (8, "llvm -mtriple=arm-linux-gnueabi", "int8", "int8", False), + (7, "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.4a,+dotprod", "int8", "int8", False), + (8, "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.4a,+dotprod", "int8", "int8", True), + (9, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True), + (8, "llvm -mtriple=aarch64-linux-gnu", "int8", "int8", True), # Testing dtype - ("llvm -keys=arm_cpu -mattr=+neon", "int16", "int8", False), - ("llvm -keys=arm_cpu -mattr=+neon", "int8", "int16", False), - ("llvm -keys=arm_cpu -mattr=+neon", "int16", "int16", False), + (8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int8", False), + (8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int8", "int16", False), + (8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int16", False), ) -def test_arm_conv2d_int8_support(arm_target, input_dtype, kernel_dtype, is_supported): +def test_arm_conv2d_int8_support( + monkeypatch, llvm_version, arm_target, input_dtype, kernel_dtype, is_supported +): """Test ARM conv2d int8 support for different targets. Parameters @@ -52,5 +59,5 @@ def test_arm_conv2d_int8_support(arm_target, input_dtype, kernel_dtype, is_suppo Expected result. """ with tvm.target.Target(arm_target): - expected_result = is_supported and (codegen.llvm_version_major() >= 8) - assert is_int8_hw_support(input_dtype, kernel_dtype) == expected_result + monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version) + assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 6070cafa9c2c..c84f39ab5a66 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -26,7 +26,6 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple -from tvm.topi.arm_cpu.conv2d_gemm import is_aarch64_arm from tvm.topi.nn.conv2d import _get_workload from tvm.topi.generic.conv2d import fallback_schedule_cpu_common_int8 @@ -94,8 +93,8 @@ def compile_conv2d_NHWC_gemm_int8_arm( print("Skip because %s is not enabled" % target) return print("Compiling on arm AArch64 target: %s" % target) - with tvm.target.Target(target): - assert is_aarch64_arm(), "AArch64 target not recognized" + with tvm.target.Target(target) as tvm_target: + assert tvm_target.features.is_aarch64, "AArch64 target not recognized" C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype) if add_bias: