diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 446f34340caf..447bbd4926b4 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -54,6 +54,96 @@ Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std:
 // relay.nn.conv1d
 TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
+bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCW("NCW");
+  static const Layout kOIW("OIW");
+
+  const auto* param = attrs.as<Conv1DAttrs>();
+  ICHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
+  ICHECK(trans_in_layout.defined())
+      << "Conv only support input layouts that are convertible from NCW."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
+  ICHECK(trans_kernel_layout.defined())
+      << "Conv only support kernel layouts that are convertible from OIW."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
+  ICHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCW."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_ncw = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize;
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    Array<IndexExpr> wshape;
+
+    wshape = {{param->channels, indexdiv(dshape_ncw[1], param->groups), param->kernel_size[0]}};
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    channels = param->channels;
+    dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    DataType weight_dtype = data->dtype;
+    if (weight != nullptr) {
+      weight_dtype = weight->dtype;
+    }
+    // assign result to reporter
+    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      // check the size
+      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
+          << "Conv1D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
+    }
+    if (param->channels.defined()) {
+      ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
+          << "Conv1D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels << " wshape=" << wshape;
+    }
+    if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
+    }
+    channels = wshape[0];
+    dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
+  }
+  // dilation
+  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
+
+  if (!dshape_ncw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize,
+                           param->strides[0]) +
+                      1);
+  } else {
+    oshape.Set(2, dshape_ncw[2]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d")
     .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                        Array<IndexExpr> dilation, int groups, IndexExpr channels,
@@ -82,12 +172,190 @@ with the layer input to produce a tensor of outputs.
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_support_level(2)
-    .add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)
+    .add_type_rel("Conv1D", Conv1DRel)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>);
 // relay.nn.conv2d
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
+bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const auto* param = attrs.as<Conv2DAttrs>();
+  ICHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  if (!trans_in_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "conv2d only support input layouts that are convertible from NCHW."
+        << " The provided layout is: " << in_layout);
+    return false;
+  }
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  if (!trans_kernel_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "conv2d only support kernel layouts that are convertible from OIHW."
+        << " The provided layout is: " << kernel_layout);
+    return false;
+  }
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
+  if (!trans_out_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "conv2d only support output layouts that are convertible from NCHW."
+ << "The provided layout is: " << out_layout); + return false; + } + + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + bool is_depthwise = false; + if (param->groups > 1) { + if (!(weight && weight->shape.defined())) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "Weight shape must be specified when groups is greater than 1."); + return false; + } + + Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); + if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && + tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { + is_depthwise = true; + } + } + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + ICHECK_EQ(param->dilation.size(), 2); + Array wshape; + + if (is_depthwise) { + // infer weight's shape for depthwise convolution + wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0], + param->kernel_size[1]}}; + } else { + wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0], + param->kernel_size[1]}}; + } + + wshape = trans_kernel_layout.BackwardShape(wshape); + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + + if (param->auto_scheduler_rewritten_layout.size() == 0) { + // Normal case: assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // If the layout is rewritten by auto-scheduler, + // we just forcly apply the layout provided by auto-scheduler and + // skip the normal inference logic. + {} // do nothing + } + } else { + // use weight to infer the conv shape. 
+ if (weight == nullptr) return false; + + Array wshape; + if (param->auto_scheduler_rewritten_layout.size() == 0) { + wshape = weight->shape; + } else { + // works for the default kernel layout "HWIO" + ICHECK_EQ(param->kernel_layout, "HWIO"); + wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, + {"ry", "rx", "rc", "ff"}); + } + + wshape = trans_kernel_layout.ForwardShape(wshape); + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + + if (!reporter->AssertEQ(param->kernel_size[0], wshape[2])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "Conv2D: shape of weight is inconsistent with kernel_size," + << " kernel_size=" << param->kernel_size + << " wshape=" << wshape); + } + + if (!reporter->AssertEQ(param->kernel_size[1], wshape[3])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "Conv2D: shape of weight is inconsistent with kernel_size," + << " kernel_size=" << param->kernel_size + << " wshape=" << wshape); + return false; + } + } + + if (param->channels.defined() && !reporter->AssertEQ(param->channels, wshape[0])) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "conv2D: the first dimensions of the weight tensor (" << wshape << ")" + << "does not match the number of channels (" << param->channels << ")."); + return false; + } + + if (!dshape_nchw[1].as() && !wshape[1].as()) { + if (!reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "conv2d: requires that `" + << indexdiv(dshape_nchw[1], param->groups) << "`," + << " the input channels (" << dshape_nchw[1] << ")" + << " divided by groups (" << param->groups << ")" + << ",\n must match the input channels" + << " of the weight `" << wshape[1] + << "`, where the weight shape is (" << wshape << ")."); + return false; + } + } + channels = wshape[0]; + dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + } + // dilation + Array oshape({dshape_nchw[0], channels, 0, 0}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + if (!dshape_nchw[2].as()) { + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); + } else { + oshape.Set(2, dshape_nchw[2]); + } + + if (!dshape_nchw[3].as()) { + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); + } else { + oshape.Set(3, dshape_nchw[3]); + } + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -116,12 +384,152 @@ with the layer input to produce a tensor of outputs. 
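Aside: the spatial arithmetic Conv2DRel applies above reduces to out = (in + pad - dilated_ksize) / stride + 1, with dilated_ksize = 1 + (k - 1) * dilation and pad being the per-axis sum produced by GetPaddingHeightWidth. A minimal standalone C++ sketch with hypothetical values (3x3 kernel, stride 1, dilation 1, total padding 2 per axis), not taken from this patch:

#include <cstdio>

int main() {
  const int in_h = 224, in_w = 224;           // hypothetical NCHW spatial dims
  const int k = 3, stride = 1, dilation = 1;  // hypothetical kernel/stride/dilation
  const int pad = 2;                          // pad_top + pad_bottom (resp. left + right)
  const int dilated_k = 1 + (k - 1) * dilation;
  const int out_h = (in_h + pad - dilated_k) / stride + 1;  // 224
  const int out_w = (in_w + pad - dilated_k) / stride + 1;  // 224
  std::printf("output spatial dims: %d x %d\n", out_h, out_w);
}

The same formula drives the oshape.Set(2, ...) and oshape.Set(3, ...) calls in the relation above.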
.add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) - .add_type_rel("Conv2D", Conv2DRel) + .add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); +bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const auto* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + ICHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + ICHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + ICHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x; + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 3); + ICHECK_EQ(param->dilation.size(), 3); + Array wshape; + tvm::tir::ExprDeepEqual expr_equal; + + if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) { + // infer weight's shape for depthwise convolution + wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], + param->kernel_size[1], param->kernel_size[2]}}; + } else { + wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0], + param->kernel_size[1], param->kernel_size[2]}}; + } + + wshape = trans_kernel_layout.BackwardShape(wshape); + channels = param->channels; + dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + + if (param->auto_scheduler_rewritten_layout.size() == 0) { + // Normal case: assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // If the layout is rewritten by auto-scheduler, + // we just forcly apply the layout provided by auto-scheduler and + // skip the normal inference logic. + {} // do nothing + } + + } else { + // use weight to infer the conv shape. 
+ if (weight == nullptr) return false; + + Array wshape; + if (param->auto_scheduler_rewritten_layout.size() == 0) { + wshape = weight->shape; + } else { + // works for the default kernel layout "DHWIO" + ICHECK_EQ(param->kernel_layout, "DHWIO"); + wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, + {"rd", "rh", "rw", "rc", "cc"}); + } + + wshape = trans_kernel_layout.ForwardShape(wshape); + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 3); + // check the size + ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3]) && + reporter->AssertEQ(param->kernel_size[2], wshape[4])) + << "Conv3D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + } + + if (param->channels.defined()) { + ICHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "Conv3D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << wshape; + } + + if (!dshape_ncdhw[1].as() && !wshape[1].as()) { + ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + } + channels = wshape[0]; + dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2]; + } + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + + IndexExpr pad_d, pad_h, pad_w; + GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + + if (!dshape_ncdhw[3].as()) { + oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -151,12 +559,128 @@ with the layer input to produce a tensor of outputs. 
.add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) - .add_type_rel("Conv3D", Conv3DRel) + .add_type_rel("Conv3D", Conv3DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d_transpose TVM_REGISTER_NODE_TYPE(Conv3DTransposeAttrs); +template +bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const Conv3DTransposeAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + ICHECK(trans_in_layout.defined()) + << "Conv3d_transpose only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + ICHECK(trans_kernel_layout.defined()) + << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + ICHECK(trans_out_layout.defined()) + << "Conv3d_transpose only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; + + auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 3); + ICHECK_EQ(param->dilation.size(), 3); + + Array wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + channels = param->channels; + + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. 
+ if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 3); + // check the size + ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3]) && + reporter->AssertEQ(param->kernel_size[2], wshape[4])) + << "Conv3D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + ICHECK(reporter->AssertEQ(param->channels, wshape[1])) + << "Conv3D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << Array(wshape); + } + if (!dshape_ncdhw[1].as() && !wshape[0].as()) { + ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + } + channels = wshape[1]; + dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2]; + } + + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + IndexExpr pad_d, pad_h, pad_w; + GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); + + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + if (!dshape_ncdhw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + + param->output_padding[2])); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d_transpose") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -202,6 +726,115 @@ said convolution. // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); +bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCHW("NCHW"); + static const Layout kIOHW("IOHW"); + + const Conv2DTransposeAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); + ICHECK(trans_in_layout.defined()) + << "Conv2DTransposed only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW); + ICHECK(trans_kernel_layout.defined()) + << "Conv2DTransposed only support kernel layouts that are convertible from IOHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); + ICHECK(trans_out_layout.defined()) + << "Conv2DTransposed only support output layouts that are convertible from NCHW." + << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + auto dshape_nchw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + ICHECK_EQ(param->dilation.size(), 2); + + Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + channels = param->channels; + + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + // check the size + ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3])) + << "Conv2DTransposed: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1])) + << "Conv2DTransposed: shape of weight is inconsistent with out_channels, " + << " out_channels // groups != weight.shape[1] " + << " out_channels=" << param->channels << " groups=" << param->groups + << " weight.shape=" << Array(wshape); + } + if (!dshape_nchw[1].as() && !wshape[0].as()) { + ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0])) + << "Conv2DTransposed: shape of weight is inconsistent with in_channels." 
+ << " data.shape= " << Array(dshape_nchw) << " groups= " << param->groups + << " weight.shape= " << Array(wshape); + } + channels = wshape[1]; + dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + } + // dilation + Array oshape({dshape_nchw[0], channels, 0, 0}); + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + if (!dshape_nchw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_nchw[2]); + } + if (!dshape_nchw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_nchw[3]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -241,11 +874,106 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .set_support_level(2) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout) - .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); + .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); // relay.nn.conv1d_transpose TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); +bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCW("NCW"); + static const Layout kOIW("OIW"); + + const Conv1DTransposeAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); + ICHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); + ICHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); + ICHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCW." 
+ << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + auto dshape_ncw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 1); + ICHECK_EQ(param->dilation.size(), 1); + + Array wshape( + {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + channels = param->channels; + + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 1); + // check the size + ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) + << "Conv1D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + ICHECK(reporter->AssertEQ(param->channels, wshape[1])) + << "Conv1D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << Array(wshape); + } + if (!dshape_ncw[1].as() && !wshape[0].as()) { + ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + } + channels = wshape[1]; + dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; + } + // dilation + IndexExpr pad_w; + GetPaddingWidth(param->padding, &pad_w); + Array oshape({dshape_ncw[0], channels, 0}); + if (!dshape_ncw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncw[2]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -282,7 +1010,7 @@ said convolution. 
.add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) - .add_type_rel("Conv1DTranspose", Conv1DTransposeRel); + .add_type_rel("Conv1DTranspose", Conv1DTransposeRel); // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); @@ -322,6 +1050,28 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") // relay.nn.contrib_conv2d_winograd_weight_transform TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs); +bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const ConvWinogradWeightTransformAttrs* param = attrs.as(); + ICHECK(param != nullptr); + + ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; + + std::vector oshape{ + param->tile_size + data->shape[2] - 1, + param->tile_size + data->shape[3] - 1, + data->shape[0], + data->shape[1], + }; + + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") .set_body_typed([](Expr weight, int tile_size) { return MakeConvWinogradWeightTransform(weight, tile_size, @@ -345,6 +1095,88 @@ weight transformation in advance. // relay.nn.contrib_conv3d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs); +bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const auto* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + ICHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + ICHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + ICHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; + + ICHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or inferred by previous pass"; + + ICHECK_EQ(param->kernel_size.size(), 3); + ICHECK_EQ(param->dilation.size(), 3); + + channels = param->channels; + dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + + // NOTE: Do not check weight shape here! 
+ // Different backend requires different layout to compute + // the batch gemm stage in winograd efficiently, but we want to + // make this op work for all backends. + // So we accept all weight shapes, and assume the TOPI developers + // can handle this correctly in alter_op_layout. + + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + + IndexExpr pad_d, pad_h, pad_w; + GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + if (!dshape_ncdhw[2].as()) { + oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform") .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, Array padding, Array dilation, int groups, @@ -373,7 +1205,7 @@ RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) - .add_type_rel("Conv3DWinograd", Conv3DWinogradRel) + .add_type_rel("Conv3DWinograd", Conv3DWinogradRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); @@ -384,6 +1216,35 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform" "nn.contrib_conv3d_winograd_weight_transform"); }); +bool Conv3DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const ConvWinogradWeightTransformAttrs* param = attrs.as(); + ICHECK(param != nullptr); + + ICHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout"; + + // Shape of packed weights depends on whether depth is being transformed or not. + Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); + auto* depth_imm = data->shape[2].as(); + bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); + if (transform_depth) { + oshape.Set(0, param->tile_size + data->shape[2] - 1); + oshape.Set(1, param->tile_size + data->shape[3] - 1); + oshape.Set(2, param->tile_size + data->shape[4] - 1); + } else { + oshape.Set(0, param->tile_size + data->shape[3] - 1); + oshape.Set(1, param->tile_size + data->shape[4] - 1); + oshape.Set(2, data->shape[2]); + } + + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform") .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm. @@ -401,6 +1262,35 @@ weight transformation in advance. 
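Aside: the Winograd weight-transform relations above pack the kernel so its leading dimensions become tile_size + kernel_dim - 1. A hypothetical 2D example, assuming tile_size 4 and a 3x3 OIHW kernel with 64 output and 64 input channels:

#include <cstdio>

int main() {
  const int tile_size = 4;     // hypothetical Winograd tile size
  const int kh = 3, kw = 3;    // kernel height/width
  const int co = 64, ci = 64;  // output/input channels of an OIHW kernel
  // packed weight shape: [tile_size + KH - 1, tile_size + KW - 1, CO, CI]
  std::printf("[%d, %d, %d, %d]\n", tile_size + kh - 1, tile_size + kw - 1, co, ci);  // [6, 6, 64, 64]
}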
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); +bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + + const Conv2DWinogradNNPACKWeightTransformAttrs* param = + attrs.as(); + ICHECK(param != nullptr); + + ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; + + std::vector oshape{ + data->shape[0], + data->shape[1], + 8, + 8, + }; + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + reporter->Assign(types[1], TensorType(Array(oshape), out_dtype)); + return true; +} + Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { auto attrs = make_object(); @@ -438,6 +1328,77 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transf kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform"); }); +bool Conv2DGemmRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNHWC("NHWC"); + static const Layout kHWIO("HWIO"); + + const auto* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC); + ICHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NHWC." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO); + ICHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from HWIO." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC); + ICHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NHWC." + << " But got " << out_layout; + + Array dshape_nhwc = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + ICHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or inferred by previous pass"; + + ICHECK_EQ(param->kernel_size.size(), 2); + ICHECK_EQ(param->dilation.size(), 2); + + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + + // NOTE: Do not check weight shape here! 
+ + // dilation + Array oshape({dshape_nhwc[0], 0, 0, channels}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + if (!dshape_nhwc[2].as()) { + oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1); + } else { + oshape.Set(1, dshape_nhwc[1]); + } + if (!dshape_nhwc[3].as()) { + oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1); + } else { + oshape.Set(2, dshape_nhwc[2]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform") .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout. This operator assumes the weight tensor is already pre-transformed by @@ -455,13 +1416,72 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) - .add_type_rel("Conv2DGemm", Conv2DGemmRel) + .add_type_rel("Conv2DGemm", Conv2DGemmRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.contrib_conv2d_gemm_weight_transform TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); +// Gemm convolution shape relations +// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. +// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and +// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] +// matrix that we call W_interleaved_t. +// +// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed +// for tile_rows = 4 and tile_cols = 16 +// +// W[0,0,:,:] W_interleaved_t[0,0,:,:] +// +-------------------------------+ +----------------------------------- + +// |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]| +// |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]| +// |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]| +// | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]| +// | ... ... ... ... | +------------------------------------+ +// |W[15,0] W[15,1] W[15,2] W[15,3]| +// +-------------------------------+ +// +// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements +// at the time, we should set tile_cols = k. +// Tile rows is connected with the number of registers available for the given target. 
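Aside: a standalone sketch of the padded, interleaved shape described above and computed by Conv2DGemmWeightTransformRel just below; the K, N and tile sizes are hypothetical values chosen only for illustration:

#include <cstdio>

int main() {
  const int K = 3 * 3 * 64;                 // kh * kw * ic of the flattened HWIO weight
  const int N = 128;                        // output channels
  const int tile_rows = 4, tile_cols = 16;  // hypothetical GEMM tiling
  const int pad_K = (K % tile_cols) ? tile_cols - K % tile_cols : 0;
  const int pad_N = (N % tile_rows) ? tile_rows - N % tile_rows : 0;
  // W_interleaved_t shape: [N_padded / tile_rows, K_padded / tile_cols, tile_rows, tile_cols]
  std::printf("[%d, %d, %d, %d]\n", (N + pad_N) / tile_rows, (K + pad_K) / tile_cols, tile_rows,
              tile_cols);  // [32, 36, 4, 16]
}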
+// +bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* weight = types[0].as(); + if (weight == nullptr) return false; + + const ConvGemmWeightTransformAttrs* param = attrs.as(); + ICHECK(param != nullptr); + int n = param->tile_rows; + int k = param->tile_cols; + + ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout"; + + const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2]; + const auto N = weight->shape[3]; + + auto K_mod_k = indexmod(K, k); + auto N_mod_n = indexmod(N, n); + + auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32))); + auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32))); + + const auto N_padded = N + pad_N; + const auto K_padded = K + pad_K; + + Array oshape{ + indexdiv(N_padded, n), + indexdiv(K_padded, k), + n, + k, + }; + + reporter->Assign(types[1], TensorType(oshape, weight->dtype)); + return true; +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform") .set_body_typed([](Expr weights, int tile_rows, int tile_cols) { return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols, @@ -532,11 +1552,133 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) - .add_type_rel("Conv2D", Conv2DRel) + .add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs); +// Deformable Convolution shape relations. +bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + const auto* weight = types[2].as(); + + ICHECK(data); + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + + auto* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); + if (!trans_in_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "deformable_conv2d only support input layouts that are convertible from NCHW." + << " The provided layout is: " << in_layout); + return false; + } + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); + if (!trans_kernel_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "deformable_conv2d only support kernel layouts that are convertible from OIHW." + << " The provided layout is: " << kernel_layout); + return false; + } + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); + if (!trans_out_layout.defined()) { + reporter->GetDiagCtx().Emit( + Diagnostic::Error(reporter->GetSpan()) + << "deformable_conv2d only support output layouts that are convertible from NCHW." 
+ << "The provided layout is: " << out_layout); + return false; + } + + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x; + + // infer weight shape if kernel_size and channels are defiend + if (param->kernel_size.defined() && param->channels.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + ICHECK_EQ(param->dilation.size(), 2); + Array wshape({param->channels, indexdiv(dshape_nchw[1], param->groups), + param->kernel_size[0], param->kernel_size[1]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + channels = param->channels; + ksize_y = param->kernel_size[0]; + ksize_x = param->kernel_size[1]; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + // assign result to reporter + reporter->Assign(types[2], TensorType(wshape, data->dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + + if (param->kernel_size.defined()) { + ICHECK_EQ(param->kernel_size.size(), 2); + // check the size + ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3])) + << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + } + if (param->channels.defined()) { + ICHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "DeformableConv2D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << wshape; + } + if (!dshape_nchw[1].as() && !wshape[1].as()) { + ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + } + channels = wshape[0]; + ksize_y = wshape[2]; + ksize_x = wshape[3]; + dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + } + // dilation + Array oshape({dshape_nchw[0], channels, 0, 0}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); + DataType out_dtype = param->out_dtype; + + // infer offset shape + Array offset_shape( + {dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); + offset_shape = trans_in_layout.BackwardShape(offset_shape); + reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[3], TensorType(oshape, out_dtype)); + return true; +} + +InferCorrectLayoutOutput DeformableConvInferCorrectLayout( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types) { + const auto* params = attrs.as(); + return InferCorrectLayoutOutput( + {params->data_layout, params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? params->data_layout : params->out_layout}, attrs); +} + RELAY_REGISTER_OP("nn.deformable_conv2d") .describe(R"code(Compute 2-D deformable convolution on 4-D input. The deformable convolution operation is described in https://arxiv.org/abs/1703.06211 @@ -563,9 +1705,8 @@ by concating all the *g* results. 
.add_argument("offset", "Tensor", "The offset tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(5) - .add_type_rel("DeformableConv2D", DeformableConv2DRel) - .set_attr("FInferCorrectLayout", - DeformableConvInferCorrectLayout); + .add_type_rel("DeformableConv2D", DeformableConv2DRel) + .set_attr("FInferCorrectLayout", DeformableConvInferCorrectLayout); // Positional relay function to create deformable_conv2d operator // used by frontend FFI. diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 954daeaa86cf..62552ee4783e 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -37,558 +37,11 @@ namespace tvm { namespace relay { -// Standard convolution operator shape relations -template -bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - static const Layout kNCW("NCW"); - static const Layout kOIW("OIW"); - - const AttrType* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); - ICHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); - ICHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); - ICHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; - - Array dshape_ncw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize; - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - Array wshape; - - wshape = {{param->channels, indexdiv(dshape_ncw[1], param->groups), param->kernel_size[0]}}; - - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) - << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << wshape; - } - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << wshape; - } - if (!dshape_ncw[1].as() && !wshape[1].as()) { - ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); - } - channels = wshape[0]; - dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0]; - } - // dilation - Array oshape({dshape_ncw[0], channels, 0}); - - if (!dshape_ncw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, - param->strides[0]) + - 1); - } else { - oshape.Set(2, dshape_ncw[2]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -template bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - const AttrType* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - if (!trans_in_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "conv2d only support input layouts that are convertible from NCHW." - << " The provided layout is: " << in_layout); - return false; - } - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); - if (!trans_kernel_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "conv2d only support kernel layouts that are convertible from OIHW." - << " The provided layout is: " << kernel_layout); - return false; - } - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - if (!trans_out_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "conv2d only support output layouts that are convertible from NCHW." 
- << "The provided layout is: " << out_layout); - return false; - } - - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); - bool is_depthwise = false; - if (param->groups > 1) { - if (!(weight && weight->shape.defined())) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "Weight shape must be specified when groups is greater than 1."); - return false; - } - - Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); - if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && - tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { - is_depthwise = true; - } - } - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - ICHECK_EQ(param->dilation.size(), 2); - Array wshape; - - if (is_depthwise) { - // infer weight's shape for depthwise convolution - wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0], - param->kernel_size[1]}}; - } else { - wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0], - param->kernel_size[1]}}; - } - - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - - if (param->auto_scheduler_rewritten_layout.size() == 0) { - // Normal case: assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // If the layout is rewritten by auto-scheduler, - // we just forcly apply the layout provided by auto-scheduler and - // skip the normal inference logic. - {} // do nothing - } - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - - Array wshape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - wshape = weight->shape; - } else { - // works for the default kernel layout "HWIO" - ICHECK_EQ(param->kernel_layout, "HWIO"); - wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"ry", "rx", "rc", "ff"}); - } - - wshape = trans_kernel_layout.ForwardShape(wshape); - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - - if (!reporter->AssertEQ(param->kernel_size[0], wshape[2])) { - reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) - << "Conv2D: shape of weight is inconsistent with kernel_size," - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape); - } - - if (!reporter->AssertEQ(param->kernel_size[1], wshape[3])) { - reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) - << "Conv2D: shape of weight is inconsistent with kernel_size," - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape); - return false; - } - } - - if (param->channels.defined() && !reporter->AssertEQ(param->channels, wshape[0])) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "conv2D: the first dimensions of the weight tensor (" << wshape << ")" - << "does not match the number of channels (" << param->channels << ")."); - return false; - } - - if (!dshape_nchw[1].as() && !wshape[1].as()) { - if (!reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])) { - reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) - << "conv2d: requires that `" - << indexdiv(dshape_nchw[1], param->groups) << "`," - << " the input channels (" << dshape_nchw[1] << ")" - << " divided by groups (" << param->groups << ")" - << ",\n must match the input channels" - << " of the weight `" << wshape[1] - << "`, where the weight shape is (" << wshape << ")."); - return false; - } - } - channels = wshape[0]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); - } else { - oshape.Set(2, dshape_nchw[2]); - } - - if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); - } else { - oshape.Set(3, dshape_nchw[3]); - } - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -template -bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - static const Layout kNCDHW("NCDHW"); - static const Layout kOIDHW("OIDHW"); - - const AttrType* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); - ICHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCDHW." 
- << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); - ICHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIDHW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); - ICHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCDHW." - << " But got " << out_layout; - - Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x; - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 3); - ICHECK_EQ(param->dilation.size(), 3); - Array wshape; - tvm::tir::ExprDeepEqual expr_equal; - - if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) { - // infer weight's shape for depthwise convolution - wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], - param->kernel_size[1], param->kernel_size[2]}}; - } else { - wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0], - param->kernel_size[1], param->kernel_size[2]}}; - } - - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } + const TypeReporter& reporter); - if (param->auto_scheduler_rewritten_layout.size() == 0) { - // Normal case: assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // If the layout is rewritten by auto-scheduler, - // we just forcly apply the layout provided by auto-scheduler and - // skip the normal inference logic. - {} // do nothing - } - - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - - Array wshape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - wshape = weight->shape; - } else { - // works for the default kernel layout "DHWIO" - ICHECK_EQ(param->kernel_layout, "DHWIO"); - wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"rd", "rh", "rw", "rc", "cc"}); - } - - wshape = trans_kernel_layout.ForwardShape(wshape); - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 3); - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3]) && - reporter->AssertEQ(param->kernel_size[2], wshape[4])) - << "Conv3D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << wshape; - } - - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "Conv3D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << wshape; - } - - if (!dshape_ncdhw[1].as() && !wshape[1].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); - } - channels = wshape[0]; - dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1]; - dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2]; - } - // dilation - Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); - - IndexExpr pad_d, pad_h, pad_w; - GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); - if (!dshape_ncdhw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); - } else { - oshape.Set(2, dshape_ncdhw[2]); - } - - if (!dshape_ncdhw[3].as()) { - oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); - } else { - oshape.Set(3, dshape_ncdhw[3]); - } - - if (!dshape_ncdhw[4].as()) { - oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); - } else { - oshape.Set(4, dshape_ncdhw[4]); - } - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -// Winograd convolution shape relations -inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, - const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const ConvWinogradWeightTransformAttrs* param = attrs.as(); - ICHECK(param != nullptr); - - ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - - std::vector oshape{ - param->tile_size + data->shape[2] - 1, - param->tile_size + data->shape[3] - 1, - data->shape[0], - data->shape[1], - }; - - reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); - return true; -} - -// Gemm convolution shape relations -// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. -// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and -// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] -// matrix that we call W_interleaved_t. 
-// -// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed -// for tile_rows = 4 and tile_cols = 16 -// -// W[0,0,:,:] W_interleaved_t[0,0,:,:] -// +-------------------------------+ +----------------------------------- + -// |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]| -// |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]| -// |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]| -// | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]| -// | ... ... ... ... | +------------------------------------+ -// |W[15,0] W[15,1] W[15,2] W[15,3]| -// +-------------------------------+ -// -// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements -// at the time, we should set tile_cols = k. -// Tile rows is connected with the number of registers available for the given target. -// -inline bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, - const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); - const auto* weight = types[0].as(); - if (weight == nullptr) return false; - - const ConvGemmWeightTransformAttrs* param = attrs.as(); - ICHECK(param != nullptr); - int n = param->tile_rows; - int k = param->tile_cols; - - ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout"; - - const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2]; - const auto N = weight->shape[3]; - - auto K_mod_k = indexmod(K, k); - auto N_mod_n = indexmod(N, n); - - auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32))); - auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32))); - - const auto N_padded = N + pad_N; - const auto K_padded = K + pad_K; - - Array oshape{ - indexdiv(N_padded, n), - indexdiv(K_padded, k), - n, - k, - }; - - reporter->Assign(types[1], TensorType(oshape, weight->dtype)); - return true; -} - -inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_inputs, - const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const ConvWinogradWeightTransformAttrs* param = attrs.as(); - ICHECK(param != nullptr); - - ICHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout"; - - // Shape of packed weights depends on whether depth is being transformed or not. 
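Editor's note: to make the padded, interleaved weight shape computed above concrete, here is a sketch of the same computation with plain integers (hypothetical helper names, not TVM API). For an HWIO kernel, K = kh * kw * in_channels and N = out_channels; both are rounded up to multiples of tile_cols and tile_rows before blocking into [N_padded / tile_rows, K_padded / tile_cols, tile_rows, tile_cols]:

#include <array>
#include <cstdio>

std::array<int, 4> InterleavedWeightShape(int kh, int kw, int in_c, int out_c,
                                          int tile_rows, int tile_cols) {
  int K = kh * kw * in_c;
  int N = out_c;
  int K_padded = (K + tile_cols - 1) / tile_cols * tile_cols;  // round up to tile_cols
  int N_padded = (N + tile_rows - 1) / tile_rows * tile_rows;  // round up to tile_rows
  return {N_padded / tile_rows, K_padded / tile_cols, tile_rows, tile_cols};
}

int main() {
  // 3x3x32 -> 30 output channels with tile_rows = 4, tile_cols = 16:
  // K = 288 (already a multiple of 16), N = 30 padded to 32 -> {8, 18, 4, 16}
  auto s = InterleavedWeightShape(3, 3, 32, 30, 4, 16);
  std::printf("{%d, %d, %d, %d}\n", s[0], s[1], s[2], s[3]);
}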
- Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); - auto* depth_imm = data->shape[2].as(); - bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); - if (transform_depth) { - oshape.Set(0, param->tile_size + data->shape[2] - 1); - oshape.Set(1, param->tile_size + data->shape[3] - 1); - oshape.Set(2, param->tile_size + data->shape[4] - 1); - } else { - oshape.Set(0, param->tile_size + data->shape[3] - 1); - oshape.Set(1, param->tile_size + data->shape[4] - 1); - oshape.Set(2, data->shape[2]); - } - - reporter->Assign(types[1], TensorType(oshape, data->dtype)); - return true; -} - -inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { - return false; - } - - const Conv2DWinogradNNPACKWeightTransformAttrs* param = - attrs.as(); - ICHECK(param != nullptr); - - ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - - std::vector oshape{ - data->shape[0], - data->shape[1], - 8, - 8, - }; - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - reporter->Assign(types[1], TensorType(Array(oshape), out_dtype)); - return true; -} +bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter); template bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -667,608 +120,6 @@ bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& at return true; } -template -bool Conv2DGemmRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) return false; - static const Layout kNHWC("NHWC"); - static const Layout kHWIO("HWIO"); - - const AttrType* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC); - ICHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NHWC." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO); - ICHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from HWIO." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC); - ICHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NHWC." - << " But got " << out_layout; - - Array dshape_nhwc = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - ICHECK(param->kernel_size.defined() && param->channels.defined()) - << "The kernel size and channels of a Conv must be set or inferred by previous pass"; - - ICHECK_EQ(param->kernel_size.size(), 2); - ICHECK_EQ(param->dilation.size(), 2); - - channels = param->channels; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - - // NOTE: Do not check weight shape here! 
- - // dilation - Array oshape({dshape_nhwc[0], 0, 0, channels}); - - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nhwc[2].as()) { - oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1); - } else { - oshape.Set(1, dshape_nhwc[1]); - } - if (!dshape_nhwc[3].as()) { - oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1); - } else { - oshape.Set(2, dshape_nhwc[2]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -template -bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - if (data == nullptr) return false; - static const Layout kNCDHW("NCDHW"); - static const Layout kOIDHW("OIDHW"); - - const AttrType* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); - ICHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); - ICHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIDHW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); - ICHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCDHW." - << " But got " << out_layout; - - Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; - - ICHECK(param->kernel_size.defined() && param->channels.defined()) - << "The kernel size and channels of a Conv must be set or inferred by previous pass"; - - ICHECK_EQ(param->kernel_size.size(), 3); - ICHECK_EQ(param->dilation.size(), 3); - - channels = param->channels; - dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; - - // NOTE: Do not check weight shape here! - // Different backend requires different layout to compute - // the batch gemm stage in winograd efficiently, but we want to - // make this op work for all backends. - // So we accept all weight shapes, and assume the TOPI developers - // can handle this correctly in alter_op_layout. 
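Editor's note: the Winograd weight-transform relations above emit transformed tiles of side alpha = tile_size + kernel - 1 per spatial axis (the NNPACK variant is fixed at 8x8). A short sketch of the resulting 2-D shape, with a hypothetical helper rather than TVM API:

#include <array>
#include <cstdio>

std::array<int, 4> WinogradWeightShape(int out_c, int in_c, int kh, int kw, int tile_size) {
  int alpha_h = tile_size + kh - 1;
  int alpha_w = tile_size + kw - 1;
  // transformed layout: alpha_h x alpha_w x O x I
  return {alpha_h, alpha_w, out_c, in_c};
}

int main() {
  // F(4x4, 3x3): tile_size 4 with a 3x3 kernel gives 6x6 transformed tiles -> {6, 6, 64, 32}
  auto s = WinogradWeightShape(64, 32, 3, 3, 4);
  std::printf("{%d, %d, %d, %d}\n", s[0], s[1], s[2], s[3]);
}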
- - // dilation - Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); - - IndexExpr pad_d, pad_h, pad_w; - GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); - if (!dshape_ncdhw[2].as()) { - oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); - } else { - oshape.Set(2, dshape_ncdhw[2]); - } - if (!dshape_ncdhw[2].as()) { - oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); - } else { - oshape.Set(3, dshape_ncdhw[3]); - } - if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); - } else { - oshape.Set(4, dshape_ncdhw[4]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -// Transposed convolution shape relations -template -bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - static const Layout kNCW("NCW"); - static const Layout kOIW("OIW"); - - const Conv1DTransposeAttrs* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); - ICHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); - ICHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); - ICHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - auto dshape_ncw = trans_in_layout.ForwardShape(data->shape); - - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 1); - ICHECK_EQ(param->dilation.size(), 1); - - Array wshape( - {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - channels = param->channels; - - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 1); - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) - << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); - } - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[1])) - << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << Array(wshape); - } - if (!dshape_ncw[1].as() && !wshape[0].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); - } - channels = wshape[1]; - dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; - } - // dilation - IndexExpr pad_w; - GetPaddingWidth(param->padding, &pad_w); - Array oshape({dshape_ncw[0], channels, 0}); - if (!dshape_ncw[2].as()) { - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + - param->output_padding[0])); - } else { - oshape.Set(2, dshape_ncw[2]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -template -bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - static const Layout kNCDHW("NCDHW"); - static const Layout kOIDHW("OIDHW"); - - const Conv3DTransposeAttrs* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); - ICHECK(trans_in_layout.defined()) - << "Conv3d_transpose only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); - ICHECK(trans_kernel_layout.defined()) - << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); - ICHECK(trans_out_layout.defined()) - << "Conv3d_transpose only support output layouts that are convertible from NCDHW." 
- << " But got " << out_layout; - - IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; - - auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); - - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 3); - ICHECK_EQ(param->dilation.size(), 3); - - Array wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups), - param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; - channels = param->channels; - - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // use weight to infer the conv shape. - if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 3); - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3]) && - reporter->AssertEQ(param->kernel_size[2], wshape[4])) - << "Conv3D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); - } - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[1])) - << "Conv3D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << Array(wshape); - } - if (!dshape_ncdhw[1].as() && !wshape[0].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); - } - channels = wshape[1]; - dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2]; - } - - // dilation - Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); - IndexExpr pad_d, pad_h, pad_w; - GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); - - if (!dshape_ncdhw[2].as()) { - oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + - param->output_padding[0])); - } else { - oshape.Set(2, dshape_ncdhw[2]); - } - if (!dshape_ncdhw[3].as()) { - oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + - param->output_padding[1])); - } else { - oshape.Set(3, dshape_ncdhw[3]); - } - if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + - param->output_padding[2])); - } else { - oshape.Set(4, dshape_ncdhw[4]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -template -bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - - static const Layout kNCHW("NCHW"); - static const Layout kIOHW("IOHW"); - - 
const Conv2DTransposeAttrs* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - ICHECK(trans_in_layout.defined()) - << "Conv2DTransposed only support input layouts that are convertible from NCHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW); - ICHECK(trans_kernel_layout.defined()) - << "Conv2DTransposed only support kernel layouts that are convertible from IOHW." - << " But got " << kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - ICHECK(trans_out_layout.defined()) - << "Conv2DTransposed only support output layouts that are convertible from NCHW." - << " But got " << out_layout; - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - - auto dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - ICHECK_EQ(param->dilation.size(), 2); - - Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), - param->kernel_size[0], param->kernel_size[1]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - channels = param->channels; - - DataType weight_dtype = data->dtype; - if (weight != nullptr) { - weight_dtype = weight->dtype; - } - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); - } else { - // use weight to infer the conv shape. - if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2DTransposed: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); - } - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1])) - << "Conv2DTransposed: shape of weight is inconsistent with out_channels, " - << " out_channels // groups != weight.shape[1] " - << " out_channels=" << param->channels << " groups=" << param->groups - << " weight.shape=" << Array(wshape); - } - if (!dshape_nchw[1].as() && !wshape[0].as()) { - ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0])) - << "Conv2DTransposed: shape of weight is inconsistent with in_channels." 
- << " data.shape= " << Array(dshape_nchw) << " groups= " << param->groups - << " weight.shape= " << Array(wshape); - } - channels = wshape[1]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nchw[2].as()) { - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + - param->output_padding[0])); - } else { - oshape.Set(2, dshape_nchw[2]); - } - if (!dshape_nchw[3].as()) { - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + - param->output_padding[1])); - } else { - oshape.Set(3, dshape_nchw[3]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - -// Deformable Convolution shape relations. -template -bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 4); - const auto* data = types[0].as(); - const auto* weight = types[2].as(); - - ICHECK(data); - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - auto* param = attrs.as(); - ICHECK(param != nullptr); - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); - if (!trans_in_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "deformable_conv2d only support input layouts that are convertible from NCHW." - << " The provided layout is: " << in_layout); - return false; - } - - const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); - if (!trans_kernel_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "deformable_conv2d only support kernel layouts that are convertible from OIHW." - << " The provided layout is: " << kernel_layout); - return false; - } - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); - if (!trans_out_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "deformable_conv2d only support output layouts that are convertible from NCHW." 
- << "The provided layout is: " << out_layout); - return false; - } - - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x; - - // infer weight shape if kernel_size and channels are defiend - if (param->kernel_size.defined() && param->channels.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - ICHECK_EQ(param->dilation.size(), 2); - Array wshape({param->channels, indexdiv(dshape_nchw[1], param->groups), - param->kernel_size[0], param->kernel_size[1]}); - - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - ksize_y = param->kernel_size[0]; - ksize_x = param->kernel_size[1]; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - // assign result to reporter - reporter->Assign(types[2], TensorType(wshape, data->dtype)); - } else { - // use weight to infer the conv shape. - if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - - if (param->kernel_size.defined()) { - ICHECK_EQ(param->kernel_size.size(), 2); - // check the size - ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size << " wshape=" << wshape; - } - if (param->channels.defined()) { - ICHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "DeformableConv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels << " wshape=" << wshape; - } - if (!dshape_nchw[1].as() && !wshape[1].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); - } - channels = wshape[0]; - ksize_y = wshape[2]; - ksize_x = wshape[3]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - - IndexExpr pad_h, pad_w; - GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); - DataType out_dtype = param->out_dtype; - - // infer offset shape - Array offset_shape( - {dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); - offset_shape = trans_in_layout.BackwardShape(offset_shape); - reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - - oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[3], TensorType(oshape, out_dtype)); - return true; -} - -template -InferCorrectLayoutOutput DeformableConvInferCorrectLayout( - const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array& old_in_types) { - const AttrType* params = attrs.as(); - return InferCorrectLayoutOutput( - {params->data_layout, params->data_layout, params->kernel_layout}, - {params->out_layout == "" ? 
       params->data_layout : params->out_layout}, attrs);
-}
-
 template <typename T>
 InferCorrectLayoutOutput ConvInferCorrectLayout(const Attrs& attrs,
                                                 const Array<Layout>& new_in_layouts,
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 03fa770e404f..8a7521e8ee50 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -84,7 +84,7 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Conv2D infer type function.
   Array<Type> tensor_types = {types[0], types[1], types[6]};
-  return Conv2DRel<Conv2DAttrs>(tensor_types, 3, attrs, reporter);
+  return Conv2DRel(tensor_types, 3, attrs, reporter);
 }
 
 InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs,
diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc
index b9227ff96c2a..cdc6c8d98f3a 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -128,7 +128,7 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Conv2D infer type function.
   Array<Type> tensor_types = {types[0], types[1], types[6]};
-  return Conv2DTransposeRel<Conv2DTransposeAttrs>(tensor_types, 3, attrs, reporter);
+  return Conv2DTransposeRel(tensor_types, 3, attrs, reporter);
 }
 
 RELAY_REGISTER_OP("qnn.conv2d_transpose")
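Editor's note, referring back to the deformable_conv2d relation removed above: the offset tensor carries a (dy, dx) pair for every kernel tap and deformable group, so its inferred channel dimension is 2 * kh * kw * deformable_groups over the output spatial extent. A small sketch with hypothetical helper names (not TVM API):

#include <array>
#include <cstdio>

std::array<int, 4> DeformableOffsetShape(int batch, int kh, int kw, int deformable_groups,
                                         int out_h, int out_w) {
  // NCHW offset layout: {N, 2 * kh * kw * deformable_groups, out_h, out_w}
  return {batch, 2 * kh * kw * deformable_groups, out_h, out_w};
}

int main() {
  // batch 1, 3x3 kernel, 4 deformable groups, 56x56 output -> {1, 72, 56, 56}
  auto s = DeformableOffsetShape(1, 3, 3, 4, 56, 56);
  std::printf("{%d, %d, %d, %d}\n", s[0], s[1], s[2], s[3]);
}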