16 changes: 8 additions & 8 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -88,14 +88,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             if kernel_layout == "OIHW":
-                # ARM conv2d spatial pack schedule.
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
-                    name="conv2d_nchw_spatial_pack.arm_cpu",
-                    plevel=10,
-                )
-
                 if (
                     topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
                     and kernel.shape[1] >= 64
@@ -107,6 +99,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         plevel=15,
                     )
                 else:
+                    # ARM conv2d spatial pack schedule.
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                        name="conv2d_nchw_spatial_pack.arm_cpu",
+                        plevel=10,
+                    )
+
Member Author:
schedule_conv2d_nchw_spatial_pack has tophub entries, so it is always picked based on its recorded "time cost", even when the topi.arm_cpu.is_int8_hw_support path is taken and the NCHWc int8 schedule is surely faster.
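
To make that selection behaviour concrete, here is a small illustrative sketch in plain Python; it is not TVM's actual implementation-selection code, and the names and timings below are hypothetical:

```python
# Illustrative only: when tuning records (e.g. tophub entries) exist for an
# implementation, selection is driven by the recorded time cost; the plevel
# priority only decides among implementations with no records at all.
def pick_implementation(impls, tuning_records):
    # impls: list of (name, plevel); tuning_records: {name: measured seconds}
    tuned = [(name, plevel) for name, plevel in impls if name in tuning_records]
    if tuned:
        return min(tuned, key=lambda item: tuning_records[item[0]])
    return max(impls, key=lambda item: item[1])


impls = [
    ("conv2d_nchw_spatial_pack.arm_cpu", 10),  # has a tophub entry
    ("conv2d_NCHWc_int8.arm_cpu", 15),         # no entry, despite higher plevel
]
records = {"conv2d_nchw_spatial_pack.arm_cpu": 1.2e-3}  # hypothetical timing
print(pick_implementation(impls, records))  # -> the spatial pack schedule wins
```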

Contributor:
Seems like another good argument for removing tophub. You can disable it for the test by wrapping everything in an empty ApplyHistoryBest.
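
For reference, a minimal sketch of that suggestion, assuming the test compiles through relay.build (the helper name below is made up): relay.build only loads the tophub fallback when the current dispatch context is the default FallbackContext, so entering an ApplyHistoryBest built from an empty record set should keep tophub entries out of the picture.

```python
from tvm import autotvm, relay


def build_without_tophub(mod, params, target):
    # Hypothetical helper: an empty ApplyHistoryBest replaces the default
    # FallbackContext, so relay.build skips loading tophub records and no
    # pre-tuned spatial-pack entry can influence implementation selection.
    with autotvm.apply_history_best([]):
        return relay.build(mod, target=target, params=params)
```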

Member Author:
The annoying thing is that this happens only for specific input or filter sizes. It took me some time to figure out why the NCHWc schedule was not used when the input shape is (56, 56) and the filter size is 3x3. With a (128, 128) input or a 1x1 filter, I didn't see this issue.

                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
3 changes: 2 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -152,8 +152,9 @@ def _callback(op):
             _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
             dtype = "uint" if data.dtype == "uint8" else "int"
             if is_dotprod_available():
-                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4)
+                intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)
             elif is_neon_available():
+                assert dtype == "int", "uint8 not supported if dot product is not available"
                 intrin = dot_int8_int8_int32_neon()
             else:
                 raise RuntimeError(
2 changes: 1 addition & 1 deletion python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -516,7 +516,7 @@ def _instr(index):
                 int32_lanes * num_int8_elements,
             )
             vdot = tvm.tir.call_llvm_pure_intrin(
-                dtype_c, inst, tvm.tir.const(2, "uint32"), vec_c, vec_a, vec_b
+                dtype_c, inst, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b
             )
             ib.emit(outs[0].vstore(0, vdot))
             return ib.get()
97 changes: 46 additions & 51 deletions tests/python/topi/python/test_topi_conv2d_int8.py
@@ -259,7 +259,7 @@ def verify_conv2d_NCHWc_int8(
     lo = -128 if in_dtype == "int8" else 0
     hi = 127 if in_dtype == "int8" else 255
 
-    def check_target(target, compute, schedule, oc_block_factor):
+    def check_target(target, compute, schedule, oc_block_factor, build_only):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
             print("Skip because %s is not enabled" % target)
@@ -323,45 +323,27 @@ def get_ref_data():
         w = tvm.nd.array(w_np.astype(dtype), dev)
         b = tvm.nd.array(b_np.astype(out_dtype), dev)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
 
         if add_bias:
-            tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            func = tvm.build(
-                s,
-                [A, W, bias, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, b, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, bias, C]
+            run_args = [a, w, b, c]
         else:
-            func = tvm.build(
-                s,
-                [A, W, C],
-                target,
-                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
-                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
-            )
-            try:
-                func(a, w, c)
-            except tvm.TVMError as e:
-                if "architecture mismatch" in str(e):
-                    print(f"Skipping execution because {target} is not supported by this CPU")
-                    return
-                else:
-                    raise
+            compile_args = [A, W, C]
+            run_args = [a, w, c]
+
+        func = tvm.build(
+            s,
+            compile_args,
+            target,
+            name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
+            % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
+        )
+
+        if build_only:
+            return
+
+        func(*run_args)
 
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
 
     targets = [
@@ -370,29 +352,42 @@ def get_ref_data():
             lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
             topi.cuda.schedule_conv2d_NCHWc_int8,
             4,
+            False,
         ),
         # Disable on CI since it does not support spirv int8 dot product
         # (
         #     "vulkan -from_device=0",
         #     lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
         #     topi.cuda.schedule_conv2d_NCHWc_int8,
         #     4,
+        #     False,
         # ),
     ]
 
-    # TODO(Mousius) Re-enable once implementation is fixed
-    # if in_dtype == "int8":
-    #     targets.append(
-    #         (
-    #             "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
-    #             topi.arm_cpu.conv2d_NCHWc_int8,
-    #             topi.arm_cpu.schedule_conv2d_NCHWc_int8,
-    #             8,
-    #         )
-    #     )
-
-    for target, compute, schedule, oc_block_factor in targets:
-        check_target(target, compute, schedule, oc_block_factor)
+    # TODO(tvm-team): Properly run ARM code on CI aarch64 environment
+    targets.append(
+        (
+            "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
+            topi.arm_cpu.conv2d_NCHWc_int8,
+            topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+            8,
+            True,
+        )
+    )
+
+    if in_dtype == "int8":
+        targets.append(
+            (
+                "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
+                topi.arm_cpu.conv2d_NCHWc_int8,
+                topi.arm_cpu.schedule_conv2d_NCHWc_int8,
+                8,
+                True,
+            )
+        )
+
+    for target, compute, schedule, oc_block_factor, build_only in targets:
+        check_target(target, compute, schedule, oc_block_factor, build_only)
 
 
 def verify_conv2d_nchw_int8(
4 changes: 3 additions & 1 deletion tests/python/unittest/test_meta_schedule_integration.py
@@ -20,6 +20,7 @@
 
 import pytest
 import tvm
+import tvm.testing
 from tvm import relay
 from tvm import meta_schedule as ms
 from tvm.ir.module import IRModule
@@ -151,7 +152,8 @@ def extract_task_qbert():
     assert "vnni" in annotations["schedule_rule"]
 
 
-def extract_task_arm_conv2d_nchwc():
+@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old")
+def test_extract_task_arm_conv2d_nchwc():
     data_shape = (1, 64, 128, 128)
     weight_shape = (32, 64, 1, 1)
     bias_shape = (weight_shape[0],)