
Commit 3131cdc

[Target] Replace utility functions with target.features (#12455)
Following on from #12454, this patch removes the utility functions in favour of the centralised `target.features` property.
1 parent 24e89be commit 3131cdc
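
As a rough illustration (not part of this commit), the centralised property replaces ad-hoc string matching on `target.mattr` with boolean feature flags filled in by the target parsers. A minimal sketch, assuming an AArch64 LLVM target; the target string below is illustrative only:

    import tvm

    # Feature flags come from the target parsers (see #12454 and src/target/parsers/ below).
    target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod")
    print(target.features.has_asimd)    # Advanced SIMD is implied by AArch64
    print(target.features.has_dotprod)  # set via +dotprod (or a new enough architecture version)

    # Strategy and schedule code reads the target from the active scope:
    with target:
        current = tvm.target.Target.current(allow_none=False)
        assert current.features.has_dotprod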

File tree: 11 files changed, +88 -118 lines


python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 5 additions & 6 deletions
@@ -207,21 +207,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_nhwc_dsp.arm_cpu",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
-                has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
+                has_asimd = target.features.has_asimd
+                has_dot_prod = target.features.has_dotprod
                 if has_dot_prod and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
                         name="conv2d_NHWC_quantized_native.arm_cpu",
                     )
-                if is_aarch64 and data.dtype in ["int8", "uint8"]:
+                if has_asimd and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
                         name="conv2d_NHWC_quantized_interleaved.arm_cpu",
                     )
-                if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
+                if (not has_asimd) or (data.dtype not in ["int8", "uint8"]):
                     # TODO(@giuseros)
                     # This strategy errors out for quantized data types when tuning.
                     # Let's use this only for non-aarch64 or non-quantized cases
@@ -283,8 +283,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
-            if is_aarch64 or "+neon" in target.mattr:
+            if target.features.has_asimd:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
                     wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 3 additions & 18 deletions
@@ -405,18 +405,6 @@ def is_fast_int8_on_intel():
     return target_has_sse42(target.mcpu)


-def is_fast_int8_on_arm():
-    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "+v8.2a" in target.mattr and "+dotprod" in target.mattr
-
-
-def is_aarch64_arm():
-    """Checks whether we are compiling for an AArch64 target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "aarch64" in target.attrs.get("mtriple", "")
-
-
 ########################
 # ARM CPU legalizations.
 ########################
@@ -425,7 +413,6 @@ def is_aarch64_arm():
 @qnn_conv2d_legalize.register("arm_cpu")
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     target = tvm.target.Target.current(allow_none=False)
-    has_asimd = is_aarch64_arm() or "+neon" in target.mattr
     is_depthwise = relay.op.strategy.is_depthwise_conv2d(
         types[0].shape,
         attrs["data_layout"],
@@ -434,9 +421,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["groups"],
     )
     use_int8_on_arm = (not is_depthwise) and attrs["data_layout"] == "NHWC"
-    has_dotprod = is_fast_int8_on_arm()
-    other_options = use_int8_on_arm or has_dotprod
-    if has_asimd and not other_options:
+    other_options = use_int8_on_arm or target.features.has_dotprod
+    if target.features.has_asimd and not other_options:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
     # ARM prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
@@ -445,8 +431,7 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
 @qnn_dense_legalize.register("arm_cpu")
 def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     target = tvm.target.Target.current(allow_none=False)
-    has_asimd = is_aarch64_arm() or "+neon" in target.mattr
-    if has_asimd and not is_fast_int8_on_arm():
+    if target.features.has_asimd and not target.features.has_dotprod:
         return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
     # ARM prefers the dtypes to be same.
     return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)

python/tvm/topi/arm_cpu/arm_utils.py

Lines changed: 5 additions & 53 deletions
@@ -17,57 +17,7 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Arm target utility functions"""

-import re
-import tvm
-
-
-def get_arch_version(target_mattr):
-    """Parse the LLVM target -mattr, and return
-    the architecture version in a decimal representation
-    (e.g., if -mattr=v8.4a, return 8.4)
-    """
-
-    arch_version = 8.0
-    m = re.compile(r"\+v(.*)\.(.*)a")
-    for attr in target_mattr:
-        match_obj = m.match(attr)
-        if match_obj:
-            major = int(match_obj.group(1))
-            minor = int(match_obj.group(2))
-            decimal = 10
-            if minor >= 10:
-                decimal = 100
-            arch_version = major + float(minor) / decimal
-
-    return arch_version
-
-
-def is_dotprod_available():
-    """Checks whether the hardware has support for udot/sdot instructions."""
-    target = tvm.target.Target.current(allow_none=False)
-    arch_version = get_arch_version(target.mattr)
-    return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)
-
-
-def is_mmla_available():
-    """Checks whether the hardware has support for ummla/smmla instructions."""
-    target = tvm.target.Target.current(allow_none=False)
-    arch_version = get_arch_version(target.mattr)
-    return arch_version >= 8.6 or (
-        (arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
-    )
-
-
-def is_aarch64_arm():
-    """Checks whether we are compiling for an AArch64 target."""
-    target = tvm.target.Target.current(allow_none=False)
-    return "aarch64" in target.attrs.get("mtriple", "")
-
-
-def is_neon_available():
-    """Check if neon instructions are available"""
-    target = tvm.target.Target.current(allow_none=False)
-    return "+neon" in target.mattr
+from tvm.target import Target


 def get_tiling_B_interleaved_t(interleave_A):
@@ -94,13 +44,15 @@ def get_tiling_B_interleaved_t(interleave_A):
     tile_rows_B: the output tile rows of B'
     tile_cols_B: the output tile columns of B'
     """
-    if is_mmla_available():
+    target = Target.current(allow_none=False)
+
+    if target.features.has_matmul_i8:
         # If smmla/ummla is available, A must be interleaved.
         # Each load from B' will contain 8 elements
         # and we are loading 12 rows of B' (i.e., 12 columns of B)
         tile_rows_B = 12
         tile_cols_B = 8
-    elif is_dotprod_available():
+    elif target.features.has_dotprod:
         # The number of tile rows of B' vary depending on the
         # strategy:
         # * If we are interleaving A, then we select 12 columns from B'(i.e.,
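
Note that get_tiling_B_interleaved_t now reads Target.current(allow_none=False) instead of parsing -mattr itself, so it must be called under an active target scope. A hypothetical usage sketch (the target string is illustrative; +i8mm is what the removed is_mmla_available() used to look for):

    import tvm
    from tvm.topi.arm_cpu.arm_utils import get_tiling_B_interleaved_t

    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.6a,+i8mm"):
        # With i8mm reported as has_matmul_i8, B is tiled as 12x8 per the code above.
        tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A=True)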

python/tvm/topi/arm_cpu/conv2d_gemm.py

Lines changed: 13 additions & 10 deletions
@@ -18,6 +18,7 @@
 # pylint: disable=unused-argument, redefined-builtin
 """GEMM Convolution schedule on ARM"""
 import tvm
+from tvm.target import Target
 from tvm import te
 from tvm.topi import nn
 from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
@@ -29,10 +30,9 @@
     gemm_acc_nx16_int8_int8_int32,
     gemm_acc_2x2_int8_int8_int32,
 )
-from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available


-def configure_knobs(cfg, M, K):
+def configure_knobs(cfg, M, K, target):
     """Configure auto-tuning knobs for the interleaved strategy"""

     x, y = cfg.axis(M // 4), cfg.axis(K // 16)
@@ -48,7 +48,7 @@ def configure_knobs(cfg, M, K):
     cfg["reorder_gemm"] = ReorderEntity([0, 1])
     cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])

-    if not is_dotprod_available():
+    if not target.features.has_dotprod:
         cfg.define_knob("gemm_quantized_unroll", [True, False])
         if cfg.is_fallback:
             cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
@@ -133,12 +133,13 @@ def compute_conv2d_gemm_without_weight_transform(
     # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
     # In order to have more information
     #
-    if is_mmla_available():
+    target = Target.current(allow_none=False)
+    if target.features.has_matmul_i8:
         # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
         # will contain 8 elements
         tile_rows_A = 8
         tile_cols_A = 8
-    elif is_dotprod_available() and interleave_A:
+    elif target.features.has_dotprod and interleave_A:
         # If dot product has been enabled, and we are interleaving A
         # tile size should be 8x4
         tile_rows_A = 8
@@ -173,15 +174,16 @@ def compute_conv2d_gemm_without_weight_transform(

     if interleave_A:
         # Configuration space
-        configure_knobs(cfg, M_padded, K_padded)
+        configure_knobs(cfg, M_padded, K_padded, target)

         # Pack the input data
         A_interleaved = te.compute(
             (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
             lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
             name="A_interleaved",
         )
-        if is_mmla_available():
+        target = Target.current(allow_none=False)
+        if target.features.has_matmul_i8:
             # Execute GEMM. In the case of mmla, we need to enforce the tiling
             # from the compute. This is because mmla is doing a tiled computation
             # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
@@ -323,7 +325,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
     k = C_interleaved.op.reduce_axis[0]
     _, M, N = C.shape
     if in_type in ["int8", "uint8"]:
-        if is_mmla_available():
+        target = Target.current(allow_none=False)
+        if target.features.has_matmul_i8:
             gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
             xi_inner, yi_inner = C_interleaved.op.axis[-2:]
             k_outer, k_inner = s[C_interleaved].split(k, 8)
@@ -333,7 +336,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
             s[C_interleaved].tensorize(xi_inner, gemm_acc)
             s[C_interleaved].unroll(xi)
             s[C_interleaved].unroll(yi)
-        elif is_dotprod_available():
+        elif target.features.has_dotprod:
             gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
             xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
                 xi, yi, x_factor=8, y_factor=4
@@ -354,7 +357,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
             s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
             s[C_interleaved].unroll(xi_inner_outer)

-        elif is_aarch64_arm():
+        elif target.features.has_asimd:
             s[C_interleaved].reorder(yi, xi)
             K = A_interleaved_input.shape[2]
             assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"

python/tvm/topi/arm_cpu/conv2d_int8.py

Lines changed: 8 additions & 4 deletions
@@ -30,7 +30,7 @@
     schedule_conv2d_gemm_interleaved,
     schedule_conv2d_gemm_native,
 )
-from .arm_utils import get_tiling_B_interleaved_t, is_dotprod_available, is_neon_available
+from .arm_utils import get_tiling_B_interleaved_t


 def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -124,7 +124,10 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
     is_llvm_support = llvm_version >= 8

     # 3) Check target
-    is_target_support = is_neon_available() or is_dotprod_available()
+    current_target = target.Target.current(allow_none=False)
+    is_target_support = bool(
+        current_target.features.has_asimd or current_target.features.has_dotprod
+    )

     return is_dtype_support and is_llvm_support and is_target_support

@@ -154,9 +157,10 @@ def _callback(op):
         _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
         assert n_elems == 4
         dtype = "uint" if data.dtype == "uint8" else "int"
-        if is_dotprod_available():
+        current_target = target.Target.current(allow_none=False)
+        if current_target.features.has_dotprod:
             intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
-        elif is_neon_available():
+        elif current_target.features.has_asimd:
             assert dtype == "int", "uint8 not supported if dot product is not available"
             intrin = dot_int8_int8_int32_neon()
         else:
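
A hypothetical check of the reworked is_int8_hw_support, which now needs an active target scope to read the feature flags (target string and dtypes chosen for illustration):

    import tvm
    from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support

    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"):
        # True only when the dtypes, the LLVM version, and the target features all line up.
        print(is_int8_hw_support("int8", "int8"))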

python/tvm/topi/arm_cpu/depthwise_conv2d.py

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 """Depthwise convolution schedule for ARM CPU"""

 import tvm
+from tvm.target import Target
 from tvm import te
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
@@ -26,7 +27,6 @@
 from ..utils import traverse_inline, get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import smlal_int16_int32
-from .arm_utils import is_aarch64_arm
 from .mprofile.dsp.depthwise_conv2d import (
     depthwise_conv2d_nhwc_dsp_compute,
     depthwise_conv2d_nhwc_dsp_schedule,
@@ -333,12 +333,13 @@ def schedule_conv(conv):
         co, ci = cfg["tile_c"].apply(s, conv, c)

         split_val = cfg["tile_c"].size[-1]
+        target = Target.current(allow_none=False)
         use_tensorization = (
             (in_type == "int16")
             and (split_val == 8)
             and (IC % split_val == 0)
             and (channel_multiplier == 1)
-            and is_aarch64_arm()
+            and target.features.has_asimd
         )

         data_pad_value = -1

src/target/parsers/cpu.cc

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@

 #include <string>

+#include "aprofile.h"
 #include "mprofile.h"

 namespace tvm {
@@ -32,6 +33,10 @@ TargetJSON ParseTarget(TargetJSON target) {
     return mprofile::ParseTarget(target);
   }

+  if (aprofile::IsArch(target)) {
+    return aprofile::ParseTarget(target);
+  }
+
   return target;
 }

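
With this hook in place, a target whose attributes identify an A-profile Arm CPU is routed through the new aprofile parser, which is what populates the target.features flags queried by the Python changes above. A small sketch of the observable effect (feature names as used in this patch):

    import tvm

    # Parsing an AArch64 target now goes through aprofile::ParseTarget.
    aarch64 = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
    print(bool(aarch64.features.has_asimd))  # expected True: Advanced SIMD is mandatory on AArch64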

tests/python/contrib/test_arm_compute_lib/test_network.py

Lines changed: 19 additions & 6 deletions
@@ -20,11 +20,9 @@

 import numpy as np
 import pytest
-from tvm import testing
 from tvm import relay

-from test_arm_compute_lib.infrastructure import skip_runtime_test, build_and_run, verify
-from test_arm_compute_lib.infrastructure import Device
+from test_arm_compute_lib.infrastructure import Device, skip_runtime_test, build_and_run, verify


 def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, atol, rtol):
@@ -108,7 +106,12 @@ def get_model():
        return mod, params, inputs

    _build_and_run_network(
-        *get_model(), device=device, tvm_ops=4, acl_partitions=21, atol=0.002, rtol=0.01
+        *get_model(),
+        device=device,
+        tvm_ops=4,
+        acl_partitions=21,
+        atol=0.002,
+        rtol=0.01,
    )


@@ -180,7 +183,12 @@ def get_model():
        return mod, params, inputs

    _build_and_run_network(
-        *get_model(), device=device, tvm_ops=3, acl_partitions=30, atol=9, rtol=0
+        *get_model(),
+        device=device,
+        tvm_ops=3,
+        acl_partitions=30,
+        atol=10,
+        rtol=0,
    )


@@ -207,7 +215,12 @@ def get_model():
        return mod, params, inputs

    _build_and_run_network(
-        *get_model(), device=device, tvm_ops=9, acl_partitions=31, atol=8, rtol=0
+        *get_model(),
+        device=device,
+        tvm_ops=9,
+        acl_partitions=31,
+        atol=8,
+        rtol=0,
    )
