From 5adb6fcfb3ea50c76d387e63ce367bbe8c3f18d3 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Sat, 16 Feb 2019 15:17:33 -0800 Subject: [PATCH] Add NHWC layout support to Pooling (cpu, gpu cuda, gpu cuDNN) (#13749) * Adds layout support: mx.sym.Pooling(..., layout='NHWC',...) with tests. * Docs changes * Trigger * Skip NHWC pooling tests on non-cuDNN platforms * Fix pylint NHWC pooling * Fixes from review * Add CuDNNPoolingOp::Supports() in place of Forward()/Backward() bool return. * Add layout support to cpu implementation of Pooling, with tests. * Fix cpplint. * Fix bug in cpu nhwc impl. * Add MXNet CUDA pooling in NWC, NHWC and NDHWC. Turn on 3D cuDNN pooling. Tests. * Add PoolingParam::GetLayout() for better default layout handling. * Fix cpplint. * Throw exception for quantization pooling not NCHW. * Expand nhwc pooling test coverage. * SupportMKLDNNPooling() to examine layout param. * Compare 'std' and 'v1' pooling versions only when op definitions permit. * Add pooling test diagnostic output. * Fix syntax. * Fix pooling FInplaceOption so it can be shared by all implementations. * Add missing param definition. * Fix #if logic. * Temp switch to DickJC123/mshadow: shows effect of half round-to-nearest on cpu. * Move back to dmlc/mshadow.git, now with float->half rounding. * Avoid underflow of lp pooling calc for dtype=float16. * Remove redundant pooling test. * Minor variable naming fixes. * Modify FInplaceOption handling per reviewer comments. Expand testing. * Correct gluon Pooling layout param description. * Correct Symbol Pooling description. * Use 'CHECK(x)' rather than 'if (x) LOG(FATAL)'. * Empty commit to trigger CI. --- python/mxnet/gluon/nn/conv_layers.py | 93 +- src/operator/nn/cudnn/cudnn_pooling-inl.h | 239 +++- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 3 +- src/operator/nn/pool.cuh | 479 +++++--- src/operator/nn/pool.h | 1010 +++++++++++++++-- src/operator/nn/pool_utils.h | 18 +- src/operator/nn/pooling-inl.h | 73 +- src/operator/nn/pooling.cc | 207 ++-- src/operator/nn/pooling.cu | 28 +- .../quantization/quantized_pooling.cc | 3 + tests/python/gpu/test_operator_gpu.py | 460 ++++---- tests/python/unittest/test_gluon.py | 165 +-- 12 files changed, 2024 insertions(+), 754 deletions(-) diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 5f20d20c02ab..c210081f6071 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -673,7 +673,7 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), class _Pooling(HybridBlock): """Abstract class for different pooling layers.""" def __init__(self, pool_size, strides, padding, ceil_mode, global_pool, - pool_type, count_include_pad=None, **kwargs): + pool_type, layout, count_include_pad=None, **kwargs): super(_Pooling, self).__init__(**kwargs) if strides is None: strides = pool_size @@ -684,6 +684,7 @@ def __init__(self, pool_size, strides, padding, ceil_mode, global_pool, self._kwargs = { 'kernel': pool_size, 'stride': strides, 'pad': padding, 'global_pool': global_pool, 'pool_type': pool_type, + 'layout': layout, 'pooling_convention': 'full' if ceil_mode else 'valid'} if count_include_pad is not None: self._kwargs['count_include_pad'] = count_include_pad @@ -695,7 +696,8 @@ def hybrid_forward(self, F, x): return F.Pooling(x, name='fwd', **self._kwargs) def __repr__(self): - s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode})' + s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode}' + s += ', global_pool={global_pool}, pool_type={pool_type}, layout={layout})' return s.format(name=self.__class__.__name__, ceil_mode=self._kwargs['pooling_convention'] == 'full', **self._kwargs) @@ -716,7 +718,7 @@ class MaxPool1D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCW' - Dimension ordering of data and weight. Only supports 'NCW' layout for now. + Dimension ordering of data and out ('NCW' or 'NWC'). 'N', 'C', 'W' stands for batch, channel, and width (time) dimensions respectively. Pooling is applied on the W dimension. ceil_mode : bool, default False @@ -738,12 +740,13 @@ class MaxPool1D(_Pooling): """ def __init__(self, pool_size=2, strides=None, padding=0, layout='NCW', ceil_mode=False, **kwargs): - assert layout == 'NCW', "Only supports 'NCW' layout for now" + assert layout in ('NCW', 'NWC'),\ + "Only NCW and NWC layouts are valid for 1D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,) assert len(pool_size) == 1, "pool_size must be a number or a list of 1 ints" super(MaxPool1D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'max', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'max', layout, **kwargs) class MaxPool2D(_Pooling): @@ -761,7 +764,7 @@ class MaxPool2D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCHW' - Dimension ordering of data and weight. Only supports 'NCHW' layout for now. + Dimension ordering of data and out ('NCHW' or 'NHWC'). 'N', 'C', 'H', 'W' stands for batch, channel, height, and width dimensions respectively. padding is applied on 'H' and 'W' dimension. ceil_mode : bool, default False @@ -786,12 +789,13 @@ class MaxPool2D(_Pooling): """ def __init__(self, pool_size=(2, 2), strides=None, padding=0, layout='NCHW', ceil_mode=False, **kwargs): - assert layout == 'NCHW', "Only supports 'NCHW' layout for now" + assert layout in ('NCHW', 'NHWC'),\ + "Only NCHW and NHWC layouts are valid for 2D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*2 assert len(pool_size) == 2, "pool_size must be a number or a list of 2 ints" super(MaxPool2D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'max', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'max', layout, **kwargs) class MaxPool3D(_Pooling): @@ -809,7 +813,7 @@ class MaxPool3D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCDHW' - Dimension ordering of data and weight. Only supports 'NCDHW' layout for now. + Dimension ordering of data and out ('NCDHW' or 'NDHWC'). 'N', 'C', 'H', 'W', 'D' stands for batch, channel, height, width and depth dimensions respectively. padding is applied on 'D', 'H' and 'W' dimension. @@ -836,12 +840,13 @@ class MaxPool3D(_Pooling): """ def __init__(self, pool_size=(2, 2, 2), strides=None, padding=0, ceil_mode=False, layout='NCDHW', **kwargs): - assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now" + assert layout in ('NCDHW', 'NDHWC'),\ + "Only NCDHW and NDHWC layouts are valid for 3D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*3 assert len(pool_size) == 3, "pool_size must be a number or a list of 3 ints" super(MaxPool3D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'max', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'max', layout, **kwargs) class AvgPool1D(_Pooling): @@ -858,7 +863,7 @@ class AvgPool1D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCW' - Dimension ordering of data and weight. Only supports 'NCW' layout for now. + Dimension ordering of data and out ('NCW' or 'NWC'). 'N', 'C', 'W' stands for batch, channel, and width (time) dimensions respectively. padding is applied on 'W' dimension. ceil_mode : bool, default False @@ -882,12 +887,14 @@ class AvgPool1D(_Pooling): """ def __init__(self, pool_size=2, strides=None, padding=0, layout='NCW', ceil_mode=False, count_include_pad=True, **kwargs): - assert layout == 'NCW', "Only supports 'NCW' layout for now" + assert layout in ('NCW', 'NWC'),\ + "Only NCW and NWC layouts are valid for 1D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,) assert len(pool_size) == 1, "pool_size must be a number or a list of 1 ints" super(AvgPool1D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', layout, count_include_pad, + **kwargs) class AvgPool2D(_Pooling): @@ -904,7 +911,7 @@ class AvgPool2D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCHW' - Dimension ordering of data and weight. Only supports 'NCHW' layout for now. + Dimension ordering of data and out ('NCHW' or 'NHWC'). 'N', 'C', 'H', 'W' stands for batch, channel, height, and width dimensions respectively. padding is applied on 'H' and 'W' dimension. ceil_mode : bool, default False @@ -931,12 +938,14 @@ class AvgPool2D(_Pooling): """ def __init__(self, pool_size=(2, 2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, **kwargs): - assert layout == 'NCHW', "Only supports 'NCHW' layout for now" + assert layout in ('NCHW', 'NHWC'),\ + "Only NCHW and NHWC layouts are valid for 2D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*2 assert len(pool_size) == 2, "pool_size must be a number or a list of 2 ints" super(AvgPool2D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', layout, count_include_pad, + **kwargs) class AvgPool3D(_Pooling): @@ -953,7 +962,7 @@ class AvgPool3D(_Pooling): If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points. layout : str, default 'NCDHW' - Dimension ordering of data and weight. Can be 'NCDHW', 'NDHWC', etc. + Dimension ordering of data and out ('NCDHW' or 'NDHWC'). 'N', 'C', 'H', 'W', 'D' stands for batch, channel, height, width and depth dimensions respectively. padding is applied on 'D', 'H' and 'W' dimension. @@ -982,12 +991,14 @@ class AvgPool3D(_Pooling): """ def __init__(self, pool_size=(2, 2, 2), strides=None, padding=0, ceil_mode=False, layout='NCDHW', count_include_pad=True, **kwargs): - assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now" + assert layout in ('NCDHW', 'NDHWC'),\ + "Only NCDHW and NDHWC layouts are valid for 3D Pooling" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*3 assert len(pool_size) == 3, "pool_size must be a number or a list of 3 ints" super(AvgPool3D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', layout, count_include_pad, + **kwargs) class GlobalMaxPool1D(_Pooling): @@ -997,7 +1008,7 @@ class GlobalMaxPool1D(_Pooling): Parameters ---------- layout : str, default 'NCW' - Dimension ordering of data and weight. Only supports 'NCW' layout for now. + Dimension ordering of data and out ('NCW' or 'NWC'). 'N', 'C', 'W' stands for batch, channel, and width (time) dimensions respectively. Pooling is applied on the W dimension. @@ -1011,9 +1022,10 @@ class GlobalMaxPool1D(_Pooling): when `layout` is `NCW`. """ def __init__(self, layout='NCW', **kwargs): - assert layout == 'NCW', "Only supports 'NCW' layout for now" + assert layout in ('NCW', 'NWC'),\ + "Only NCW and NWC layouts are valid for 1D Pooling" super(GlobalMaxPool1D, self).__init__( - (1,), None, 0, True, True, 'max', **kwargs) + (1,), None, 0, True, True, 'max', layout, **kwargs) class GlobalMaxPool2D(_Pooling): @@ -1023,7 +1035,7 @@ class GlobalMaxPool2D(_Pooling): Parameters ---------- layout : str, default 'NCHW' - Dimension ordering of data and weight. Only supports 'NCHW' layout for now. + Dimension ordering of data and out ('NCHW' or 'NHWC'). 'N', 'C', 'H', 'W' stands for batch, channel, height, and width dimensions respectively. padding is applied on 'H' and 'W' dimension. @@ -1038,9 +1050,10 @@ class GlobalMaxPool2D(_Pooling): `(batch_size, channels, 1, 1)` when `layout` is `NCHW`. """ def __init__(self, layout='NCHW', **kwargs): - assert layout == 'NCHW', "Only supports 'NCHW' layout for now" + assert layout in ('NCHW', 'NHWC'),\ + "Only NCHW and NHWC layouts are valid for 2D Pooling" super(GlobalMaxPool2D, self).__init__( - (1, 1), None, 0, True, True, 'max', **kwargs) + (1, 1), None, 0, True, True, 'max', layout, **kwargs) class GlobalMaxPool3D(_Pooling): @@ -1050,7 +1063,7 @@ class GlobalMaxPool3D(_Pooling): Parameters ---------- layout : str, default 'NCDHW' - Dimension ordering of data and weight. Only supports 'NCDHW' layout for now. + Dimension ordering of data and out ('NCDHW' or 'NDHWC'). 'N', 'C', 'H', 'W', 'D' stands for batch, channel, height, width and depth dimensions respectively. padding is applied on 'D', 'H' and 'W' dimension. @@ -1066,9 +1079,10 @@ class GlobalMaxPool3D(_Pooling): `(batch_size, channels, 1, 1, 1)` when `layout` is `NCDHW`. """ def __init__(self, layout='NCDHW', **kwargs): - assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now" + assert layout in ('NCDHW', 'NDHWC'),\ + "Only NCDHW and NDHWC layouts are valid for 3D Pooling" super(GlobalMaxPool3D, self).__init__( - (1, 1, 1), None, 0, True, True, 'max', **kwargs) + (1, 1, 1), None, 0, True, True, 'max', layout, **kwargs) class GlobalAvgPool1D(_Pooling): @@ -1077,7 +1091,7 @@ class GlobalAvgPool1D(_Pooling): Parameters ---------- layout : str, default 'NCW' - Dimension ordering of data and weight. Only supports 'NCW' layout for now. + Dimension ordering of data and out ('NCW' or 'NWC'). 'N', 'C', 'W' stands for batch, channel, and width (time) dimensions respectively. padding is applied on 'W' dimension. @@ -1090,9 +1104,10 @@ class GlobalAvgPool1D(_Pooling): - **out**: 3D output tensor with shape `(batch_size, channels, 1)`. """ def __init__(self, layout='NCW', **kwargs): - assert layout == 'NCW', "Only supports 'NCW' layout for now" + assert layout in ('NCW', 'NWC'),\ + "Only NCW and NWC layouts are valid for 1D Pooling" super(GlobalAvgPool1D, self).__init__( - (1,), None, 0, True, True, 'avg', **kwargs) + (1,), None, 0, True, True, 'avg', layout, **kwargs) class GlobalAvgPool2D(_Pooling): @@ -1101,7 +1116,7 @@ class GlobalAvgPool2D(_Pooling): Parameters ---------- layout : str, default 'NCHW' - Dimension ordering of data and weight. Only supports 'NCHW' layout for now. + Dimension ordering of data and out ('NCHW' or 'NHWC'). 'N', 'C', 'H', 'W' stands for batch, channel, height, and width dimensions respectively. @@ -1116,9 +1131,10 @@ class GlobalAvgPool2D(_Pooling): `(batch_size, channels, 1, 1)` when `layout` is `NCHW`. """ def __init__(self, layout='NCHW', **kwargs): - assert layout == 'NCHW', "Only supports 'NCHW' layout for now" + assert layout in ('NCHW', 'NHWC'),\ + "Only NCHW and NHWC layouts are valid for 2D Pooling" super(GlobalAvgPool2D, self).__init__( - (1, 1), None, 0, True, True, 'avg', **kwargs) + (1, 1), None, 0, True, True, 'avg', layout, **kwargs) class GlobalAvgPool3D(_Pooling): @@ -1127,7 +1143,7 @@ class GlobalAvgPool3D(_Pooling): Parameters ---------- layout : str, default 'NCDHW' - Dimension ordering of data and weight. Can be 'NCDHW', 'NDHWC', etc. + Dimension ordering of data and out ('NCDHW' or 'NDHWC'). 'N', 'C', 'H', 'W', 'D' stands for batch, channel, height, width and depth dimensions respectively. padding is applied on 'D', 'H' and 'W' dimension. @@ -1143,9 +1159,10 @@ class GlobalAvgPool3D(_Pooling): `(batch_size, channels, 1, 1, 1)` when `layout` is `NCDHW`. """ def __init__(self, layout='NCDHW', **kwargs): - assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now" + assert layout in ('NCDHW', 'NDHWC'),\ + "Only NCDHW and NDHWC layouts are valid for 3D Pooling" super(GlobalAvgPool3D, self).__init__( - (1, 1, 1), None, 0, True, True, 'avg', **kwargs) + (1, 1, 1), None, 0, True, True, 'avg', layout, **kwargs) class ReflectionPad2D(HybridBlock): diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index 89fa73ef5471..ada605db0ee9 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -21,13 +21,13 @@ * Copyright (c) 2015 by Contributors * \file cudnn_pooling-inl.h * \brief - * \author Bing Xu + * \author Bing Xu, Dick Carter */ #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_POOLING_INL_H_ #define MXNET_OPERATOR_NN_CUDNN_CUDNN_POOLING_INL_H_ #include -#include +#include #include "../pooling-inl.h" namespace mxnet { @@ -63,7 +63,7 @@ class CuDNNPoolingOp { } break; default: - LOG(FATAL) << "Not implmented"; + LOG(FATAL) << "Pooling type not implemented by cuDNN."; } } @@ -81,7 +81,7 @@ class CuDNNPoolingOp { CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; - this->Init(s, in_data, out_data); + CHECK(this->Init(s, in_data, out_data)) << "cuDNN Pooling invoked with unsupported parameters."; if (param_.kernel.ndim() == 2) { // 2d pool Tensor data = in_data.get(s); @@ -111,7 +111,7 @@ class CuDNNPoolingOp { out_desc_, out.dptr_)); } else { - LOG(FATAL) << "Only support 2D or 3D pooling"; + LOG(FATAL) << "cuDNN only supports 2D or 3D pooling."; } } @@ -125,7 +125,7 @@ class CuDNNPoolingOp { CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; - this->Init(s, in_data, out_data); + CHECK(this->Init(s, in_data, out_data)) << "cuDNN Pooling invoked with unsupported parameters."; if (param_.kernel.ndim() == 2) { // 2d pool Tensor m_out_grad = out_grad.get(s); @@ -163,106 +163,219 @@ class CuDNNPoolingOp { in_desc_, m_in_grad.dptr_)); } else { - LOG(FATAL) << "Only support 2D or 3D pooling"; + LOG(FATAL) << "cuDNN only supports 2D or 3D pooling."; } } +/*! + * \brief Returns whether the cuDNN library version supports the pooling operation + * described by `param`: cuDNN v5 and earlier does not support 3D pooling for example. + * CuDNN v7.1.4 backprop kernel doesn't support kernel sizes 9 and above. + */ + static bool Supports(const PoolingParam ¶m, const TBlob& input) { + using namespace mshadow; + static bool sum_pooling_warning_issued = false; + static bool lp_pooling_warning_issued = false; + static bool unsupported_dim_warning_issued = false; + int layout = param.GetLayout(input.ndim()); + + switch (param.pool_type) { + case pool_enum::kMaxPooling: + case pool_enum::kAvgPooling: + break; + case pool_enum::kSumPooling: + if (!sum_pooling_warning_issued) { + sum_pooling_warning_issued = true; + LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied."; + } + return false; + case pool_enum::kLpPooling: + if (!lp_pooling_warning_issued) { + lp_pooling_warning_issued = true; + LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet Lp pooling is applied."; + } + return false; + default: + return false; + } + + if (param.kernel.ndim() == 2) { + // 2d pooling + if (!(layout == mshadow::kNCHW || layout == mshadow::kNHWC)) + return false; +#if CUDNN_VERSION == 7104 + // CuDNN v7.1.4 backprop kernel doesn't support kernel sizes 9 and above. + // Perform shape calculations in a standard (NCHW) layout space + mshadow::Shape<4> input_shape = input.shape_.get<4>(); + mshadow::Shape<4> dshape_nchw = (layout == mshadow::kNHWC) ? + ConvertLayout(input_shape, mshadow::kNHWC, mshadow::kNCHW) : + input_shape; + int kernel_height = param.global_pool ? dshape_nchw[2] : param.kernel[0]; + int kernel_width = param.global_pool ? dshape_nchw[3] : param.kernel[1]; + if (kernel_height > 8 || kernel_width > 8) + return false; +#endif +#if CUDNN_VERSION >= 7105 && CUDNN_VERSION < 7500 + // Avoid strided NHWC max pooling for some configs + if (layout == mshadow::kNHWC && + param.pool_type == pool_enum::kMaxPooling && !param.global_pool) { + if (param.stride[0] >= 3 || + param.stride[0] == 2 && param.kernel[0] % 2 == 0 && param.kernel[0] != 2) + return false; + if (param.stride[1] >= 3 || + param.stride[1] == 2 && param.kernel[1] % 2 == 0 && param.kernel[1] != 2) + return false; + } +#endif + } else if (param.kernel.ndim() == 3) { + // 3d pooling +#if CUDNN_MAJOR < 5 + LogUnsupportedDim(&unsupported_dim_warning_issued, param.kernel.ndim()); + return false; +#endif + if (!(layout == mshadow::kNCDHW || layout == mshadow::kNDHWC)) + return false; + } else { + // Unsupported kernel dim + LogUnsupportedDim(&unsupported_dim_warning_issued, param.kernel.ndim()); + return false; + } + + return true; + } + private: - inline void Init(mshadow::Stream *s, const TBlob &in_data, + // Return boolean saying whether pooling configuration is supported + inline bool Init(mshadow::Stream *s, const TBlob &in_data, const TBlob &out_data) { using namespace mshadow; + bool is_supported = true; #if CUDNN_MAJOR >= 5 nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; #endif + int layout = param_.GetLayout(in_data.ndim()); if (param_.kernel.ndim() == 2) { - // 2d conv + // 2d pooling + CHECK(layout == mshadow::kNCHW || layout == mshadow::kNHWC) << "Need 2D layout NCHW or NHWC."; + cudnnTensorFormat_t cudnn_layout = (layout == mshadow::kNCHW) ? CUDNN_TENSOR_NCHW + : CUDNN_TENSOR_NHWC; Tensor data = in_data.get(s); Tensor out = out_data.get(s); - mshadow::Shape<4> dshape = data.shape_; + // Perform shape calculations in a standard (NCHW) layout space + mshadow::Shape<4> dshape_nchw = (layout == mshadow::kNHWC) ? + ConvertLayout(data.shape_, mshadow::kNHWC, mshadow::kNCHW) : + data.shape_; + mshadow::Shape<4> oshape_nchw = (layout == mshadow::kNHWC) ? + ConvertLayout(out.shape_, mshadow::kNHWC, mshadow::kNCHW) : + out.shape_; CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, - CUDNN_TENSOR_NCHW, + cudnn_layout, dtype_, - data.shape_[0], - data.shape_[1], - data.shape_[2], - data.shape_[3])); + dshape_nchw[0], + dshape_nchw[1], + dshape_nchw[2], + dshape_nchw[3])); CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, - CUDNN_TENSOR_NCHW, + cudnn_layout, dtype_, - out.shape_[0], - out.shape_[1], - out.shape_[2], - out.shape_[3])); + oshape_nchw[0], + oshape_nchw[1], + oshape_nchw[2], + oshape_nchw[3])); + int kernel_height = param_.global_pool ? dshape_nchw[2] : param_.kernel[0]; + int kernel_width = param_.global_pool ? dshape_nchw[3] : param_.kernel[1]; + // CuDNN v7.1.4 backprop kernel doesn't support kernel sizes 9 and above. + // For reference see Fixed Issues section in + // https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_721.html#rel_721 + #if CUDNN_VERSION == 7104 + is_supported = kernel_height <= 8 && kernel_width <= 8; + #endif #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, mode_, nan_prop_, - param_.global_pool ? dshape[2] : param_.kernel[0], - param_.global_pool ? dshape[3] : param_.kernel[1], + kernel_height, + kernel_width, param_.global_pool ? 0 : param_.pad[0], param_.global_pool ? 0 : param_.pad[1], param_.global_pool ? 1 : param_.stride[0], - param_.global_pool ? 1 :param_.stride[1])); + param_.global_pool ? 1 : param_.stride[1])); #else CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, mode_, - param_.global_pool ? dshape[2] : param_.kernel[0], - param_.global_pool ? dshape[3] : param_.kernel[1], + kernel_height, + kernel_width, param_.global_pool ? 0 : param_.pad[0], - param_.global_ppol ? 0 : param_.pad[1], + param_.global_pool ? 0 : param_.pad[1], param_.global_pool ? 1 : param_.stride[0], param_.global_pool ? 1 : param_.stride[1])); #endif } else { + CHECK(layout == mshadow::kNCDHW || + layout == mshadow::kNDHWC) << "Need 3D layout NCDHW or NDHWC."; Tensor data = in_data.get(s); - Tensor out = out_data.get(s); - std::vector ishape = {static_cast(data.shape_[0]), - static_cast(data.shape_[1]), - static_cast(data.shape_[2]), - static_cast(data.shape_[3]), - static_cast(data.shape_[4])}; - - std::vector istride = {static_cast(ishape[1] * ishape[2] * ishape[3] * ishape[4]), - static_cast(ishape[2] * ishape[3] * ishape[4]), - static_cast(ishape[3] * ishape[4]), - static_cast(ishape[4]), 1}; + mshadow::Shape<5> dshape = data.shape_; + mshadow::Shape<5> dstride = mshadow::Shape5(dshape.ProdShape(1, 5), + dshape.ProdShape(2, 5), + dshape.ProdShape(3, 5), + dshape.ProdShape(4, 5), + dshape.ProdShape(5, 5)); - std::vector oshape = {static_cast(out.shape_[0]), - static_cast(out.shape_[1]), - static_cast(out.shape_[2]), - static_cast(out.shape_[3]), - static_cast(out.shape_[4])}; + Tensor out = out_data.get(s); + mshadow::Shape<5> oshape = out.shape_; + mshadow::Shape<5> ostride = mshadow::Shape5(oshape.ProdShape(1, 5), + oshape.ProdShape(2, 5), + oshape.ProdShape(3, 5), + oshape.ProdShape(4, 5), + oshape.ProdShape(5, 5)); + // Convert to a standard (NCDHW) layout space to create args for cuDNN - std::vector ostride = {static_cast(oshape[1] * oshape[2] * oshape[3] * oshape[4]), - static_cast(oshape[2] * oshape[3] * oshape[4]), - static_cast(oshape[3] * oshape[4]), - static_cast(oshape[4]), 1}; + mshadow::Shape<5> dshape_ncdhw = (layout == mshadow::kNDHWC) ? + ConvertLayout(dshape, mshadow::kNDHWC, mshadow::kNCDHW) : + dshape; + mshadow::Shape<5> dstride_ncdhw = (layout == mshadow::kNDHWC) ? + ConvertLayout(dstride, mshadow::kNDHWC, mshadow::kNCDHW) : + dstride; + mshadow::Shape<5> oshape_ncdhw = (layout == mshadow::kNDHWC) ? + ConvertLayout(oshape, mshadow::kNDHWC, mshadow::kNCDHW) : + oshape; + mshadow::Shape<5> ostride_ncdhw = (layout == mshadow::kNDHWC) ? + ConvertLayout(ostride, mshadow::kNDHWC, mshadow::kNCDHW) : + ostride; + // Create int arrays for passing into cuDNN + std::array dshape_ncdhw_int, dstride_ncdhw_int, oshape_ncdhw_int, ostride_ncdhw_int; + for (int i = 0; i < 5; ++i) { + dshape_ncdhw_int[i] = static_cast(dshape_ncdhw[i]); + dstride_ncdhw_int[i] = static_cast(dstride_ncdhw[i]); + oshape_ncdhw_int[i] = static_cast(oshape_ncdhw[i]); + ostride_ncdhw_int[i] = static_cast(ostride_ncdhw[i]); + } - std::vector kernel_vec = {param_.global_pool ? ishape[2] : + std::array kernel_vec = {param_.global_pool ? static_cast(dshape_ncdhw[2]) : static_cast(param_.kernel[0]), - param_.global_pool ? ishape[3] : + param_.global_pool ? static_cast(dshape_ncdhw[3]) : static_cast(param_.kernel[1]), - param_.global_pool ? ishape[4] : + param_.global_pool ? static_cast(dshape_ncdhw[4]) : static_cast(param_.kernel[2])}; - std::vector pad_vec = {param_.global_pool ? 0 : static_cast(param_.pad[0]), + std::array pad_vec = {param_.global_pool ? 0 : static_cast(param_.pad[0]), param_.global_pool ? 0 : static_cast(param_.pad[1]), param_.global_pool ? 0 : static_cast(param_.pad[2])}; - std::vector stride_vec = {param_.global_pool ? 1 : static_cast(param_.stride[0]), + std::array stride_vec = {param_.global_pool ? 1 : static_cast(param_.stride[0]), param_.global_pool ? 1 : static_cast(param_.stride[1]), param_.global_pool ? 1 : static_cast(param_.stride[2])}; CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, dtype_, - static_cast(ishape.size()), - &ishape[0], - &istride[0])); + static_cast(dshape_ncdhw_int.size()), + &dshape_ncdhw_int[0], + &dstride_ncdhw_int[0])); CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, dtype_, - static_cast(oshape.size()), - &oshape[0], - &ostride[0])); + static_cast(oshape_ncdhw_int.size()), + &oshape_ncdhw_int[0], + &ostride_ncdhw_int[0])); #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetPoolingNdDescriptor(pooling_desc_, mode_, @@ -272,9 +385,19 @@ class CuDNNPoolingOp { &(pad_vec[0]), &(stride_vec[0]))); #else - LOG(FATAL) << "3D pooling only support CUDNN v5 and above"; + LOG(FATAL) << "3D pooling is only supported by CUDNN v5 and above."; #endif } + return is_supported; + } + + // Log once that the dimension of the pooling operation isn't supported + static void LogUnsupportedDim(bool *msg_logged, int ndim) { + if (!*msg_logged) { + *msg_logged = true; + LOG(WARNING) << ndim << "D pooling is not supported by cudnn, " + << "MXNet " << ndim << "D pooling is applied."; + } } cudnnDataType_t dtype_; diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index f548778c7615..de3d63e24f6c 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -104,7 +104,8 @@ class MKLDNNPoolingBwd { inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { return param.kernel.ndim() == 2 && (param.pool_type == pool_enum::kMaxPooling || - param.pool_type == pool_enum::kAvgPooling); + param.pool_type == pool_enum::kAvgPooling) && + (!param.layout.has_value() || param.layout.value() == mshadow::kNCHW); } inline bool SupportMKLDNNPooling(const PoolingParam ¶m, diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh index 976aacf63a55..671bc7932ef9 100644 --- a/src/operator/nn/pool.cuh +++ b/src/operator/nn/pool.cuh @@ -89,29 +89,32 @@ namespace mxnet { namespace op { /*! - * \brief max pooling gpu kernel for 1-D images. + * \brief max pooling gpu kernel for 1-D images, for both NCW and NWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_max_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* out_data) { using mshadow::red::limits::MinValue; - // index is the output image's pixel index in NCW + // index is the output image's pixel index CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int c = (index / pooled_width) % channels; + const bool nwc_layout = layout == mshadow::kNWC; + const int idx = nwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int c = nwc_layout ? (index % channels) : (index / pooled_width) % channels; const int n = index / pooled_width / channels; int wstart = pw * stride_w - pad_w; const int wend = min(wstart + kernel_w, width); wstart = max(wstart, 0); - const DType* in_slice = - in_data + (n * channels + c) * width; + const DType* in_slice = nwc_layout ? in_data + n * channels * width + c + : in_data + (n * channels + c) * width; DType max_val = MinValue(); + const int multiplier = nwc_layout ? channels : 1; for (int w = wstart; w < wend; ++w) { - const DType in_val = in_slice[w]; + const DType in_val = in_slice[w * multiplier]; if (in_val > max_val) { max_val = in_val; } @@ -121,10 +124,10 @@ __global__ void pool_max_1d_gpu_kernel(const int nthreads, const DType* in_data, } /*! - * \brief max pooling gpu kernel for 2-D images. + * \brief max pooling gpu kernel for 2-D images, for both NCHW and NHWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_max_2d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, @@ -132,11 +135,14 @@ __global__ void pool_max_2d_gpu_kernel(const int nthreads, const DType* in_data, const int stride_w, const int pad_h, const int pad_w, DType* out_data) { using mshadow::red::limits::MinValue; - // index is the output image's pixel index in NCHW + // index is the output image's pixel index CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int c = (index / pooled_width / pooled_height) % channels; + const bool nhwc_layout = layout == mshadow::kNHWC; + const int idx = nhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int c = nhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height) % channels; const int n = index / pooled_width / pooled_height / channels; int hstart = ph * stride_h - pad_h; int wstart = pw * stride_w - pad_w; @@ -144,12 +150,13 @@ __global__ void pool_max_2d_gpu_kernel(const int nthreads, const DType* in_data, const int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - const DType* in_slice = - in_data + (n * channels + c) * height * width; + const DType* in_slice = nhwc_layout ? in_data + n * channels * height * width + c + : in_data + (n * channels + c) * height * width; DType max_val = MinValue(); + const int multiplier = nhwc_layout ? channels : 1; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - const DType in_val = in_slice[h * width + w]; + const DType in_val = in_slice[(h * width + w) * multiplier]; if (in_val > max_val) { max_val = in_val; } @@ -160,10 +167,10 @@ __global__ void pool_max_2d_gpu_kernel(const int nthreads, const DType* in_data, } /*! - * \brief max pooling gpu kernel for 3-D images. + * \brief max pooling gpu kernel for 3-D images, for both NCDHW and NDHWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_max_3d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int depth, const int height, const int width, const int pooled_depth, const int pooled_height, @@ -173,12 +180,15 @@ __global__ void pool_max_3d_gpu_kernel(const int nthreads, const DType* in_data, const int pad_h, const int pad_w, DType* out_data) { using mshadow::red::limits::MinValue; - // index is the output image's pixel index in NCDHW + // index is the output image's pixel index CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int pd = (index / pooled_width / pooled_height) % pooled_depth; - const int c = (index / pooled_width / pooled_height / pooled_depth) % channels; + const bool ndhwc_layout = layout == mshadow::kNDHWC; + const int idx = ndhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int pd = (idx / pooled_width / pooled_height) % pooled_depth; + const int c = ndhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height / pooled_depth) % channels; const int n = index / pooled_width / pooled_height / pooled_depth / channels; int dstart = pd * stride_d - pad_d; int hstart = ph * stride_h - pad_h; @@ -189,13 +199,14 @@ __global__ void pool_max_3d_gpu_kernel(const int nthreads, const DType* in_data, dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); - const DType* in_slice = - in_data + (n * channels + c) * depth * height * width; + const DType* in_slice = ndhwc_layout ? in_data + n * channels * depth * height * width + c + : in_data + (n * channels + c) * depth * height * width; DType max_val = MinValue(); + const int multiplier = ndhwc_layout ? channels : 1; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - const DType in_val = in_slice[(d * height + h) * width + w]; + const DType in_val = in_slice[((d * height + h) * width + w) * multiplier]; if (in_val > max_val) { max_val = in_val; } @@ -207,17 +218,21 @@ __global__ void pool_max_3d_gpu_kernel(const int nthreads, const DType* in_data, } /*! - * \brief avg/sum pooling gpu kernel for 1-D images. + * \brief avg/sum pooling gpu kernel for 1-D images, for both NCW and NWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* out_data, - const bool get_avg = false, const bool count_include_pad = true) { + const bool get_avg = false, + const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int c = (index / pooled_width) % channels; + const bool nwc_layout = layout == mshadow::kNWC; + const int idx = nwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int c = nwc_layout ? (index % channels) : (index / pooled_width) % channels; const int n = index / pooled_width / channels; int wstart = pw * stride_w - pad_w; int wend = min(wstart + kernel_w, width + pad_w); @@ -227,20 +242,22 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, if (get_avg && !count_include_pad) { pool_size = (wend - wstart); } - DType sum = 0; - const DType* out_slice = in_data + (n * channels + c) * width; + AccType sum = 0; + const DType* out_slice = nwc_layout ? in_data + n * channels * width + c + : in_data + (n * channels + c) * width; + const int multiplier = nwc_layout ? channels : 1; for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(out_slice[w]) / pool_size; + sum += a_pow_p::Map(out_slice[w * multiplier]) / pool_size; } - out_data[index] = a_root_p::Map(sum); + out_data[index] = a_root_p::Map(sum); } } /*! - * \brief avg/sum pooling gpu kernel for 2-D images. + * \brief avg/sum pooling gpu kernel for 2-D images, for both NCHW and NHWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, @@ -249,10 +266,14 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, const int pad_h, const int pad_w, DType* out_data, const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int c = (index / pooled_width / pooled_height) % channels; + const bool nhwc_layout = layout == mshadow::kNHWC; + const int idx = nhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int c = nhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height) % channels; const int n = index / pooled_width / pooled_height / channels; int hstart = ph * stride_h - pad_h; int wstart = pw * stride_w - pad_w; @@ -266,22 +287,24 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, if (get_avg && !count_include_pad) { pool_size = (hend - hstart) * (wend - wstart); } - DType sum = 0; - const DType* out_slice = in_data + (n * channels + c) * height * width; + AccType sum = 0; + const DType* out_slice = nhwc_layout ? in_data + n * channels * height * width + c + : in_data + (n * channels + c) * height * width; + const int multiplier = nhwc_layout ? channels : 1; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(out_slice[h * width + w]) / pool_size; + sum += a_pow_p::Map(out_slice[(h * width + w) * multiplier]) / pool_size; } } - out_data[index] = a_root_p::Map(sum); + out_data[index] = a_root_p::Map(sum); } } /*! - * \brief avg/sum pooling gpu kernel for 3-D images. + * \brief avg/sum pooling gpu kernel for 3-D images, for both NCDHW and NDHWC layouts. * Do not call this kernel directly. Use the interface pool(). */ -template +template __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int depth, const int height, const int width, const int pooled_depth, const int pooled_height, @@ -291,11 +314,15 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, const int pad_d, const int pad_h, const int pad_w, DType* out_data, const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int pd = (index / pooled_width / pooled_height) % pooled_depth; - const int c = (index / pooled_width / pooled_height / pooled_depth) % channels; + const bool ndhwc_layout = layout == mshadow::kNDHWC; + const int idx = ndhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int pd = (idx / pooled_width / pooled_height) % pooled_depth; + const int c = ndhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height / pooled_depth) % channels; const int n = index / pooled_width / pooled_height / pooled_depth / channels; int dstart = pd * stride_d - pad_d; int hstart = ph * stride_h - pad_h; @@ -313,51 +340,57 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, if (get_avg && !count_include_pad) { pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); } - DType sum = 0; - const DType* out_slice = in_data + (n * channels + c) * depth * height * width; + AccType sum = 0; + const DType* out_slice = ndhwc_layout ? in_data + n * channels * depth * height * width + c + : in_data + (n * channels + c) * depth * height * width; + const int multiplier = ndhwc_layout ? channels : 1; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(out_slice[(d * height + h) * width + w]) / pool_size; + sum += a_pow_p::Map(out_slice[((d * height + h) * width + w) * + multiplier]) / pool_size; } } } out_data[index] = (pool_size == 0) ? - DType(nanf("")) : - a_root_p::Map(sum); + AccType(nanf("")) : + a_root_p::Map(sum); } } /*! - * \brief max unpooling gpu kernel for 1-D images. + * \brief max unpooling gpu kernel for 1-D images, for both NCW and NWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_max_1d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* in_grad) { - // index is the output image's pixel index in NCHW + // index is the output image's pixel index // the order has to be consistent with pooling max // to avoid adding out_grad to the wrong in_grad // in the case where there are multiple max pixels // covered by a kernel window CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int c = (index / pooled_width) % channels; - const int n = index / pooled_width / channels; + const bool nwc_layout = layout == mshadow::kNWC; + const int idx = nwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int c = nwc_layout ? index % channels : (index / pooled_width) % channels; + const int n = index / channels / pooled_width; int wstart = pw * stride_w - pad_w; const int wend = min(wstart + kernel_w, width); wstart = max(wstart, 0); // in data/grad offset batch and channel dims - int in_offset = (n * channels + c) * width; + const int in_offset = nwc_layout ? n * channels * width + c : (n * channels + c) * width; const DType* in_data_slice = in_data + in_offset; int max_idx = -1; DType max_val = out_data[index]; + const int multiplier = nwc_layout ? channels : 1; for (int w = wstart; w < wend; ++w) { - if (in_data_slice[w] == max_val) { + if (in_data_slice[w * multiplier] == max_val) { max_idx = w; break; } @@ -366,16 +399,16 @@ __global__ void unpool_max_1d_gpu_kernel(const int nthreads, const DType* out_gr // In the case where pad > 0 and kernel = 1, for example, // max_idx can be -1 reaching this step. if (max_idx >= 0) { - atomicAdd(&in_grad[in_offset+max_idx], out_grad[index]); + atomicAdd(&in_grad[in_offset + max_idx * multiplier], out_grad[index]); } } } /*! - * \brief max unpooling gpu kernel for 2-D images. + * \brief max unpooling gpu kernel for 2-D images, for both NCHW and NHWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_max_2d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int height, const int width, @@ -384,15 +417,18 @@ __global__ void unpool_max_2d_gpu_kernel(const int nthreads, const DType* out_gr const int stride_h, const int stride_w, const int pad_h, const int pad_w, DType* in_grad) { - // index is the output image's pixel index in NCHW + // index is the output image's pixel index // the order has to be consistent with pooling max // to avoid adding out_grad to the wrong in_grad // in the case where there are multiple max pixels // covered by a kernel window CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int c = (index / pooled_width / pooled_height) % channels; + const bool nhwc_layout = layout == mshadow::kNHWC; + const int idx = nhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int c = nhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height) % channels; const int n = index / pooled_width / pooled_height / channels; int hstart = ph * stride_h - pad_h; int wstart = pw * stride_w - pad_w; @@ -401,15 +437,17 @@ __global__ void unpool_max_2d_gpu_kernel(const int nthreads, const DType* out_gr hstart = max(hstart, 0); wstart = max(wstart, 0); // in data/grad offset batch and channel dims - int in_offset = (n * channels + c) * height * width; + int in_offset = nhwc_layout ? n * channels * height * width + c + : (n * channels + c) * height * width; const DType* in_data_slice = in_data + in_offset; int max_idx = -1; DType max_val = out_data[index]; + const int multiplier = nhwc_layout ? channels : 1; bool found = false; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { const int idx = h * width + w; - if (in_data_slice[idx] == max_val) { + if (in_data_slice[idx * multiplier] == max_val) { max_idx = idx; found = true; break; @@ -421,16 +459,16 @@ __global__ void unpool_max_2d_gpu_kernel(const int nthreads, const DType* out_gr // In the case where pad > 0 and kernel = 1, for example, // max_idx can be -1 reaching this step. if (max_idx >= 0) { - atomicAdd(&in_grad[in_offset+max_idx], out_grad[index]); + atomicAdd(&in_grad[in_offset + max_idx * multiplier], out_grad[index]); } } } /*! - * \brief max unpooling gpu kernel for 3-D images. + * \brief max unpooling gpu kernel for 3-D images, for both NCDHW and NDHWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_max_3d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int depth, const int height, @@ -441,16 +479,19 @@ __global__ void unpool_max_3d_gpu_kernel(const int nthreads, const DType* out_gr const int stride_h, const int stride_w, const int pad_d, const int pad_h, const int pad_w, DType* in_grad) { - // index is the output image's pixel index in NCDHW + // index is the output image's pixel index // the order has to be consistent with pooling max // to avoid adding out_grad to the wrong in_grad // in the case where there are multiple max pixels // covered by a kernel window CUDA_KERNEL_LOOP(index, nthreads) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int pd = (index / pooled_width / pooled_height) % pooled_depth; - const int c = (index / pooled_width / pooled_height / pooled_depth) % channels; + const bool ndhwc_layout = layout == mshadow::kNDHWC; + const int idx = ndhwc_layout ? (index / channels) : index; + const int pw = idx % pooled_width; + const int ph = (idx / pooled_width) % pooled_height; + const int pd = (idx / pooled_width / pooled_height) % pooled_depth; + const int c = ndhwc_layout ? (index % channels) + : (index / pooled_width / pooled_height / pooled_depth) % channels; const int n = index / pooled_width / pooled_height / pooled_depth / channels; int dstart = pd * stride_d - pad_d; int hstart = ph * stride_h - pad_h; @@ -462,16 +503,18 @@ __global__ void unpool_max_3d_gpu_kernel(const int nthreads, const DType* out_gr hstart = max(hstart, 0); wstart = max(wstart, 0); // in data/grad offset batch and channel dims - int in_offset = (n * channels + c) * depth * height * width; + int in_offset = ndhwc_layout ? n * channels * depth * height * width + c + : (n * channels + c) * depth * height * width; const DType* in_data_slice = in_data + in_offset; int max_idx = -1; DType max_val = out_data[index]; + const int multiplier = ndhwc_layout ? channels : 1; bool found = false; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { const int idx = (d * height + h) * width + w; - if (in_data_slice[idx] == max_val) { + if (in_data_slice[idx * multiplier] == max_val) { max_idx = idx; found = true; break; @@ -485,16 +528,16 @@ __global__ void unpool_max_3d_gpu_kernel(const int nthreads, const DType* out_gr // In the case where pad > 0 and kernel = 1, for example, // max_idx can be -1 reaching this step. if (max_idx >= 0) { - atomicAdd(&in_grad[in_offset+max_idx], out_grad[index]); + atomicAdd(&in_grad[in_offset + max_idx * multiplier], out_grad[index]); } } } /*! - * \brief avg/sum unpooling gpu kernel for 1-D images. + * \brief avg/sum unpooling gpu kernel for 1-D images, for both NCW and NWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int width, @@ -502,20 +545,23 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr const int stride_w, const int pad_w, DType* in_grad, const bool is_avg = false, const bool count_include_pad = true) { - // index is the input image index in NCW + // index is the input image index CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index // find out the local offset - const int w = index % width + pad_w; - const int c = (index / width) % channels; + const bool nwc_layout = layout == mshadow::kNWC; + const int idx = nwc_layout ? (index / channels) : index; + const int w = idx % width + pad_w; + const int c = nwc_layout ? index % channels : (index / width) % channels; const int n = index / width / channels; const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; const int pwend = min(w / stride_w + 1, pooled_width); DType gradient = 0; - const DType* out_grad_slice = - out_grad + (n * channels + c) * pooled_width; - const DType* out_data_slice = - out_data + (n * channels + c) * pooled_width; + const int slice_offset = nwc_layout ? n * channels * pooled_width + c + : (n * channels + c) * pooled_width; + const DType* out_grad_slice = out_grad + slice_offset; + const DType* out_data_slice = out_data + slice_offset; + const int multiplier = nwc_layout ? channels : 1; for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size int wstart = pw * stride_w - pad_w; @@ -527,7 +573,8 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr pool_size = (wend - wstart); } gradient += - lp_grad::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size; + lp_grad::Map(out_grad_slice[pw * multiplier], in_data[index], + out_data_slice[pw * multiplier]) / pool_size; } // if req=kWriteTo, in_grad has already been assigned zero values in unpool() // use "+=" here instead of "=" to accommodate when req=kAddTo @@ -536,10 +583,10 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr } /*! - * \brief avg/sum unpooling gpu kernel for 2-D images. + * \brief avg/sum unpooling gpu kernel for 2-D images, for both NCHW and NHWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int height, const int width, @@ -549,23 +596,26 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr const int pad_h, const int pad_w, DType* in_grad, const bool is_avg = false, const bool count_include_pad = true) { - // index is the input image index in NCHW + // index is the input image index CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index // find out the local offset - const int w = index % width + pad_w; - const int h = (index / width) % height + pad_h; - const int c = (index / width / height) % channels; + const bool nhwc_layout = layout == mshadow::kNHWC; + const int idx = nhwc_layout ? (index / channels) : index; + const int w = idx % width + pad_w; + const int h = (idx / width) % height + pad_h; + const int c = nhwc_layout ? index % channels : (index / width / height) % channels; const int n = index / width / height / channels; const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; const int phend = min(h / stride_h + 1, pooled_height); const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; const int pwend = min(w / stride_w + 1, pooled_width); DType gradient = 0; - const DType* out_grad_slice = - out_grad + (n * channels + c) * pooled_height * pooled_width; - const DType* out_data_slice = - out_data + (n * channels + c) * pooled_height * pooled_width; + const int slice_offset = nhwc_layout ? n * channels * pooled_height * pooled_width + c + : (n * channels + c) * pooled_height * pooled_width; + const DType* out_grad_slice = out_grad + slice_offset; + const DType* out_data_slice = out_data + slice_offset; + const int multiplier = nhwc_layout ? channels : 1; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size @@ -583,9 +633,9 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr pool_size = (hend - hstart) * (wend - wstart); } gradient += - lp_grad::Map(out_grad_slice[out_index], + lp_grad::Map(out_grad_slice[out_index * multiplier], in_data[index], - out_data_slice[out_index]) / pool_size; + out_data_slice[out_index * multiplier]) / pool_size; } } // if req=kWriteTo, in_grad has already been assigned zero values in unpool() @@ -595,10 +645,10 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr } /*! - * \brief avg/sum unpooling gpu kernel for 3-D images. + * \brief avg/sum unpooling gpu kernel for 3-D images, for both NCDHW and NDHWC layouts. * Do not call this kernel directly. Use the interface unpool(). */ -template +template __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_grad, const DType* in_data, const DType* out_data, const int channels, const int depth, const int height, @@ -609,14 +659,16 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr const int stride_w, const int pad_d, const int pad_h, const int pad_w, DType* in_grad, const bool is_avg = false, const bool count_include_pad = true) { - // index is the input image index in NCDHW + // index is the input image index CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index // find out the local offset - const int w = index % width + pad_w; - const int h = (index / width) % height + pad_h; - const int d = (index / width / height) % depth + pad_d; - const int c = (index / width / height / depth) % channels; + const bool ndhwc_layout = layout == mshadow::kNDHWC; + const int idx = ndhwc_layout ? (index / channels) : index; + const int w = idx % width + pad_w; + const int h = (idx / width) % height + pad_h; + const int d = (idx / width / height) % depth + pad_d; + const int c = ndhwc_layout ? index % channels : (index / width / height / depth) % channels; const int n = index / width / height / depth / channels; const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1; const int pdend = min(d / stride_d + 1, pooled_depth); @@ -625,10 +677,12 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; const int pwend = min(w / stride_w + 1, pooled_width); DType gradient = 0; - const DType* out_grad_slice = - out_grad + (n * channels + c) * pooled_depth * pooled_height * pooled_width; - const DType* out_data_slice = - out_data + (n * channels + c) * pooled_depth * pooled_height * pooled_width; + const int slice_offset = + ndhwc_layout ? n * channels * pooled_depth * pooled_height * pooled_width + c + : (n * channels + c) * pooled_depth * pooled_height * pooled_width; + const DType* out_grad_slice = out_grad + slice_offset; + const DType* out_data_slice = out_data + slice_offset; + const int multiplier = ndhwc_layout ? channels : 1; for (int pd = pdstart; pd < pdend; ++pd) { for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { @@ -650,9 +704,9 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr wend = min(wend, width); pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); } - gradient += lp_grad::Map(out_grad_slice[out_index], + gradient += lp_grad::Map(out_grad_slice[out_index * multiplier], in_data[index], - out_data_slice[out_index]) / pool_size; + out_data_slice[out_index * multiplier]) / pool_size; } } } @@ -674,9 +728,9 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr * \param pool_type supported pooling type: max, avg, sum * \param req_type operator request type, only support kWriteTo for now * \param out_data pointer of the output tensor data in the format of NCW, NCHW, or NCDHW - * \param p_value value of p for Lp pooling + * \param count_include_pad for avg pooling, should 0 pad values be averaged in the window */ -template +template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, @@ -686,14 +740,14 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is if (kernel.ndim() == 1) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_max_1d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], kernel[0], stride[0], pad[0], out_data); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_max_1d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], kernel[0], stride[0], pad[0], out_data, @@ -701,14 +755,14 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], kernel[0], stride[0], pad[0], out_data); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], kernel[0], stride[0], pad[0], out_data); @@ -719,7 +773,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is } else if (kernel.ndim() == 2) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_max_2d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], @@ -727,7 +781,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_max_2d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], @@ -736,7 +790,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], @@ -744,7 +798,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], @@ -756,7 +810,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is } else if (kernel.ndim() == 3) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_max_3d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], @@ -765,7 +819,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_max_3d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], @@ -774,7 +828,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], @@ -783,7 +837,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - pool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], @@ -796,6 +850,70 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is } } +/*! + * \brief This function serves as an interface for 1/2/3-D pooling operations. + * \param s context stream defining the device in use is gpu + * \param in_data pointer of the input tensor data + * \param ishape input tensor shape + * \param oshape output tensor shape + * \param kernel kernel shape + * \param pad pad shape + * \param stride stride shape + * \param pool_type supported pooling type: max, avg, sum + * \param req_type operator request type, only support kWriteTo for now + * \param out_data pointer of the output tensor data + * \param count_include_pad for avg pooling, should 0 pad values be averaged in the window + * \param layout I/O tensor layout, e.g. NCHW vs. NHWC + */ +template +inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, + const TShape& oshape, const TShape& kernel, const TShape& pad, + const TShape& stride, const int pool_type, OpReqType req_type, + DType* out_data, const bool count_include_pad, int layout) { + if (kernel.ndim() == 1) { + if (layout == mshadow::kNWC) { + // standardize shapes to NCW to aid templated kernel invocation + TShape ishape_ncw = ConvertLayout(ishape.get<3>(), mshadow::kNWC, mshadow::kNCW); + TShape oshape_ncw = ConvertLayout(oshape.get<3>(), mshadow::kNWC, mshadow::kNCW); + pool(s, in_data, ishape_ncw, oshape_ncw, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else if (layout == mshadow::kNCW) { + pool(s, in_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCW or kNWC, saw: " << layout; + } + } else if (kernel.ndim() == 2) { + if (layout == mshadow::kNHWC) { + // standardize shapes to NCHW to aid templated kernel invocation + TShape ishape_nchw = ConvertLayout(ishape.get<4>(), mshadow::kNHWC, mshadow::kNCHW); + TShape oshape_nchw = ConvertLayout(oshape.get<4>(), mshadow::kNHWC, mshadow::kNCHW); + pool(s, in_data, ishape_nchw, oshape_nchw, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else if (layout == mshadow::kNCHW) { + pool(s, in_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCHW or kNHWC, saw: " << layout; + } + } else if (kernel.ndim() == 3) { + if (layout == mshadow::kNDHWC) { + // standardize shapes to NCDHW to aid templated kernel invocation + TShape ishape_ncdhw = ConvertLayout(ishape.get<5>(), mshadow::kNDHWC, mshadow::kNCDHW); + TShape oshape_ncdhw = ConvertLayout(oshape.get<5>(), mshadow::kNDHWC, mshadow::kNCDHW); + pool(s, in_data, ishape_ncdhw, oshape_ncdhw, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else if (layout == mshadow::kNCDHW) { + pool(s, in_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, out_data, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCDHW or kNDHWC, saw: " << layout; + } + } else { + LOG(FATAL) << "Unsupported " << kernel.ndim() << "-D pooling"; + } +} + /*! * \brief This function serves as an interface for 1/2/3-D unpooling operations. * \param s context stream defining the device in use is gpu @@ -810,9 +928,9 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is * \param pool_type supported pooling type: max, avg, sum * \param req_type operator request type: kNullOp, kNullWriteInplace, kNullWriteTo, kNullAddTo * \param in_grad pointer of the gradient of the operator's input tensor - * \param p_value value of p for Lp pooling + * \param count_include_pad for avg pooling, should 0 pad values be averaged in the window */ -template +template inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, @@ -826,7 +944,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* if (kernel.ndim() == 1) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_max_1d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], stride[0], pad[0], @@ -834,7 +952,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_max_1d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], @@ -842,7 +960,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], @@ -850,7 +968,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_1d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], @@ -862,7 +980,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* } else if (kernel.ndim() == 2) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_max_2d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], @@ -871,7 +989,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_max_2d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], @@ -881,7 +999,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], @@ -890,7 +1008,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_2d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], @@ -903,7 +1021,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* } else if (kernel.ndim() == 3) { if (pool_enum::kMaxPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_max_3d_gpu_kernel<<<<::GetStream(s)>>>( oshape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], ishape[4], @@ -913,7 +1031,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_max_3d_gpu_kernel); } else if (pool_enum::kAvgPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], ishape[4], @@ -923,7 +1041,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], ishape[4], @@ -933,7 +1051,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel); } else if (pool_enum::kLpPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) - unpool_sum_3d_gpu_kernel<<<<::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], ishape[4], @@ -949,6 +1067,73 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* } } +/*! + * \brief This function serves as an interface for 1/2/3-D unpooling operations. + * \param s context stream defining the device in use is gpu + * \param out_grad pointer of the gradient of operator's output tensor + * \param in_data pointer of the input tensor in the format of NCW, NCHW, or NCDHW + * \param out_data pointer of the output tensor in the format of NCW, NCHW, or NCDHW + * \param ishape input tensor shape + * \param oshape output tensor shape + * \param kernel kernel shape + * \param pad pad shape + * \param stride stride shape + * \param pool_type supported pooling type: max, avg, sum + * \param req_type operator request type: kNullOp, kNullWriteInplace, kNullWriteTo, kNullAddTo + * \param in_grad pointer of the gradient of the operator's input tensor + * \param count_include_pad for avg pooling, should 0 pad values be averaged in the window + * \param layout I/O tensor layout, e.g. NCHW vs. NHWC + */ +template +inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, + const DType* out_data, const TShape& ishape, const TShape& oshape, + const TShape& kernel, const TShape& pad, const TShape& stride, + const int pool_type, OpReqType req_type, DType* in_grad, + const bool count_include_pad, int layout) { + if (kernel.ndim() == 1) { + if (layout == mshadow::kNWC) { + // standardize shapes to NCW to aid templated kernel invocation + TShape ishape_ncw = ConvertLayout(ishape.get<3>(), mshadow::kNWC, mshadow::kNCW); + TShape oshape_ncw = ConvertLayout(oshape.get<3>(), mshadow::kNWC, mshadow::kNCW); + unpool(s, out_grad, in_data, out_data, ishape_ncw, oshape_ncw, + kernel, pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else if (layout == mshadow::kNCW) { + unpool(s, out_grad, in_data, out_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCW or kNWC, saw: " << layout; + } + } else if (kernel.ndim() == 2) { + if (layout == mshadow::kNHWC) { + // standardize shapes to NCHW to aid templated kernel invocation + TShape ishape_nchw = ConvertLayout(ishape.get<4>(), mshadow::kNHWC, mshadow::kNCHW); + TShape oshape_nchw = ConvertLayout(oshape.get<4>(), mshadow::kNHWC, mshadow::kNCHW); + unpool(s, out_grad, in_data, out_data, ishape_nchw, oshape_nchw, + kernel, pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else if (layout == mshadow::kNCHW) { + unpool(s, out_grad, in_data, out_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCHW or kNHWC, saw: " << layout; + } + } else if (kernel.ndim() == 3) { + if (layout == mshadow::kNDHWC) { + // standardize shapes to NCDHW to aid templated kernel invocation + TShape ishape_ncdhw = ConvertLayout(ishape.get<5>(), mshadow::kNDHWC, mshadow::kNCDHW); + TShape oshape_ncdhw = ConvertLayout(oshape.get<5>(), mshadow::kNDHWC, mshadow::kNCDHW); + unpool(s, out_grad, in_data, out_data, ishape_ncdhw, oshape_ncdhw, + kernel, pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else if (layout == mshadow::kNCDHW) { + unpool(s, out_grad, in_data, out_data, ishape, oshape, kernel, + pad, stride, pool_type, req_type, in_grad, count_include_pad); + } else { + LOG(FATAL) << "Unsupported layout, expecting kNCDHW or kNDHWC, saw: " << layout; + } + } else { + LOG(FATAL) << "Unsupported " << kernel.ndim() << "-D unpooling"; + } +} + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h index 33005c8e5f0f..3c8c19a02607 100644 --- a/src/operator/nn/pool.h +++ b/src/operator/nn/pool.h @@ -61,6 +61,7 @@ #include #include +#include #include #include "./pool_utils.h" #include "../mxnet_op.h" @@ -77,13 +78,13 @@ enum PoolingOpPadConventionType {kValid, kFull, kSame}; } // namespace pool_enum /*! - * \brief max pooling cpu function for 1-D images. + * \brief max pooling cpu function for 1-D images in 'ncw' layout. * Do not call this kernel directly. Use the interface pool(). */ template -inline void pool_max_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, - const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data) { +inline void pool_max_1d_ncw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data) { using mshadow::red::limits::MinValue; const int width = ishape[2]; const int pooled_width = oshape[2]; @@ -113,14 +114,53 @@ inline void pool_max_1d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief max pooling cpu function for 2-D images. + * \brief max pooling cpu function for 1-D images in 'nwc' layout. * Do not call this kernel directly. Use the interface pool(). */ template -inline void pool_max_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, +inline void pool_max_1d_nwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* out_data) { using mshadow::red::limits::MinValue; + const int width = ishape[1]; + const int pooled_width = oshape[1]; + const int kernel_w = kernel[0]; + const int pad_w = pad[0]; + const int stride_w = stride[0]; + const int features = oshape[2]; + const index_t in_data_offset = ishape[1] * features; + const index_t out_data_offset = oshape[1] * features; + std::vector max_vals(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w - pad_w; + int wend = std::min(wstart + kernel_w, width); + wstart = std::max(wstart, 0); + std::fill(max_vals.begin(), max_vals.end(), MinValue()); + for (int w = wstart; w < wend; ++w) { + for (index_t c = 0; c < features; ++c) { + if (in_data[w * features + c] > max_vals[c]) { + max_vals[c] = in_data[w * features + c]; + } + } + } + for (index_t c = 0; c < features; ++c) + out_data[pw * features + c] = max_vals[c]; + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief max pooling cpu function for 2-D images in 'nchw' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_max_2d_nchw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data) { + using mshadow::red::limits::MinValue; const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -158,14 +198,62 @@ inline void pool_max_2d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief max pooling cpu function for 3-D images. + * \brief max pooling cpu function for 2-D images in 'nhwc' layout. * Do not call this kernel directly. Use the interface pool(). */ template -inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, +inline void pool_max_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* out_data) { using mshadow::red::limits::MinValue; + const int height = ishape[1], width = ishape[2]; + const int pooled_height = oshape[1], pooled_width = oshape[2]; + const int kernel_h = kernel[0], kernel_w = kernel[1]; + const int pad_h = pad[0], pad_w = pad[1]; + const int stride_h = stride[0], stride_w = stride[1]; + const int features = oshape[3]; + const index_t in_data_offset = ishape[1] * ishape[2] * features; + const index_t out_data_offset = oshape[1] * oshape[2] * features; + std::vector max_vals(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, height); + int wend = std::min(wstart + kernel_w, width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + const int pool_index = ph * pooled_width + pw; + std::fill(max_vals.begin(), max_vals.end(), MinValue()); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = h * width + w; + for (index_t c = 0; c < features; ++c) { + if (in_data[in_index * features + c] > max_vals[c]) { + max_vals[c] = in_data[in_index * features + c]; + } + } + } + } + for (index_t c = 0; c < features; ++c) + out_data[pool_index * features + c] = max_vals[c]; + } + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief max pooling cpu function for 3-D images in 'ncdhw' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_max_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data) { + using mshadow::red::limits::MinValue; const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -210,14 +298,70 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief avg/sum pooling cpu function for 1-D images. + * \brief max pooling cpu function for 3-D images in 'ndhwc' layout. * Do not call this kernel directly. Use the interface pool(). */ -template -inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, +template +inline void pool_max_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, - const bool get_avg = false, const bool count_include_pad = true) { + DType* out_data) { + using mshadow::red::limits::MinValue; + const int depth = ishape[1], height = ishape[2], width = ishape[3]; + const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3]; + const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; + const int pad_d = pad[0], pad_h = pad[1], pad_w = pad[2]; + const int stride_d = stride[0], stride_h = stride[1], stride_w = stride[2]; + const int features = oshape[4]; + const index_t in_data_offset = ishape[1] * ishape[2] * ishape[3] * features; + const index_t out_data_offset = oshape[1] * oshape[2] * oshape[3] * features; + std::vector max_vals(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pd = 0; pd < pooled_depth; ++pd) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int dstart = pd * stride_d - pad_d; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int dend = std::min(dstart + kernel_d, depth); + int hend = std::min(hstart + kernel_h, height); + int wend = std::min(wstart + kernel_w, width); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; + std::fill(max_vals.begin(), max_vals.end(), MinValue()); + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = (d * height + h) * width + w; + for (index_t c = 0; c < features; ++c) { + if (in_data[in_index * features + c] > max_vals[c]) { + max_vals[c] = in_data[in_index * features + c]; + } + } + } + } + } + for (index_t c = 0; c < features; ++c) + out_data[pool_index * features + c] = max_vals[c]; + } + } + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief avg/sum pooling cpu function for 1-D images in 'ncw' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_sum_1d_ncw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data, + const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -236,11 +380,11 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS if (get_avg && !count_include_pad) { pool_size = (wend - wstart); } - DType sum = 0; + AccType sum = 0; for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(in_data[w]) / pool_size; + sum += a_pow_p::Map(in_data[w]) / pool_size; } - out_data[pw] = a_root_p::Map(sum); + out_data[pw] = a_root_p::Map(sum); } in_data += in_data_offset; out_data += out_data_offset; @@ -249,14 +393,58 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief avg/sum pooling cpu function for 2-D images. + * \brief avg/sum pooling cpu function for 1-D images in 'nwc' layout. * Do not call this kernel directly. Use the interface pool(). */ template -inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, +inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* out_data, const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; + const int width = ishape[1]; + const int pooled_width = oshape[1]; + const int kernel_w = kernel[0]; + const int pad_w = pad[0]; + const int stride_w = stride[0]; + const int features = oshape[2]; + const index_t in_data_offset = ishape[1] * features; + const index_t out_data_offset = oshape[1] * features; + std::vector sums(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w - pad_w; + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (get_avg ? (wend - wstart) : 1); + wstart = std::max(wstart, 0); + wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (wend - wstart); + } + std::fill(sums.begin(), sums.end(), 0); + for (int w = wstart; w < wend; ++w) { + for (index_t c = 0; c < features; ++c) { + sums[c] += a_pow_p::Map(in_data[w * features + c]) / pool_size; + } + } + for (index_t c = 0; c < features; ++c) + out_data[pw * features + c] = a_root_p::Map(sums[c]); + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief avg/sum pooling cpu function for 2-D images in 'nchw' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_sum_2d_nchw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data, + const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -280,13 +468,13 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS if (get_avg && !count_include_pad) { pool_size = (hend - hstart) * (wend - wstart); } - DType sum = 0; + AccType sum = 0; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(in_data[h*width+w]) / pool_size; + sum += a_pow_p::Map(in_data[h*width+w]) / pool_size; } } - out_data[ph*pooled_width+pw] = a_root_p::Map(sum); + out_data[ph*pooled_width+pw] = a_root_p::Map(sum); } } in_data += in_data_offset; @@ -296,14 +484,68 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief avg/sum pooling cpu function for 3-D images. + * \brief avg/sum pooling cpu function for 2-D images in 'nhwc' layout. * Do not call this kernel directly. Use the interface pool(). */ template -inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, +inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* out_data, const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; + const int height = ishape[1], width = ishape[2]; + const int pooled_height = oshape[1], pooled_width = oshape[2]; + const int kernel_h = kernel[0], kernel_w = kernel[1]; + const int pad_h = pad[0], pad_w = pad[1]; + const int stride_h = stride[0], stride_w = stride[1]; + const int features = oshape[3]; + const index_t in_data_offset = ishape[1] * ishape[2] * features; + const index_t out_data_offset = oshape[1] * oshape[2] * features; + std::vector sums(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, height + pad_h); + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (get_avg ? (hend - hstart) * (wend - wstart) : 1); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, height); + wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } + const int pool_index = ph * pooled_width + pw; + std::fill(sums.begin(), sums.end(), 0); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = h * width + w; + for (index_t c = 0; c < features; ++c) { + sums[c] += a_pow_p::Map(in_data[in_index * features + c]) / pool_size; + } + } + } + for (index_t c = 0; c < features; ++c) + out_data[pool_index * features + c] = a_root_p::Map(sums[c]); + } + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief avg/sum pooling cpu function for 3-D images in 'ncdhw' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_sum_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, const TShape &oshape, + const TShape &kernel, const TShape &pad, const TShape &stride, + DType *out_data, + const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -332,17 +574,17 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS if (get_avg && !count_include_pad) { pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); } - DType sum = 0; + AccType sum = 0; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - sum += a_pow_p::Map(in_data[(d*height+h)*width+w]) / pool_size; + sum += a_pow_p::Map(in_data[(d*height+h)*width+w]) / pool_size; } } } out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ? - DType(nanf("")) : - a_root_p::Map(sum); + AccType(nanf("")) : + a_root_p::Map(sum); } } } @@ -353,15 +595,78 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS } /*! - * \brief max unpooling cpu function for 1-D images. + * \brief avg/sum pooling cpu function for 3-D images in 'ndhwc' layout. + * Do not call this kernel directly. Use the interface pool(). + */ +template +inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, + const TShape& kernel, const TShape& pad, const TShape& stride, + DType* out_data, + const bool get_avg = false, const bool count_include_pad = true) { + using AccType = typename PoolingTypes::AccType; + const int depth = ishape[1], height = ishape[2], width = ishape[3]; + const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3]; + const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; + const int pad_d = pad[0], pad_h = pad[1], pad_w = pad[2]; + const int stride_d = stride[0], stride_h = stride[1], stride_w = stride[2]; + const int features = oshape[4]; + const index_t in_data_offset = ishape[1] * ishape[2] * ishape[3] * features; + const index_t out_data_offset = oshape[1] * oshape[2] * oshape[3] * features; + std::vector sums(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pd = 0; pd < pooled_depth; ++pd) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int dstart = pd * stride_d - pad_d; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int dend = std::min(dstart + kernel_d, depth + pad_d); + int hend = std::min(hstart + kernel_h, height + pad_h); + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (get_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + dend = std::min(dend, depth); + hend = std::min(hend, height); + wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } + const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; + std::fill(sums.begin(), sums.end(), 0); + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = (d * height + h) * width + w; + for (index_t c = 0; c < features; ++c) { + sums[c] += a_pow_p::Map(in_data[in_index * features + c]) / pool_size; + } + } + } + } + for (index_t c = 0; c < features; ++c) + out_data[pool_index * features + c] = (pool_size == 0) ? + AccType(nanf("")) : + a_root_p::Map(sums[c]); + } + } + } + in_data += in_data_offset; + out_data += out_data_offset; + } +} + +/*! + * \brief max unpooling cpu function for 1-D images in 'ncw' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_max_1d_cpu(const DType* out_grad, const DType* in_data, - const DType* out_data, const TShape& ishape, - const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad) { +inline void unpool_max_1d_ncw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -397,15 +702,63 @@ inline void unpool_max_1d_cpu(const DType* out_grad, const DType* in_data, } /*! - * \brief max unpooling cpu function for 2-D images. + * \brief max unpooling cpu function for 1-D images in 'nwc' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_max_2d_cpu(const DType* out_grad, const DType* in_data, +inline void unpool_max_1d_nwc_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* in_grad) { + const int width = ishape[1]; + const int pooled_width = oshape[1]; + const int kernel_w = kernel[0]; + const int pad_w = pad[0]; + const int stride_w = stride[0]; + const int features = oshape[2]; + const index_t in_offset = ishape[1] * features; + const index_t out_offset = oshape[1] * features; + std::vector max_idxs(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w - pad_w; + int wend = std::min(wstart + kernel_w, width); + wstart = std::max(wstart, 0); + std::fill(max_idxs.begin(), max_idxs.end(), -1); + for (index_t c = 0; c < features; ++c) { + for (int w = wstart; w < wend; ++w) { + if (in_data[w * features + c] == out_data[pw * features + c]) { + max_idxs[c] = w; + break; + } + } + } + // In the case where pad > 0 and kernel = 1, for example, + // max_idx can be -1 reaching this step. + for (index_t c = 0; c < features; ++c) { + if (max_idxs[c] >= 0) { + in_grad[max_idxs[c] * features + c] += out_grad[pw * features + c]; + } + } + } + in_data += in_offset; + in_grad += in_offset; + out_data += out_offset; + out_grad += out_offset; + } +} + +/*! + * \brief max unpooling cpu function for 2-D images in 'nchw' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_max_2d_nchw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -453,15 +806,75 @@ inline void unpool_max_2d_cpu(const DType* out_grad, const DType* in_data, } /*! - * \brief max unpooling cpu function for 3-D images. + * \brief max unpooling cpu function for 2-D images in 'nhwc' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data, +inline void unpool_max_2d_nhwc_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, DType* in_grad) { + const int height = ishape[1], width = ishape[2]; + const int pooled_height = oshape[1], pooled_width = oshape[2]; + const int kernel_h = kernel[0], kernel_w = kernel[1]; + const int pad_h = pad[0], pad_w = pad[1]; + const int stride_h = stride[0], stride_w = stride[1]; + const int features = oshape[3]; + const index_t in_offset = ishape[1] * ishape[2] * features; + const index_t out_offset = oshape[1] * oshape[2] * features; + std::vector max_idxs(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, height); + int wend = std::min(wstart + kernel_w, width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + const int pool_index = ph * pooled_width + pw; + std::fill(max_idxs.begin(), max_idxs.end(), -1); + for (index_t c = 0; c < features; ++c) { + bool found = false; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int idx = h * width + w; + if (in_data[idx * features + c] == out_data[pool_index * features + c]) { + max_idxs[c] = idx; + found = true; + break; + } + } + if (found) break; + } + } + // In the case where pad > 0 and kernel = 1, for example, + // max_idx can be -1 reaching this step. + for (index_t c = 0; c < features; ++c) { + if (max_idxs[c] >= 0) { + in_grad[max_idxs[c] * features + c] += out_grad[pool_index * features + c]; + } + } + } + } + in_data += in_offset; + in_grad += in_offset; + out_data += out_offset; + out_grad += out_offset; + } +} + +/*! + * \brief max unpooling cpu function for 3-D images in 'ncdhw' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_max_3d_ncdhw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -517,14 +930,83 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data, } /*! - * \brief avg/sum unpooling cpu function for 1-D images. + * \brief max unpooling cpu function for 3-D images in 'ndhwc' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_max_3d_ndhwc_cpu(const DType* out_grad, const DType* in_data, + const DType* out_data, const TShape& ishape, + const TShape& oshape, const TShape& kernel, + const TShape& pad, const TShape& stride, + DType* in_grad) { + const int depth = ishape[1], height = ishape[2], width = ishape[3]; + const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3]; + const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; + const int pad_d = pad[0], pad_h = pad[1], pad_w = pad[2]; + const int stride_d = stride[0], stride_h = stride[1], stride_w = stride[2]; + const int features = oshape[4]; + const index_t in_offset = ishape[1] * ishape[2] * ishape[3] * features; + const index_t out_offset = oshape[1] * oshape[2] * oshape[3] * features; + std::vector max_idxs(features); + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pd = 0; pd < pooled_depth; ++pd) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int dstart = pd * stride_d - pad_d; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int dend = std::min(dstart + kernel_d, depth); + int hend = std::min(hstart + kernel_h, height); + int wend = std::min(wstart + kernel_w, width); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; + std::fill(max_idxs.begin(), max_idxs.end(), -1); + for (index_t c = 0; c < features; ++c) { + bool found = false; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int idx = (d * height + h) * width + w; + if (in_data[idx * features + c] == out_data[pool_index * features + c]) { + max_idxs[c] = idx; + found = true; + break; + } + } + if (found) break; + } + if (found) break; + } + } + // In the case where pad > 0 and kernel = 1, for example, + // max_idx can be -1 reaching this step. + for (index_t c = 0; c < features; ++c) { + if (max_idxs[c] >= 0) { + in_grad[max_idxs[c] * features + c] += out_grad[pool_index * features + c]; + } + } + } + } + } + in_data += in_offset; + in_grad += in_offset; + out_data += out_offset; + out_grad += out_offset; + } +} + +/*! + * \brief avg/sum unpooling cpu function for 1-D images in 'ncw' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, - const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, DType* in_grad, - const bool is_avg = false, const bool count_include_pad = true) { +inline void unpool_sum_1d_ncw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, + const TShape &ishape, const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, DType *in_grad, + const bool is_avg = false, const bool count_include_pad = true) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -556,14 +1038,61 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const } /*! - * \brief avg/sum unpooling cpu function for 2-D images. + * \brief avg/sum unpooling cpu function for 1-D images in 'nwc' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_sum_1d_nwc_cpu(const DType* out_grad, const DType* in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad, const bool is_avg = false, + const bool count_include_pad = true) { + const int width = ishape[1]; + const int pooled_width = oshape[1]; + const int kernel_w = kernel[0]; + const int pad_w = pad[0]; + const int stride_w = stride[0]; + const int features = oshape[2]; + const index_t in_grad_offset = ishape[1] * features; + const index_t out_grad_offset = oshape[1] * features; + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w - pad_w; + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (is_avg ? (wend - wstart) : 1); + wstart = std::max(wstart, 0); + wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (wend - wstart); + } + for (int w = wstart; w < wend; ++w) { + for (index_t c = 0; c < features; ++c) { + in_grad[w * features + c] += + lp_grad::Map(out_grad[pw * features + c], + in_data[w * features + c], + out_data[pw * features + c]) / pool_size; + } + } + } + in_grad += in_grad_offset; + in_data += in_grad_offset; + out_grad += out_grad_offset; + out_data += out_grad_offset; + } +} + +/*! + * \brief avg/sum unpooling cpu function for 2-D images in 'nchw' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, - const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, DType* in_grad, - const bool is_avg = false, const bool count_include_pad = true) { +inline void unpool_sum_2d_nchw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad, const bool is_avg = false, + const bool count_include_pad = true) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -607,14 +1136,71 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const } /*! - * \brief avg/sum unpooling cpu function for 3-D images. + * \brief avg/sum unpooling cpu function for 2-D images in 'nhwc' layout. * Do not call this kernel directly. Use the interface unpool(). */ template -inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, - const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, DType* in_grad, - const bool is_avg = false, const bool count_include_pad = true) { +inline void unpool_sum_2d_nhwc_cpu(const DType* out_grad, const DType* in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad, const bool is_avg = false, + const bool count_include_pad = true) { + const int height = ishape[1], width = ishape[2]; + const int pooled_height = oshape[1], pooled_width = oshape[2]; + const int kernel_h = kernel[0], kernel_w = kernel[1]; + const int pad_h = pad[0], pad_w = pad[1]; + const int features = oshape[3]; + const int stride_h = stride[0], stride_w = stride[1]; + const index_t in_grad_offset = ishape[1] * ishape[2] * features; + const index_t out_grad_offset = oshape[1] * oshape[2] * features; + for (index_t n = 0; n < oshape[0]; ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, height + pad_h); + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (is_avg ? (hend - hstart) * (wend - wstart) : 1); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, height); + wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } + const int pool_index = ph * pooled_width + pw; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = h * width + w; + for (index_t c = 0; c < features; ++c) { + in_grad[in_index * features + c] += + lp_grad::Map(out_grad[pool_index * features + c], + in_data[in_index * features + c], + out_data[pool_index * features + c]) / pool_size; + } + } + } + } + } + in_grad += in_grad_offset; + in_data += in_grad_offset; + out_grad += out_grad_offset; + out_data += out_grad_offset; + } +} + +/*! + * \brief avg/sum unpooling cpu function for 3-D images in 'ncdhw' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_sum_3d_ncdhw_cpu(const DType *out_grad, const DType *in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad, const bool is_avg = false, + const bool count_include_pad = true) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -665,6 +1251,69 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const } } +/*! + * \brief avg/sum unpooling cpu function for 3-D images in 'ndhwc' layout. + * Do not call this kernel directly. Use the interface unpool(). + */ +template +inline void unpool_sum_3d_ndhwc_cpu(const DType* out_grad, const DType* in_data, + const DType *out_data, const TShape &ishape, + const TShape &oshape, const TShape &kernel, + const TShape &pad, const TShape &stride, + DType *in_grad, const bool is_avg = false, + const bool count_include_pad = true) { + const int depth = ishape[1], height = ishape[2], width = ishape[3]; + const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3]; + const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; + const int pad_d = pad[0], pad_h = pad[1], pad_w = pad[2]; + const int stride_d = stride[0], stride_h = stride[1], stride_w = stride[2]; + const int features = oshape[4]; + const index_t in_grad_offset = ishape[1] * ishape[2] * ishape[3] * features; + const index_t out_grad_offset = oshape[1] * oshape[2] * oshape[3] * features; + for (index_t n = 0; n < oshape[0]; ++n) { + for (int pd = 0; pd < pooled_depth; ++pd) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int dstart = pd * stride_d - pad_d; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int dend = std::min(dstart + kernel_d, depth + pad_d); + int hend = std::min(hstart + kernel_h, height + pad_h); + int wend = std::min(wstart + kernel_w, width + pad_w); + int pool_size = (is_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + dstart = std::max(dstart, 0); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + dend = std::min(dend, depth); + hend = std::min(hend, height); + wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } + const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int in_index = (d * height + h) * width + w; + for (index_t c = 0; c < features; ++c) { + in_grad[in_index * features + c] += + lp_grad::Map(out_grad[pool_index * features + c], + in_data[in_index * features + c], + out_data[pool_index * features + c]) / pool_size; + } + } + } + } + } + } + } + in_grad += in_grad_offset; + in_data += in_grad_offset; + out_grad += out_grad_offset; + out_data += out_grad_offset; + } +} + /*! * \brief This function serves as an interface for 1/2/3-D pooling operations. * \param s context stream defining the device in use is cpu @@ -683,46 +1332,97 @@ template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, - DType* out_data, const bool count_include_pad) { + DType* out_data, const bool count_include_pad, int layout) { CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations"; if (kernel.ndim() == 1) { - if (pool_enum::kMaxPooling == pool_type) { - pool_max_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kLpPooling == pool_type) { - pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + if (layout == mshadow::kNWC) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_1d_nwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_1d_nwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_1d_nwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_1d_nwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCW) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_1d_ncw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_1d_ncw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_1d_ncw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_1d_ncw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCW or kNWC, saw: " << layout; } } else if (kernel.ndim() == 2) { - if (pool_enum::kMaxPooling == pool_type) { - pool_max_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kLpPooling == pool_type) { - pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + if (layout == mshadow::kNHWC) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_2d_nhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_2d_nhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_2d_nhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_2d_nhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCHW) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_2d_nchw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_2d_nchw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_2d_nchw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_2d_nchw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCHW or kNHWC, saw: " << layout; } } else if (kernel.ndim() == 3) { - if (pool_enum::kMaxPooling == pool_type) { - pool_max_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); - } else if (pool_enum::kLpPooling == pool_type) { - pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + if (layout == mshadow::kNDHWC) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_3d_ndhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_3d_ndhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_3d_ndhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_3d_ndhwc_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCDHW) { + if (pool_enum::kMaxPooling == pool_type) { + pool_max_3d_ncdhw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kAvgPooling == pool_type) { + pool_sum_3d_ncdhw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + pool_sum_3d_ncdhw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else if (pool_enum::kLpPooling == pool_type) { + pool_sum_3d_ncdhw_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCDHW or kNDHWC, saw: " << layout; } } else { LOG(FATAL) << "Unsupported " << kernel.ndim() << "-D pooling"; @@ -750,52 +1450,128 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, DType* in_grad, - const bool count_include_pad) { + const bool count_include_pad, int layout) { if (mxnet::kNullOp == req_type) return; if (mxnet::kAddTo != req_type) { mxnet_op::Kernel::Launch(s, ishape.Size(), in_grad); } if (kernel.ndim() == 1) { - if (pool_enum::kMaxPooling == pool_type) { - unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kAvgPooling == pool_type) { - unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kLpPooling == pool_type) { - unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, - in_grad); + if (layout == mshadow::kNWC) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_1d_nwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_1d_nwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_1d_nwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_1d_nwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCW) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_1d_ncw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_1d_ncw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_1d_ncw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_1d_ncw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCW or kNWC, saw: " << layout; } } else if (kernel.ndim() == 2) { - if (pool_enum::kMaxPooling == pool_type) { - unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kAvgPooling == pool_type) { - unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kLpPooling == pool_type) { - unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, - in_grad); + if (layout == mshadow::kNHWC) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_2d_nhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_2d_nhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_2d_nhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_2d_nhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCHW) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_2d_nchw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_2d_nchw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_2d_nchw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_2d_nchw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCHW or kNHWC, saw: " << layout; } } else if (kernel.ndim() == 3) { - if (pool_enum::kMaxPooling == pool_type) { - unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kAvgPooling == pool_type) { - unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true, count_include_pad); - } else if (pool_enum::kSumPooling == pool_type) { - unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); - } else if (pool_enum::kLpPooling == pool_type) { - unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, - in_grad); + if (layout == mshadow::kNDHWC) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_3d_ndhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_3d_ndhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_3d_ndhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_3d_ndhwc_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } + } else if (layout == mshadow::kNCDHW) { + if (pool_enum::kMaxPooling == pool_type) { + unpool_max_3d_ncdhw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kAvgPooling == pool_type) { + unpool_sum_3d_ncdhw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad, + true, count_include_pad); + } else if (pool_enum::kSumPooling == pool_type) { + unpool_sum_3d_ncdhw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, + in_grad); + } else if (pool_enum::kLpPooling == pool_type) { + unpool_sum_3d_ncdhw_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, + stride, + in_grad); + } else { + LOG(FATAL) << "Unknown pooling type " << pool_type; + } } else { - LOG(FATAL) << "Unknown pooling type " << pool_type; + LOG(FATAL) << "Unsupported layout, expecting kNCDHW or kNDHWC, saw: " << layout; } } else { LOG(FATAL) << "Unsupported " << kernel.ndim() << "-D unpooling"; diff --git a/src/operator/nn/pool_utils.h b/src/operator/nn/pool_utils.h index 641cc4a995ab..6bf7235048dc 100644 --- a/src/operator/nn/pool_utils.h +++ b/src/operator/nn/pool_utils.h @@ -25,6 +25,17 @@ namespace mxnet { namespace op { +// Define an accumulator type AccType to permit float16-I/O lp pooling to avoid underflow. +template +struct PoolingTypes { + typedef DType AccType; +}; + +template<> +struct PoolingTypes { + typedef float AccType; +}; + template struct a_pow_p { static MSHADOW_XINLINE DType Map(const DType a) { @@ -98,14 +109,17 @@ struct lp_grad { template struct lp_grad { static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) { - return grad * in_data / out_data; + // Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0. + return (out_data == DType(0.0)) ? DType(0.0) : grad * (in_data / out_data); } }; template struct lp_grad { static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) { - return grad * in_data * in_data / (out_data * out_data); + // Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0. + DType in_out_ratio = in_data / out_data; + return (out_data == DType(0.0)) ? DType(0.0) : grad * in_out_ratio * in_out_ratio; } }; diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index 71d85da9ba52..af00fd5cfa3c 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -53,6 +53,7 @@ struct PoolingParam : public dmlc::Parameter { bool cudnn_off; dmlc::optional p_value; dmlc::optional count_include_pad; + dmlc::optional layout; DMLC_DECLARE_PARAMETER(PoolingParam) { DMLC_DECLARE_FIELD(kernel).set_default(TShape()) // add default value here .enforce_nonzero() @@ -92,6 +93,17 @@ struct PoolingParam : public dmlc::Parameter { "calculation. For example, with a 5*5 kernel on a 3*3 corner of a image," "the sum of the 9 valid elements will be divided by 25 if this is set to true," "or it will be divided by 9 if this is set to false. Defaults to true."); + + DMLC_DECLARE_FIELD(layout) + .add_enum("NCW", mshadow::kNCW) + .add_enum("NCHW", mshadow::kNCHW) + .add_enum("NCDHW", mshadow::kNCDHW) + .add_enum("NWC", mshadow::kNWC) + .add_enum("NHWC", mshadow::kNHWC) + .add_enum("NDHWC", mshadow::kNDHWC) + .set_default(dmlc::optional()) + .describe("Set layout for input and output. Empty for\n " + "default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d."); } bool operator==(const PoolingParam& other) const { @@ -103,7 +115,28 @@ struct PoolingParam : public dmlc::Parameter { this->global_pool == other.global_pool && this->cudnn_off == other.cudnn_off && this->p_value == other.p_value && - this->count_include_pad == other.count_include_pad; + this->count_include_pad == other.count_include_pad && + this->layout == other.layout; + } + + // Extract layout from param, or supply default layout based on provided input dimension. + int GetLayout(int input_dim) const { + int ret_val = mshadow::kNCW; + if (layout.has_value()) { + ret_val = layout.value(); + } else { + switch (input_dim) { + case 3U: ret_val = mshadow::kNCW; break; + case 4U: ret_val = mshadow::kNCHW; break; + case 5U: ret_val = mshadow::kNCDHW; break; + default: + LOG(FATAL) << "Unexpected input data dim " << input_dim << "\n" + << "Pooling: Input data should be 3D in (batch, channel, x), " + << " or 4D in (batch, channel, y, x), " + << " or 5D in (batch, channel, d, y, x)."; + } + } + return ret_val; } }; @@ -124,6 +157,8 @@ struct hash { ret = dmlc::HashCombine(ret, val.cudnn_off); ret = dmlc::HashCombine(ret, val.p_value); ret = dmlc::HashCombine(ret, val.count_include_pad); + int val_layout = val.layout.has_value() ? val.layout.value() : -1; + ret = dmlc::HashCombine(ret, val_layout); return ret; } }; @@ -154,9 +189,17 @@ class PoolingOp { TShape kernel = param_.kernel; TShape padding = param_.pad; TShape stride = param_.stride; + int layout = param_.GetLayout(ishape.ndim()); if (param_.global_pool) { - kernel = TShape(ishape.data() + 2, - ishape.data() + ishape.ndim()); + // with global pooling, kernel shape corresponds to input shape with 'N' and 'C' removed + if (layout == mshadow::kNWC || layout == mshadow::kNHWC || layout == mshadow::kNDHWC) { + kernel = TShape(ishape.data() + 1, + ishape.data() + ishape.ndim() - 1); + + } else { + kernel = TShape(ishape.data() + 2, + ishape.data() + ishape.ndim()); + } padding = TShape(ishape.ndim() - 2); for (index_t i = 0; i < ishape.ndim() - 2; i++) { padding[i] = 0; @@ -173,21 +216,21 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, out_data.dptr(), count_include_pad); + param_.pool_type, req, out_data.dptr(), count_include_pad, layout); break; case 2: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr(), count_include_pad); + param_.pool_type, req, out_data.dptr(), count_include_pad, layout); break; case 3: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr(), count_include_pad); + param_.pool_type, req, out_data.dptr(), count_include_pad, layout); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet..."; @@ -203,9 +246,17 @@ class PoolingOp { TShape kernel = param_.kernel; TShape padding = param_.pad; TShape stride = param_.stride; + int layout = param_.GetLayout(ishape.ndim()); if (param_.global_pool) { - kernel = TShape(ishape.data() + 2, - ishape.data() + ishape.ndim()); + // with global pooling, kernel shape corresponds to input shape with 'N' and 'C' removed + if (layout == mshadow::kNWC || layout == mshadow::kNHWC || layout == mshadow::kNDHWC) { + kernel = TShape(ishape.data() + 1, + ishape.data() + ishape.ndim() - 1); + + } else { + kernel = TShape(ishape.data() + 2, + ishape.data() + ishape.ndim()); + } padding = TShape(ishape.ndim() - 2); for (index_t i = 0; i < ishape.ndim() - 2; i++) { padding[i] = 0; @@ -224,7 +275,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr(), count_include_pad); + param_.pool_type, req, in_grad.dptr(), count_include_pad, layout); break; case 2: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -232,7 +283,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr(), count_include_pad); + param_.pool_type, req, in_grad.dptr(), count_include_pad, layout); break; case 3: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -240,7 +291,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr(), count_include_pad); + param_.pool_type, req, in_grad.dptr(), count_include_pad, layout); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet..."; diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 611568807a9a..9e9af4d97fd9 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -39,6 +39,9 @@ void PoolingParamParser(nnvm::NodeAttrs *attrs) { using namespace mshadow; PoolingParam param; param.Init(attrs->dict); + // Set default layout if it can be inferred from kernel shape. + if (param.kernel.ndim() > 0) + param.layout = param.GetLayout(param.kernel.ndim() + 2); if (param.kernel.ndim() == 1) { if (param.stride.ndim() == 0) param.stride = Shape1(1); if (param.pad.ndim() == 0) param.pad = Shape1(0); @@ -111,38 +114,65 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; - TShape oshape = dshape; if (dshape.ndim() == 0) return false; + int layout = param.GetLayout(dshape.ndim()); if (param.global_pool) { - for (size_t i{2}; i < dshape.ndim(); i++) - oshape[i] = 1; - out_shape->clear(); - out_shape->push_back(oshape); // save output shape + TShape oshape = dshape; + size_t c_index = 0; + switch (layout) { + case mshadow::kNCW: + case mshadow::kNCHW: + case mshadow::kNCDHW: + c_index = 1; + break; + case mshadow::kNWC: + case mshadow::kNHWC: + case mshadow::kNDHWC: + c_index = dshape.ndim() - 1; + break; + default: + LOG(FATAL) << "Unsupported tensor layout " << param.layout.value(); + } + for (size_t i{1}; i < dshape.ndim(); i++) + if (i != c_index) + oshape[i] = 1; + out_shape->clear(); + out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 - if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif + } else if (param.kernel.ndim() == 0) { + return false; } else if (param.kernel.ndim() == 1) { - CHECK_EQ(dshape.ndim(), 3U) - << "Pooling: Input data should be 3D in (batch, channel, x)"; - CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0]) - << "kernel size (" << param.kernel[0] << ") exceeds input (" - << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0]) - << ")"; + CHECK_EQ(dshape.ndim(), 3U) << + "Pooling: Input data should be 3D in (batch, channel, x)"; + CHECK(layout == mshadow::kNCW || layout == mshadow::kNWC) << "Need 1D layout"; + // Perform shape calculations in a standard (NCW) layout space + mshadow::Shape<3> dshape_ncw = (layout == mshadow::kNWC) ? + ConvertLayout(dshape.get<3>(), mshadow::kNWC, mshadow::kNCW) : + dshape.get<3>(); + mshadow::Shape<3> oshape_ncw = dshape_ncw; + CHECK(param.kernel[0] <= dshape_ncw[2] + 2 * param.pad[0]) + << "kernel size (" << param.kernel[0] << ") exceeds input (" << dshape[2] + << " padded to " << (dshape_ncw[2] + 2*param.pad[0]) << ")"; if (param.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + - (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / - param.stride[0]; + oshape_ncw[2] = 1 + + (dshape_ncw[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; } else if (param.pooling_convention == pool_enum::kFull) { - oshape[2] = 1 + static_cast(std::ceil( - static_cast(dshape[2] + 2 * param.pad[0] - - param.kernel[0]) / - param.stride[0])); + oshape_ncw[2] = 1 + static_cast(std::ceil( + static_cast(dshape_ncw[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); } else { - oshape[2] = static_cast(std::ceil( - static_cast(dshape[2] + 2 * param.pad[0]) / + oshape_ncw[2] = static_cast(std::ceil( + static_cast(dshape_ncw[2] + 2 * param.pad[0]) / param.stride[0])); } + // Convert back from standard (NCW) layout space to the actual layout type + TShape oshape = (layout == mshadow::kNWC) ? + ConvertLayout(oshape_ncw, mshadow::kNCW, mshadow::kNWC) : oshape_ncw; out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 @@ -150,33 +180,37 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, out_shape->push_back(oshape); // for workspace #endif } else if (param.kernel.ndim() == 2) { - CHECK_EQ(dshape.ndim(), 4U) - << "Pooling: Input data should be 4D in (batch, channel, y, x)"; - CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0]) - << "kernel size (" << param.kernel[0] << ") exceeds input (" - << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0]) - << ")"; - CHECK(param.kernel[1] <= dshape[3] + 2 * param.pad[1]) - << "kernel size (" << param.kernel[1] << ") exceeds input (" - << dshape[3] << " padded to " << (dshape[3] + 2 * param.pad[1]) - << ")"; + CHECK_EQ(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x)"; + CHECK(layout == mshadow::kNCHW || layout == mshadow::kNHWC) << "Need 2D layout"; + // Perform shape calculations in a standard (NCHW) layout space + mshadow::Shape<4> dshape_nchw = (layout == mshadow::kNHWC) ? + ConvertLayout(dshape.get<4>(), mshadow::kNHWC, mshadow::kNCHW) : + dshape.get<4>(); + mshadow::Shape<4> oshape_nchw = dshape_nchw; + CHECK(param.kernel[0] <= dshape_nchw[2] + 2 * param.pad[0]) + << "kernel size (" << param.kernel[0] << ") exceeds input (" << dshape_nchw[2] + << " padded to " << (dshape_nchw[2] + 2*param.pad[0]) << ")"; + CHECK(param.kernel[1] <= dshape_nchw[3] + 2 * param.pad[1]) + << "kernel size (" << param.kernel[1] << ") exceeds input (" << dshape_nchw[3] + << " padded to " << (dshape_nchw[3] + 2*param.pad[1]) << ")"; if (param.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + - (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / - param.stride[0]; - oshape[3] = 1 + - (dshape[3] + 2 * param.pad[1] - param.kernel[1]) / - param.stride[1]; + oshape_nchw[2] = 1 + (dshape_nchw[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; + oshape_nchw[3] = 1 + (dshape_nchw[3] + 2 * param.pad[1] - param.kernel[1]) / + param.stride[1]; } else { - oshape[2] = 1 + static_cast(std::ceil( - static_cast(dshape[2] + 2 * param.pad[0] - - param.kernel[0]) / - param.stride[0])); - oshape[3] = 1 + static_cast(std::ceil( - static_cast(dshape[3] + 2 * param.pad[1] - - param.kernel[1]) / - param.stride[1])); + oshape_nchw[2] = 1 + static_cast(ceil( + static_cast(dshape_nchw[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); + oshape_nchw[3] = 1 + static_cast(ceil( + static_cast(dshape_nchw[3] + 2 * param.pad[1] - + param.kernel[1]) / + param.stride[1])); } + // Convert back from standard (NCHW) layout space to the actual layout type + TShape oshape = (layout == mshadow::kNHWC) ? + ConvertLayout(oshape_nchw, mshadow::kNCHW, mshadow::kNHWC) : oshape_nchw; out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 @@ -185,38 +219,40 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, #endif } else if (param.kernel.ndim() == 3) { CHECK_EQ(dshape.ndim(), 5U) - << "Pooling: Input data should be 5D in (batch, channel, d, y, x)"; - CHECK_LE(param.kernel[0], dshape[2] + 2 * param.pad[0]) - << "kernel size exceeds input"; - CHECK_LE(param.kernel[1], dshape[3] + 2 * param.pad[1]) - << "kernel size exceeds input"; - CHECK_LE(param.kernel[2], dshape[4] + 2 * param.pad[2]) - << "kernel size exceeds input"; + << "Pooling: Input data should be 5D in (batch, channel, d, y, x)"; + CHECK(layout == mshadow::kNCDHW || layout == mshadow::kNDHWC) << "Need 3D layout"; + // Perform shape calculations in a standard (NCDHW) layout space + mshadow::Shape<5> dshape_ncdhw = (layout == mshadow::kNDHWC) ? + ConvertLayout(dshape.get<5>(), mshadow::kNDHWC, mshadow::kNCDHW) : + dshape.get<5>(); + mshadow::Shape<5> oshape_ncdhw = dshape_ncdhw; + CHECK_LE(param.kernel[0], dshape_ncdhw[2] + 2 * param.pad[0]) << "kernel size exceeds input"; + CHECK_LE(param.kernel[1], dshape_ncdhw[3] + 2 * param.pad[1]) << "kernel size exceeds input"; + CHECK_LE(param.kernel[2], dshape_ncdhw[4] + 2 * param.pad[2]) << "kernel size exceeds input"; if (param.pooling_convention == pool_enum::kValid) { - oshape[2] = 1 + - (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / - param.stride[0]; - oshape[3] = 1 + - (dshape[3] + 2 * param.pad[1] - param.kernel[1]) / - param.stride[1]; - oshape[4] = 1 + - (dshape[4] + 2 * param.pad[2] - param.kernel[2]) / - param.stride[2]; + oshape_ncdhw[2] = 1 + (dshape_ncdhw[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; + oshape_ncdhw[3] = 1 + (dshape_ncdhw[3] + 2 * param.pad[1] - param.kernel[1]) / + param.stride[1]; + oshape_ncdhw[4] = 1 + (dshape_ncdhw[4] + 2 * param.pad[2] - param.kernel[2]) / + param.stride[2]; } else { - oshape[2] = 1 + static_cast(std::ceil( - static_cast(dshape[2] + 2 * param.pad[0] - - param.kernel[0]) / - param.stride[0])); - oshape[3] = 1 + static_cast(std::ceil( - static_cast(dshape[3] + 2 * param.pad[1] - - param.kernel[1]) / - param.stride[1])); - oshape[4] = 1 + static_cast(std::ceil( - static_cast(dshape[4] + 2 * param.pad[2] - - param.kernel[2]) / - param.stride[2])); + oshape_ncdhw[2] = 1 + static_cast(ceil( + static_cast(dshape_ncdhw[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); + oshape_ncdhw[3] = 1 + static_cast(ceil( + static_cast(dshape_ncdhw[3] + 2 * param.pad[1] - + param.kernel[1]) / + param.stride[1])); + oshape_ncdhw[4] = 1 + static_cast(ceil( + static_cast(dshape_ncdhw[4] + 2 * param.pad[2] - + param.kernel[2]) / + param.stride[2])); } - + // Convert back from standard (NCDHW) layout space to the actual layout type + TShape oshape = (layout == mshadow::kNDHWC) ? + ConvertLayout(oshape_ncdhw, mshadow::kNCDHW, mshadow::kNDHWC) : oshape_ncdhw; out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 @@ -224,6 +260,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, out_shape->push_back(oshape); // for workspace #endif } + return true; } @@ -331,13 +368,13 @@ NNVM_REGISTER_OP(Pooling) The shapes for 1-D pooling are -- **data**: *(batch_size, channel, width)*, -- **out**: *(batch_size, num_filter, out_width)*. +- **data** and **out**: *(batch_size, channel, width)* (NCW layout) or + *(batch_size, width, channel)* (NWC layout), The shapes for 2-D pooling are -- **data**: *(batch_size, channel, height, width)* -- **out**: *(batch_size, num_filter, out_height, out_width)*, with:: +- **data** and **out**: *(batch_size, channel, height, width)* (NCHW layout) or + *(batch_size, height, width, channel)* (NHWC layout), out_height = f(height, kernel[0], pad[0], stride[0]) out_width = f(width, kernel[1], pad[1], stride[1]) @@ -363,8 +400,8 @@ Three pooling options are supported by ``pool_type``: - **lp**: Lp pooling For 3-D pooling, an additional *depth* dimension is added before -*height*. Namely the input data will have shape *(batch_size, channel, depth, -height, width)*. +*height*. Namely the input data and output will have shape *(batch_size, channel, depth, +height, width)* (NCDHW layout) or *(batch_size, depth, height, width, channel)* (NDHWC layout). Notes on Lp pooling: @@ -421,11 +458,13 @@ NNVM_REGISTER_OP(_backward_Pooling) .set_attr( "FInplaceOption", [](const NodeAttrs &attrs) { -#if MXNET_USE_CUDNN == 1 - return std::vector >(); -#else - return std::vector >{{1, 0}}; +// Different backend requires different FInplaceOption +#if MXNET_USE_MKLDNN == 1 + const PoolingParam ¶m = nnvm::get(attrs.parsed); + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) + return std::vector >{{1, 0}}; #endif + return std::vector >(); }) #if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { diff --git a/src/operator/nn/pooling.cu b/src/operator/nn/pooling.cu index 997218620c3a..84cacc15e239 100644 --- a/src/operator/nn/pooling.cu +++ b/src/operator/nn/pooling.cu @@ -56,19 +56,11 @@ void PoolingCompute(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), GetNumOutputs(param)); #if MXNET_USE_CUDNN == 1 - if (!param.cudnn_off && param.kernel.ndim() > 1) { + if (!param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - switch (param.pool_type) { - case pool_enum::kMaxPooling: - case pool_enum::kAvgPooling: - GetCuDNNPoolingOp(param).Forward(ctx, inputs[0], req[0], outputs[0]); - return; - case pool_enum::kSumPooling: - LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied."; - break; - case pool_enum::kLpPooling: - LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet lp pooling is applied."; - break; + if (CuDNNPoolingOp::Supports(param, inputs[0])) { + GetCuDNNPoolingOp(param).Forward(ctx, inputs[0], req[0], outputs[0]); + return; } }); } @@ -111,21 +103,13 @@ void PoolingGradCompute(const nnvm::NodeAttrs& attrs, } #if MXNET_USE_CUDNN == 1 - if (!param.cudnn_off && param.kernel.ndim() > 1) { + if (!param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - switch (param.pool_type) { - case pool_enum::kMaxPooling: - case pool_enum::kAvgPooling: + if (CuDNNPoolingOp::Supports(param, inputs[in_data_idx])) { GetCuDNNPoolingOp(param).Backward(ctx, inputs[ograd_idx], inputs[in_data_idx], inputs[out_data_idx], req[0], outputs[0]); return; - case pool_enum::kSumPooling: - LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied."; - break; - case pool_enum::kLpPooling: - LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet Lp pooling is applied."; - break; } }); } diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc index 8b62db9c061d..b9daf2592b7d 100644 --- a/src/operator/quantization/quantized_pooling.cc +++ b/src/operator/quantization/quantized_pooling.cc @@ -40,6 +40,9 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshape.ndim(), 4U) << "quantized_pooling: Input data should be 4D in " << "(batch, channel, y, x)"; + int layout = param.GetLayout(dshape.ndim()); + CHECK_EQ(layout, mshadow::kNCHW) + << "QuantizedPoolingOp only supports NCHW layout for now, saw " << layout; // NCHW layout const int N = 0, H = 2, W = 3, C = 1; TShape oshape(4); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 7a7c6f69dd77..010cf504fe70 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -607,7 +607,40 @@ def test_convolution_versions(): check_consistency(syms, ctx_list) +# More max-pooling strides and pads to test cudnn pooling implementation code paths @with_seed() +def test_pooling_nhwc_with_convention(): + def make_pooling_syms(**kwargs): + # Conventional NCHW layout pooling + sym = mx.sym.Pooling(**kwargs) + # NHWC pooling + data = mx.sym.Variable('pool_data') + sym_nhwc = mx.sym.transpose(data, axes=(0,2,3,1)) + sym_nhwc = mx.sym.Pooling(sym_nhwc, layout='NHWC', **kwargs) + sym_nhwc = mx.sym.transpose(sym_nhwc, axes=(0,3,1,2), name='pool') + return [sym, sym_nhwc] + + # While the float32 and float64 output is reliably consistent, float16 departs occasionally. + # We compare nhwc and nchw results only within a given precision. + for in_shape in [(3, 4, 8, 8), (2, 2, 20, 20)]: + for kernel in [(2,2), (3,3), (4,4)]: + for stride in [(1,1), (1,2), (2,1), (2,2)]: + for data_type in [np.float64, np.float32, np.float16]: + ctx_list = [{'ctx': mx.gpu(0), 'pool_data': in_shape, + 'type_dict': {'pool_data': data_type}}] + symlist = make_pooling_syms(kernel=kernel, pool_type='max', stride=stride, + pooling_convention='valid', name='pool') + check_consistency_NxM(symlist, ctx_list) + + symlist = make_pooling_syms(kernel=kernel, pool_type='max', stride=stride, + pooling_convention='full', name='pool') + check_consistency_NxM(symlist, ctx_list) + + symlist = make_pooling_syms(kernel=(300,300), pool_type='max', + global_pool=True, name='pool') + check_consistency_NxM(symlist, ctx_list) + + def test_pooling_with_type(): ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}, @@ -768,232 +801,241 @@ def test_spatial_transformer_with_type(): check_consistency(sym, ctx_list) check_consistency(sym, ctx_list, grad_req="add") - @with_seed() def test_pooling_with_type2(): - ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, - {'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}, - {'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float16}}, - {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, - {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] + # While the float32 and float64 output is reliably consistent, float16 departs occasionally. + # We compare cpu and gpu results only within a given precision. + for data_type in [np.float64, np.float32, np.float16]: + ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': data_type}}, + {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': data_type}}] - sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') - check_consistency(sym, ctx_list, rand_type=np.float16) + sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') + check_consistency(sym, ctx_list) - sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') - check_consistency(sym, ctx_list) + sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') + check_consistency(sym, ctx_list) - sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') - check_consistency(sym, ctx_list, rand_type=np.float16) + sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') + check_consistency(sym, ctx_list) + + sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') + check_consistency(sym, ctx_list) + +@with_seed() +def test_pooling_nhwc_with_type(): + def make_pooling_syms(**kwargs): + # Conventional NCHW layout pooling + sym = mx.sym.Pooling(**kwargs) + # NHWC pooling + data = mx.sym.Variable('pool_data') + sym_nhwc = mx.sym.transpose(data, axes=(0,2,3,1)) + sym_nhwc = mx.sym.Pooling(sym_nhwc, layout='NHWC', **kwargs) + sym_nhwc = mx.sym.transpose(sym_nhwc, axes=(0,3,1,2), name='pool') + return [sym, sym_nhwc] + + # While the float32 and float64 output is reliably consistent, float16 departs occasionally. + # We compare nhwc and nchw results only within a given precision. + for data_type in [np.float64, np.float32, np.float16]: + # NHWC pooling only enabled on GPU with CUDNN + ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': data_type}}] + symlist = make_pooling_syms(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') + check_consistency_NxM(symlist, ctx_list) + + symlist = make_pooling_syms(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') + check_consistency_NxM(symlist, ctx_list) + + symlist = make_pooling_syms(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') + check_consistency_NxM(symlist, ctx_list) - sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') - check_consistency(sym, ctx_list) -@unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517") @with_seed() def test_pooling_versions(): - def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, pooling_convention='valid', - global_pool=False, p_value=2, count_include_pad=True, tol=None): - ctx_list = [] - sym_list = [] - # PoolingV1 cpu - if 'pool_v1_cpu' in pool_op_list: - ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) - if not global_pool: - sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, name='pool')) - else: - sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool')) - # PoolingV1 gpu - if 'pool_v1_gpu' in pool_op_list: - ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) - if not global_pool: - sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, name='pool')) - else: - sym_list.append(mx.sym.Pooling_v1(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool')) - # Pooling cpu - if 'pool_cpu' in pool_op_list: - ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) - if not global_pool: - sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, name='pool', - p_value=p_value, count_include_pad=count_include_pad)) - else: - sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', - p_value=p_value, count_include_pad=count_include_pad)) - # Pooling gpu - if 'pool_gpu' in pool_op_list: - ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) - if not global_pool: - sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, cudnn_off=True, name='pool', - p_value=p_value, count_include_pad=count_include_pad)) - else: - sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=True, - name='pool', p_value=p_value, count_include_pad=count_include_pad)) - # CuDNNPooling - if 'pool_cudnn' in pool_op_list: - ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) - if not global_pool: - sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, - name='pool', count_include_pad=count_include_pad)) - else: - sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, p_value=p_value, - cudnn_off=False, name='pool', count_include_pad=count_include_pad)) - check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol) - def test_1d_pooling(pool_type, p_value=2, count_include_pad=True): - data = (2, 3, 20) - kernel = (4,) - pad = (0,) - stride = (1,) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) + # Produce the name of the 'transposed' layout, given the dimension + def transposed_layout(ndim): + if ndim < 3 or ndim > 5: + raise RuntimeError("Invalid data dim, expecting 3, 4 or 5") + return ('NWC', 'NHWC', 'NDHWC')[ndim-3] - pad = (2,) - stride = (2,) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - pad = (0,) - stride = (1,) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) + # default padding is all zeros + def is_default_pad(pad): + return pad == (0,) * len(pad) - pad = (2,) - stride = (2,) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) + # default stride is all ones + def is_default_stride(stride): + return stride == (1,) * len(stride) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, p_value=p_value, count_include_pad=count_include_pad) + # returns True/False randomly with equal probability + def random_choice(): + return np.random.random(1)[0] < 0.5 - def test_2d_pooling(pool_type, p_value=2, count_include_pad=True): - data = (2, 3, 20, 20) - kernel = (4, 5) - pad = (0, 0) - stride = (1, 1) - if pool_type == 'lp': - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) - else: - test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, count_include_pad=count_include_pad) - - # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here - pad = (2, 3) - stride = (2, 3) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - pad = (0, 0) - stride = (1, 1) - if pool_type == 'lp': - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) - else: - if count_include_pad: - test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, - count_include_pad=count_include_pad) + def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, + pooling_convention='valid', global_pool=False, p_value=2, + count_include_pad=True, tol=None, dtype=np.float32): + ctx_list = [] + sym_list = [] + for pool_ctx in pool_op_list: + (pool_op, ctx_type) = pool_ctx.rsplit('_', 1) + expected_ctxs = ['cpu', 'gpu', 'cudnn'] + if ctx_type not in expected_ctxs: + raise RuntimeError('Expected one of {}, saw {}.'.format(expected_ctxs, ctx_type)) + ctx = mx.cpu(0) if ctx_type == 'cpu' else mx.gpu(0) + ctx_list.append({'ctx': ctx, 'pool_data': data, 'type_dict': {'pool_data': dtype}}) + # start with pool args present in all cases + pool_op_args = {'kernel': kernel, 'pool_type': pool_type, + 'pooling_convention' : pooling_convention, 'name' : 'pool'} + # add other args as needed + if global_pool: + pool_op_args['global_pool'] = True else: - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, - count_include_pad=count_include_pad) - - # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here - pad = (2, 3) - stride = (2, 3) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - if pool_type == 'lp': - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, p_value=p_value) - else: - test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, count_include_pad=count_include_pad) - - def test_3d_pooling(pool_type, p_value=2, count_include_pad=True): - data = (2, 3, 20, 20, 20) - kernel = (4, 5, 3) - pad = (0, 0, 0) - stride = (1, 1, 1) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - pad = (2, 3, 3) - stride = (2, 3, 1) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - pad = (0, 0, 0) - stride = (1, 1, 1) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - pad = (2, 3, 3) - stride = (2, 3, 1) - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value, - count_include_pad=count_include_pad) - - test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, p_value=p_value, count_include_pad=count_include_pad) - - test_1d_pooling('max') - test_1d_pooling('avg', count_include_pad=True) - test_1d_pooling('avg', count_include_pad=False) - test_1d_pooling('sum') - test_1d_pooling('lp', p_value=1) - test_1d_pooling('lp', p_value=2) - test_1d_pooling('lp', p_value=3) + # Add pad and stride param if needed, plus randomly when it matches the default + if not is_default_pad(pad) or random_choice(): + pool_op_args.update({'pad' : pad}) + if not is_default_stride(stride) or random_choice(): + pool_op_args.update({'stride' : stride}) + + expected_pool_ops = ['pool', 'pool_transposed', 'pool_v1'] + if pool_op == 'pool_v1': + sym = mx.sym.Pooling_v1(**pool_op_args) + else: + pool_op_args.update({'p_value' : p_value, 'count_include_pad' : count_include_pad}) + if ctx_type != 'cpu': + pool_op_args['cudnn_off'] = ctx_type == 'gpu' + if pool_op == 'pool': + # isolate pooling input from symbol input to test shared tensor optimizations + buffered_input = mx.sym.identity(name='pool') + sym = mx.sym.Pooling(buffered_input, **pool_op_args) + elif pool_op == 'pool_transposed': + ndim = len(data) + # NCW->NWC axes=(0,2,1) NCHW->NHWC axes=(0,2,3,1) NCDHW->NDHWC axes=(0,2,3,4,1); + axes = (0,) + tuple(range(2,ndim)) + (1,) + transposed = mx.sym.transpose(axes=axes, name='pool') + pooled = mx.sym.Pooling(data=transposed, layout=transposed_layout(ndim), + **pool_op_args) + # NWC->NCW axes=(0,2,1) NHWC->NCHW axes=(0,3,1,2) NDHWC->NCDHW axes=(0,4,1,2,3); + axes = (0, ndim-1) + tuple(range(1,ndim-1)) + sym = mx.sym.transpose(data=pooled, axes=axes, name='pool') + else: + raise RuntimeError('Expected one of {}, saw {}.'.format(expected_pool_ops, + pool_op)) + sym_list.append(sym) - test_2d_pooling('max') - test_2d_pooling('avg', count_include_pad=True) - test_2d_pooling('avg', count_include_pad=False) - test_2d_pooling('sum') - test_2d_pooling('lp', p_value=1) - test_2d_pooling('lp', p_value=2) - test_2d_pooling('lp', p_value=3) + check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol) - test_3d_pooling('max') - test_3d_pooling('avg', count_include_pad=True) - test_3d_pooling('avg', count_include_pad=False) - test_3d_pooling('sum') - test_3d_pooling('lp', p_value=1) - test_3d_pooling('lp', p_value=2) - test_3d_pooling('lp', p_value=3) + def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_include_pad=True, + tol=None): + if dim == '1D': + data = (3, 3, 10) + kernels = [(4,), (4,), (5,)] + pads = [(0,), (2,), (2,)] + strides = [(1,), (2,), (1,)] + elif dim == '2D_no_padding': + data = (3, 2, 20, 20) + kernels = [(3, 3), (4, 5)] + pads = [(0, 0), (0, 0)] + strides = [(1, 1), (2, 1)] + elif dim == '2D': + data = (2, 2, 20, 20) + kernels = [(3, 3), (3, 5), (4, 5), (4, 5)] + pads = [(0, 0), (1, 2), (0, 0), (2, 3)] + strides = [(1, 1), (1, 1), (2, 1), (1, 1)] + elif dim == '3D': + data = (2, 3, 20, 20, 20) + kernels = [(4, 5, 3), (4, 5, 3), (3, 5, 7)] + pads = [(0, 0, 0), (2, 3, 2), (1, 2, 3)] + strides = [(1, 1, 1), (2, 3, 1), (1, 1, 1)] + else: + raise RuntimeError('Unexpected pooling test class: {}.'.format(dim)) + + for kernel, pad, stride in zip(kernels, pads, strides): + for pooling_convention in ['valid', 'full']: + try: + test_pooling_versions_helper(pool_op_list=pool_op_list, + data=data, kernel=kernel, pad=pad, stride=stride, + pool_type=pool_type, pooling_convention=pooling_convention, + global_pool=False, p_value=p_value, + count_include_pad=count_include_pad, tol=tol, dtype=dtype) + except: + print('pool_op_list = {}'.format(pool_op_list)) + print('kernel={}, pad={}, stride={}'.format(kernel, pad, stride)) + print('pool_type={}, pooling_convention={}, global_pool=False'.format(pool_type, + pooling_convention)) + print('p_value={}, count_include_pad={}, dtype={}'.format(p_value, + count_include_pad, dtype)) + print('environ = \n{}'.format(os.environ)) + raise + + # Make sure kernel is ignored during global_pool by sometimes setting it to a crazy value + kernel = kernels[0] + if random_choice(): + kernel = (300,) * len(kernel) + + test_pooling_versions_helper(pool_op_list=pool_op_list, + data=data, kernel=kernel, pad=None, stride=None, + pool_type=pool_type, global_pool=True, p_value=p_value, + count_include_pad=count_include_pad, tol=tol, dtype=dtype) + + # The various implementations of the standard pooling operator + std_pool_op_list = ['pool_cpu', 'pool_transposed_cpu', + 'pool_gpu', 'pool_transposed_gpu', + 'pool_cudnn', 'pool_transposed_cudnn'] + # The implementations of the 'v1' pooling operator + v1_pool_op_list = ['pool_v1_cpu', 'pool_v1_gpu'] + # For those cases when all implementations should match- the combined implementation list. + combo_pool_op_list = std_pool_op_list + v1_pool_op_list + + for dtype in [np.float32, np.float64, np.float16]: + # Testing of the standard (not 'v1') pooling operator is universal across all + # data dimensions, implementations and layouts. + for dim in ['1D', '2D', '3D']: + test_pooling_dim(dim, 'max', dtype, std_pool_op_list) + test_pooling_dim(dim, 'avg', dtype, std_pool_op_list, count_include_pad=True) + test_pooling_dim(dim, 'avg', dtype, std_pool_op_list, count_include_pad=False) + test_pooling_dim(dim, 'sum', dtype, std_pool_op_list) + test_pooling_dim(dim, 'lp', dtype, std_pool_op_list, p_value=1) + test_pooling_dim(dim, 'lp', dtype, std_pool_op_list, p_value=2) + test_pooling_dim(dim, 'lp', dtype, std_pool_op_list, p_value=3) + + # Testing of the 'v1' pooling operator is over its restricted support domain of + # 2D data only and not with the 'lp' pooling type. The 'v1' cpu and gpu versions are + # always tested against each other, and sometimes against the standard operator versions. + # The slightly different 'v1' definition prevents this in the following cases: + # + # 1. In max pooling, when multiple input values are the maximum in the input window, + # the 'v1' implementation backprops the gradient to all maxima, whereas the standard + # pooling operator backprops the gradient to the lowest-indexed maximum only. + # 2. In max pooling, the 'v1' operator pads with 0's and this value can become the + # maximum output value in the case of an all-negative input. The standard pooling + # operator effectively considers the padding to be the largest negative value, so + # only input values should appear in the output. + # 3. In avg pooling, the 'v1' operator divides the sum by the same window size factor, + # even at the edges, and so does not support count_include_pad = False. + # 4. The float16 'v1' pooling operator performs forward sums and averages in + # float16, whereas the std operators perform those calculations in float32, so + # greater float16 tolerances are needed when comparing across implementations. + + # Double the float16 tol when comparing v1 and non-v1 implemenations, per note 4 above. + relaxed_tol = {np.dtype(np.float16): 2e-1, + np.dtype(np.float32): 1e-3, + np.dtype(np.float64): 1e-5, + np.dtype(np.uint8): 0, + np.dtype(np.int32): 0, + np.dtype(np.int64): 0} + + # Exclude std implementations due to points 1 and 2 above. + test_pooling_dim('2D', 'max', dtype, v1_pool_op_list) + # The standard and 'v1' implementations match for this case. + test_pooling_dim('2D', 'avg', dtype, combo_pool_op_list, count_include_pad=True, + tol=relaxed_tol) + # Exclude std implementations due to point 3 above. + test_pooling_dim('2D', 'avg', dtype, v1_pool_op_list, count_include_pad=False) + # The standard and 'v1' implementations match for this case. + test_pooling_dim('2D', 'sum', dtype, combo_pool_op_list, tol=relaxed_tol) + + # We can compare the standard and 'v1' max pooling implementations if we eliminate padding + # (see point 2 above) and use np.float64 data so that no two random input window values are + # likely to be the same (see point 1 above). + test_pooling_dim('2D_no_padding', 'max', np.float64, combo_pool_op_list) @with_seed() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index abe6b136fe0c..34380dc00314 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -506,50 +506,75 @@ def test_deconv(): @with_seed() def test_pool(): - layers1d = [ - nn.MaxPool1D(), - nn.MaxPool1D(3), - nn.MaxPool1D(3, 2), - nn.AvgPool1D(), - nn.AvgPool1D(count_include_pad=False), - nn.GlobalAvgPool1D(), - ] - for layer in layers1d: - check_layer_forward(layer, (1, 2, 10)) - - - layers2d = [ - nn.MaxPool2D(), - nn.MaxPool2D((3, 3)), - nn.MaxPool2D(3, 2), - nn.AvgPool2D(), - nn.AvgPool2D(count_include_pad=False), - nn.GlobalAvgPool2D(), - ] - for layer in layers2d: - check_layer_forward(layer, (1, 2, 10, 10)) - - layers3d = [ - nn.MaxPool3D(), - nn.MaxPool3D((3, 3, 3)), - nn.MaxPool3D(3, 2), - nn.AvgPool3D(), - nn.AvgPool3D(count_include_pad=False), - nn.GlobalAvgPool3D(), - ] - for layer in layers3d: - check_layer_forward(layer, (1, 2, 10, 10, 10)) + # transpose shape to bring feature dimension 'c' from 2nd position to last + def transpose(shape): + return (shape[0],) + shape[2:] + (shape[1],) + + for layout in ['NCW', 'NWC']: + shape1d = (1, 2, 10) + if layout == 'NWC': + shape1d = transpose(shape1d) + layers1d = [ + nn.MaxPool1D(layout=layout), + nn.MaxPool1D(3, layout=layout), + nn.MaxPool1D(3, 2, layout=layout), + nn.AvgPool1D(layout=layout), + nn.AvgPool1D(count_include_pad=False, layout=layout), + nn.GlobalAvgPool1D(layout=layout), + ] + for layer in layers1d: + check_layer_forward(layer, shape1d) + + + for layout in ['NCHW', 'NHWC']: + shape2d = (1, 2, 10, 10) + if layout == 'NHWC': + shape2d = transpose(shape2d) + layers2d = [ + nn.MaxPool2D(layout=layout), + nn.MaxPool2D((3, 3), layout=layout), + nn.MaxPool2D(3, 2, layout=layout), + nn.AvgPool2D(layout=layout), + nn.AvgPool2D(count_include_pad=False, layout=layout), + nn.GlobalAvgPool2D(layout=layout), + ] + for layer in layers2d: + check_layer_forward(layer, shape2d) + + for layout in ['NCDHW', 'NDHWC']: + shape3d = (1, 2, 10, 10, 10) + if layout == 'NDHWC': + shape3d = transpose(shape3d) + layers3d = [ + nn.MaxPool3D(layout=layout), + nn.MaxPool3D((3, 3, 3), layout=layout), + nn.MaxPool3D(3, 2, layout=layout), + nn.AvgPool3D(layout=layout), + nn.AvgPool3D(count_include_pad=False, layout=layout), + nn.GlobalAvgPool3D(layout=layout), + ] + for layer in layers3d: + check_layer_forward(layer, shape3d) # test ceil_mode - x = mx.nd.zeros((2, 2, 10, 10)) + for layout in ['NCHW', 'NHWC']: + xshape = (2, 2, 10, 10) + noceil_out_shape = (2, 2, 3, 3) + ceil_out_shape = (2, 2, 4, 4) + if layout == 'NHWC': + xshape = transpose(xshape) + noceil_out_shape = transpose(noceil_out_shape) + ceil_out_shape = transpose(ceil_out_shape) - layer = nn.MaxPool2D(3, ceil_mode=False) - layer.collect_params().initialize() - assert (layer(x).shape==(2, 2, 3, 3)) + x = mx.nd.zeros(xshape) - layer = nn.MaxPool2D(3, ceil_mode=True) - layer.collect_params().initialize() - assert (layer(x).shape==(2, 2, 4, 4)) + layer = nn.MaxPool2D(3, ceil_mode=False, layout=layout) + layer.collect_params().initialize() + assert (layer(x).shape==noceil_out_shape) + + layer = nn.MaxPool2D(3, ceil_mode=True, layout=layout) + layer.collect_params().initialize() + assert (layer(x).shape==ceil_out_shape) @with_seed() @@ -2091,31 +2116,41 @@ def hybrid_forward(self, F, x): @with_seed() def test_slice_pooling2d(): - max_pooling = nn.MaxPool2D(strides=(2, 3), padding=(1, 1)) - avg_pooling = nn.AvgPool2D(strides=(2, 2), padding=(1, 1)) - global_maxpooling = nn.GlobalMaxPool2D() - global_avgpooling = nn.GlobalAvgPool2D() - pooling_layers = [max_pooling, avg_pooling, global_maxpooling, global_avgpooling] - class Net(gluon.HybridBlock): - def __init__(self, - slice, - pooling_layer, - **kwargs): - super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.pool0 = pooling_layer - - def hybrid_forward(self, F, x): - x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) - out = self.pool0(x_slice) - return out - - x = mx.nd.random.uniform(shape=(16, 128, 256, 256)) - slice = [(0, 0, 0, 0), (4, 16, 32, 64)] - for i in range(len(pooling_layers)): - net = Net(slice, pooling_layers[i]) - check_layer_forward_withinput(net, x) + # transpose shape to bring feature dimension 'c' from 2nd position to last + def transpose(shape): + return (shape[0],) + shape[2:] + (shape[1],) + + for layout in ['NCHW', 'NHWC']: + max_pooling = nn.MaxPool2D(strides=(2, 3), padding=(1, 1), layout=layout) + avg_pooling = nn.AvgPool2D(strides=(2, 2), padding=(1, 1), layout=layout) + global_maxpooling = nn.GlobalMaxPool2D(layout=layout) + global_avgpooling = nn.GlobalAvgPool2D(layout=layout) + pooling_layers = [max_pooling, avg_pooling, global_maxpooling, global_avgpooling] + class Net(gluon.HybridBlock): + def __init__(self, + slice, + pooling_layer, + **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope(): + self.slice = slice + self.pool0 = pooling_layer + + def hybrid_forward(self, F, x): + x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) + out = self.pool0(x_slice) + return out + + xshape = (16, 128, 256, 256) + slice_shape = (4, 16, 32, 64) + if layout == 'NHWC': + xshape = transpose(xshape) + slice_shape = transpose(slice_shape) + x = mx.nd.random.uniform(shape=xshape) + slice = [(0, 0, 0, 0), slice_shape] + for i in range(len(pooling_layers)): + net = Net(slice, pooling_layers[i]) + check_layer_forward_withinput(net, x) @with_seed() @unittest.skip('skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/11164')