Skip to content

Commit f719151

Browse files
authored
[Bugfix][Strategy] Fix arm_cpu int8 conv2d strategy for dotprod and i8mm targets (#15711)
Whenever both dotprod and i8mm were available together on a target (e.g. `"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod,+i8mm"`), the native int8 conv2d implementation corresponding to the `+dotprod` attribute would be selected, but the compute definition of the conv2d operation would be constructed for the `+i8mm` attribute and its related interleaved schedule instead. The reason for this was that the conditional statements were checked in a different order in two separate files: - `arm_cpu.py`: when selecting the conv2d implementation, the program first checked for `dotprod` support and, if present, chose the native schedule. - `conv2d_gemm.py`: when constructing the compute definition, `i8mm` support is checked first, then `dotprod`. To fix this, I modified the int8 conv2d strategy so that it also prioritizes `i8mm` over `dotprod` when both are available.
1 parent f23d6b2 commit f719151

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,19 +213,35 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
213213
is_aarch64 = target.features.is_aarch64
214214
has_asimd = target.features.has_asimd
215215
has_dot_prod = target.features.has_dotprod
216+
has_matmul_i8 = target.features.has_matmul_i8
216217

217-
if has_dot_prod and data.dtype in ["int8", "uint8"]:
218-
strategy.add_implementation(
219-
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
220-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
221-
name="conv2d_NHWC_quantized_native.arm_cpu",
222-
)
223-
if is_aarch64 and has_asimd and data.dtype in ["int8", "uint8"]:
224-
strategy.add_implementation(
225-
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
226-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
227-
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
228-
)
218+
if data.dtype in ["int8", "uint8"]:
219+
if has_matmul_i8:
220+
strategy.add_implementation(
221+
wrap_compute_conv2d(
222+
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
223+
),
224+
wrap_topi_schedule(
225+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
226+
),
227+
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
228+
)
229+
if has_dot_prod:
230+
strategy.add_implementation(
231+
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
232+
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
233+
name="conv2d_NHWC_quantized_native.arm_cpu",
234+
)
235+
if is_aarch64 and has_asimd:
236+
strategy.add_implementation(
237+
wrap_compute_conv2d(
238+
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
239+
),
240+
wrap_topi_schedule(
241+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
242+
),
243+
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
244+
)
229245
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
230246
# TODO(@giuseros)
231247
# This strategy errors out for quantized data types when tuning.
@@ -471,10 +487,19 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
471487
is_aarch64 = target.features.is_aarch64
472488
has_asimd = target.features.has_asimd
473489
has_dot_prod = target.features.has_dotprod
490+
has_matmul_i8 = target.features.has_matmul_i8
474491

475492
interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
476493
native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
477494
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
495+
if has_matmul_i8:
496+
strategy.add_implementation(
497+
wrap_compute_conv2d_gemm(interleaved_compute),
498+
wrap_topi_schedule(
499+
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
500+
),
501+
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
502+
)
478503
if has_dot_prod:
479504
strategy.add_implementation(
480505
wrap_compute_conv2d_gemm(native_compute),

tests/python/relay/strategy/test_select_implementation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def test_concatenate(target, expected_implementation):
8181
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
8282
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
8383
),
84+
(
85+
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod,+i8mm",
86+
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
87+
),
88+
(
89+
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
90+
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
91+
),
8492
],
8593
)
8694
def test_int8_conv2d(target, expected_impl):

0 commit comments

Comments
 (0)