[TOPI][Relay] max_pool2d & avg_pool2d gradient (apache#3601)
vinx13 authored and wweic committed Aug 9, 2019
1 parent 31a8d76 commit f0d1ed8
Showing 15 changed files with 730 additions and 21 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -22,6 +22,7 @@
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like
from . import nn as _nn


@register_gradient("log")
@@ -146,3 +147,20 @@ def clip_grad(orig, grad):
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]

@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode)
    return [pool_grad]


@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                                    count_include_pad=attrs.count_include_pad)
    return [pool_grad]
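A minimal usage sketch (not part of this commit) of the gradients registered above, driven through Relay's first-order AD; API names and shapes assume the TVM of this era (~v0.6):

```python
import numpy as np
from tvm import relay

x = relay.var("x", shape=(1, 1, 4, 4))
fwd = relay.Function([x], relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2)))

mod = relay.Module.from_expr(fwd)
mod = relay.transform.InferType()(mod)
# gradient() dispatches to the max_pool2d_grad rule registered above.
bwd = relay.transform.gradient(mod["main"], mode="first_order")

intrp = relay.create_executor()
forward, (grad,) = intrp.evaluate(bwd)(np.ones((1, 1, 4, 4), dtype="float32"))
# For max pooling, each output gradient flows back to its argmax location.
```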
22 changes: 22 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -255,6 +255,28 @@ def schedule_avg_pool2d(attrs, outs, target):
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
    """Schedule definition of max_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d_grad
@reg.register_schedule("nn.avg_pool2d_grad")
def schedule_avg_pool2d_grad(attrs, outs, target):
    """Schedule definition of avg_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
@reg.register_schedule("nn.global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
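For context, a hedged sketch of how these schedule registrations are exercised: compiling a graph that contains one of the grad ops makes Relay look up the registered schedule for the active target (API names as of this era; not part of the commit):

```python
from tvm import relay

out_grad = relay.var("out_grad", shape=(1, 3, 4, 4))
data = relay.var("data", shape=(1, 3, 8, 8))
y = relay.nn.max_pool2d_grad(out_grad, data, pool_size=(2, 2), strides=(2, 2))
mod = relay.Module.from_expr(relay.Function([out_grad, data], y))

# Building for llvm dispatches to schedule_max_pool2d_grad above,
# which defers to topi.generic.schedule_pool_grad for the target.
graph, lib, params = relay.build(mod, target="llvm")
```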
82 changes: 82 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -327,6 +327,88 @@ def avg_pool2d(data,
    return _make.avg_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)

def max_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False):
    r"""Gradient of 2D maximum pooling operator.

    This operator takes out_grad and data as input and computes
    the gradient of max_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient.

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of the pooling window.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode)

def avg_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False,
                    count_include_pad=False):
    r"""Gradient of 2D average pooling operator.

    This operator takes out_grad and data as input and computes
    the gradient of avg_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient.

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of the pooling window.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding when computing the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.avg_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode, count_include_pad)

def global_max_pool2d(data,
                      layout="NCHW"):
    r"""2D global maximum pooling operator.
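A short usage sketch of the new Python wrappers (hypothetical shapes, NCHW layout, not part of the commit); the gradient is expected to have the same type as `data`, per the Pool2DGradRel relation added in pooling.cc below:

```python
from tvm import relay

data = relay.var("data", shape=(1, 16, 32, 32))
out_grad = relay.var("out_grad", shape=(1, 16, 16, 16))
g = relay.nn.avg_pool2d_grad(out_grad, data, pool_size=(2, 2), strides=(2, 2),
                             padding=(0, 0), count_include_pad=True)

mod = relay.Module.from_expr(relay.Function([out_grad, data], g))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expect: Tensor[(1, 16, 32, 32), float32]
```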
10 changes: 10 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -251,3 +251,13 @@ class YoloReorgAttrs(Attrs):
@register_relay_attr_node
class ProposalAttrs(Attrs):
    """Attributes used in proposal operators"""


@register_relay_attr_node
class MaxPool2DAttrs(Attrs):
    """Attributes used in max_pool2d operators"""


@register_relay_attr_node
class AvgPool2DAttrs(Attrs):
    """Attributes used in avg_pool2d operators"""
160 changes: 158 additions & 2 deletions src/relay/op/nn/pooling.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -557,5 +557,161 @@ RELAY_REGISTER_OP("contrib.adaptive_max_pool2d")
                   Pool2DInferCorrectLayout<AdaptivePool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);


bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[1].as<TensorTypeNode>();

  if (data == nullptr) return false;

  // assign output type: the gradient has the same type as the input data
  reporter->Assign(types[2], types[1]);
  return true;
}

template <typename AttrType, topi::nn::PoolType mode>
Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
                                const Type& out_type, const Target& target) {
  static const Layout kNCHW("NCHW");
  const auto* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  CHECK_EQ(inputs.size(), 2);
  auto pool_size = param->pool_size;
  auto strides = param->strides;
  auto padding = param->padding;
  auto ceil_mode = param->ceil_mode;
  Layout layout(param->layout);

  CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
      << "pool2d_grad currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
      << "pool2d_grad does not support input split on height";
  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
      << "pool2d_grad does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
      << "Pool2DGrad only supports 4-D output gradient (e.g., NCHW)"
      << " or 5-D output gradient (last dimension is a split of channel)";

  CHECK(inputs[1].ndim() == 4U || inputs[1].ndim() == 5U)
      << "Pool2DGrad only supports 4-D input (e.g., NCHW)"
      << " or 5-D input (last dimension is a split of channel)";

  if (param->padding.size() == 1) {
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
  } else if (param->padding.size() == 2) {
    padding.push_back(padding[0]);
    padding.push_back(padding[1]);
  }
  if (mode == topi::nn::kAvgPool) {
    bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
                                             mode, ceil_mode, layout.name(), count_include_pad)};
  } else {
    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
                                             mode, ceil_mode, layout.name())};
  }
}
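The padding canonicalization above expands 1- or 2-element paddings to the 4-element (top, left, bottom, right) form. A rough Python mirror of that logic, for illustration only:

```python
def expand_padding(padding):
    """Mirror of the if/else above: canonicalize to (top, left, bottom, right)."""
    if len(padding) == 1:        # one value: same padding on all four sides
        return list(padding) * 4
    if len(padding) == 2:        # (top, left): bottom and right reuse them
        return list(padding) + list(padding)
    return list(padding)         # already fully specified

assert expand_padding([1]) == [1, 1, 1, 1]
assert expand_padding([2, 3]) == [2, 3, 2, 3]
```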


// MaxPool2DGrad
Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                       Array<IndexExpr> strides, Array<IndexExpr> padding,
                       std::string layout, bool ceil_mode) {
  auto attrs = make_node<MaxPool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->layout = std::move(layout);
  attrs->ceil_mode = ceil_mode;
  static const Op& op = Op::Get("nn.max_pool2d_grad");
  return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad);


RELAY_REGISTER_OP("nn.max_pool2d_grad")
.describe(R"code(Gradient of max pooling operation for two dimensional data.
- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are the output size of the pooling operation,
which are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
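A worked check of the out_height/out_width formulas in the description above (plain Python, not part of the commit):

```python
import math

def pool_out_size(size, pad_lo, pad_hi, window, stride, ceil_mode=False):
    """out = floor_or_ceil((size + pad_lo + pad_hi - window) / stride) + 1"""
    rounding = math.ceil if ceil_mode else math.floor
    return rounding((size + pad_lo + pad_hi - window) / stride) + 1

# 8x8 input, 2x2 window, stride 2, no padding -> 4x4 output gradient
assert pool_out_size(8, 0, 0, 2, 2) == 4
# 7x7 input, same pooling: floor gives 3; ceil_mode gives 4
assert pool_out_size(7, 0, 0, 2, 2) == 3
assert pool_out_size(7, 0, 0, 2, 2, ceil_mode=True) == 4
```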


// AvgPool2DGrad
Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                       Array<IndexExpr> strides, Array<IndexExpr> padding,
                       std::string layout, bool ceil_mode, bool count_include_pad) {
  auto attrs = make_node<AvgPool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->layout = std::move(layout);
  attrs->ceil_mode = ceil_mode;
  attrs->count_include_pad = count_include_pad;
  static const Op& op = Op::Get("nn.avg_pool2d_grad");
  return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad);


RELAY_REGISTER_OP("nn.avg_pool2d_grad")
.describe(R"code(Gradient of average pooling operation for two dimensional data.
- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are the output size of the pooling operation,
which are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.AvgPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("AvgPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);


} // namespace relay
} // namespace tvm
2 changes: 1 addition & 1 deletion topi/include/topi/detail/ravel_unravel.h
@@ -42,7 +42,7 @@ using namespace tvm;
 *
 * \return The index after flattening
 */
inline Expr RavelIndex(Array<Var> indices, Array<Expr> shape) {
inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
  CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
  CHECK_GT(indices.size(), 0) << "indices must not be empty";
  Expr idx;
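A rough Python rendering of RavelIndex (row-major flattening), to make the recurrence explicit; the signature change widens `indices` from variables to arbitrary expressions:

```python
def ravel_index(indices, shape):
    """Flatten multi-dim indices row-major: idx = (...(i0*s1 + i1)*s2 + i2)..."""
    assert len(indices) == len(shape) and indices, \
        "indices and shape must have equal, non-zero size"
    idx = indices[0]
    for i, s in zip(indices[1:], shape[1:]):
        idx = idx * s + i
    return idx

# (1, 2, 3) in a (4, 5, 6) tensor -> (1*5 + 2)*6 + 3 = 45
assert ravel_index([1, 2, 3], [4, 5, 6]) == 45
```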