14 changes: 6 additions & 8 deletions include/tvm/relay/attrs/nn.h
@@ -891,10 +891,9 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
.set_default(Array<IndexExpr>({0}))
.describe(
"If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom, right)");
"Padding supports both symmetric and asymmetric as"
"one int : same padding used on each side"
"two int : indicates left padding, right padding");
TVM_ATTR_FIELD(layout).set_default("NCW").describe(
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
@@ -933,10 +932,9 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
.set_default(Array<IndexExpr>({0}))
.describe(
"If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom, right)");
"Padding supports both symmetric and asymmetric as"
"one int : same padding used on each side"
"two int : indicates left padding, right padding");
TVM_ATTR_FIELD(layout).set_default("NCW").describe(
"Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
206 changes: 170 additions & 36 deletions src/relay/transforms/fold_explicit_padding.cc
@@ -22,11 +22,15 @@
* \brief A pass for folding explicit pads into other ops.
*/

#include <dmlc/optional.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/topi/nn/pooling.h>

#include "../op/tensor/transform.h"
#include "pattern_utils.h"
@@ -35,46 +39,70 @@ namespace tvm {
namespace relay {

/*!
* \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
* \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool
* with a pad attribute and merges the padding into the kernel.
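*
* For example (illustrative Relay, with a zero-valued constant pad):
*   nn.conv1d(nn.pad(x, pad_width=[[0,0],[0,0],[1,1]]), w, padding=[0,0])
*     --> nn.conv1d(x, w, padding=[1,1])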
*/
class SimplifyConvPad {
class SimplifyExplicitPad {
public:
DFPattern pattern() const { return pattern_; }

SimplifyConvPad() {
SimplifyExplicitPad() {
x_ = IsWildcard();
w_ = IsWildcard();
pad_ = IsOp("nn.pad")({x_, IsWildcard()});

// pad->conv patterns
w_ = IsWildcard();
conv1d_ = IsOp("nn.conv1d");
conv2d_ = IsOp("nn.conv2d");
conv3d_ = IsOp("nn.conv3d");

conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
contrib_conv2d_nchwc_ = IsOp("nn.contrib_conv2d_NCHWc");
conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_)({pad_, w_});

input_zero_point_ = IsWildcard();
kernel_zero_point_ = IsWildcard();
input_scale_ = IsWildcard();
kernel_scale_ = IsWildcard();

qconv2d_ = IsOp("qnn.conv2d")(
{pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});

pattern_ = conv_ || qconv2d_;
// pad->pool patterns
avg_pool1d_ = IsOp("nn.avg_pool1d");
avg_pool2d_ = IsOp("nn.avg_pool2d");
avg_pool3d_ = IsOp("nn.avg_pool3d");
max_pool1d_ = IsOp("nn.max_pool1d");
max_pool2d_ = IsOp("nn.max_pool2d");
max_pool3d_ = IsOp("nn.max_pool3d");
max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_;
pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_});

pattern_ = conv_ || qconv2d_ || pool_;
}

template <typename T>
Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
ICHECK(old_attrs);
Array<PrimExpr> get_combined_padding(const T* old_attrs, Array<PrimExpr> padding) const {
ICHECK(padding.size() == old_attrs->padding.size())
<< "Number of dimensions to pad and convolution padding attributes should have the same "
"extent";

auto new_attrs = make_object<T>();
Array<PrimExpr> combined_padding;
for (size_t i = 0; i < padding.size(); ++i) {
combined_padding.push_back(padding[i] + old_attrs->padding[i]);
}
return combined_padding;
}

template <typename T>
Attrs MakeConvAttrs(const PadAttrs* param, const T* old_attrs) const {
// Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs
ICHECK(old_attrs);
ICHECK(param);
auto padding = get_padding(param, old_attrs->data_layout);
if (!padding) {
return Attrs();
}
auto combined_padding = get_combined_padding(old_attrs, padding.value());

auto new_attrs = make_object<T>();
new_attrs->strides = old_attrs->strides;
new_attrs->padding = combined_padding;
new_attrs->dilation = old_attrs->dilation;
@@ -89,22 +117,85 @@ class SimplifyConvPad {
}

template <typename T>
Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
Attrs MakeConv2D3DAttrs(const PadAttrs* param, const T* old_attrs) const {
// Propagate additional Conv2D- and Conv3D-specific attrs
auto attrs = MakeConvAttrs(param, old_attrs);
if (!attrs.defined()) {
return Attrs();
}

T* new_attrs = const_cast<T*>(attrs.template as<T>());
new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
return attrs;
}

template <typename T>
Attrs MakePoolAttrs(const PadAttrs* param, const T* old_attrs) const {
// Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs
ICHECK(old_attrs);
ICHECK(param);
ICHECK(attrs);
ICHECK(attrs->data_layout.size() == param->pad_width.size())
auto padding = get_padding(param, old_attrs->layout);
if (!padding) {
return Attrs();
}
auto combined_padding = get_combined_padding(old_attrs, padding.value());

auto new_attrs = make_object<T>();
new_attrs->pool_size = old_attrs->pool_size;
new_attrs->strides = old_attrs->strides;
new_attrs->dilation = old_attrs->dilation;
new_attrs->padding = combined_padding;
new_attrs->layout = old_attrs->layout;
new_attrs->out_layout = old_attrs->out_layout;
new_attrs->ceil_mode = old_attrs->ceil_mode;
return Attrs(new_attrs);
}

template <typename T>
Attrs MakeAvgPoolAttrs(const PadAttrs* param, const T* old_attrs) const {
// Propagate additional AvgPool-specific attrs
auto attrs = MakePoolAttrs(param, old_attrs);
if (!attrs.defined()) {
return attrs;
}

T* new_attrs = const_cast<T*>(attrs.template as<T>());
new_attrs->count_include_pad = old_attrs->count_include_pad;
if (!new_attrs->count_include_pad) {
// AvgPool's divisor doesn't include padding, so don't fold the explicit pad
// unless all original pad items are 0.
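// Worked example: data = [1, 2, 3], an explicit left pad of one zero, then
// avg_pool1d(pool_size=2, count_include_pad=false):
//   pad-then-pool computes (0 + 1) / 2 = 0.5 for the first window, but a
//   naive fold would exclude the pad from the divisor: (0 + 1) / 1 = 1.0.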
for (IndexExpr pad : old_attrs->padding) {
const IntImmNode* maybe_int_imm = pad.as<IntImmNode>();
if (!maybe_int_imm || maybe_int_imm->value != 0) {
// Return undefined attrs to signal that we don't want to fold explicit pad
return Attrs();
}
}
// Turn on `count_include_pad` so the divisor counts the folded padding,
// matching the original pad-then-pool behavior.
new_attrs->count_include_pad = true;
}

return attrs;
}

static const Optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
std::string data_layout) {
// Gets spatial axes padding from the given PadAttrs `param`. If padding
// is non-zero on non-spatial axes, return NullOpt.
ICHECK(param);
ICHECK(data_layout.size() == param->pad_width.size())
<< "Data Layout and padding attributes should have the same extent";

std::string data_layout = attrs->data_layout;
std::set<char> image_dims({'H', 'W', 'D'});
Array<PrimExpr> padding;
// If we're padding a non-spatial dimension, don't simplify
// Convolution can only pad on spatial axes
// Convolution/Pool can only pad on spatial axes
for (size_t i = 0; i < param->pad_width.size(); ++i) {
if (!image_dims.count(data_layout[i])) {
for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
if (param->pad_width[i][j] != 0) {
return Attrs();
return NullOpt;
}
}
}
@@ -116,8 +207,7 @@ class SimplifyConvPad {
}
}
}

return MakeConvAttrs(attrs, padding);
return padding;
}

Expr callback(const Expr& pre, const Expr& post,
@@ -131,40 +221,75 @@ class SimplifyConvPad {
ICHECK(param);

auto x = node_map[x_][0];
auto w = node_map[w_][0];

// Possibly perform more optimizations if the pad_value is 0
const Expr& pv = pad_node->args[1];
const ConstantNode* pad_value = pv.as<ConstantNode>();
auto pad_scalar = ToScalar(pad_value->data);

if (node_map.find(qconv2d_) != node_map.end()) {
Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
Attrs attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
if (!attrs.defined()) {
return post;
}
auto input_zero_point = node_map[input_zero_point_][0];
auto kernel_zero_point = node_map[kernel_zero_point_][0];
auto input_scale = node_map[input_scale_][0];
auto kernel_scale = node_map[kernel_scale_][0];
// Fold Padding and QNN Convolution only if pad value == input zero point.
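// (the input zero point is the quantized encoding of real 0, which is also
// what qnn.conv2d's own implicit padding inserts).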
if (IsEqualScalar(input_zero_point, pv)) {
auto w = node_map[w_][0];
return Call(call_node->op,
{x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
call_node->type_args, call_node->span);
} else {
return post;
}
} else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
return post;
}

if (param->pad_mode == "constant" && pad_value) {
Attrs attrs;
if (node_map.count(conv1d_)) {
attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
} else if (node_map.count(conv2d_)) {
attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
} else if (node_map.count(conv3d_)) {
attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
} else {
return post;
if (pad_scalar == 0.0) {
// Fold Padding and Conv/AvgPool only if pad_value == 0.
if (node_map.count(conv_)) {
if (node_map.count(conv1d_)) {
attrs = MakeConvAttrs(param, call_node->attrs.as<Conv1DAttrs>());
} else if (node_map.count(conv2d_)) {
attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
} else if (node_map.count(conv3d_)) {
attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv3DAttrs>());
}
if (!attrs.defined()) {
return post;
}
auto w = node_map[w_][0];
return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
} else if (node_map.count(avg_pool1d_)) {
attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool1DAttrs>());
} else if (node_map.count(avg_pool2d_)) {
attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool2DAttrs>());
} else if (node_map.count(avg_pool3d_)) {
attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool3DAttrs>());
}
} else if (node_map.count(max_pool_)) {
// Fold Padding and MaxPool only if pad_value is the min possible value for the dtype
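// (i.e. the identity element of max; e.g. std::numeric_limits<float>::lowest() for float32)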
auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype));
const FloatImmNode* maybe_min_float = min_value.as<FloatImmNode>();
const IntImmNode* maybe_min_int = min_value.as<IntImmNode>();

if ((maybe_min_float && pad_scalar == maybe_min_float->value) ||
(maybe_min_int && pad_scalar == maybe_min_int->value)) {
if (node_map.count(max_pool1d_)) {
attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool1DAttrs>());
} else if (node_map.count(max_pool2d_)) {
attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool2DAttrs>());
} else if (node_map.count(max_pool3d_)) {
attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool3DAttrs>());
}
}
}
if (!attrs.defined()) {
return post;
}
return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
return Call(call_node->op, {x}, attrs, call_node->type_args, call_node->span);
}
return post;
}
@@ -183,18 +308,27 @@ class SimplifyConvPad {
DFPattern conv1d_;
DFPattern conv2d_;
DFPattern conv3d_;
DFPattern contrib_conv2d_nchwc_;
DFPattern qconv2d_;
DFPattern input_zero_point_;
DFPattern kernel_zero_point_;
DFPattern input_scale_;
DFPattern kernel_scale_;
/*! \brief Pattern pool */
DFPattern pool_;
DFPattern avg_pool1d_;
DFPattern avg_pool2d_;
DFPattern avg_pool3d_;
DFPattern max_pool1d_;
DFPattern max_pool2d_;
DFPattern max_pool3d_;
DFPattern max_pool_;
};

class SimplifyExplicitPadding {
public:
explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) {
CreateCallback(SimplifyConvPad());
// TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x))
CreateCallback(SimplifyExplicitPad());
}
template <typename T>
void CreateCallback(const T& pattern) {
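As a usage sketch (assuming a TVM build that includes this pass; names and shapes are illustrative), the new pad->pool folding can be exercised like this:

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 32), dtype="float32")  # NCW layout
# For max pooling the fold only fires when the pad value is the dtype's
# minimum, i.e. the identity element of max.
pad = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1)),
                   pad_value=float(np.finfo(np.float32).min))
pool = relay.nn.max_pool1d(pad, pool_size=3)
mod = tvm.IRModule.from_expr(relay.Function([x], pool))

# The nn.pad should now be folded into max_pool1d's padding attribute.
folded = relay.transform.FoldExplicitPadding()(mod)
print(folded)
```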