Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
else:
logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic",
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/op/strategy/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target):
name="depthwise_conv2d_nchw.hexagon",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.hexagon",
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout):
logger.warning(
"depthwise_conv2d NHWC layout is not optimized for x86 with autotvm."
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic",
)
Expand Down
19 changes: 16 additions & 3 deletions python/tvm/topi/nn/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
import numpy as np
from tvm import te

from .dilate import dilate
Expand Down Expand Up @@ -211,7 +212,9 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
return Output


def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None):
def depthwise_conv2d_nhwc(
Input, Filter, stride, padding, dilation, kernel_layout="HWOI", out_dtype=None
):
"""Depthwise convolution nhwc forward operator.

Parameters
Expand Down Expand Up @@ -252,8 +255,14 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
dilation_h, dilation_w = dilation

batch, in_height, in_width, in_channel = Input.shape

# shape of dilated kernel
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
if kernel_layout == "HWIO":
filter_height, filter_width, channel_multiplier, filter_channel = Filter.shape
kernel_permutation = [0, 1, 3, 2]
else:
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
kernel_permutation = [0, 1, 2, 3]

dilated_kernel_h = (filter_height - 1) * dilation_h + 1
dilated_kernel_w = (filter_width - 1) * dilation_w + 1
Expand Down Expand Up @@ -285,7 +294,11 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
idxdiv(c, channel_multiplier),
].astype(out_dtype)
* Filter[
di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
tuple(
np.array(
[di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)]
)[kernel_permutation]
)
].astype(out_dtype)
),
axis=[di, dj],
Expand Down