Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add space_to_batch_nd and batch_to_space_nd operators #6477

Merged
merged 6 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,34 @@ struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
}
}; // struct CorrelationAttrs

/*! \brief Attributes used in SpaceToBatchND operator */
struct SpaceToBatchNDAttrs : public tvm::AttrsNode<SpaceToBatchNDAttrs> {
  /*! \brief 1-D block size, one entry per spatial dimension. */
  Array<Integer> block_shape;
  /*! \brief 2-D paddings: one [pad_before, pad_after] pair per spatial dimension. */
  Array<Array<IndexExpr>> paddings;
  /*! \brief Constant value written into the padded region. */
  double pad_value;

  TVM_DECLARE_ATTRS(SpaceToBatchNDAttrs, "relay.attrs.SpaceToBatchNDAttrs") {
    TVM_ATTR_FIELD(block_shape)
        .set_default(Array<Integer>({1, 1}))
        .describe("1-D containing block size for each spatial dimension.");
    // No default for paddings: the caller must always supply them.
    TVM_ATTR_FIELD(paddings).describe("2-D containing paddings for each spatial dimension.");
    TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value used for padding.");
  }
};  // struct SpaceToBatchNDAttrs

/*! \brief Attributes used in BatchToSpaceND operator */
struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
  /*! \brief 1-D block size, one entry per spatial dimension. */
  Array<Integer> block_shape;
  /*! \brief 2-D crops: one [crop_begin, crop_end] pair per spatial dimension. */
  Array<Array<IndexExpr>> crops;

  TVM_DECLARE_ATTRS(BatchToSpaceNDAttrs, "relay.attrs.BatchToSpaceNDAttrs") {
    TVM_ATTR_FIELD(block_shape)
        .set_default(Array<Integer>({1, 1}))
        .describe("1-D containing block size for each spatial dimension.");
    // No default for crops: the caller must always supply them.
    TVM_ATTR_FIELD(crops).describe("2-D containing amount to crop from spatial dimension.");
  }
};  // struct BatchToSpaceNDAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
178 changes: 178 additions & 0 deletions include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/tir/op.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>
#include <tvm/topi/transform.h>

#include <algorithm>
#include <string>
Expand Down Expand Up @@ -459,6 +460,183 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t
return tvm::te::compute(output_shape, l, name, tag);
}

/*!
 * \brief Divide spatial dimensions of the input into a grid of blocks.
 *
 * Implemented as pad -> reshape -> transpose -> reshape: each padded spatial
 * dimension is split into a (dim / block, block) pair, the block factors are
 * permuted in front of the batch axis, and a final reshape folds them into a
 * batch dimension of size batch * prod(block_shape).
 *
 * \param data The input tensor.
 * \param block_shape The size of the spatial block.
 * \param pad_before The zero-padding size before each spatial dimension.
 * \param pad_after The zero-padding size after each spatial dimension.
 * \param pad_value The value used for padding.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the space_to_batch_nd operation
 */
inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& pad_before,
                                         const tvm::Array<tvm::PrimExpr>& pad_after,
                                         PrimExpr pad_value = PrimExpr(),
                                         std::string name = "space_to_batch_nd",
                                         std::string tag = kInjective) {
  tvm::te::Tensor padded_t;
  // One pad pair is required per spatial (block) dimension.
  CHECK_EQ(pad_before.size(), pad_after.size());
  CHECK_EQ(block_shape.size(), pad_before.size())
      << "Paddings must be provided for each spatial dimension";
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  // pad size for batch dimension is 0
  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  // insert pad sizes given for spatial dimensions
  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // pad the input with paddings provided; when no pad value is given, pad
  // with 0 of the input's dtype
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(data->dtype, 0);
  }
  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

  auto input_shape = data->shape;
  auto padded_shape = padded_t->shape;

  // infer shapes
  tvm::Array<PrimExpr> r_shape;  // intermediate shape: [batch, d1/b1, b1, ..., rest]
  tvm::Array<Integer> axis;      // permutation applied between the two reshapes
  tvm::Array<PrimExpr> o_shape;  // final output shape

  size_t num_block_dims = block_shape.size();
  int batch = static_cast<int>(GetConstInt(input_shape[0]));
  tvm::PrimExpr block_shape_prod(1);
  r_shape.push_back(batch);

  for (size_t i = 1; i <= num_block_dims; i++) {
    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
    // each padded spatial dimension must split evenly into blocks
    CHECK_EQ((padded_input % block_size), 0)
        << "(" << i
        << ")th "
           "Input dimension after padding ("
        << padded_input << ")"
        << " must be divisible by its block size (" << block_size << ")";

    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
    r_shape.push_back(block_shape[i - 1]);
    block_shape_prod *= block_shape[i - 1];
    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
  }

  size_t n = axis.size();
  axis.push_back(0);  // batch is at index 0
  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape; each sits one
  // position before its block factor recorded in axis[i]
  for (size_t i = 0; i < n; i++) {
    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
  }
  // block factors fold into the batch: new batch = batch * prod(block_shape)
  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
  }
  // append remaining shape (non-spatial trailing dims pass through unchanged)
  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
    r_shape.push_back(input_shape[i]);
    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
    o_shape.push_back(input_shape[i]);
  }

  tvm::te::Tensor output = reshape(padded_t, r_shape);
  output = transpose(output, axis);
  output = reshape(output, o_shape);

  return output;
}

/*!
 * \brief Reshape the batch dimension into spatial dimensions.
 *
 * Inverse of space_to_batch_nd: implemented as reshape -> transpose ->
 * reshape -> strided_slice. The block factors are peeled off the batch
 * dimension, interleaved back into the spatial dimensions, and the requested
 * begin/end crops are then sliced away from each spatial dimension.
 *
 * \param data The input tensor.
 * \param block_shape The size of the spatial block.
 * \param crop_begin_list The begin crop size for each spatial dimension.
 * \param crop_end_list The end crop size for each spatial dimension.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the batch_to_space_nd operation
 */
inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& crop_begin_list,
                                         const tvm::Array<tvm::PrimExpr>& crop_end_list,
                                         std::string name = "batch_to_space_nd",
                                         std::string tag = kInjective) {
  // Construct shapes for reshape and transpose operation
  Array<PrimExpr> in_shape = data->shape;
  Array<PrimExpr> r_shape;  // [b1, ..., bM, batch/prod(b), in1, ..., inN-1]
  Array<Integer> axis;      // permutation interleaving block factors with spatial dims
  size_t num_block_dims = block_shape.size();
  size_t num_input_dims = in_shape.size();
  tvm::PrimExpr block_shape_prod(1);
  int batch = static_cast<int>(GetConstInt(in_shape[0]));

  // leading dims of the intermediate shape are the block factors
  for (size_t i = 0; i < num_block_dims; i++) {
    r_shape.push_back(block_shape[i]);
    block_shape_prod *= block_shape[i];
  }
  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
  r_shape.push_back(batch / block_shape_prod);

  // interleave each input spatial dim with its block factor
  for (size_t i = 1; i < num_input_dims; i++) {
    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
    if (axis.size() < (num_block_dims + num_input_dims)) {
      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i]
    }
    r_shape.push_back(in_shape[i]);
  }

  // shape after folding each block factor into its spatial dim
  Array<PrimExpr> r_p_shape;
  r_p_shape.push_back(batch / block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
  }
  // trailing non-spatial dims pass through unchanged
  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
    r_p_shape.push_back(in_shape[i]);
  }

  tvm::te::Tensor out;
  out = reshape(data, r_shape);
  out = transpose(out, axis);
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
  Array<Integer> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
      // prepare begin and end index for spatial dimensions
      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
      // Fixed diagnostic: added the missing space before the dim size so the
      // message reads "output size 4 vs 5" rather than "output size4 vs 5".
      CHECK_GT(out_i, (begin_i + end_i))
          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
          << " output size " << out_i << " vs " << (begin_i + end_i);
      begin_idx.push_back(begin_i);
      end_idx.push_back(out_i - end_i);
    } else {
      // ignore the batch and remaining dimension
      begin_idx.push_back(Integer(0));
      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
    }
  }

  out = strided_slice(out, begin_idx, end_idx, strides);
  return out;
}
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_NN_H_
88 changes: 11 additions & 77 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1980,8 +1980,6 @@ def _impl(inputs, attr, params, mod):

def _space_to_batch_nd():
def _impl(inputs, attr, params, mod):
input_node = inputs[0]
input_shape = _infer_shape(input_node, mod)
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
Expand All @@ -1995,48 +1993,18 @@ def _impl(inputs, attr, params, mod):
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
remaining_shape_length = N - M - 1
paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d:
# Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
# to produce padded of shape padded_shape.
padded = tvm.relay.nn.pad(input_node, pad_width=paddings)
# Reshape padded to reshaped_padded of shape:
# [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...,
# padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape
shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
reshaped_padded = tvm.relay.reshape(padded, newshape=shape1)
# Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
# block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
axes = (
[2 * i + 2 for i in range(M)]
+ [0]
+ [2 * i + 1 for i in range(M)]
+ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
)
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
reshaped_permuted_reshaped_padded = tvm.relay.reshape(
permuted_reshaped_padded, newshape=shape2
)
return reshaped_permuted_reshaped_padded

attr["block_shape"] = block_shape
attr["paddings"] = paddings
out = AttrCvt("space_to_batch_nd", ignores=["Tblock_shape", "Tpaddings"])([inputs[0]], attr)

return out

return _impl


def _batch_to_space_nd():
def _impl(inputs, attr, params, mod):
input_node = inputs[0]
input_shape = _infer_shape(input_node, mod)
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
Expand All @@ -2050,46 +2018,12 @@ def _impl(inputs, attr, params, mod):
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()
M = len(block_shape)
batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# Reshape input to reshaped of shape:
# [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
# input_shape[1], ..., input_shape[N-1]]
shape1 = block_shape + [batch // np.prod(block_shape)] + list(input_shape[1:])
reshaped = tvm.relay.reshape(input_node, newshape=shape1)
# Permute dimensions of reshaped to produce permuted of shape
# [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
# input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
axes = (
[M]
+ [axis for i in range(M) for axis in [M + i + 1, i]]
+ list(range(2 * M + 1, len(shape1)))
)
permuted = tvm.relay.transpose(reshaped, axes=axes)
# Reshape permuted to produce reshaped_permuted of shape
# [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
# input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
shape2 = [0] + [-3] * M + [-2]
reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2)
# Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
# to produce the output of shape:
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod)
cropped = reshaped_permuted
for axis in range(1, M + 1):
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
_expr.const(crop[0]),
_expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype="int32",
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)

return cropped
attr["block_shape"] = block_shape
attr["crops"] = crops
out = AttrCvt("batch_to_space_nd", ignores=["Tblock_shape", "Tcrops"])([inputs[0]], attr)

return out

return _impl

Expand Down
Loading