Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
and layout in ["NCHW", "NHWC"]
and padding[0] == padding[2]
and padding[1] == padding[3]
and not ((data.dtype == "int8" or kernel.dtype == "int8") and layout == "NCHW")
):
# add cudnn implementation
if layout == "NHWC":
Expand Down Expand Up @@ -347,7 +348,12 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
# add cudnn implementation, if any
cudnn_impl = False
if target.kind.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
if (
layout in ["NCHW", "NHWC"]
and padding[0] == padding[2]
and padding[1] == padding[3]
and not ((data.dtype == "int8" or kernel.dtype == "int8") and layout == "NCHW")
):
strategy.add_implementation(
wrap_compute_conv2d(
topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
Expand Down