Commit e53a8bc

[TOPI][Target] Add fp16 SIMD support for conv2d on arm_cpu targets (#16383)
Optimised fp16 conv2d matrix tiling for Arm(R) Neon(TM) instructions and exposed `+fullfp16` as a target attribute for Arm(R) Cortex(R) A-Profile CPUs. A target test was also added to `cpptest` for Arm(R) Cortex(R) A-Profile CPUs, checking that the `has_fp16_simd` flag is set only when the user explicitly passes the `+fullfp16` or `+sve` attribute together with a supporting architecture version at target creation.
1 parent 0b2358c commit e53a8bc
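
For illustration only (not part of the commit), the new flag can be queried from Python once the attributes are passed at target creation; the exact target string below is an assumption, not taken from this change:

import tvm

# Hypothetical AArch64 target: an Armv8.2-A triple with the fullfp16
# extension requested via -mattr (string is an assumption for the sketch).
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16")

# With this change the A-Profile parser should report fp16 SIMD support.
print(target.features.has_fp16_simd)  # expected: True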

4 files changed: +52 -10 lines changed

python/tvm/topi/arm_cpu/arm_utils.py

Lines changed: 6 additions & 1 deletion
@@ -74,8 +74,13 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
         # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
         tile_N = 4
         tile_K = 16
+    # In non-quantized cases, A is not interleaved.
+    elif in_dtype == "float16" and target.features.has_fp16_simd:
+        # Each load from B' contains 32 elements (i.e. 32 columns from B)
+        # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
+        tile_N = 32
+        tile_K = 4
     else:
-        # In non-quantized cases, A is not interleaved.
         # Each load from B' contains 16 elements (i.e. 16 columns from B)
         # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
         tile_N = 16
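
As a rough standalone paraphrase of the branch added above (a hypothetical helper, not the actual TVM function), the non-quantized tile shapes become:

def fp16_tile_shape(in_dtype, has_fp16_simd):
    # fp16 with SIMD support: each load takes 32 columns of B, reducing over 4 rows.
    if in_dtype == "float16" and has_fp16_simd:
        return 32, 4  # tile_N, tile_K
    # Generic fallback: 16 columns per load, reducing over 4 rows.
    return 16, 4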

python/tvm/topi/arm_cpu/conv2d_gemm.py

Lines changed: 12 additions & 4 deletions
@@ -22,6 +22,7 @@
 from tvm import te
 from tvm.topi import nn
 from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
+from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed
 from ..utils import get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import (
@@ -339,7 +340,15 @@ def compute_conv2d_gemm_without_weight_transform(
         ),
         name="C",
     )
-    zero = tvm.tir.const(0)
+    # Ensure padding on the N axis does not get removed during tir passes
+    # by adding a dummy reference to the specific padded area of the result
+    if in_dtype == "float16" and target.features.has_fp16_simd:
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+        )
+    else:
+        zero = tvm.tir.const(0)

     # Reshape the result into a convolution output
     out_shape = (batches, OH, OW, OC)
@@ -454,14 +463,14 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     C = out.op.input_tensors[0]
     A = C.op.input_tensors[0]
     in_type = A.dtype
+    y_tile_size, _ = get_tiling_B_transformed(False, in_type)

     # Computation
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis

     if in_type in ["int8", "uint8"]:
         k_outer, k_inner = s[C].split(k, 16)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
         s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
         gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
@@ -470,9 +479,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
         s[C].parallel(x_outer)
     else:
         k_outer, k_inner = s[C].split(k, 4)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
-        y_inner_outer, y_inner_inner = s[C].split(y_inner, 4)
+        y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
         b_x_outer_fused = s[C].fuse(b, x_outer)
         s[C].parallel(b_x_outer_fused)
         s[C].reorder(
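
The self-cancelling expression added above evaluates to zero at runtime but keeps a live read of C[0, 0, N_padded - 1], so later TIR passes cannot treat the padded region of C as unused and strip it. A minimal sketch of the pattern with illustrative shapes (the names and sizes are assumptions, not the commit's):

import tvm
from tvm import te

N_padded = 32  # assumed padded width for the sketch
C = te.placeholder((1, 4, N_padded), "float16", name="C")

# Algebraically zero, yet it still references the last padded column of C.
zero = (
    tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
    - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
)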

src/target/parsers/aprofile.cc

Lines changed: 7 additions & 5 deletions
@@ -127,11 +127,13 @@ static TargetFeatures GetFeatures(TargetJSON target) {
   const bool has_dotprod =
       (dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);

-  return {
-      {"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
-      {"has_sve", Bool(has_sve)},       {"has_dotprod", Bool(has_dotprod)},
-      {"has_matmul_i8", Bool(has_i8mm)},
-  };
+  const bool fp16_flag = HasFlag(mcpu, mattr, "+fullfp16");
+  const bool fp16_support = arch_version >= 8.2;
+  const bool has_fp16_simd = fp16_support && (fp16_flag || has_sve);
+
+  return {{"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
+          {"has_sve", Bool(has_sve)},       {"has_dotprod", Bool(has_dotprod)},
+          {"has_matmul_i8", Bool(has_i8mm)}, {"has_fp16_simd", Bool(has_fp16_simd)}};
 }

 static Array<String> MergeKeys(Optional<Array<String>> existing_keys) {
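
In rough Python terms (a paraphrase for readability, not actual parser code), the new flag is derived as:

def has_fp16_simd(arch_version, attrs, has_sve):
    # Requires Armv8.2-A or later, plus either an explicit +fullfp16 attribute
    # or SVE, which the parser treats as implying fp16 SIMD support.
    fp16_flag = "+fullfp16" in attrs
    return arch_version >= 8.2 and (fp16_flag or has_sve)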

tests/cpp/target/parsers/aprofile_test.cc

Lines changed: 27 additions & 0 deletions
@@ -307,11 +307,38 @@ TEST_P(AProfileOptionalSVE, OptionalSVESupport) {
   EXPECT_TRUE(Downcast<Bool>(features.at("has_sve")));
 }

+using AProfileOptionalFP16 = testing::TestWithParam<float>;
+TEST_P(AProfileOptionalFP16, OptionalFP16Support) {
+  const std::string arch_attr = "+v" + std::to_string(GetParam()) + "a";
+
+  // Check that the "has_fp16_simd" feature is not set by default when "+fullfp16" isn't set as an
+  // attribute.
+  TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
+  TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_FALSE(Downcast<Bool>(features.at("has_fp16_simd")));
+
+  // Check that the "has_fp16_simd" feature is set when "+fullfp16" is explicitly set as an
+  // attribute.
+  target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+fullfp16"});
+  features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_TRUE(Downcast<Bool>(features.at("has_fp16_simd")));
+
+  // Check that the "has_fp16_simd" feature is set when "+sve" is explicitly set as an attribute.
+  target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+sve"});
+  features = Downcast<TargetFeatures>(target.at("features"));
+  EXPECT_TRUE(IsArch(target));
+  EXPECT_TRUE(Downcast<Bool>(features.at("has_fp16_simd")));
+}
+
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalI8MM, ::testing::ValuesIn(optionalI8MM));
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalDotProd,
                         ::testing::ValuesIn(optionalDotProd));
 INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalSVE,
                         ::testing::Values(8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0));
+INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalFP16,
+                        ::testing::Values(8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0));

 } // namespace aprofile
 } // namespace parsers
