diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index b2555de6d35e..e697ac45bd12 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1324,6 +1324,34 @@ struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
   }
 };  // struct CorrelationAttrs
 
+/*! \brief Attributes used in SpaceToBatchND operator */
+struct SpaceToBatchNDAttrs : public tvm::AttrsNode<SpaceToBatchNDAttrs> {
+  Array<Integer> block_shape;
+  Array<Array<IndexExpr>> paddings;
+  double pad_value;
+
+  TVM_DECLARE_ATTRS(SpaceToBatchNDAttrs, "relay.attrs.SpaceToBatchNDAttrs") {
+    TVM_ATTR_FIELD(block_shape)
+        .set_default(Array<Integer>({1, 1}))
+        .describe("1-D containing block size for each spatial dimension.");
+    TVM_ATTR_FIELD(paddings).describe("2-D containing paddings for each spatial dimension.");
+    TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value used for padding.");
+  }
+};  // struct SpaceToBatchNDAttrs
+
+/*! \brief Attributes used in BatchToSpaceND operator */
+struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
+  Array<Integer> block_shape;
+  Array<Array<IndexExpr>> crops;
+
+  TVM_DECLARE_ATTRS(BatchToSpaceNDAttrs, "relay.attrs.BatchToSpaceNDAttrs") {
+    TVM_ATTR_FIELD(block_shape)
+        .set_default(Array<Integer>({1, 1}))
+        .describe("1-D containing block size for each spatial dimension.");
+    TVM_ATTR_FIELD(crops).describe("2-D containing amount to crop from spatial dimension.");
+  }
+};  // struct BatchToSpaceNDAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
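Note on the attrs semantics: `block_shape` divides each (padded) spatial dimension and the blocks are folded into the batch dimension. A minimal sketch of the resulting shape rule, not part of the patch (the helper name `s2b_out_shape` is made up for illustration):

```python
import numpy as np

def s2b_out_shape(in_shape, block_shape, paddings):
    """Output shape implied by SpaceToBatchNDAttrs: pad the spatial dims,
    divide each by its block size, and fold the blocks into batch."""
    M = len(block_shape)
    padded = [in_shape[i + 1] + sum(paddings[i]) for i in range(M)]
    assert all(p % b == 0 for p, b in zip(padded, block_shape))
    return ([in_shape[0] * int(np.prod(block_shape))]
            + [p // b for p, b in zip(padded, block_shape)]
            + list(in_shape[M + 1:]))

# NHWC input, 2x2 blocks, two columns of padding before W:
print(s2b_out_shape([2, 2, 4, 1], [2, 2], [[0, 0], [2, 0]]))  # [8, 1, 3, 1]
```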
diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index ba1be3424fcc..f958048f13c3 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -30,6 +30,7 @@
 #include <tvm/topi/detail/constant_utils.h>
 #include <tvm/topi/reduction.h>
 #include <tvm/topi/tags.h>
+#include <tvm/topi/transform.h>
 
 #include <algorithm>
 #include <string>
@@ -459,6 +460,183 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
   return tvm::te::compute(output_shape, l, name, tag);
 }
 
+/*!
+ * \brief Divide spatial dimensions of the input into a grid of blocks.
+ *
+ * \param data The input tensor.
+ * \param block_shape The size of the spatial block.
+ * \param pad_before The zero-padding size before each spatial dimension.
+ * \param pad_after The zero-padding size after each spatial dimension.
+ * \param pad_value The value used for padding.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the space_to_batch_nd operation
+ */
+inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
+                                         const tvm::Array<Integer>& block_shape,
+                                         const tvm::Array<tvm::PrimExpr>& pad_before,
+                                         const tvm::Array<tvm::PrimExpr>& pad_after,
+                                         PrimExpr pad_value = PrimExpr(),
+                                         std::string name = "space_to_batch_nd",
+                                         std::string tag = kInjective) {
+  tvm::te::Tensor padded_t;
+  CHECK_EQ(pad_before.size(), pad_after.size());
+  CHECK_EQ(block_shape.size(), pad_before.size())
+      << "Paddings must be provided for each spatial dimension";
+  tvm::Array<tvm::PrimExpr> pad_before_int32;
+  tvm::Array<tvm::PrimExpr> pad_after_int32;
+
+  // pad size for batch dimension is 0
+  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
+  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
+  // insert pad sizes given for spatial dimensions
+  for (const auto& ele : pad_before) {
+    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
+  }
+  for (const auto& ele : pad_after) {
+    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
+  }
+
+  // pad the input with paddings provided
+  if (!pad_value.defined()) {
+    pad_value = tvm::tir::make_const(data->dtype, 0);
+  }
+  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);
+
+  auto input_shape = data->shape;
+  auto padded_shape = padded_t->shape;
+
+  // infer shapes
+  tvm::Array<PrimExpr> r_shape;
+  tvm::Array<Integer> axis;
+  tvm::Array<PrimExpr> o_shape;
+
+  size_t num_block_dims = block_shape.size();
+  int batch = static_cast<int>(GetConstInt(input_shape[0]));
+  tvm::PrimExpr block_shape_prod(1);
+  r_shape.push_back(batch);
+
+  for (size_t i = 1; i <= num_block_dims; i++) {
+    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
+    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
+    CHECK_EQ((padded_input % block_size), 0)
+        << "(" << i << ")th input dimension after padding (" << padded_input << ")"
+        << " must be divisible by its block size (" << block_size << ")";
+
+    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
+    r_shape.push_back(block_shape[i - 1]);
+    block_shape_prod *= block_shape[i - 1];
+    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
+  }
+
+  size_t n = axis.size();
+  axis.push_back(0);  // batch is at index 0
+  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
+  for (size_t i = 0; i < n; i++) {
+    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
+  }
+  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
+  for (size_t i = 1; i <= num_block_dims; i++) {
+    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
+  }
+  // append remaining shape
+  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
+    r_shape.push_back(input_shape[i]);
+    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
+    o_shape.push_back(input_shape[i]);
+  }
+
+  tvm::te::Tensor output = reshape(padded_t, r_shape);
+  output = transpose(output, axis);
+  output = reshape(output, o_shape);
+
+  return output;
+}
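The compute above is a pad → reshape → transpose → reshape pipeline. A self-contained NumPy sketch of the same decomposition (illustrative only, assuming static shapes; it mirrors the `r_shape`/`axis`/`o_shape` bookkeeping in the C++):

```python
import numpy as np

def space_to_batch_ref(data, block, pad_before, pad_after):
    M = len(block)
    pads = [(0, 0)] + list(zip(pad_before, pad_after)) + [(0, 0)] * (data.ndim - 1 - M)
    padded = np.pad(data, pads)
    batch = data.shape[0]
    # [batch, p1/b1, b1, ..., pM/bM, bM, remaining...]
    r_shape = [batch]
    for i in range(M):
        r_shape += [padded.shape[i + 1] // block[i], block[i]]
    r_shape += list(data.shape[M + 1:])
    # move the block axes in front of batch, keep the quotients after it
    axes = ([2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)]
            + list(range(2 * M + 1, 1 + 2 * M + data.ndim - 1 - M)))
    out = padded.reshape(r_shape).transpose(axes)
    return out.reshape([batch * int(np.prod(block))]
                       + [padded.shape[i + 1] // block[i] for i in range(M)]
                       + list(data.shape[M + 1:]))

x = np.arange(16).reshape(1, 4, 4, 1)
print(space_to_batch_ref(x, [2, 2], [0, 0], [0, 0]).shape)  # (4, 2, 2, 1)
```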
+/*!
+ * \brief Reshape the batch dimension into spatial dimensions.
+ *
+ * \param data The input tensor.
+ * \param block_shape The size of the spatial block.
+ * \param crop_begin_list The begin crop size for each spatial dimension.
+ * \param crop_end_list The end crop size for each spatial dimension.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the batch_to_space_nd operation
+ */
+inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
+                                         const tvm::Array<Integer>& block_shape,
+                                         const tvm::Array<tvm::PrimExpr>& crop_begin_list,
+                                         const tvm::Array<tvm::PrimExpr>& crop_end_list,
+                                         std::string name = "batch_to_space_nd",
+                                         std::string tag = kInjective) {
+  // Construct shapes for reshape and transpose operation
+  Array<PrimExpr> in_shape = data->shape;
+  Array<PrimExpr> r_shape;
+  Array<Integer> axis;
+  size_t num_block_dims = block_shape.size();
+  size_t num_input_dims = in_shape.size();
+  tvm::PrimExpr block_shape_prod(1);
+  int batch = static_cast<int>(GetConstInt(in_shape[0]));
+
+  for (size_t i = 0; i < num_block_dims; i++) {
+    r_shape.push_back(block_shape[i]);
+    block_shape_prod *= block_shape[i];
+  }
+  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
+  r_shape.push_back(batch / block_shape_prod);
+
+  for (size_t i = 1; i < num_input_dims; i++) {
+    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
+    if (axis.size() < (num_block_dims + num_input_dims)) {
+      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i]
+    }
+    r_shape.push_back(in_shape[i]);
+  }
+
+  Array<PrimExpr> r_p_shape;
+  r_p_shape.push_back(batch / block_shape_prod);
+  for (size_t i = 1; i <= num_block_dims; i++) {
+    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
+  }
+  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
+    r_p_shape.push_back(in_shape[i]);
+  }
+
+  tvm::te::Tensor out;
+  out = reshape(data, r_shape);
+  out = transpose(out, axis);
+  out = reshape(out, r_p_shape);
+
+  // Crop the start and end of dimensions of out
+  Array<Integer> begin_idx, end_idx, strides;
+  for (size_t i = 0; i < r_p_shape.size(); ++i) {
+    strides.push_back(Integer(1));
+    if (i > 0 && i <= num_block_dims) {
+      // prepare begin and end index for spatial dimensions
+      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
+      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
+      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
+      CHECK_GT(out_i, (begin_i + end_i))
+          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
+          << " output size " << out_i << " vs " << (begin_i + end_i);
+      begin_idx.push_back(begin_i);
+      end_idx.push_back(out_i - end_i);
+    } else {
+      // ignore the batch and remaining dimension
+      begin_idx.push_back(Integer(0));
+      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
+    }
+  }
+
+  out = strided_slice(out, begin_idx, end_idx, strides);
+  return out;
+}
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_NN_H_
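And the inverse direction, `batch_to_space_nd`, is reshape → transpose → reshape followed by a crop. A NumPy sketch mirroring the index bookkeeping above (illustrative only):

```python
import numpy as np

def batch_to_space_ref(data, block, crop_begin, crop_end):
    M = len(block)
    batch = data.shape[0] // int(np.prod(block))
    # [b1, ..., bM, batch/prod(block), in1, ..., inN-1]
    r = data.reshape(list(block) + [batch] + list(data.shape[1:]))
    # interleave each block axis after its matching spatial axis
    axes = ([M] + [ax for i in range(M) for ax in (M + 1 + i, i)]
            + list(range(2 * M + 1, r.ndim)))
    out = r.transpose(axes).reshape(
        [batch] + [data.shape[1 + i] * block[i] for i in range(M)]
        + list(data.shape[M + 1:]))
    # crop the spatial dims
    idx = tuple([slice(None)] + [slice(cb, out.shape[1 + i] - ce)
                                 for i, (cb, ce) in enumerate(zip(crop_begin, crop_end))])
    return out[idx]

x = np.arange(8).reshape(8, 1, 1, 2)  # batch 8, block [2, 2]
print(batch_to_space_ref(x, [2, 2], [0, 0], [0, 0]).shape)  # (2, 2, 2, 2)
```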
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 89b36256152e..27f67b83850d 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1980,8 +1980,6 @@ def _impl(inputs, attr, params, mod):
 
 def _space_to_batch_nd():
     def _impl(inputs, attr, params, mod):
-        input_node = inputs[0]
-        input_shape = _infer_shape(input_node, mod)
         try:
             block_shape = _get_list_param(params, inputs[1])
         except (IndexError, KeyError, AttributeError):
@@ -1995,48 +1993,18 @@ def _impl(inputs, attr, params, mod):
         if len(paddings.shape) == 1:
             paddings = np.expand_dims(paddings, axis=0)
         paddings = paddings.tolist()
-        N = len(input_shape)
-        M = len(block_shape)
-        batch = input_shape[0]
-        remaining_shape_length = N - M - 1
-        paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
-        # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d:
-        # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
-        # to produce padded of shape padded_shape.
-        padded = tvm.relay.nn.pad(input_node, pad_width=paddings)
-        # Reshape padded to reshaped_padded of shape:
-        # [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...,
-        #            padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape
-        shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
-        reshaped_padded = tvm.relay.reshape(padded, newshape=shape1)
-        # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
-        # block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
-        #                          padded_shape[M] / block_shape[M-1]] + remaining_shape
-        axes = (
-            [2 * i + 2 for i in range(M)]
-            + [0]
-            + [2 * i + 1 for i in range(M)]
-            + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
-        )
-        permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
-        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod)
-        # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
-        # producing an output tensor of shape:
-        # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
-        #                                padded_shape[M] / block_shape[M-1]] + remaining_shape
-        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
-        reshaped_permuted_reshaped_padded = tvm.relay.reshape(
-            permuted_reshaped_padded, newshape=shape2
-        )
-        return reshaped_permuted_reshaped_padded
+
+        attr["block_shape"] = block_shape
+        attr["paddings"] = paddings
+        out = AttrCvt("space_to_batch_nd", ignores=["Tblock_shape", "Tpaddings"])([inputs[0]], attr)
+
+        return out
 
     return _impl
 
 
 def _batch_to_space_nd():
     def _impl(inputs, attr, params, mod):
-        input_node = inputs[0]
-        input_shape = _infer_shape(input_node, mod)
         try:
             block_shape = _get_list_param(params, inputs[1])
         except (IndexError, KeyError, AttributeError):
@@ -2050,46 +2018,12 @@ def _impl(inputs, attr, params, mod):
         if len(crops.shape) == 1:
             crops = np.expand_dims(crops, axis=0)
         crops = crops.tolist()
-        M = len(block_shape)
-        batch = input_shape[0]
-        # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
-        # Reshape input to reshaped of shape:
-        # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
-        #  input_shape[1], ..., input_shape[N-1]]
-        shape1 = block_shape + [batch // np.prod(block_shape)] + list(input_shape[1:])
-        reshaped = tvm.relay.reshape(input_node, newshape=shape1)
-        # Permute dimensions of reshaped to produce permuted of shape
-        # [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
-        #  input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
-        axes = (
-            [M]
-            + [axis for i in range(M) for axis in [M + i + 1, i]]
-            + list(range(2 * M + 1, len(shape1)))
-        )
-        permuted = tvm.relay.transpose(reshaped, axes=axes)
-        # Reshape permuted to produce reshaped_permuted of shape
-        # [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
-        #  input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
-        shape2 = [0] + [-3] * M + [-2]
-        reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2)
-        # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
-        # to produce the output of shape:
-        # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
-        #  ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
-        #  input_shape[M+1], ..., input_shape[N-1]]
-        reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod)
-        cropped = reshaped_permuted
-        for axis in range(1, M + 1):
-            crop = crops[axis - 1]
-            if crop != [0, 0]:
-                indices = tvm.relay.arange(
-                    _expr.const(crop[0]),
-                    _expr.const(reshaped_permuted_shape[axis] - crop[1]),
-                    dtype="int32",
-                )
-                cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
-        return cropped
+        attr["block_shape"] = block_shape
+        attr["crops"] = crops
+        out = AttrCvt("batch_to_space_nd", ignores=["Tblock_shape", "Tcrops"])([inputs[0]], attr)
+
+        return out
 
     return _impl
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f52c318c8e97..594ab2df7a8d 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2570,46 +2570,12 @@ def convert_batch_to_space_nd(self, op):
         input_tensor_idx = input_tensor.tensor_idx
         in_expr = self.get_expr(input_tensor_idx)
 
-        input_shape = list(input_tensor.tensor.ShapeAsNumpy())
-        batch = input_shape[0]
-
         block_shape = list(self.get_tensor_value(input_tensors[1]))
-        M = len(block_shape)
-
-        crops = list(self.get_tensor_value(input_tensors[2]))
+        crops = self.get_tensor_value(input_tensors[2]).tolist()
 
-        # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
-        # Reshape input to reshaped of shape
-        shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
-        reshaped = _op.reshape(in_expr, newshape=shape1)
-
-        # Permute dimensions of reshaped to produce permuted of shape
-        axes = (
-            [M]
-            + [axis for i in range(M) for axis in [M + i + 1, i]]
-            + list(range(2 * M + 1, len(shape1)))
-        )
-        permuted = _op.transpose(reshaped, axes=axes)
-
-        # Reshape permuted to produce reshaped_permuted of shape
-        shape2 = [0] + [-3] * M + [-2]
-        reshaped_permuted = _op.reshape(permuted, newshape=shape2)
-
-        # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
-        # to produce the output of shape:
-        reshaped_permuted_shape = _infer_shape(reshaped_permuted)
-        cropped = reshaped_permuted
-        for axis in range(1, M + 1):
-            crop = crops[axis - 1]
-            if (crop != [0, 0]).any():
-                indices = _op.arange(
-                    _expr.const(crop[0]),
-                    _expr.const(reshaped_permuted_shape[axis] - crop[1]),
-                    dtype="int32",
-                )
-                cropped = _op.take(cropped, indices=indices, axis=axis)
+        out = _op.nn.batch_to_space_nd(in_expr, block_shape, crops)
 
-        return cropped
+        return out
 
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
@@ -2620,51 +2586,12 @@ def convert_space_to_batch_nd(self, op):
         input_tensor_idx = input_tensor.tensor_idx
         in_expr = self.get_expr(input_tensor_idx)
 
-        input_shape = list(input_tensor.tensor.ShapeAsNumpy())
-        batch = input_shape[0]
-        N = len(input_shape)
-
         block_shape = list(self.get_tensor_value(input_tensors[1]))
-        M = len(block_shape)
-
-        paddings = list(self.get_tensor_value(input_tensors[2]))
+        paddings = self.get_tensor_value(input_tensors[2]).tolist()
 
-        # From https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd:
-        # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
-        # to produce padded of shape padded_shape.
-        remaining_shape_length = N - M - 1
-        padded_list = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
+        out = _op.nn.space_to_batch_nd(in_expr, block_shape, paddings)
 
-        padded_shape = []
-        for element in padded_list:
-            if isinstance(element, np.ndarray):
-                element = element.tolist()
-
-            padded_shape.append(element)
-
-        padded_shape = tuple(padded_shape)
-        padded = _op.nn.pad(in_expr, pad_width=tuple(padded_shape))
-
-        # Reshape padded to reshaped_padded of shape:
-        shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
-        reshaped_padded = _op.reshape(padded, newshape=shape1)
-
-        # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
-        axes = (
-            [2 * i + 2 for i in range(M)]
-            + [0]
-            + [2 * i + 1 for i in range(M)]
-            + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
-        )
-        permuted_reshaped_padded = _op.transpose(reshaped_padded, axes=axes)
-        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
-
-        # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
-        # producing an output tensor of shape:
-        shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
-        reshaped_permuted_reshaped_padded = _op.reshape(permuted_reshaped_padded, newshape=shape2)
-
-        return reshaped_permuted_reshaped_padded
+        return out
 
     def convert_depth_to_space(self, op):
         """Convert TFLite DEPTH_TO_SPACE"""
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c9926647989e..f5057f854b27 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -746,6 +746,11 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# space_to_batch_nd and batch_to_space_nd
+reg.register_injective_schedule("nn.space_to_batch_nd")
+reg.register_injective_schedule("nn.batch_to_space_nd")
+
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 1aad4e7125fd..e11646e90978 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -3156,3 +3156,64 @@ def correlation(
     return _make.correlation(
         data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, layout
     )
+
+
+def space_to_batch_nd(data, block_shape, paddings, pad_value=0):
+    r"""Divide spatial dimensions of the data into a grid of blocks
+    and interleave them into batch dim.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        N-D with shape [batch, spatial_shape, remaining_shape]
+
+    block_shape : list of ints
+        1-D of size [M] where M is number of spatial dims, specifies block size
+        for each spatial dimension.
+
+    paddings : list of ints
+        2-D of shape [M, 2] where M is number of spatial dims, specifies
+        [before, after] paddings for each spatial dimension.
+
+    pad_value : float, or relay.Expr, optional, default=0
+        The value used for padding.
+
+    Returns
+    -------
+    result : relay.Expr
+        N-D Tensor with shape
+        [in_batch * prod(block_shape),
+        padded_data[1] / block_shape[0], ..., padded_data[M] / block_shape[M-1],
+        remaining_shape]
+    """
+
+    return _make.space_to_batch_nd(data, block_shape, paddings, pad_value)
+
+
+def batch_to_space_nd(data, block_shape, crops):
+    r"""Reshape the batch dimension into spatial dimensions.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        N-D with shape [batch, spatial_shape, remaining_shape]
+
+    block_shape : list of ints
+        1-D of size [M] where M is number of spatial dims, specifies block size
+        for each spatial dimension.
+
+    crops : list of ints
+        2-D of shape [M, 2] where M is number of spatial dims, specifies
+        [begin, end] crop size for each spatial dimension.
+
+    Returns
+    -------
+    result : relay.Expr
+        N-D Tensor with shape
+        [batch / prod(block_shape),
+        in_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], ...,
+        in_shape[M] * block_shape[M-1] - crops[M-1, 0] - crops[M-1, 1],
+        remaining_shape]
+    """
+
+    return _make.batch_to_space_nd(data, block_shape, crops)
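A minimal usage sketch of the new relay API (not part of the patch; executor defaults assumed to target CPU). With matching paddings and crops the two ops compose to the identity:

```python
import numpy as np
import tvm
from tvm import relay

# Fold 2x2 spatial blocks of a 1x4x4x1 NHWC tensor into batch, then undo it.
x = relay.var("x", shape=(1, 4, 4, 1), dtype="float32")
s2b = relay.nn.space_to_batch_nd(x, block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
b2s = relay.nn.batch_to_space_nd(s2b, block_shape=[2, 2], crops=[[0, 0], [0, 0]])
func = relay.Function([x], b2s)

data = np.random.uniform(size=(1, 4, 4, 1)).astype("float32")
res = relay.create_executor("graph").evaluate(func)(data)
np.testing.assert_allclose(res.asnumpy(), data)  # round trip is the identity
```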
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 5dc2c2402c08..c7af42661da0 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -532,3 +532,13 @@ class TupleGetItemAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.WithFuncIdAttrs")
 class WithFuncIdAttrs(Attrs):
     """Attributes used in with_funcid annotation operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.SpaceToBatchNDAttrs")
+class SpaceToBatchNDAttrs(Attrs):
+    """Attributes used in SpaceToBatchND operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.BatchToSpaceNDAttrs")
+class BatchToSpaceNDAttrs(Attrs):
+    """Attributes used in BatchToSpaceND operators"""
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index a035f6778c97..2ebbd1d67bd1 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -46,3 +46,5 @@
 from .fifo_buffer import *
 from .depth_to_space import *
 from .space_to_depth import *
+from .space_to_batch_nd import *
+from .batch_to_space_nd import *
diff --git a/python/tvm/topi/nn/batch_to_space_nd.py b/python/tvm/topi/nn/batch_to_space_nd.py
new file mode 100644
index 000000000000..c61a90a7777b
--- /dev/null
+++ b/python/tvm/topi/nn/batch_to_space_nd.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""TVM operator batch_to_space_nd compute."""
+from __future__ import absolute_import
+from .. import cpp
+
+
+def batch_to_space_nd(data, block_shape, crop_begin_list, crop_end_list):
+    """Perform batch to space transformation on the data
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D Tensor with shape [batch, spatial_shape, remaining_shapes],
+        where spatial_shape has M dimensions.
+
+    block_shape : list of ints
+        list of size [M] where M is number of spatial dims, specifies block
+        size for each spatial dimension.
+
+    crop_begin_list : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        begin crop size for each spatial dimension.
+
+    crop_end_list : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        end crop size for each spatial dimension.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+    """
+
+    return cpp.nn.batch_to_space_nd(data, block_shape, crop_begin_list, crop_end_list)
diff --git a/python/tvm/topi/nn/space_to_batch_nd.py b/python/tvm/topi/nn/space_to_batch_nd.py
new file mode 100644
index 000000000000..149f2b6464c6
--- /dev/null
+++ b/python/tvm/topi/nn/space_to_batch_nd.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""TVM operator space_to_batch_nd compute."""
+from __future__ import absolute_import
+from .. import cpp
+
+
+def space_to_batch_nd(data, block_shape, pad_before, pad_after, pad_value=0.0):
+    """Perform space to batch transformation on the data
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D Tensor with shape [batch, spatial_shape, remaining_shapes],
+        where spatial_shape has M dimensions.
+
+    block_shape : list of ints
+        list of size [M] where M is number of spatial dims, specifies block
+        size for each spatial dimension.
+
+    pad_before : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        zero-padding size before each spatial dimension.
+
+    pad_after : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        zero-padding size after each spatial dimension.
+
+    pad_value : float, optional
+        The value used for padding.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+    """
+
+    return cpp.nn.space_to_batch_nd(data, block_shape, pad_before, pad_after, pad_value)
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 5b23e8f4600e..4f905500d3f1 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -67,3 +67,5 @@
 from .adaptive_pool_python import adaptive_pool
 from .grid_sample_python import affine_grid_python, grid_sample_nchw_python
 from .matrix_set_diag import matrix_set_diag
+from .space_to_batch_nd import space_to_batch_nd_python
+from .batch_to_space_nd import batch_to_space_nd_python
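At the TE level the wrappers just forward to the C++ FFI and return a te.Tensor whose shape is already folded. A short sketch (illustrative; shapes follow the formulas in the docstrings above):

```python
import tvm
from tvm import te, topi

# 1x4x4x1 NHWC input, 2x2 blocks, no padding: batch becomes 1*4,
# each spatial dim becomes 4/2.
A = te.placeholder((1, 4, 4, 1), name="A", dtype="float32")
B = topi.nn.space_to_batch_nd(A, [2, 2], [0, 0], [0, 0])
print(B.shape)  # [4, 2, 2, 1]
```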
diff --git a/python/tvm/topi/testing/batch_to_space_nd.py b/python/tvm/topi/testing/batch_to_space_nd.py
new file mode 100644
index 000000000000..80af79b8cacb
--- /dev/null
+++ b/python/tvm/topi/testing/batch_to_space_nd.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Batch to space ND in python"""
+import numpy as np
+from .strided_slice_python import strided_slice_python
+
+
+def batch_to_space_nd_python(data, block_shape, crop_begin_list, crop_end_list):
+    """Batch to Space operator in python for NHWC layout.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        N-D with shape [batch, spatial_shape, remaining_shapes],
+        where spatial_shape has M dimensions.
+
+    block_shape : list of ints
+        1-D array of size [M] where M is number of spatial dims, specifies block
+        size for each spatial dimension.
+
+    crop_begin_list : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        begin crop size for each spatial dimension.
+
+    crop_end_list : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        end crop size for each spatial dimension.
+
+    Returns
+    -------
+    b2s_out : np.ndarray
+        N-D with shape
+        [batch / prod(block_shape),
+        in_shape[1] * block_shape[0] - crop_begin_list[0] - crop_end_list[0], ...,
+        in_shape[M] * block_shape[M-1] - crop_begin_list[M-1] - crop_end_list[M-1],
+        remaining_shape]
+    """
+    in_shape = data.shape
+    N = len(in_shape)
+    M = len(block_shape)
+    block_shape_prod = np.prod(block_shape)
+    in_batch = data.shape[0]
+    axis = []
+    r_p_shape = []
+
+    r_shape = [block_shape[i] for i in range(0, M)]
+    axis.append(len(r_shape))
+    r_shape.append(in_batch // block_shape_prod)
+
+    for i in range(1, N):
+        axis.append(len(r_shape))
+        if len(axis) < (M + N):
+            axis.append(len(r_shape) - (M + 1))
+        r_shape.append(in_shape[i])
+
+    r_p_shape.append(int(in_batch / block_shape_prod))
+    for i in range(1, M + 1):
+        r_p_shape.append(in_shape[i] * block_shape[i - 1])
+    for i in range(M + 1, N):
+        r_p_shape.append(in_shape[i])
+
+    b2s_out = np.reshape(data, newshape=r_shape)
+    b2s_out = np.transpose(b2s_out, axes=axis)
+    b2s_out = np.reshape(b2s_out, newshape=r_p_shape)
+
+    # Crop the start and end of dimensions of b2s_out
+    begin_idx = []
+    end_idx = []
+    strides = []
+
+    for i, _ in enumerate(r_p_shape):
+        strides.append(1)
+        if 0 < i <= M:
+            # begin and end index for spatial dimensions
+            begin_idx.append(crop_begin_list[i - 1])
+            end_idx.append(r_p_shape[i] - crop_end_list[i - 1])
+        else:
+            begin_idx.append(0)
+            end_idx.append(r_p_shape[i])
+
+    b2s_out = strided_slice_python(b2s_out, begin_idx, end_idx, strides)
+    return b2s_out
diff --git a/python/tvm/topi/testing/space_to_batch_nd.py b/python/tvm/topi/testing/space_to_batch_nd.py
new file mode 100644
index 000000000000..de88c27e56d6
--- /dev/null
+++ b/python/tvm/topi/testing/space_to_batch_nd.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Space to batch ND in python"""
+import numpy as np
+
+
+def space_to_batch_nd_python(data, block_shape, pad_before, pad_after, pad_value=0):
+    """Space to Batch operator in python for NHWC layout.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        N-D with shape [batch, spatial_shape, remaining_shapes],
+        where spatial_shape has M dimensions.
+
+    block_shape : list of ints
+        1-D array of size [M] where M is number of spatial dims, specifies block
+        size for each spatial dimension.
+
+    pad_before : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        zero-padding size before each spatial dimension.
+
+    pad_after : list of ints
+        list of shape [M] where M is number of spatial dims, specifies
+        zero-padding size after each spatial dimension.
+
+    pad_value : float, optional
+        the value used for padding. Defaults to 0.
+
+    Returns
+    -------
+    s2b_out : np.ndarray
+        N-D with shape [batch * prod(block_shape),
+        padded_data[1] / block_shape[0], ..., padded_data[M] / block_shape[M-1],
+        remaining_shape]
+    """
+    M = len(block_shape)
+    in_batch = data.shape[0]
+    block_shape_prod = np.prod(block_shape)
+
+    # Apply padding to input data
+    input_shape = data.shape
+    # Add the paddings for batch and remaining dims
+    paddings = map(list, zip(pad_before, pad_after))
+    paddings = [[0, 0]] + list(paddings) + [[0, 0]] * (data.ndim - 1 - M)
+    padded_data = np.pad(data, paddings, mode="constant", constant_values=pad_value)
+    padded_shape = padded_data.shape
+
+    # Get the reshape shape and transpose axes
+    r_shape = []
+    trans_axis = []
+    r_shape.append(in_batch)
+    for i in range(1, M + 1):
+        r_shape.append(int(padded_shape[i] // block_shape[i - 1]))
+        r_shape.append(block_shape[i - 1])
+        trans_axis.append(len(r_shape) - 1)
+
+    axis_len = len(trans_axis)
+    trans_axis.append(0)
+    for i in range(axis_len):
+        trans_axis.append(trans_axis[i] - 1)
+
+    out_shape = []
+    out_shape.append(int(in_batch * block_shape_prod))
+    for i in range(1, M + 1):
+        out_shape.append(int(padded_shape[i] // block_shape[i - 1]))
+
+    for i in range(M + 1, len(input_shape)):
+        r_shape.append(input_shape[i])
+        trans_axis.append(len(r_shape) - 1)
+        out_shape.append(input_shape[i])
+
+    s2b_out = np.reshape(padded_data, newshape=r_shape)
+    s2b_out = np.transpose(s2b_out, axes=trans_axis)
+    s2b_out = np.reshape(s2b_out, newshape=out_shape)
+
+    return s2b_out
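The two reference implementations are exact inverses when paddings equal crops, which gives a quick self-check (a sketch, not part of the patch):

```python
import numpy as np
import tvm.topi.testing

# space_to_batch followed by batch_to_space with matching
# paddings/crops should reproduce the input exactly.
x = np.random.uniform(size=(2, 4, 6, 3)).astype("float32")
s2b = tvm.topi.testing.space_to_batch_nd_python(x, [2, 3], [0, 0], [0, 0])
assert s2b.shape == (12, 2, 2, 3)
b2s = tvm.topi.testing.batch_to_space_nd_python(s2b, [2, 3], [0, 0], [0, 0])
np.testing.assert_array_equal(b2s, x)
```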
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index ea25c1a9c0f9..816b98038e46 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -1145,5 +1145,223 @@ RELAY_REGISTER_OP("nn.space_to_depth")
     .set_support_level(5)
     .add_type_rel("SpaceToDepth", SpaceToDepthRel);
 
+// Positional relay function to create SpaceToBatchND operator
+// used by frontend FFI
+TVM_REGISTER_NODE_TYPE(SpaceToBatchNDAttrs);
+
+Expr MakeSpaceToBatchND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> paddings,
+                        double pad_value) {
+  auto attrs = make_object<SpaceToBatchNDAttrs>();
+  attrs->block_shape = std::move(block_shape);
+  attrs->paddings = std::move(paddings);
+  attrs->pad_value = pad_value;
+  static const Op& op = Op::Get("nn.space_to_batch_nd");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+bool SpaceToBatchNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+
+  auto* input = types[0].as<TensorTypeNode>();
+  // Input must be a TensorType
+  if (input == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "SpaceToBatchND: expect input type to be TensorType but got " << types[0];
+    return false;
+  }
+
+  if (input->shape.size() <= 1) return false;
+
+  const auto* param = attrs.as<SpaceToBatchNDAttrs>();
+  CHECK(param != nullptr);
+
+  auto block_shape = param->block_shape;
+  auto paddings = param->paddings;
+  const int bdims = static_cast<int>(block_shape.size());
+  const int pdims = static_cast<int>(paddings.size());
+  // Paddings must be provided for each spatial dim.
+  CHECK(pdims == bdims) << "SpaceToBatchND: Paddings must be provided for each spatial dim";
+
+  // Apply paddings to input
+  auto in_shape = input->shape;
+  std::vector<IndexExpr> padded_shape(input->shape.begin(), input->shape.end());
+  for (size_t i = 0; i < paddings.size(); i++) {
+    CHECK_EQ(paddings[i].size(), 2U);
+    auto pad_before = tir::as_const_int(param->paddings[i][0]);
+    auto pad_after = tir::as_const_int(param->paddings[i][1]);
+    auto padding = tir::make_const(input->shape[i].dtype(), *pad_before + *pad_after);
+    padded_shape[i + 1] = in_shape[i + 1] + padding;
+  }
+
+  auto block_shape_numele = tir::make_const(DataType::Int(32), 1);
+  for (size_t i = 0; i < block_shape.size(); i++) {
+    block_shape_numele *= block_shape[i];
+  }
+
+  // Construct output shape
+  std::vector<IndexExpr> out_shape(padded_shape);
+  out_shape[0] = in_shape[0] * block_shape_numele;
+  for (size_t i = 1; i <= block_shape.size(); i++) {
+    out_shape[i] = div(padded_shape[i], block_shape[i - 1]);
+  }
+
+  // Assign output shape
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(out_shape), input->dtype));
+  return true;
+}
+
+Array<te::Tensor> SpaceToBatchNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                        const Type& out_type) {
+  const auto* param = attrs.as<SpaceToBatchNDAttrs>();
+  CHECK(param != nullptr);
+
+  auto b_shape = param->block_shape;
+  auto paddings = param->paddings;
+  Array<IndexExpr> pad_before;
+  Array<IndexExpr> pad_after;
+
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    pad_before.push_back(paddings[i][0]);
+  }
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    pad_after.push_back(paddings[i][1]);
+  }
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return Array<te::Tensor>{
+      topi::space_to_batch_nd(inputs[0], b_shape, pad_before, pad_after,
+                              tvm::tir::make_const(out_ttype->dtype, param->pad_value))};
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_batch_nd").set_body_typed(MakeSpaceToBatchND);
+
+RELAY_REGISTER_OP("nn.space_to_batch_nd")
+    .describe(R"code(Divide spatial dimensions of the input into a grid of blocks
+and interleave them into batch dim.
+
+- **data**: data is a ND array of shape
+            (batch, spatial_shapes, remaining_shapes) for NHWC
+
+- **out**: Output is a ND array of shape
+           (batch * prod(block_shape), padded_data[1] / block_shape[0], ...,
+           padded_data[M] / block_shape[M-1], remaining_shape) for NHWC,
+           where M is the number of spatial dimensions.
+
+Example::
+
+  x = [[[[1], [2]], [[3], [4]]]]
+
+  space_to_batch_nd(x, block_shape = [2, 2]) =
+    [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attrs_type<SpaceToBatchNDAttrs>()
+    .set_support_level(5)
+    .add_type_rel("SpaceToBatchND", SpaceToBatchNDRel)
+    .set_attr<FTVMCompute>("FTVMCompute", SpaceToBatchNDCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+/*****************************************************************/
+
+// Positional relay function to create BatchToSpaceND operator
+// used by frontend FFI
+TVM_REGISTER_NODE_TYPE(BatchToSpaceNDAttrs);
+
+Expr MakeBatchToSpaceND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> crops) {
+  auto attrs = make_object<BatchToSpaceNDAttrs>();
+  attrs->block_shape = std::move(block_shape);
+  attrs->crops = std::move(crops);
+  static const Op& op = Op::Get("nn.batch_to_space_nd");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+bool BatchToSpaceNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+
+  auto* input = types[0].as<TensorTypeNode>();
+  // Input must be a TensorType
+  if (input == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "BatchToSpaceND: expect input type to be TensorType but got " << types[0];
+    return false;
+  }
+
+  if (input->shape.size() <= 1) return false;
+
+  const auto* param = attrs.as<BatchToSpaceNDAttrs>();
+  CHECK(param != nullptr);
+
+  auto block_shape = param->block_shape;
+  auto crops = param->crops;
+  const int bdims = static_cast<int>(block_shape.size());
+  const int cdims = static_cast<int>(crops.size());
+  const int indims = static_cast<int>(input->shape.size());
+  // crops must be provided for each spatial dim.
+  CHECK(cdims == bdims) << "BatchToSpaceND: crops must be provided for each spatial dim";
+  CHECK(bdims < indims) << "BatchToSpaceND: block_shape must be less than input shape";
+
+  auto block_shape_numele = tir::make_const(DataType::Int(32), 1);
+  for (size_t i = 0; i < block_shape.size(); i++) {
+    block_shape_numele *= block_shape[i];
+  }
+
+  auto in_shape = input->shape;
+
+  // Construct output shape
+  // Start with input shape, only batch and spatial dims shapes are modified.
+  std::vector<IndexExpr> out_shape(input->shape.begin(), input->shape.end());
+  out_shape[0] = in_shape[0] / block_shape_numele;
+  for (size_t i = 1; i <= block_shape.size(); i++) {
+    out_shape[i] = (in_shape[i] * block_shape[i - 1]) - crops[i - 1][0] - crops[i - 1][1];
+  }
+  for (int i = bdims + 1; i < indims; i++) {
+    out_shape[i] = in_shape[i];
+  }
+
+  // Assign output shape
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(out_shape), input->dtype));
+  return true;
+}
+
+Array<te::Tensor> BatchToSpaceNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                        const Type& out_type) {
+  const auto* param = attrs.as<BatchToSpaceNDAttrs>();
+  CHECK(param != nullptr);
+
+  auto b_shape = param->block_shape;
+  auto crops = param->crops;
+  Array<IndexExpr> crop_begin_list, crop_end_list;
+  for (size_t i = 0; i < crops.size(); ++i) {
+    crop_begin_list.push_back(crops[i][0]);
+    crop_end_list.push_back(crops[i][1]);
+  }
+
+  return Array<te::Tensor>{
+      topi::batch_to_space_nd(inputs[0], b_shape, crop_begin_list, crop_end_list)};
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_to_space_nd").set_body_typed(MakeBatchToSpaceND);
+
+RELAY_REGISTER_OP("nn.batch_to_space_nd")
+    .describe(R"code(Reshape the batch dimension into spatial dimensions.
+
+Example::
+
+  x = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+
+  batch_to_space_nd(x, block_shape = [2, 2]) =
+    [[[[1], [2]], [[3], [4]]]]
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attrs_type<BatchToSpaceNDAttrs>()
+    .set_support_level(5)
+    .add_type_rel("BatchToSpaceND", BatchToSpaceNDRel)
+    .set_attr<FTVMCompute>("FTVMCompute", BatchToSpaceNDCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace relay
 }  // namespace tvm
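For reference, the output shape computed by `BatchToSpaceNDRel` is batch / prod(block_shape) followed by in_shape[i] * block_shape[i-1] minus the crops. In Python terms (illustrative helper, name made up):

```python
import numpy as np

def b2s_out_shape(in_shape, block_shape, crops):
    """Shape rule mirrored from BatchToSpaceNDRel above."""
    M = len(block_shape)
    batch = in_shape[0] // int(np.prod(block_shape))
    spatial = [in_shape[1 + i] * block_shape[i] - crops[i][0] - crops[i][1]
               for i in range(M)]
    return [batch] + spatial + list(in_shape[M + 1:])

# batch: 8 / (2*2) = 2; spatial: 1*2 - 0 - 0 = 2 and 3*2 - 2 - 0 = 4
print(b2s_out_shape([8, 1, 3, 1], [2, 2], [[0, 0], [2, 0]]))  # [2, 2, 4, 1]
```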
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 2c9546507de6..092fe65e19dc 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -57,6 +57,14 @@ TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = pad(args[0], args[1], args[2], args[3]);
 });
 
+TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = space_to_batch_nd(args[0], args[1], args[2], args[3], args[4]);
+});
+
+TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = batch_to_space_nd(args[0], args[1], args[2], args[3]);
+});
+
 /* Ops from nn/dense.h */
 TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = nn::dense(args[0], args[1], args[2], args[3]);
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index cfb85b6d1e91..5a5a12c9efe0 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -1141,6 +1141,60 @@ def verify_grid_sample(data_shape, grid_shape):
     verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
 
 
+@tvm.testing.uses_gpu
+def test_space_to_batch_nd():
+    def verify_space_to_batch_nd(dshape, block_shape, paddings):
+        x_data = np.random.uniform(size=dshape).astype("float32")
+        pad_before, pad_after = map(list, zip(*paddings))
+        ref_res = tvm.topi.testing.space_to_batch_nd_python(
+            x_data, block_shape, pad_before, pad_after
+        )
+
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        z = relay.nn.space_to_batch_nd(x, block_shape, paddings)
+        assert "block_shape=" in z.astext()
+        assert "paddings=" in z.astext()
+        zz = run_infer_type(z)
+        assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
+        func = relay.Function([x], z)
+
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+
+    verify_space_to_batch_nd([3, 3, 2, 1], [3], [[0, 0]])
+    verify_space_to_batch_nd([2, 2, 4, 1], [2, 2], [[0, 0], [2, 0]])
+
+
+@tvm.testing.uses_gpu
+def test_batch_to_space_nd():
+    def verify_batch_to_space_nd(dshape, block_shape, crops):
+        x_data = np.random.uniform(size=dshape).astype("float32")
+        crop_begin_list, crop_end_list = map(list, zip(*crops))
+        ref_res = tvm.topi.testing.batch_to_space_nd_python(
+            x_data, block_shape, crop_begin_list, crop_end_list
+        )
+
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        z = relay.nn.batch_to_space_nd(x, block_shape, crops)
+        assert "block_shape=" in z.astext()
+        assert "crops=" in z.astext()
+        zz = run_infer_type(z)
+        assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
+        func = relay.Function([x], z)
+
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+
+    verify_batch_to_space_nd([4, 1, 1, 3], [2, 2], [[0, 0], [0, 0]])
+    verify_batch_to_space_nd([8, 1, 3, 1], [2, 2], [[0, 0], [2, 0]])
+
+
 if __name__ == "__main__":
     test_resize_infer_type()
     test_resize()
@@ -1163,3 +1217,5 @@ def verify_grid_sample(data_shape, grid_shape):
     test_dilation2d_run()
     test_affine_grid()
     test_grid_sample()
+    test_space_to_batch_nd()
+    test_batch_to_space_nd()
diff --git a/tests/python/topi/python/test_topi_batch_to_space_nd.py b/tests/python/topi/python/test_topi_batch_to_space_nd.py
new file mode 100644
index 000000000000..89d044fed963
--- /dev/null
+++ b/tests/python/topi/python/test_topi_batch_to_space_nd.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for batch to space"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.testing
+import tvm.topi.testing
+
+
+def verify_batch_to_space_nd(input_shape, block_shape, crop_begin_list, crop_end_list):
+    out_shape = []
+    out_shape.append(int(input_shape[0] / np.prod(block_shape)))
+    for i in range(1, len(block_shape) + 1):
+        crop = crop_begin_list[i - 1] + crop_end_list[i - 1]
+        out_shape.append(input_shape[i] * block_shape[i - 1] - crop)
+    for i in range(len(block_shape) + 1, len(input_shape)):
+        out_shape.append(input_shape[i])
+
+    A = te.placeholder(input_shape, name="A", dtype="float32")
+    dtype = A.dtype
+    a_np = np.random.uniform(size=input_shape).astype(dtype)
+
+    B = topi.nn.batch_to_space_nd(A, block_shape, crop_begin_list, crop_end_list)
+
+    b_np = tvm.topi.testing.batch_to_space_nd_python(
+        a_np, block_shape, crop_begin_list, crop_end_list
+    )
+
+    def check_device(device, ctx):
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = tvm.topi.testing.get_injective_schedule(device)(B)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
+
+    for device, ctx in tvm.testing.enabled_targets():
+        check_device(device, ctx)
+
+
+@tvm.testing.uses_gpu
+def test_batch_to_space():
+    # Without crops
+    verify_batch_to_space_nd([4, 1, 1, 1], [2, 2], [0, 0], [0, 0])
+    # With crops
+    verify_batch_to_space_nd([8, 1, 3, 1], [2, 2], [0, 2], [0, 0])
+    verify_batch_to_space_nd([18, 2, 1, 2], [2, 3], [1, 1], [0, 0])
+    verify_batch_to_space_nd([20, 5, 8, 7], [2, 2], [1, 1], [1, 1])
+
+
+if __name__ == "__main__":
+    test_batch_to_space()
diff --git a/tests/python/topi/python/test_topi_space_to_batch_nd.py b/tests/python/topi/python/test_topi_space_to_batch_nd.py
new file mode 100644
index 000000000000..6f969f391002
--- /dev/null
+++ b/tests/python/topi/python/test_topi_space_to_batch_nd.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for space to batch"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.testing
+import tvm.topi.testing
+
+
+def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after, pad_value=0):
+    out_shape = []
+    out_shape.append(int(input_shape[0] * np.prod(block_shape)))
+    for i in range(1, len(block_shape) + 1):
+        pad = pad_before[i - 1] + pad_after[i - 1]
+        out_shape.append(int((input_shape[i] + pad) // block_shape[i - 1]))
+    for i in range(len(block_shape) + 1, len(input_shape)):
+        out_shape.append(input_shape[i])
+
+    A = te.placeholder(input_shape, name="A", dtype="float32")
+    dtype = A.dtype
+    a_np = np.random.uniform(size=input_shape).astype(dtype)
+
+    B = topi.nn.space_to_batch_nd(A, block_shape, pad_before, pad_after, pad_value)
+
+    b_np = tvm.topi.testing.space_to_batch_nd_python(
+        a_np, block_shape, pad_before, pad_after, pad_value
+    )
+
+    def check_device(device, ctx):
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = tvm.topi.testing.get_injective_schedule(device)(B)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
+
+    for device, ctx in tvm.testing.enabled_targets():
+        check_device(device, ctx)
+
+
+@tvm.testing.uses_gpu
+def test_space_to_batch():
+    # Without paddings
+    verify_space_to_batch_nd([3, 3, 2, 1], [3], [0], [0])
+    # With paddings
+    verify_space_to_batch_nd([3, 3, 2, 1], [3], [1], [2])
+    # Multiple spatial dims
+    verify_space_to_batch_nd([3, 3, 4, 5, 2], [3, 4, 2], [1, 0, 3], [2, 0, 0])
+    # No remaining dims
+    verify_space_to_batch_nd([3, 3, 4, 5, 2], [3, 4, 2, 2], [1, 4, 0, 0], [2, 0, 1, 0])
+
+
+if __name__ == "__main__":
+    test_space_to_batch()