15 changes: 12 additions & 3 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -93,13 +93,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.arm_cpu",
plevel=10,
)

if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
if (
topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
and kernel.shape[1] >= 64
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_int8),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.arm_cpu",
plevel=15,
)
else:
strategy.add_implementation(
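
The new `kernel.shape[1] >= 64` guard above reads the second kernel dimension, which for the OIHW layout used on this path is the input-channel count. Below is a minimal sketch (plain Python with hypothetical helper names, not the TVM strategy API) of the resulting selection logic: the int8 schedule is only offered when there are at least 64 input channels, and its higher plevel (15 vs. 10) makes it win over spatial pack when both apply and no tuning records exist.

def pick_conv2d_nchw_impl(data_dtype, kernel_dtype, kernel_shape, is_int8_hw_support):
    """Name of the implementation a strategy like the one above would prefer untuned."""
    candidates = [("conv2d_nchw_spatial_pack.arm_cpu", 10)]
    in_channels = kernel_shape[1]  # OIHW kernel layout
    if is_int8_hw_support(data_dtype, kernel_dtype) and in_channels >= 64:
        candidates.append(("conv2d_nchw_int8.arm_cpu", 15))
    # With no tuning records, the implementation with the highest plevel is chosen.
    return max(candidates, key=lambda c: c[1])[0]


# Example: int8 data/kernel with 64 input channels selects the int8 schedule.
print(pick_conv2d_nchw_impl("int8", "int8", (32, 64, 1, 1), lambda d, k: True))
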
@@ -383,12 +388,16 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
else:
29 changes: 23 additions & 6 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -27,6 +27,7 @@
from ..nn import conv2d_alter_layout, conv2d_legalize
from ..utils import get_const_tuple
from ..x86.conv2d import _get_default_config as _get_x86_default_config
from ..x86.conv2d_int8 import _get_default_config_int8
from .conv2d_int8 import is_int8_hw_support
from .arm_utils import get_tiling_B_interleaved_t
from ..generic.conv2d import conv2d_alter_int8_common
@@ -101,9 +102,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# we then assume it's not necessary to alter this op.
return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None
Member Author:
I hope it's safe to remove this. The x86 counterpart doesn't have anything like this.

Reviewer:
How are you testing it?

Member Author:
If the CI passes with this change, I assume it is safe.

Reviewer:
@Mousius - could you please look at this, given you've recently been turning on topi tests on aarch64?

Member Author:
I need to remove this code because it always makes alter_layout a no-op in the fallback mode. In contrast, in the x86 schedule, alter_layout always fires.

Member Author:
The CI has passed. I found that Giuseppe's im2col-based conv2d implementation can fail to tensorize in the fallback mode, so I partially restored the fallback return path above.
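
To summarize the outcome of this thread, here is a condensed, self-contained sketch (simplified names, not the real `_alter_conv2d_layout`) of the fallback handling after this change: the blanket early-return is removed, `conv2d_NCHWc_int8.arm_cpu` fills a default config instead, and only the GEMM/im2col-based NHWC quantized templates keep the fallback bail-out because they can fail to tensorize without a tuned config.

def alter_layout_on_fallback(topi_tmpl: str, is_fallback: bool) -> str:
    """Describe what the altered pass does for a given template and config state."""
    if topi_tmpl == "conv2d_NCHWc_int8.arm_cpu":
        if is_fallback:
            return "fill a default int8 config, then rewrite to NCHWc"
        return "rewrite to NCHWc with the tuned config"

    if topi_tmpl in (
        "conv2d_NHWC_quantized_interleaved.arm_cpu",
        "conv2d_NHWC_quantized_native.arm_cpu",
    ):
        if is_fallback:
            # Restored bail-out: these schedules may fail to tensorize untuned.
            return "skip layout alteration"
        return "rewrite to the *_without_transform variant"

    return "no alteration"


print(alter_layout_on_fallback("conv2d_NCHWc_int8.arm_cpu", True))
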


topi_tmpl = workload[0]
new_attrs = {k: attrs[k] for k in attrs.keys()}
@@ -346,6 +344,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

if topi_tmpl == "conv2d_NCHWc_int8.arm_cpu":
assert data_layout == "NCHW" and kernel_layout == "OIHW"
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)

n_elems = 8

if cfg.is_fallback:
_get_default_config_int8(
cfg,
Expand All @@ -357,12 +360,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
out_dtype,
False,
data_layout,
int32_lanes=4,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
n_elems = 8

if cfg.is_fallback:
# ic_bn needs to be divided by n_elems below
ic_bn = max(ic_bn, n_elems)

# update new attrs
new_attrs["channels"] = out_channel
@@ -395,6 +400,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
Expand All @@ -411,6 +422,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
12 changes: 12 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -297,6 +297,12 @@ def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs):
return _schedule_conv2d_NHWC_quantized(cfg, outs, True)


@autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu")
def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs):
"""Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, True)


# Native schedules: these schedules won't interleave A (which is left in its native form).
# The weights are interleaved and transposed
@autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu")
@@ -330,3 +336,9 @@ def compute_conv2d_NHWC_quantized_native_without_transform(
def schedule_conv2d_NHWC_quantized_native(cfg, outs):
"""Interface for native schedule_conv2d_NHWC_quantized"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, False)


@autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu")
def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs):
"""Interface for native schedule_conv2d_NHWC_quantized"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
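
These new registrations pair with the strategy fix in python/tvm/relay/op/strategy/arm_cpu.py above: previously the *_without_transform computes were wrapped with schedules registered under the with-transform workload names. A quick consistency check (plain Python, names taken from this diff) of the workload-name/schedule pairing:

pairs = {
    "conv2d_NHWC_quantized_native_without_transform.arm_cpu":
        "schedule_conv2d_NHWC_quantized_native_without_transform",
    "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu":
        "schedule_conv2d_NHWC_quantized_interleaved_without_transform",
}
for workload_name, schedule_fn in pairs.items():
    assert workload_name.rsplit(".", 1)[0] in schedule_fn
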
1 change: 1 addition & 0 deletions python/tvm/topi/x86/conv2d_alter_op.py
@@ -159,6 +159,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
out_dtype,
False,
data_layout,
int32_lanes=16,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
16 changes: 13 additions & 3 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -34,7 +34,16 @@


def _get_default_config_int8(
cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
is_depthwise=False,
layout="NCHW",
int32_lanes=4,
):
"""
Get default schedule config for the workload
@@ -50,11 +59,11 @@ def _get_default_config_int8(
is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4
cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4
cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
)
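
The hard-coded 16 lanes here previously assumed an AVX-512-sized vector of int32 accumulators; the new `int32_lanes` parameter lets the arm_cpu caller pass 4 (one 128-bit NEON register of int32) while the x86 callers keep passing 16. A tiny illustrative helper, not part of TVM:

def int32_lanes_for(vector_width_bits: int) -> int:
    """Number of int32 accumulator lanes per SIMD register."""
    return vector_width_bits // 32


assert int32_lanes_for(512) == 16  # x86 AVX-512 path: int32_lanes=16
assert int32_lanes_for(128) == 4   # arm_cpu NEON path: int32_lanes=4
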


@@ -163,6 +172,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
padding,
dilation,
out_dtype,
int32_lanes=16,
)

# Pack data if raw 4-D data is provided.
46 changes: 46 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -16,9 +16,11 @@
# under the License.
import sys
from typing import List
import numpy as np

import pytest
import tvm
from tvm import relay
from tvm import meta_schedule as ms
from tvm.ir.module import IRModule
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
@@ -149,5 +151,49 @@ def extract_task_qbert():
assert "vnni" in annotations["schedule_rule"]


def extract_task_arm_conv2d_nchwc():
Member:
Are we sure this runs @masahi? There's no test_ prefix.

Member Author:
Oops, thanks for pointing this out. I added a fix in #10773.

data_shape = (1, 64, 128, 128)
weight_shape = (32, 64, 1, 1)
bias_shape = (weight_shape[0],)
padding = (1, 1)

data = relay.var("data", shape=data_shape, dtype="int8")
weight = relay.var("weight", shape=weight_shape, dtype="int8")
bias = relay.var("bias", shape=bias_shape, dtype="int32")
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=weight_shape[2:],
channels=weight_shape[0],
padding=padding,
strides=(1, 1),
out_dtype="int32",
)
bias_add = relay.nn.bias_add(conv2d, bias)
relay_mod = tvm.IRModule.from_expr(bias_add)

weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")

params = {"weight": weight_np, "bias": bias_np}

target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
extracted_tasks = extract_task_from_relay(relay_mod, target, params)
tune_tasks = list(
filter(
lambda task: "conv2d" in task.task_name,
extracted_tasks,
)
)

assert len(tune_tasks) == 1

relay_func = list(tune_tasks[0].mod.functions.values())[0]
out_type = relay_func.body.checked_type

# Check that the output is in NCHWc layout
assert list(out_type.shape) == [1, 8, 130, 130, 4]
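
The asserted shape can be verified by hand (plain arithmetic; the oc_bn of 4 is the value the arm_cpu fallback config with int32_lanes=4 is expected to pick): a 1x1 kernel with padding (1, 1) and stride 1 grows the 128x128 spatial dims to 130x130, and 32 output channels packed 4-wide give 8 outer channel blocks.

def nchwc_out_shape(n, out_c, h, w, kh, kw, pad, stride, oc_bn):
    out_h = (h + 2 * pad - kh) // stride + 1
    out_w = (w + 2 * pad - kw) // stride + 1
    return [n, out_c // oc_bn, out_h, out_w, oc_bn]


# data (1, 64, 128, 128), weight (32, 64, 1, 1), padding (1, 1), stride 1
assert nchwc_out_shape(1, 32, 128, 128, 1, 1, 1, 1, 4) == [1, 8, 130, 130, 4]
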


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))