11 changes: 5 additions & 6 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -207,21 +207,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="conv2d_nhwc_dsp.arm_cpu",
)
elif kernel_layout == "HWIO":
is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
has_asimd = target.features.has_asimd
has_dot_prod = target.features.has_dotprod
if has_dot_prod and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native.arm_cpu",
)
if is_aarch64 and data.dtype in ["int8", "uint8"]:
if has_asimd and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
)
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
if (not has_asimd) or (data.dtype not in ["int8", "uint8"]):
# TODO(@giuseros)
# This strategy errors out for quantized data types when tuning.
# Let's use this only for non-aarch64 or non-quantized cases
@@ -281,8 +281,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
if is_aarch64 or "+neon" in target.mattr:
if target.features.has_asimd:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
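The strategy now reads boolean feature flags from the parsed target instead of string-matching -mattr/-mtriple through topi.arm_cpu.arm_utils. A minimal sketch of the new style of check, assuming a build where the A-profile target parser populates target.features (the target string below is only an example):

import tvm

tgt = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod")
# Boolean attributes replace manual parsing of the -mattr list.
print(tgt.features.has_asimd)    # expected True for an AArch64 target
print(tgt.features.has_dotprod)  # expected True since +dotprod is requested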
21 changes: 3 additions & 18 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -405,18 +405,6 @@ def is_fast_int8_on_intel():
return target_has_sse42(target.mcpu)


def is_fast_int8_on_arm():
"""Checks whether the hardware has support for fast Int8 arithmetic operations."""
target = tvm.target.Target.current(allow_none=False)
return "+v8.2a" in target.mattr and "+dotprod" in target.mattr


def is_aarch64_arm():
"""Checks whether we are compiling for an AArch64 target."""
target = tvm.target.Target.current(allow_none=False)
return "aarch64" in target.attrs.get("mtriple", "")


########################
# ARM CPU legalizations.
########################
@@ -425,7 +413,6 @@ def is_aarch64_arm():
@qnn_conv2d_legalize.register("arm_cpu")
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
target = tvm.target.Target.current(allow_none=False)
has_asimd = is_aarch64_arm() or "+neon" in target.mattr
is_depthwise = relay.op.strategy.is_depthwise_conv2d(
types[0].shape,
attrs["data_layout"],
@@ -434,9 +421,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
attrs["groups"],
)
use_int8_on_arm = (not is_depthwise) and attrs["data_layout"] == "NHWC"
has_dotprod = is_fast_int8_on_arm()
other_options = use_int8_on_arm or has_dotprod
if has_asimd and not other_options:
other_options = use_int8_on_arm or target.features.has_dotprod
if target.features.has_asimd and not other_options:
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
# ARM prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
@@ -445,8 +431,7 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
@qnn_dense_legalize.register("arm_cpu")
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
target = tvm.target.Target.current(allow_none=False)
has_asimd = is_aarch64_arm() or "+neon" in target.mattr
if has_asimd and not is_fast_int8_on_arm():
if target.features.has_asimd and not target.features.has_dotprod:
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
# ARM prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
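For reference, the decision this legalization now makes can be condensed as below; this is an illustrative sketch rather than code from the patch, and choose_qnn_conv2d_path is a hypothetical name:

def choose_qnn_conv2d_path(target, is_depthwise, data_layout):
    # Non-depthwise NHWC convolutions can use the int8 GEMM schedules directly.
    use_int8_on_arm = (not is_depthwise) and data_layout == "NHWC"
    fast_int8 = use_int8_on_arm or target.features.has_dotprod
    if target.features.has_asimd and not fast_int8:
        # Fall back to a non-quantized conv2d (helper_no_fast_int8_hw_legalization).
        return "dequantize"
    # Otherwise keep the QNN op and make the input/kernel dtypes match.
    return "same_dtype_qnn"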
58 changes: 5 additions & 53 deletions python/tvm/topi/arm_cpu/arm_utils.py
@@ -17,57 +17,7 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Arm target utility functions"""

import re
import tvm


def get_arch_version(target_mattr):
"""Parse the LLVM target -mattr, and return
the architecture version in a decimal representation
(e.g., if -mattr=v8.4a, return 8.4)
"""

arch_version = 8.0
m = re.compile(r"\+v(.*)\.(.*)a")
for attr in target_mattr:
match_obj = m.match(attr)
if match_obj:
major = int(match_obj.group(1))
minor = int(match_obj.group(2))
decimal = 10
if minor >= 10:
decimal = 100
arch_version = major + float(minor) / decimal

return arch_version


def is_dotprod_available():
"""Checks whether the hardware has support for udot/sdot instructions."""
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_mmla_available():
"""Checks whether the hardware has support for ummla/smmla instructions."""
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.6 or (
(arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
)


def is_aarch64_arm():
"""Checks whether we are compiling for an AArch64 target."""
target = tvm.target.Target.current(allow_none=False)
return "aarch64" in target.attrs.get("mtriple", "")


def is_neon_available():
"""Check if neon instructions are available"""
target = tvm.target.Target.current(allow_none=False)
return "+neon" in target.mattr
from tvm.target import Target


def get_tiling_B_interleaved_t(interleave_A):
@@ -94,13 +44,15 @@ def get_tiling_B_interleaved_t(interleave_A):
tile_rows_B: the output tile rows of B'
tile_cols_B: the output tile columns of B'
"""
if is_mmla_available():
target = Target.current(allow_none=False)

if target.features.has_matmul_i8:
# If smmla/ummla is available, A must be interleaved.
# Each load from B' will contain 8 elements
# and we are loading 12 rows of B' (i.e., 12 columns of B)
tile_rows_B = 12
tile_cols_B = 8
elif is_dotprod_available():
elif target.features.has_dotprod:
# The number of tile rows of B' vary depending on the
# strategy:
# * If we are interleaving A, then we select 12 columns from B'(i.e.,
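Since the new code (like the removed helpers) queries Target.current(), callers of get_tiling_B_interleaved_t still need a target in scope. A usage sketch under that assumption; the target string and the expected (12, 8) tiling are illustrative:

import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_interleaved_t

with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm"):
    # With +i8mm parsed as has_matmul_i8, the interleaved tiling should be 12x8.
    tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A=True)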
23 changes: 13 additions & 10 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -18,6 +18,7 @@
# pylint: disable=unused-argument, redefined-builtin
"""GEMM Convolution schedule on ARM"""
import tvm
from tvm.target import Target
from tvm import te
from tvm.topi import nn
from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
@@ -29,10 +30,9 @@
gemm_acc_nx16_int8_int8_int32,
gemm_acc_2x2_int8_int8_int32,
)
from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available


def configure_knobs(cfg, M, K):
def configure_knobs(cfg, M, K, target):
"""Configure auto-tuning knobs for the interleaved strategy"""

x, y = cfg.axis(M // 4), cfg.axis(K // 16)
@@ -48,7 +48,7 @@ def configure_knobs(cfg, M, K):
cfg["reorder_gemm"] = ReorderEntity([0, 1])
cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])

if not is_dotprod_available():
if not target.features.has_dotprod:
cfg.define_knob("gemm_quantized_unroll", [True, False])
if cfg.is_fallback:
cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
@@ -133,12 +133,13 @@ def compute_conv2d_gemm_without_weight_transform(
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# In order to have more information
#
if is_mmla_available():
target = Target.current(allow_none=False)
if target.features.has_matmul_i8:
# If smmla/ummla is enabled, we are loading 8 rows from A. Each row
# will contain 8 elements
tile_rows_A = 8
tile_cols_A = 8
elif is_dotprod_available() and interleave_A:
elif target.features.has_dotprod and interleave_A:
# If dot product has been enabled, and we are interleaving A
# tile size should be 8x4
tile_rows_A = 8
@@ -173,15 +174,16 @@

if interleave_A:
# Configuration space
configure_knobs(cfg, M_padded, K_padded)
configure_knobs(cfg, M_padded, K_padded, target)

# Pack the input data
A_interleaved = te.compute(
(batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
name="A_interleaved",
)
if is_mmla_available():
target = Target.current(allow_none=False)
if target.features.has_matmul_i8:
# Execute GEMM. In the case of mmla, we need to enforce the tiling
# from the compute. This is because mmla is doing a tiled computation
# as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
@@ -323,7 +325,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if in_type in ["int8", "uint8"]:
if is_mmla_available():
target = Target.current(allow_none=False)
if target.features.has_matmul_i8:
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[-2:]
k_outer, k_inner = s[C_interleaved].split(k, 8)
Expand All @@ -333,7 +336,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
s[C_interleaved].tensorize(xi_inner, gemm_acc)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
elif is_dotprod_available():
elif target.features.has_dotprod:
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
Expand All @@ -354,7 +357,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
s[C_interleaved].unroll(xi_inner_outer)

elif is_aarch64_arm():
elif target.features.has_asimd:
s[C_interleaved].reorder(yi, xi)
K = A_interleaved_input.shape[2]
assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
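The schedule dispatches on the same feature flags in a fixed order; a condensed, illustrative view (the function name is hypothetical and the returned strings just name the intrinsics used above):

def pick_int8_gemm_tensorization(target, in_type):
    assert in_type in ("int8", "uint8")
    if target.features.has_matmul_i8:  # smmla/ummla: 2x2 accumulator tiles
        return "gemm_acc_2x2_int8_int8_int32"
    if target.features.has_dotprod:    # sdot/udot: 4x4 accumulator tiles
        return "gemm_acc_4x4_int8_int8_int32"
    if target.features.has_asimd:      # plain Advanced SIMD interleaved path
        return "gemm_quantized"
    return None  # no Arm-specific tensorization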
12 changes: 8 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -30,7 +30,7 @@
schedule_conv2d_gemm_interleaved,
schedule_conv2d_gemm_native,
)
from .arm_utils import get_tiling_B_interleaved_t, is_dotprod_available, is_neon_available
from .arm_utils import get_tiling_B_interleaved_t


def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -124,7 +124,10 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
is_llvm_support = llvm_version >= 8

# 3) Check target
is_target_support = is_neon_available() or is_dotprod_available()
current_target = target.Target.current(allow_none=False)
is_target_support = bool(
current_target.features.has_asimd or current_target.features.has_dotprod
)

return is_dtype_support and is_llvm_support and is_target_support

@@ -154,9 +157,10 @@ def _callback(op):
_, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
assert n_elems == 4
dtype = "uint" if data.dtype == "uint8" else "int"
if is_dotprod_available():
current_target = target.Target.current(allow_none=False)
if current_target.features.has_dotprod:
intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
elif is_neon_available():
elif current_target.features.has_asimd:
assert dtype == "int", "uint8 not supported if dot product is not available"
intrin = dot_int8_int8_int32_neon()
else:
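With is_int8_hw_support now reading target.features, the capability check can be queried directly under a target scope. A hedged example; the result depends on the local LLVM version and the features parsed for the target:

import tvm
from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support

with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    # True only when the dtype pair is supported, LLVM >= 8, and
    # has_asimd or has_dotprod is set on the current target.
    print(is_int8_hw_support("int8", "int8"))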
5 changes: 3 additions & 2 deletions python/tvm/topi/arm_cpu/depthwise_conv2d.py
@@ -18,6 +18,7 @@
"""Depthwise convolution schedule for ARM CPU"""

import tvm
from tvm.target import Target
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
@@ -26,7 +27,6 @@
from ..utils import traverse_inline, get_const_tuple, get_const_int
from ..nn.utils import get_pad_tuple
from .tensor_intrin import smlal_int16_int32
from .arm_utils import is_aarch64_arm
from .mprofile.dsp.depthwise_conv2d import (
depthwise_conv2d_nhwc_dsp_compute,
depthwise_conv2d_nhwc_dsp_schedule,
@@ -333,12 +333,13 @@ def schedule_conv(conv):
co, ci = cfg["tile_c"].apply(s, conv, c)

split_val = cfg["tile_c"].size[-1]
target = Target.current(allow_none=False)
use_tensorization = (
(in_type == "int16")
and (split_val == 8)
and (IC % split_val == 0)
and (channel_multiplier == 1)
and is_aarch64_arm()
and target.features.has_asimd
)

data_pad_value = -1
5 changes: 5 additions & 0 deletions src/target/parsers/cpu.cc
@@ -20,6 +20,7 @@

#include <string>

#include "aprofile.h"
#include "mprofile.h"

namespace tvm {
@@ -32,6 +33,10 @@ TargetJSON ParseTarget(TargetJSON target) {
return mprofile::ParseTarget(target);
}

if (aprofile::IsArch(target)) {
return aprofile::ParseTarget(target);
}

return target;
}

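Routing A-profile targets through aprofile::ParseTarget is what makes the Python-side feature flags used above available. An illustrative check; the target strings and expected values are assumptions about a typical build:

import tvm

aprofile = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+dotprod")
mprofile = tvm.target.Target("c -mcpu=cortex-m55")
print(aprofile.features.has_dotprod)  # expected True via aprofile::ParseTarget
print(mprofile.features.has_mve)      # expected True via mprofile::ParseTarget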
25 changes: 19 additions & 6 deletions tests/python/contrib/test_arm_compute_lib/test_network.py
@@ -20,11 +20,9 @@

import numpy as np
import pytest
from tvm import testing
from tvm import relay

from test_arm_compute_lib.infrastructure import skip_runtime_test, build_and_run, verify
from test_arm_compute_lib.infrastructure import Device
from test_arm_compute_lib.infrastructure import Device, skip_runtime_test, build_and_run, verify


def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, atol, rtol):
@@ -108,7 +106,12 @@ def get_model():
return mod, params, inputs

_build_and_run_network(
*get_model(), device=device, tvm_ops=4, acl_partitions=21, atol=0.002, rtol=0.01
*get_model(),
device=device,
tvm_ops=4,
acl_partitions=21,
atol=0.002,
rtol=0.01,
)


@@ -180,7 +183,12 @@ def get_model():
return mod, params, inputs

_build_and_run_network(
*get_model(), device=device, tvm_ops=3, acl_partitions=30, atol=9, rtol=0
*get_model(),
device=device,
tvm_ops=3,
acl_partitions=30,
atol=10,
rtol=0,
)


@@ -207,7 +215,12 @@ def get_model():
return mod, params, inputs

_build_and_run_network(
*get_model(), device=device, tvm_ops=9, acl_partitions=31, atol=8, rtol=0
*get_model(),
device=device,
tvm_ops=9,
acl_partitions=31,
atol=8,
rtol=0,
)

