Skip to content
12 changes: 6 additions & 6 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,10 @@ struct Conv3DTransposeAttrs : public tvm::AttrsNode<Conv3DTransposeAttrs> {
"dimensions respectively. Convolution is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout)
.set_default("OIDHW")
.set_default("IODHW")
.describe(
"Dimension ordering of data and weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
"'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height, and width"
"Dimension ordering of data and weight. Can be 'IODHW', 'IODHW16i16o', etc."
"'I', 'O', 'D', 'H', 'W' stands for input_channel, num_filter, depth, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
Expand Down Expand Up @@ -588,10 +588,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout)
.set_default("OIHW")
.set_default("IOHW")
.describe(
"Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"Dimension ordering of data and weight. Can be 'IOHW', 'IOHW16i16o', etc."
"'I', 'O', 'H', 'W' stands for input_channel, num_filter, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def conv3d_transpose(
channels=None,
kernel_size=None,
data_layout="NCDHW",
kernel_layout="OIDHW",
kernel_layout="IODHW",
out_layout="",
output_padding=(0, 0, 0),
out_dtype="",
Expand Down
18 changes: 10 additions & 8 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
if (data == nullptr) return false;

static const Layout kNCDHW("NCDHW");
static const Layout kOIDHW("OIDHW");
static const Layout kIODHW("IODHW");

const Conv3DTransposeAttrs* param = attrs.as<Conv3DTransposeAttrs>();
ICHECK(param != nullptr);
Expand All @@ -606,9 +606,9 @@ bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< "Conv3d_transpose only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIODHW);
ICHECK(trans_kernel_layout.defined())
<< "Conv3d_transpose only support kernel layouts that are convertible from OIDHW."
<< "Conv3d_transpose only support kernel layouts that are convertible from IODHW."
<< " But got " << kernel_layout;

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
Expand Down Expand Up @@ -651,16 +651,18 @@ bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]) &&
reporter->AssertEQ(param->kernel_size[2], wshape[4]))
<< "Conv3D: shape of weight is inconsistent with kernel_size, "
<< "Conv3DTranspose: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv3D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
<< "Conv3DTranspose: shape of weight is inconsistent with out_channels, "
<< " out_channels // groups != weight.shape[1] "
<< " out_channels=" << param->channels << " groups=" << param->groups
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(dshape_ncdhw[1], wshape[0]));
}
channels = wshape[1];
dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0];
Expand Down