From b3b952f9d5490ee2707209ab866e6c3f094e2046 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Sun, 14 Apr 2019 22:04:52 -0700 Subject: [PATCH 1/9] fp16 safe norm operator (#14616) * use safe aggregation for norm * safe norm with DataType, AccuType and OutType * new test for backward * change back to MSHADOW_TYPE_SWITCH * remove dead debug outputs * Allow integer types --- src/operator/mshadow_op.h | 68 ++++- src/operator/mxnet_op.h | 83 +++++- src/operator/tensor/broadcast_reduce-inl.cuh | 61 +++-- src/operator/tensor/broadcast_reduce-inl.h | 38 ++- src/operator/tensor/broadcast_reduce_op.h | 257 +++++++++++++----- .../tensor/broadcast_reduce_op_value.cc | 2 +- src/operator/tensor/matrix_op-inl.h | 4 +- tests/python/unittest/test_operator.py | 65 +++-- 8 files changed, 430 insertions(+), 148 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index c27a98ac1940..d9d6151c06bf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -945,13 +945,13 @@ struct nanprod { /*! \brief compute l2 norm */ struct nrm2 { /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*) + template + MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*) sum_of_squares += src * src; } /*! \brief do stable reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*) + template + MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*) if (src != 0) { DType abs = mshadow_op::abs::Map(src); if (scale < abs) { @@ -1012,6 +1012,66 @@ struct nrm2 { } }; +/*! \brief sum reducer */ +struct sum { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*) + dst += src; + } + /*! \brief do stable reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) + DType y = src - residual; + DType t = dst + y; + residual = (t - dst) - y; + dst = t; + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + DType t1 = dst_val + src_val; + DType e = t1 - dst_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) + /*! + *\brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return 1; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*) + SetInitValue(initv); + residual = 0; + } +}; + struct nanprod_grad : public mxnet_op::tunable { template MSHADOW_XINLINE static DType Map(DType a, DType b) { diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d8fc5031e4ff..a937f839c9bb 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -273,20 +273,87 @@ inline int get_num_threads(const int N) { } \ break; \ case mshadow::kUint8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not uint8"; \ + { \ + typedef uint8_t DType; \ + typedef uint8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + typedef int8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not int8"; \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + typedef int32_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32"; \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + typedef int64_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int64"; \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + typedef float AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + typedef uint32_t AType; \ + } \ break; \ case mshadow::kInt8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not int8"; \ + { \ + typedef int8_t DType; \ + typedef int32_t AType; \ + } \ break; \ case mshadow::kInt32: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32"; \ + { \ + typedef int32_t DType; \ + typedef int64_t AType; \ + } \ break; \ case mshadow::kInt64: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int64"; \ + { \ + typedef int64_t DType; \ + typedef int64_t AType; \ + } \ break; \ default: \ LOG(FATAL) << "Unknown type enum " << type; \ diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index 5d6c49ff8882..54db35061c6a 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -72,15 +72,15 @@ void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, } const int nthread_reduce = kMaxThreadsPerBlock; -template +template __launch_bounds__(nthread_reduce) __global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, DType *small, + const DType* __restrict big, OType *small, const Shape big_shape0, const Shape small_shape, const Shape big_shape, const Shape big_stride, const int Mnext, const bool do_transpose) { extern __shared__ char shTileChar[]; - DType* shTile = (DType*)(shTileChar); + AType* shTile = (AType*)(shTileChar); const int tid = threadIdx.x + threadIdx.y*blockDim.x; const int bx = (do_transpose) ? blockDim.y : blockDim.x; const int by = (do_transpose) ? blockDim.x : blockDim.y; @@ -95,7 +95,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, Shape coord = unravel(idx, small_shape); int idx_big0 = ravel(coord, big_shape0); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); if (idx < N) { for (int k = tidy + Mstart; k < Mend; k += by*unroll) { @@ -113,7 +113,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } #pragma unroll for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual); + if (k + u*by < Mend) Reducer::Reduce(val, AType(tmp[u]), residual); } } } @@ -127,7 +127,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, shTile[it0 * 2 + 1] = residual; __syncthreads(); for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; + AType tmp, tmp_residual; Reducer::SetInitValue(tmp, tmp_residual); if (tidy + t < by) { tmp = shTile[(it0 + t*fbx) * 2]; @@ -139,12 +139,12 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } if (idx < N && tidy == 0) { Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); + assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); } } else { if (idx < N) { Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); + assign(&small[idx + m0*N], addto, OType(val)); } } } @@ -261,18 +261,18 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto, } } -template +template __global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, DType *small, const Shape bshape, + const DType* __restrict big, OType *small, const Shape bshape, const Shape sshape) { for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { Shape coord = unravel(idx, sshape); int j = ravel(coord, bshape); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, OP::Map(big[j]), residual); + Reducer::Reduce(val, AType(OP::Map(big[j])), residual); Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); + assign(&small[idx], addto, OType(val)); } } @@ -491,7 +491,7 @@ ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, const mxn if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes - config.workspace_size += config.N*config.Mnext*sizeof(DType); + config.workspace_size += config.N*config.Mnext*sizeof(double); // Set gridDim.y to Mnext config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); } @@ -516,23 +516,22 @@ ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, const mxn {__VA_ARGS__} \ } -template +template void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, const TBlob& big, const Tensor& workspace, const ReduceImplConfig& config) { if (config.M == 1) { - reduce_kernel_M1 + reduce_kernel_M1 <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), + config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), small.shape_.get()); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); } else { - - DType* small_dptr = small.dptr(); + OType* small_dptr = small.dptr(); bool addto = (req == kAddTo); if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); + small_dptr = reinterpret_cast(workspace.dptr_); addto = false; // Check that the workspace is contigiuous CHECK_EQ(workspace.CheckContiguous(), true); @@ -544,7 +543,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { - reduce_kernel + reduce_kernel <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), small.shape_.get(), config.rshape, config.rstride, config.Mnext, @@ -553,9 +552,9 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); if (config.Mnext > 1) { - reduce_lines_kernel + reduce_lines_kernel <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); + (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); } } @@ -610,14 +609,26 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const #undef KERNEL_UNROLL_SWITCH -template +template void Reduce(Stream *s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); ReduceImplConfig config = ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); - ReduceImpl(stream, small, req, big, workspace, config); + if (safe_acc) { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + config = ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); + ReduceImpl( + stream, small, req, big, workspace, config); + }); + }); + } else { + ReduceImpl(stream, small, req, big, workspace, config); + } } template diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 0f6913e6e9df..be589c41168b 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -153,21 +153,21 @@ MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto assign(&out[idx], addto, OP::Map(lhs[j], rhs[k])); } -template +template MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto, - const DType* __restrict big, DType *small, + const DType* __restrict big, OType *small, const Shape& bshape, const Shape& sshape, const Shape& rshape, const Shape& rstride) { Shape coord = unravel(idx, sshape); index_t j = ravel(coord, bshape); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); for (size_t k = 0; k < M; ++k) { coord = unravel(k, rshape); - Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual); + Reducer::Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual); } Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); + assign(&small[idx], addto, OType(val)); } #ifdef __CUDACC__ @@ -194,15 +194,15 @@ void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, out.shape_.get()); } -template +template void seq_reduce_compute(const size_t N, const size_t M, const bool addto, - const DType *big, DType *small, const Shape bshape, + const DType *big, OType *small, const Shape bshape, const Shape sshape, const Shape rshape, const Shape rstride) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t idx = 0; idx < static_cast(N); ++idx) { - seq_reduce_assign(idx, M, addto, big, small, bshape, sshape, rshape, - rstride); + seq_reduce_assign(idx, M, addto, big, small, + bshape, sshape, rshape, rstride); } } @@ -227,16 +227,28 @@ void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool add } } -template +template void Reduce(Stream* s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; Shape rshape, rstride; diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); size_t N = small.shape_.Size(), M = rshape.Size(); - seq_reduce_compute( - N, M, req == kAddTo, big.dptr(), small.dptr(), - big.shape_.get(), small.shape_.get(), rshape, rstride); + if (!safe_acc) { + seq_reduce_compute( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride); + } else { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + seq_reduce_compute( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride); + }); + }); + } } template diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index b13906af6624..069c8ddb04fb 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -67,6 +67,7 @@ struct ReduceAxesParam : public dmlc::Parameter { struct NormParam : public dmlc::Parameter { int ord; dmlc::optional axis; + dmlc::optional out_dtype; bool keepdims; DMLC_DECLARE_PARAMETER(NormParam) { DMLC_DECLARE_FIELD(ord).set_default(2) @@ -78,6 +79,15 @@ struct NormParam : public dmlc::Parameter { If `axis` is int, a reduction is performed on a particular axis. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed.)code"); + DMLC_DECLARE_FIELD(out_dtype) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("int64", mshadow::kInt64) + .add_enum("int32", mshadow::kInt32) + .add_enum("int8", mshadow::kInt8) + .set_default(dmlc::optional()) + .describe(R"code(The data type of the output.)code"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axis is left " "in the result as dimension with size one."); @@ -302,6 +312,23 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool NormType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NormParam& param = nnvm::get(attrs.parsed); + if (param.out_dtype.has_value()) { + CHECK_NE(in_attrs->at(0), -1) + << "input data type should be specified when out_dtype is not null"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.out_dtype.value()); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + } + return (*out_attrs)[0] != -1; +} + inline bool NormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { @@ -525,7 +552,7 @@ void SearchAxisCompute(const nnvm::NodeAttrs& attrs, }); } -template void ReduceAxesComputeImpl(const OpContext& ctx, const std::vector& inputs, @@ -538,20 +565,22 @@ void ReduceAxesComputeImpl(const OpContext& ctx, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - broadcast::Reduce( - s, out_data, req[0], workspace, in_data); - if (normalize) { - auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size()); - } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + broadcast::Reduce( + s, out_data, req[0], workspace, in_data); + if (normalize) { + auto out = out_data.FlatTo2D(s); + out /= scalar(src_shape.Size()/dst_shape.Size()); + } + }); }); }); } @@ -571,7 +600,7 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs, small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude); } - ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); + ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); } template @@ -813,6 +842,35 @@ void ReduceAxesOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, } } +template +struct reduce_axes_backward_broadcast { + template + MSHADOW_XINLINE static void Map(index_t i, + DType *data, + OType *out, + DType *igrad, + OType *ograd, + mshadow::Shape<5> in_shape, + mshadow::Shape<5> out_shape, + const uint32_t ndim) { + size_t in_stride = 1; + size_t out_stride = 1; + index_t idx = i; + index_t out_idx = i; + for (int iter = ndim - 1; iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + out_idx -= dim_idx * in_stride; + if (out_shape[iter] != 1) { + out_idx += dim_idx * out_stride; + } + idx /= in_shape[iter]; + in_stride *= in_shape[iter]; + out_stride *= out_shape[iter]; + } + KERNEL_ASSIGN(igrad[i], req, DType(ograd[out_idx]) * OP::Map(data[i], DType(out[out_idx]))); + } +}; + template void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, const mxnet::TShape &small, @@ -821,37 +879,58 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (dst_shape.ndim() == 2) { - Tensor igrad = - outputs[0].get_with_shape(src_shape.get<2>(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get<2>(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get<2>(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get<2>(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data, broadcast_to(out, src_shape))); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); - } else { - const int ndim = MXNET_SPECIAL_MAX_NDIM; - Tensor igrad = - outputs[0].get_with_shape(src_shape.get(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data, broadcast_to(out, src_shape))); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); - } + + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Shape<5> in_shape; + mshadow::Shape<5> out_shape; + for (uint32_t i = 0; i < 5; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; + } else { + in_shape[i] = 1; + out_shape[i] = 1; + } + } + if (dst_shape.ndim() == 2) { + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + } + }); }); } @@ -1090,14 +1169,42 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); } if (param.ord == 1) { - ReduceAxesComputeImpl( - ctx, inputs, req, outputs, small); + ReduceAxesComputeImpl( + ctx, inputs, req, outputs, small); } else if (param.ord == 2) { - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, inputs, req, outputs, small); } } +template +struct norm_backward_broadcast { + template + MSHADOW_XINLINE static void Map(index_t i, + DType *igrad, + OType *ograd, + DType *data, + mshadow::Shape<5> in_shape, + mshadow::Shape<5> out_shape, + const uint32_t ndim) { + size_t in_stride = 1; + size_t out_stride = 1; + index_t idx = i; + index_t out_idx = i; + for (int iter = ndim - 1; iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + out_idx -= dim_idx * in_stride; + if (out_shape[iter] != 1) { + out_idx += dim_idx * out_stride; + } + idx /= in_shape[iter]; + in_stride *= in_shape[iter]; + out_stride *= out_shape[iter]; + } + KERNEL_ASSIGN(igrad[i], req, ograd[out_idx] * mshadow_op::sign::Map(data[i])); + } +}; + template void LpNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1106,6 +1213,7 @@ void LpNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; if (req[0] == kNullOp) return; const NormParam& param = nnvm::get(attrs.parsed); @@ -1119,27 +1227,46 @@ void LpNormGradCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (dst_shape.ndim() == 2) { - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get<2>(), s); - Tensor igrad = - outputs[0].get_with_shape(src_shape.get<2>(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get<2>(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data)); + mshadow::Shape<5> in_shape; + mshadow::Shape<5> out_shape; + for (uint32_t i = 0; i < 5; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; } else { - const int ndim = MXNET_SPECIAL_MAX_NDIM; - Tensor igrad = - outputs[0].get_with_shape(src_shape.get(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data)); + in_shape[i] = 1; + out_shape[i] = 1; } + } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, OType, { + if (dst_shape.ndim() == 2) { + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } + }); }); } else if (param.ord == 2) { ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc index 52fd61aa110e..f4231917e90d 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cc +++ b/src/operator/tensor/broadcast_reduce_op_value.cc @@ -352,7 +352,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", NormShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferType", NormType) .set_attr("FInferStorageType", LpNormStorageType) .set_attr("FGradient", ReduceGrad{ "_backward_norm" }) .set_attr("FResourceRequest", diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index fa108158b5c9..ba62d0e9def7 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -1732,7 +1732,7 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); } @@ -1914,7 +1914,7 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ccb351f434da..59d72d4b18b6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3384,41 +3384,46 @@ def l2norm(input_data, axis=0, keepdims=True): ctx = default_context() data = mx.symbol.Variable('data') in_data_dim = random_sample([4,5,6], 1)[0] - in_shape = rand_shape_nd(in_data_dim) + in_shape = rand_shape_nd(in_data_dim, dim=5) epsilon = 1e-3 + acc_type = {np.float16: np.float32, np.float32: np.float32, np.float64: np.float64} for order in [1, 2]: for dtype in [np.float16, np.float32, np.float64]: - in_data = np.random.uniform(-1, 1, in_shape).astype(dtype) - in_data[abs(in_data) < epsilon] = 2 * epsilon for i in range(in_data_dim): - norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, keepdims=True) - npy_out = l1norm(in_data, i) if order is 1 else l2norm(in_data, i) - npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out - check_symbolic_forward(norm_sym, [in_data], [npy_out], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], - [npy_out_backward], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509 - # # check gradient - # if dtype is not np.float16: - # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-1, atol=1e-3) - if i < in_data_dim-1: - norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) - npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) + for out_dtype in ['float32', 'float64']: + backward_dtype = np.float32 if out_dtype == 'float32' else np.float64 + print(order, dtype, i, out_dtype, in_shape) + in_data = np.random.uniform(-1, 1, in_shape).astype(acc_type[dtype]) + in_data[abs(in_data) < epsilon] = 2 * epsilon + norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, out_dtype=out_dtype, keepdims=True) + npy_out = l1norm(in_data, i) if order is 1 else l2norm(in_data, i) npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out - check_symbolic_forward(norm_sym, [in_data], [npy_out], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], - [npy_out_backward], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # # check gradient - # if dtype is not np.float16: - # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-1, atol=1e-3) + check_symbolic_forward(norm_sym, [in_data.astype(dtype)], [npy_out.astype(out_dtype)], + rtol=1e-3, atol=1e-5, ctx=ctx) + check_symbolic_backward(norm_sym, [in_data.astype(dtype)], + [np.ones(npy_out.shape).astype(out_dtype)], + [npy_out_backward], rtol=1e-3, atol=1e-5, ctx=ctx, + dtype=backward_dtype) + # Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509 + # check gradient + if dtype is not np.float16: + check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, + rtol=1e-1, atol=1e-3, dtype=backward_dtype) + if i < in_data_dim-1: + norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) + npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) + npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out + check_symbolic_forward(norm_sym, [in_data], [npy_out.astype(dtype)], + rtol=1e-3 if dtype is np.float16 else 1e-3, + atol=1e-5 if dtype is np.float16 else 1e-5, ctx=ctx) + check_symbolic_backward(norm_sym, [in_data], + [np.ones(npy_out.shape).astype(out_dtype)], + [npy_out_backward.astype(out_dtype)], + rtol=1e-3, atol=1e-5, ctx=ctx, dtype=backward_dtype) + # check gradient + if dtype is not np.float16: + check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, + rtol=1e-1, atol=1e-3, dtype=backward_dtype) def test_layer_norm(): From f90d1c0d51f68c7f0efe39a64d44aa24874985e1 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Mon, 15 Apr 2019 19:34:58 -0700 Subject: [PATCH 2/9] Use ubuntu_rat container for rat check (#14678) --- dev_menu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_menu.py b/dev_menu.py index cea0f96793a9..d439d8194f2a 100755 --- a/dev_menu.py +++ b/dev_menu.py @@ -123,7 +123,7 @@ def create_virtualenv_default(): ('[Docker] sanity_check. Check for linting and code formatting and licenses.', [ "ci/build.py --platform ubuntu_cpu /work/runtime_functions.sh sanity_check", - "ci/build.py --platform ubuntu_cpu /work/runtime_functions.sh nightly_test_rat_check", + "ci/build.py --platform ubuntu_rat /work/runtime_functions.sh nightly_test_rat_check", ]), ('[Docker] Python3 CPU unittests', [ From 413fe97d01c9832e5590b1691c55579531a8289b Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Mon, 15 Apr 2019 19:35:49 -0700 Subject: [PATCH 3/9] Avoid uneccesary vector copies in imperative_utils.cc (#14665) --- src/imperative/imperative_utils.cc | 58 ++++++++++++++++-------------- src/imperative/imperative_utils.h | 4 +-- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 6cb4a70324b5..c7204c1d85e6 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -20,62 +20,61 @@ #include "./imperative_utils.h" #include "./cached_op.h" -namespace mxnet { -namespace imperative { +namespace { -inline std::vector NodeInputs(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector arrays) { +std::vector NodeInputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_inputs = node.inputs.size(); std::vector ndinputs; ndinputs.reserve(num_inputs); for (const auto& j : node.inputs) { - size_t eid = idx.entry_id(j); + const size_t eid = idx.entry_id(j); ndinputs.emplace_back(arrays[eid]); } return ndinputs; } -inline std::vector NodeOutputs(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector arrays) { +std::vector NodeOutputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_outputs = node.source->num_outputs(); std::vector ndoutputs; ndoutputs.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(node_idx, j); + const size_t eid = idx.entry_id(node_idx, j); ndoutputs.emplace_back(arrays[eid]); } return ndoutputs; } -inline std::vector NodeReq(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector array_reqs) { +std::vector NodeReq(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& array_reqs) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_outputs = node.source->num_outputs(); std::vector req; req.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(node_idx, j); + const size_t eid = idx.entry_id(node_idx, j); req.push_back(array_reqs[eid]); } return req; } -inline void InvokeOperator(const nnvm::IndexedGraph& idx, - const int node_idx, - const bool retain_graph, - const std::vector arrays, - Context ctx, - std::vector* p_states, - std::vector ndinputs, - std::vector ndoutputs, - std::vector *p_req, - std::vector *p_ref_count, - std::function invoke) { +void InvokeOperator(const nnvm::IndexedGraph& idx, + const int node_idx, + const bool retain_graph, + const std::vector& arrays, + Context ctx, + std::vector* p_states, + const std::vector& ndinputs, + const std::vector& ndoutputs, + std::vector *p_req, + std::vector *p_ref_count, + std::function invoke) { static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); @@ -122,10 +121,15 @@ inline void InvokeOperator(const nnvm::IndexedGraph& idx, } } +} // namespace + +namespace mxnet { +namespace imperative { + void RunGraph( const bool retain_graph, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, @@ -161,7 +165,7 @@ void NaiveRunGraph( const bool retain_graph, const Context& default_ctx, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 071f4fa9dd0b..d134d47c55cf 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -999,7 +999,7 @@ inline void CreateEngineOpSeg( void RunGraph(const bool retain_graph, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, @@ -1011,7 +1011,7 @@ void RunGraph(const bool retain_graph, void NaiveRunGraph(const bool retain_graph, const Context& default_ctx, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, From 1f84682576433ec152ae8a5874a687ddad9190b4 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 16 Apr 2019 03:03:33 -0500 Subject: [PATCH 4/9] Properly handling custom op exception by modify engine (#14693) * fix custom except handling by modify engine * add test * fix lint * update test * fix test * trigger CI --- docs/faq/env_var.md | 3 - include/mxnet/engine.h | 6 +- src/engine/naive_engine.cc | 3 + src/engine/threaded_engine.cc | 5 + src/engine/threaded_engine.h | 5 +- src/operator/custom/custom-inl.h | 45 +++++++-- tests/python/unittest/test_operator.py | 129 ++++++++++++++++++++----- 7 files changed, 158 insertions(+), 38 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 2768f644c066..095c214e66b3 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -60,9 +60,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_MP_OPENCV_NUM_THREADS - Values: Int ```(default=0)``` - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). -* MXNET_CUSTOM_OP_NUM_THREADS - - Values: Int ```(default=16)``` - - The maximum number of threads given to custom operators. ## Memory Options diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 408a70a5feed..9d6367509f79 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -106,7 +106,9 @@ enum class FnProperty { /*! \brief Delete variable call */ kDeleteVar, /*! \brief Prioritized sync operation on GPU */ - kGPUPrioritized + kGPUPrioritized, + /*! \brief Operation not to be skipped even with associated exception */ + kNoSkip }; // enum class FnProperty /*! @@ -230,6 +232,8 @@ class MXNET_API Engine { * \brief Wait until all the activity of engine finishes. */ virtual void WaitForAll() = 0; + /*!\brief Throw if threre are associated exception with var */ + virtual void Throw(VarHandle var) = 0; /*!\brief virtual destructor */ virtual ~Engine() noexcept(false) {} /*! diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index db4491981bdd..93853c459298 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -212,6 +212,9 @@ class NaiveEngine final : public Engine { void WaitForAll() override { } + void Throw(VarHandle var) override { + } + void NotifyShutdown() override { shutdown_phase_.store(true); } diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 986b6ad29909..38311908bdcd 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -498,6 +498,11 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) { return; } +void ThreadedEngine::Throw(VarHandle var) { + ThreadedVar *threaded_var = ThreadedVar::CastFromBase(var); + ThrowException(threaded_var); +} + void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_, const dmlc::Error* error) { OprBlock *opr_block = static_cast(opr_block_); diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 3d2119d63291..7df232b1c62a 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -306,6 +306,7 @@ class ThreadedEngine : public Engine { void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; void WaitForAll() override; + void Throw(VarHandle var) override; void NotifyShutdown() override { shutdown_phase_.store(true); } @@ -374,8 +375,8 @@ class ThreadedEngine : public Engine { LOG(INFO) << "ExecuteOprFn "; } try { - if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || - threaded_opr->wait) { + if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || + threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) { threaded_opr->fn(run_ctx, callback); } else { callback(); diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h index c5eaea13661e..3bf63b75cfdb 100644 --- a/src/operator/custom/custom-inl.h +++ b/src/operator/custom/custom-inl.h @@ -96,7 +96,12 @@ class CustomOperator { bool prev_recording = Imperative::Get()->set_is_recording(recording); bool prev_training = Imperative::Get()->set_is_training(training); - func(); + try { + func(); + } catch (dmlc::Error& e) { + exception_ = + std::make_shared(std::current_exception()); + } Imperative::Get()->set_is_training(prev_training); Imperative::Get()->set_is_recording(prev_recording); @@ -116,6 +121,16 @@ class CustomOperator { Engine::Get()->PushSync( [=](RunContext rctx) { + try { + Throw(); + for (const auto& i : arrs) { + Engine::Get()->Throw(i.var()); + } + } catch(dmlc::Error& err) { + ctx.async_on_complete(&err); + return; + } + for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) { if (arrs[i].storage_type() == kDefaultStorage || arrs[i].storage_type() == kUndefinedStorage) @@ -125,14 +140,15 @@ class CustomOperator { out_idx++; } } + ctx.async_on_complete(); }, - ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0, + ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0, "CustomOperator"); }); // increase num_threads if there is not enough threads to execute custom operator - if (q_.size() > num_free_threads) - CreateThreads(q_.size() - num_free_threads); + if (q_.size() > num_free_threads_) + CreateThreads(q_.size() - num_free_threads_); cv_.notify_all(); } @@ -142,9 +158,10 @@ class CustomOperator { } void Start() { - num_free_threads = 0; + num_free_threads_ = 0; destructing_ = false; naive_engine_ = true; + exception_ = nullptr; if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { naive_engine_ = false; } @@ -162,6 +179,14 @@ class CustomOperator { workers_.clear(); } + inline void Throw() { + if (exception_ && *exception_) { + std::exception_ptr tmp = *exception_; + exception_ = nullptr; + std::rethrow_exception(tmp); + } + } + private: CustomOperator() { this->Start(); @@ -171,21 +196,20 @@ class CustomOperator { while (!q_.empty() || !destructing_) { cv_.wait(lock, [&] {return !q_.empty() || destructing_;}); while (!q_.empty()) { - --num_free_threads; + --num_free_threads_; auto fn = q_.front(); q_.pop(); lock.unlock(); fn(); - ++num_free_threads; + ++num_free_threads_; lock.lock(); } } } void SetNumThreads(int num_threads) { - num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads); for (int i = workers_.size(); i < num_threads; ++i) { workers_.emplace_back(std::thread([this]{this->ThreadTarget();})); - ++num_free_threads; + ++num_free_threads_; } } void CreateThreads(int num_new_threads) { @@ -196,8 +220,9 @@ class CustomOperator { // async worker std::condition_variable cv_; std::vector workers_; - std::atomic num_free_threads; + std::atomic num_free_threads_; std::queue > q_; + std::shared_ptr exception_; bool naive_engine_; bool destructing_; }; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 59d72d4b18b6..287b974d151e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -29,6 +29,8 @@ from mxnet.test_utils import * from mxnet.base import py_str, MXNetError, _as_list from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises +from common import run_in_spawned_process +from nose.tools import assert_raises import unittest import os @@ -5360,29 +5362,29 @@ def create_operator(self, ctx, shapes, dtypes): # test custom operator fork # see https://github.com/apache/incubator-mxnet/issues/14396 - if not sys.platform.startswith('win'): # no fork in windows - class AdditionOP(mx.operator.CustomOp): - def __init__(self): - super(AdditionOP, self).__init__() - def forward(self, is_train, req, in_data, out_data, aux): - out_data[0][:] = in_data[0] + in_data[1] - def backward(self, req, out_grad, in_data, out_data, in_grad, aux): - in_grad[0][:] = out_grad[0] - in_grad[1][:] = out_grad[0] - - @mx.operator.register("AdditionOP") - class AdditionOPProp(mx.operator.CustomOpProp): - def __init__(self): - super(AdditionOPProp, self).__init__() - def list_arguments(self): - return ['a', 'b'] - def list_outputs(self): - return ['output'] - def infer_shape(self, in_shape): - return in_shape, [in_shape[0]] - def create_operator(self, ctx, shapes, dtypes): - return AdditionOP() + class AdditionOP(mx.operator.CustomOp): + def __init__(self): + super(AdditionOP, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + out_data[0][:] = in_data[0] + in_data[1] + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + in_grad[0][:] = out_grad[0] + in_grad[1][:] = out_grad[0] + @mx.operator.register("AdditionOP") + class AdditionOPProp(mx.operator.CustomOpProp): + def __init__(self): + super(AdditionOPProp, self).__init__() + def list_arguments(self): + return ['a', 'b'] + def list_outputs(self): + return ['output'] + def infer_shape(self, in_shape): + return in_shape, [in_shape[0]] + def create_operator(self, ctx, shapes, dtypes): + return AdditionOP() + + if not sys.platform.startswith('win'): # no fork in windows def custom_add(): a = mx.nd.array([1, 2, 3]) b = mx.nd.array([4, 5, 6]) @@ -5397,6 +5399,89 @@ def custom_add(): p.join(5) assert not p.is_alive(), "deadlock may exist in custom operator" + +def _build_dot_custom(fun_forward, name): + class Dot(mx.operator.CustomOp): + def __init__(self): + super(Dot, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + fun_forward(in_data, out_data) + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + pass + + @mx.operator.register(name) + class DotProp(mx.operator.CustomOpProp): + def __init__(self): + super(DotProp, self).__init__() + def list_arguments(self): + return ['a', 'b'] + def list_outputs(self): + return ['output'] + def infer_shape(self, in_shape): + return in_shape, [(in_shape[0][0], in_shape[1][1])] + def create_operator(self, ctx, shapes, dtypes): + return Dot() + +def _custom_exc3(seed): + def custom_exc3(): + def f(in_data, out_data): + out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) + out_data[0].wait_to_read() + _build_dot_custom(f, 'Dot3') + n = int(1e8) + a = mx.nd.zeros((n, 1)) + b = mx.nd.zeros((1, n)) + # trigger OOM + c = mx.nd.Custom(a, b, op_type='Dot3') + c.wait_to_read() + assert_raises(MXNetError, custom_exc3) + +def _custom_exc4(seed): + def custom_exc4(): + def f(in_data, out_data): + out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) + _build_dot_custom(f, 'Dot4') + n = int(1e8) + a = mx.nd.zeros((n, 1)) + b = mx.nd.zeros((1, n)) + # trigger OOM + c = mx.nd.Custom(a, b, op_type='Dot4') + c.wait_to_read() + assert_raises(MXNetError, custom_exc4) + +@with_seed() +def test_custom_op_exc(): + # test except handling + # see https://github.com/apache/incubator-mxnet/pull/14693 + # 1. error in python code + def custom_exc1(): + def f(in_data, out_data): + assert False + out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) + _build_dot_custom(f, 'Dot1') + a = mx.nd.zeros((4, 1)) + b = mx.nd.zeros((1, 4)) + c = mx.nd.Custom(a, b, op_type='Dot1') + c.wait_to_read() + assert_raises(MXNetError, custom_exc1) + + # 2. error in pushing operator to engine + def custom_exc2(): + def f(in_data, out_data): + out_data[0][:] = mx.nd.dot(in_data[0], in_data[1]) + _build_dot_custom(f, 'Dot2') + a = mx.nd.zeros((4, 2)) + b = mx.nd.zeros((1, 4)) + # trigger error by invalid input shapes of operands + c = mx.nd.Custom(a, b, op_type='Dot2') + c.wait_to_read() + assert_raises(MXNetError, custom_exc2) + + # 3. error in real execution + run_in_spawned_process(_custom_exc3, {}) + run_in_spawned_process(_custom_exc4, {}) + + @with_seed() def test_psroipooling(): for num_rois in [1, 2]: From 52a3553fe200214437c717e7b35e6ce39adb59d8 Mon Sep 17 00:00:00 2001 From: Arthur Caillau Date: Tue, 16 Apr 2019 15:31:01 +0200 Subject: [PATCH 5/9] [docstring] improve docstring and indentation in `module.clj` (#14705) --- .../src/org/apache/clojure_mxnet/module.clj | 544 +++++++++++------- .../src/org/apache/clojure_mxnet/util.clj | 2 +- 2 files changed, 345 insertions(+), 201 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj index aa5ce39f7a80..09f17e5d81f4 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj @@ -16,6 +16,7 @@ ;; (ns org.apache.clojure-mxnet.module + "Module API for Clojure package." (:refer-clojure :exclude [update symbol]) (:require [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] @@ -31,18 +32,29 @@ (:import (org.apache.mxnet.module Module FitParams BaseModule) (org.apache.mxnet.io MXDataIter NDArrayIter) (org.apache.mxnet Initializer Optimizer NDArray DataBatch - Context EvalMetric Monitor Callback$Speedometer DataDesc))) + Context EvalMetric Monitor Callback$Speedometer + DataDesc))) (defn module - "Module is a basic module that wrap a symbol. - sym : Symbol definition. - map of options - :data-names - Input data names. - :label-names - Input label names - :contexts - Default is cpu(). - :workload-list - Default nil, indicating uniform workload. - :fixed-param-names Default nil, indicating no network parameters are fixed." - ([sym {:keys [data-names label-names contexts workload-list fixed-param-names] :as opts + "Module is a basic module that wrap a `symbol`. + `sym`: Symbol definition. + `opts-map` { + `data-names`: vector of strings - Default is [\"data\"] + Input data names + `label-names`: vector of strings - Default is [\"softmax_label\"] + Input label names + `contexts`: Context - Default is `context/cpu`. + `workload-list`: Default nil + Indicating uniform workload. + `fixed-param-names`: Default nil + Indicating no network parameters are fixed. + } + Ex: + (module sym) + (module sym {:data-names [\"data\"] + :label-names [\"linear_regression_label\"]}" + ([sym {:keys [data-names label-names contexts + workload-list fixed-param-names] :as opts :or {data-names ["data"] label-names ["softmax_label"] contexts [(context/default-context)]}}] @@ -80,31 +92,41 @@ (s/def ::force-rebind boolean?) (s/def ::shared-module #(instance? Module)) (s/def ::grad-req string?) -(s/def ::bind-opts (s/keys :req-un [::data-shapes] :opt-un [::label-shapes ::for-training ::inputs-need-grad - ::force-rebind ::shared-module ::grad-req])) +(s/def ::bind-opts + (s/keys :req-un [::data-shapes] + :opt-un [::label-shapes ::for-training ::inputs-need-grad + ::force-rebind ::shared-module ::grad-req])) (defn bind "Bind the symbols to construct executors. This is necessary before one can perform computation with the module. - mod : module - map of opts: - :data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout - :label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout - :for-training Default is `true`. Whether the executors should be bind for training. - :inputs-need-grad Default is `false`. - Whether the gradients to the input data need to be computed. - Typically this is not needed. - But this might be needed when implementing composition of modules. - :force-rebind Default is `false`. - This function does nothing if the executors are already binded. - But with this `true`, the executors will be forced to rebind. - :shared-module Default is nil. This is used in bucketing. - When not `None`, the shared module essentially corresponds to - a different bucket -- a module with different symbol - but with the same sets of parameters - (e.g. unrolled RNNs with different lengths). " - [mod {:keys [data-shapes label-shapes for-training inputs-need-grad force-rebind - shared-module grad-req] :as opts + `mod`: module + `opts-map` { + `data-shapes`: map of `:name`, `:shape`, `:dtype`, and `:layout` + Typically is `(provide-data-desc data-iter)`.Data shape must be in the + form of `io/data-desc` + `label-shapes`: map of `:name` `:shape` `:dtype` and `:layout` + Typically is `(provide-label-desc data-iter)`. + `for-training`: boolean - Default is `true` + Whether the executors should be bind for training. + `inputs-need-grad`: boolean - Default is `false`. + Whether the gradients to the input data need to be computed. + Typically this is not needed. But this might be needed when + implementing composition of modules. + `force-rebind`: boolean - Default is `false`. + This function does nothing if the executors are already binded. But + with this `true`, the executors will be forced to rebind. + `shared-module`: Default is nil. + This is used in bucketing. When not `nil`, the shared module + essentially corresponds to a different bucket -- a module with + different symbol but with the same sets of parameters (e.g. unrolled + RNNs with different lengths). + } + Ex: + (bind {:data-shapes (mx-io/provide-data train-iter) + :label-shapes (mx-io/provide-label test-iter)})) " + [mod {:keys [data-shapes label-shapes for-training inputs-need-grad + force-rebind shared-module grad-req] :as opts :or {for-training true inputs-need-grad false force-rebind false @@ -129,24 +151,36 @@ (s/def ::aux-params map?) (s/def ::force-init boolean?) (s/def ::allow-extra boolean?) -(s/def ::init-params-opts (s/keys :opt-un [::initializer ::arg-params ::aux-params - ::force-init ::allow-extra])) +(s/def ::init-params-opts + (s/keys :opt-un [::initializer ::arg-params ::aux-params + ::force-init ::allow-extra])) (defn init-params - " Initialize the parameters and auxiliary states. - options map - :initializer - Called to initialize parameters if needed. - :arg-params - If not nil, should be a map of existing arg-params. - Initialization will be copied from that. - :auxParams - If not nil, should be a map of existing aux-params. - Initialization will be copied from that. - :allow-missing - If true, params could contain missing values, - and the initializer will be called to fill those missing params. - :force-init - If true, will force re-initialize even if already initialized. - :allow-extra - Whether allow extra parameters that are not needed by symbol. - If this is True, no error will be thrown when argParams or auxParams - contain extra parameters that is not needed by the executor." - ([mod {:keys [initializer arg-params aux-params allow-missing force-init allow-extra] :as opts + "Initialize the parameters and auxiliary states. + `opts-map` { + `initializer`: Initializer - Default is `uniform` + Called to initialize parameters if needed. + `arg-params`: map + If not nil, should be a map of existing arg-params. Initialization + will be copied from that. + `aux-params`: map + If not nil, should be a map of existing aux-params. Initialization + will be copied from that. + `allow-missing`: boolean - Default is `false` + If true, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-init` boolean - Default is `false` + If true, will force re-initialize even if already initialized. + `allow-extra`: boolean - Default is `false` + Whether allow extra parameters that are not needed by symbol. + If this is `true`, no error will be thrown when `arg-params` or + `aux-params` contain extra parameters that is not needed by the + executor. + Ex: + (init-params {:initializer (initializer/xavier)}) + (init-params {:force-init true :allow-extra true})" + ([mod {:keys [initializer arg-params aux-params allow-missing force-init + allow-extra] :as opts :or {initializer (initializer/uniform 0.01) allow-missing false force-init false @@ -167,17 +201,23 @@ (s/def ::kvstore string?) (s/def ::reset-optimizer boolean?) (s/def ::force-init boolean?) -(s/def ::init-optimizer-opts (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init])) +(s/def ::init-optimizer-opts + (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init])) (defn init-optimizer - " Install and initialize optimizers. - - mod Module - - options map of - - kvstore - - reset-optimizer Default `True`, indicating whether we should set - `rescaleGrad` & `idx2name` for optimizer according to executorGroup - - force-init Default `False`, indicating whether we should force - re-initializing the optimizer in the case an optimizer is already installed." + "Install and initialize optimizers. + `mod`: Module + `opts-map` { + `kvstore`: string - Default is \"local\" + `optimizer`: Optimizer - Default is `sgd` + `reset-optimizer`: boolean - Default is `true` + Indicating whether we should set `rescaleGrad` & `idx2name` for + optimizer according to executorGroup. + `force-init`: boolean - Default is `false` + Indicating whether we should force re-initializing the optimizer + in the case an optimizer is already installed. + Ex: + (init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1})})" ([mod {:keys [kvstore optimizer reset-optimizer force-init] :as opts :or {kvstore "local" optimizer (optimizer/sgd) @@ -191,8 +231,10 @@ (defn forward "Forward computation. - data-batch - input data of form io/data-batch either map or DataBatch - is-train - Default is nil, which means `is_train` takes the value of `for_training`." + `data-batch`: Either map or DataBatch + Input data of form `io/data-batch`. + `is-train`: Default is nil + Which means `is_train` takes the value of `for_training`." ([mod data-batch is-train] (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") (doto mod @@ -209,9 +251,9 @@ (defn backward "Backward computation. - out-grads - Gradient on the outputs to be propagated back. - This parameter is only needed when bind is called - on outputs that are not a loss function." + `out-grads`: collection of NDArrays + Gradient on the outputs to be propagated back. This parameter is only + needed when bind is called on outputs that are not a loss function." ([mod out-grads] (util/validate! ::out-grads out-grads "Invalid out-grads") (doto mod @@ -227,50 +269,48 @@ (.forwardBackward data-batch))) (defn outputs - " Get outputs of the previous forward computation. - In the case when data-parallelism is used, - the outputs will be collected from multiple devices. - The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`, - those `NDArray` might live on different devices." + "Get outputs of the previous forward computation. + In the case when data-parallelism is used, the outputs will be collected from + multiple devices. The results will look like + `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. + Those `NDArray`s might live on different devices." [mod] (->> (.getOutputs mod) (util/scala-vector->vec) (mapv util/scala-vector->vec))) (defn update - "Update parameters according to the installed optimizer and the gradients computed - in the previous forward-backward batch." + "Update parameters according to the installed optimizer and the gradients + computed in the previous forward-backward batch." [mod] (doto mod (.update))) (defn outputs-merged - " Get outputs of the previous forward computation. - return In the case when data-parallelism is used, - the outputs will be merged from multiple devices, - as they look like from a single executor. - The results will look like `[out1, out2]`" + "Get outputs of the previous forward computation. + In the case when data-parallelism is used, the outputs will be merged from + multiple devices, as they look like from a single executor. + The results will look like `[out1, out2]`." [mod] (->> (.getOutputsMerged mod) (util/scala-vector->vec))) (defn input-grads - " Get the gradients to the inputs, computed in the previous backward computation. - In the case when data-parallelism is used, - the outputs will be collected from multiple devices. - The results will look like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]` - those `NDArray` might live on different devices." + "Get the gradients to the inputs, computed in the previous backward computation. + In the case when data-parallelism is used, the outputs will be collected from + multiple devices. The results will look like + `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. + Those `NDArray`s might live on different devices." [mod] (->> (.getInputGrads mod) (util/scala-vector->vec) (mapv util/scala-vector->vec))) (defn input-grads-merged - " Get the gradients to the inputs, computed in the previous backward computation. - return In the case when data-parallelism is used, - the outputs will be merged from multiple devices, - as they look like from a single executor. - The results will look like `[grad1, grad2]`" + "Get the gradients to the inputs, computed in the previous backward computation. + In the case when data-parallelism is used, the outputs will be merged from + multiple devices, as they look like from a single executor. + The results will look like `[grad1, grad2]`." [mod] (->> (.getInputGradsMerged mod) (util/scala-vector->vec))) @@ -278,16 +318,25 @@ (s/def ::prefix string?) (s/def ::epoch int?) (s/def ::save-opt-states boolean?) -(s/def ::save-checkpoint-opts (s/keys :req-un [::prefix ::epoch] :opt-un [::save-opt-states ::save-checkpoint])) +(s/def ::save-checkpoint-opts + (s/keys :req-un [::prefix ::epoch] + :opt-un [::save-opt-states ::save-checkpoint])) (defn save-checkpoint - " Save current progress to checkpoint. - Use mx.callback.module_checkpoint as epoch_end_callback to save during training. - - mod Module - - opt-map with - :prefix The file prefix to checkpoint to - :epoch The current epoch number - :save-opt-states Whether to save optimizer states for continue training " + "Save current progress to checkpoint. + Use mx.callback.module_checkpoint as epoch_end_callback to save during + training. + `mod`: Module + `opts-map` { + `prefix`: string + The file prefix to checkpoint to + `epoch`: int + The current epoch number + `save-opt-states`: boolean - Default is `false` + Whether to save optimizer states for continue training + } + Ex: + (save-checkpoint {:prefix \"saved_model\" :epoch 0 :save-opt-states true})" ([mod {:keys [prefix epoch save-opt-states] :as opts :or {save-opt-states false}}] (util/validate! ::save-checkpoint-opts opts "Invalid save checkpoint opts") @@ -303,24 +352,34 @@ (s/def ::contexts (s/coll-of ::context :kind vector?)) (s/def ::workload-list (s/coll-of number? :kind vector?)) (s/def ::fixed-params-names (s/coll-of string? :kind vector?)) -(s/def ::load-checkpoint-opts (s/keys :req-un [::prefix ::epoch] - :opt-un [::load-optimizer-states ::data-names ::label-names - ::contexts ::workload-list ::fixed-param-names])) +(s/def ::load-checkpoint-opts + (s/keys :req-un [::prefix ::epoch] + :opt-un [::load-optimizer-states ::data-names ::label-names + ::contexts ::workload-list ::fixed-param-names])) (defn load-checkpoint "Create a model from previously saved checkpoint. - - opts map of - - prefix Path prefix of saved model files. You should have prefix-symbol.json, - prefix-xxxx.params, and optionally prefix-xxxx.states, - where xxxx is the epoch number. - - epoch Epoch to load. - - load-optimizer-states Whether to load optimizer states. - Checkpoint needs to have been made with save-optimizer-states=True - - dataNames Input data names. - - labelNames Input label names - - contexts Default is cpu(). - - workload-list Default nil, indicating uniform workload. - - fixed-param-names Default nil, indicating no network parameters are fixed." + `opts-map` { + `prefix`: string + Path prefix of saved model files. You should have prefix-symbol.json, + prefix-xxxx.params, and optionally prefix-xxxx.states, where xxxx is + the epoch number. + `epoch`: int + Epoch to load. + `load-optimizer-states`: boolean - Default is false + Whether to load optimizer states. Checkpoint needs to have been made + with `save-optimizer-states` = `true`. + `data-names`: vector of strings - Default is [\"data\"] + Input data names. + `label-names`: vector of strings - Default is [\"softmax_label\"] + Input label names. + `contexts`: Context - Default is `context/cpu` + `workload-list`: Default nil + Indicating uniform workload. + `fixed-param-names`: Default nil + Indicating no network parameters are fixed. + Ex: + (load-checkpoint {:prefix \"my-model\" :epoch 1 :load-optimizer-states true}" ([{:keys [prefix epoch load-optimizer-states data-names label-names contexts workload-list fixed-param-names] :as opts :or {load-optimizer-states false @@ -358,10 +417,10 @@ (util/scala-map->map (.auxParams mod))) (defn reshape - " Reshapes the module for new input shapes. - - mod module - - data-shapes Typically is `(provide-data data-iter) - - param label-shapes Typically is `(provide-label data-tier)`. " + "Reshapes the module for new input shapes. + `mod`: Module + `data-shapes`: Typically is `(provide-data data-iter)` + `label-shapes`: Typically is `(provide-label data-tier)`" ([mod data-shapes label-shapes] (util/validate! ::data-shapes data-shapes "Invalid data-shapes") (util/validate! (s/nilable ::label-shapes) label-shapes "Invalid label-shapes") @@ -376,28 +435,35 @@ ([mod data-shapes] (reshape mod data-shapes nil))) -(s/def ::set-param-opts (s/keys :opt-un [::arg-params ::aux-params ::allow-missing ::force-init ::allow-extra])) +(s/def ::set-param-opts + (s/keys :opt-un [::arg-params ::aux-params ::allow-missing + ::force-init ::allow-extra])) (defn get-params [mod] (.getParams mod)) (defn set-params - " Assign parameter and aux state values. - - mod module - - arg-params : map - map of name to value (`NDArray`) mapping. - - aux-params : map - map of name to value (`NDArray`) mapping. - - allow-missing : bool - If true, params could contain missing values, and the initializer will be - called to fill those missing params. - - force-init : bool - If true, will force re-initialize even if already initialized. - - allow-extra : bool - Whether allow extra parameters that are not needed by symbol. - If this is True, no error will be thrown when arg-params or aux-params - contain extra parameters that is not needed by the executor." - [mod {:keys [arg-params aux-params allow-missing force-init allow-extra] :as opts + "Assign parameters and aux state values. + `mod`: Module + `opts-map` { + `arg-params`: map - map of name to value (`NDArray`) mapping. + `aux-params`: map - map of name to value (`NDArray`) mapping. + `allow-missing`: boolean + If true, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-init`: boolean - Default is `false` + If true, will force re-initialize even if already initialized. + `allow-extra`: boolean - Default is `false` + Whether allow extra parameters that are not needed by symbol. If this + is `true`, no error will be thrown when arg-params or aux-params + contain extra parameters that is not needed by the executor. + } + Ex: + (set-params mod + {:arg-params {\"fc_0_weight\" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + :allow-missing true})" + [mod {:keys [arg-params aux-params allow-missing force-init + allow-extra] :as opts :or {allow-missing false force-init true allow-extra false}}] (util/validate! ::set-param-opts opts "Invalid set-params") (doto mod @@ -409,33 +475,32 @@ allow-extra))) (defn install-monitor - "Install monitor on all executors" + "Install monitor on all executors." [mod monitor] (doto mod (.installMonitor monitor))) (defn borrow-optimizer - "Borrow optimizer from a shared module. Used in bucketing, where exactly the same - optimizer (esp. kvstore) is used. - - mod module - - shared-module" + "Borrow optimizer from a shared module. Used in bucketing, where exactly the + same optimizer (esp. kvstore) is used. + `mod`: Module + `shared-module`" [mod shared-module] (doto mod (.borrowOptimizer shared-module))) (defn save-optimizer-states - "Save optimizer (updater) state to file - - mod module - - fname Path to output states file." + "Save optimizer (updater) state to file. + `mod`: Module + `fname`: string - Path to output states file." [mod fname] (doto mod (.saveOptimizerStates mod fname))) (defn load-optimizer-states - "Load optimizer (updater) state from file - - mod module - - fname Path to input states file. - " + "Load optimizer (updater) state from file. + `mod`: Module + `fname`: string - Path to input states file." [mod fname] (doto mod (.loadOptimzerStates fname))) @@ -444,10 +509,13 @@ (s/def ::labels (s/coll-of ::ndarray :kind vector?)) (defn update-metric - "Evaluate and accumulate evaluation metric on outputs of the last forward computation. - - mod module - - eval-metric - - labels" + "Evaluate and accumulate evaluation metric on outputs of the last forward + computation. + `mod`: module + `eval-metric`: EvalMetric + `labels`: collection of NDArrays + Ex: + (update-metric mod (eval-metric/mse) labels)" [mod eval-metric labels] (util/validate! ::eval-metric eval-metric "Invalid eval metric") (util/validate! ::labels labels "Invalid labels") @@ -458,18 +526,48 @@ (s/def ::validation-metric ::eval-metric) (s/def ::monitor #(instance? Monitor %)) (s/def ::batch-end-callback #(instance? Callback$Speedometer %)) -(s/def ::fit-params-opts (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer - ::arg-params ::aux-params ::allow-missing ::force-rebind - ::force-init ::begin-epoch ::validation-metric ::monitor - ::batch-end-callback])) +(s/def ::fit-params-opts + (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer + ::arg-params ::aux-params ::allow-missing ::force-rebind + ::force-init ::begin-epoch ::validation-metric ::monitor + ::batch-end-callback])) ;; callbacks are not supported for now (defn fit-params - "Fit Params" + "Initialize FitParams with provided parameters. + `eval-metric`: EvalMetric - Default is `accuracy` + `kvstore`: String - Default is \"local\" + `optimizer`: Optimizer - Default is `sgd` + `initializer`: Initializer - Default is `uniform` + Called to initialize parameters if needed. + `arg-params`: map + If not nil, should be a map of existing `arg-params`. Initialization + will be copied from that. + `aux-params`: map - + If not nil, should be a map of existing `aux-params`. Initialization + will be copied from that. + `allow-missing`: boolean - Default is `false` + If `true`, params could contain missing values, and the initializer will + be called to fill those missing params. + `force-rebind`: boolean - Default is `false` + This function does nothing if the executors are already binded. But with + this `true`, the executors will be forced to rebind. + `force-init`: boolean - Default is `false` + If `true`, will force re-initialize even if already initialized. + `begin-epoch`: int - Default is 0 + `validation-metric`: EvalMetric + `monitor`: Monitor + Ex: + (fit-params {:force-init true :force-rebind true :allow-missing true}) + (fit-params + {:batch-end-callback (callback/speedometer batch-size 100) + :initializer (initializer/xavier) + :optimizer (optimizer/sgd {:learning-rate 0.01}) + :eval-metric (eval-metric/mse)})" ([{:keys [eval-metric kvstore optimizer initializer arg-params aux-params - allow-missing force-rebind force-init begin-epoch validation-metric monitor - batch-end-callback] :as opts + allow-missing force-rebind force-init begin-epoch + validation-metric monitor batch-end-callback] :as opts :or {eval-metric (eval-metric/accuracy) kvstore "local" optimizer (optimizer/sgd) @@ -500,25 +598,36 @@ (s/def ::ndarray-iter #(instance? NDArrayIter %)) (s/def ::train-data (s/or :mx-iter ::mx-data-iter :ndarry-iter ::ndarray-iter)) (s/def ::eval-data ::train-data) -(s/def ::num-epoch int?) +(s/def ::num-epoch (s/and int? pos?)) (s/def ::fit-params #(instance? FitParams %)) -(s/def ::fit-options (s/keys :req-un [::train-data] :opt-un [::eval-data ::num-epoch ::fit-params])) +(s/def ::fit-options + (s/keys :req-un [::train-data] + :opt-un [::eval-data ::num-epoch ::fit-params])) ;;; High Level API (defn score - " Run prediction on `eval-data` and evaluate the performance according to `eval-metric`. - - mod module - - option map with - :eval-data : DataIter - :eval-metric : EvalMetric - :num-batch Number of batches to run. Default is `Integer.MAX_VALUE`, - indicating run until the `DataIter` finishes. - :batch-end-callback -not supported yet - :reset Default `True`, - indicating whether we should reset `eval-data` before starting evaluating. - :epoch Default 0. For compatibility, this will be passed to callbacks (if any). - During training, this will correspond to the training epoch number." + "Run prediction on `eval-data` and evaluate the performance according to + `eval-metric`. + `mod`: module + `opts-map` { + `eval-data`: DataIter + `eval-metric`: EvalMetric + `num-batch`: int - Default is `Integer.MAX_VALUE` + Number of batches to run. Indicating run until the `DataIter` + finishes. + `batch-end-callback`: not supported yet. + `reset`: boolean - Default is `true`, + Indicating whether we should reset `eval-data` before starting + evaluating. + `epoch`: int - Default is 0 + For compatibility, this will be passed to callbacks (if any). During + training, this will correspond to the training epoch number. + } + Ex: + (score mod {:eval-data data-iter :eval-metric (eval-metric/accuracy)}) + (score mod {:eval-data data-iter + :eval-metric (eval-metric/mse) :num-batch 10})" [mod {:keys [eval-data eval-metric num-batch reset epoch] :as opts :or {num-batch Integer/MAX_VALUE reset true @@ -537,15 +646,30 @@ (defn fit "Train the module parameters. - - mod module - - train-data (data-iterator) - - eval-data (data-iterator)If not nil, will be used as validation set and evaluate - the performance after each epoch. - - num-epoch Number of epochs to run training. - - f-params Extra parameters for training (See fit-params)." + `mod`: Module + `opts-map` { + `train-data`: DataIter + `eval-data`: DataIter + If not nil, will be used as validation set and evaluate the + performance after each epoch. + `num-epoch`: int + Number of epochs to run training. + `fit-params`: FitParams + Extra parameters for training (see fit-params). + } + Ex: + (fit {:train-data train-iter :eval-data test-iter :num-epoch 100) + (fit {:train-data train-iter + :eval-data test-iter + :num-epoch 5 + :fit-params + (fit-params {:batch-end-callback (callback/speedometer 128 100) + :initializer (initializer/xavier) + :optimizer (optimizer/sgd {:learning-rate 0.01}) + :eval-metric (eval-metric/mse)}))" [mod {:keys [train-data eval-data num-epoch fit-params] :as opts - `:or {num-epoch 1 - fit-params (new FitParams)}}] + :or {num-epoch 1 + fit-params (new FitParams)}}] (util/validate! ::fit-options opts "Invalid options for fit") (doto mod (.fit @@ -557,12 +681,13 @@ (s/def ::eval-data ::train-data) (s/def ::num-batch integer?) (s/def ::reset boolean?) -(s/def ::predict-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) +(s/def ::predict-opts + (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) (defn predict-batch - "Run the predication on a data batch - - mod module - - data-batch data-batch" + "Run the predication on a data batch. + `mod`: Module + `data-batch`: data-batch" [mod data-batch] (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") (util/coerce-return (.predict mod (if (map? data-batch) @@ -571,41 +696,60 @@ (defn predict "Run prediction and collect the outputs. - - mod module - - option map with - - :eval-data - - :num-batch Default is -1, indicating running all the batches in the data iterator. - - :reset Default is `True`, indicating whether we should reset the data iter before start - doing prediction. - The return value will be a vector of NDArrays `[out1, out2, out3]`. - Where each element is concatenation of the outputs for all the mini-batches." + `mod`: Module + `opts-map` { + `eval-data`: DataIter + `num-batch` int - Default is `-1` + Indicating running all the batches in the data iterator. + `reset`: boolean - Default is `true` + Indicating whether we should reset the data iter before start doing + prediction. + } + returns: vector of NDArrays `[out1, out2, out3]` where each element is the + concatenation of the outputs for all the mini-batches. + Ex: + (predict mod {:eval-data test-iter}) + (predict mod {:eval-data test-iter :num-batch 10 :reset false})" [mod {:keys [eval-data num-batch reset] :as opts :or {num-batch -1 reset true}}] (util/validate! ::predict-opts opts "Invalid opts for predict") (util/scala-vector->vec (.predict mod eval-data (int num-batch) reset))) -(s/def ::predict-every-batch-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) +(s/def ::predict-every-batch-opts + (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) (defn predict-every-batch - " Run prediction and collect the outputs. - - module - - option map with - :eval-data - :num-batch Default is -1, indicating running all the batches in the data iterator. - :reset Default is `True`, indicating whether we should reset the data iter before start - doing prediction. - The return value will be a nested list like - [[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]` - This mode is useful because in some cases (e.g. bucketing), - the module does not necessarily produce the same number of outputs." + "Run prediction and collect the outputs. + `mod`: Module + `opts-map` { + `eval-data`: DataIter + `num-batch` int - Default is `-1` + Indicating running all the batches in the data iterator. + `reset` boolean - Default is `true` + Indicating whether we should reset the data iter before start doing + prediction. + } + returns: nested list like this + `[[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]` + + Note: This mode is useful because in some cases (e.g. bucketing), the module + does not necessarily produce the same number of outputs. + Ex: + (predict-every-batch mod {:eval-data test-iter})" [mod {:keys [eval-data num-batch reset] :as opts :or {num-batch -1 reset true}}] - (util/validate! ::predict-every-batch-opts opts "Invalid opts for predict-every-batch") - (mapv util/scala-vector->vec (util/scala-vector->vec (.predictEveryBatch mod eval-data (int num-batch) reset)))) - -(s/def ::score-opts (s/keys :req-un [::eval-data ::eval-metric] :opt-un [::num-batch ::reset ::epoch])) + (util/validate! ::predict-every-batch-opts + opts + "Invalid opts for predict-every-batch") + (mapv util/scala-vector->vec + (util/scala-vector->vec + (.predictEveryBatch mod eval-data (int num-batch) reset)))) + +(s/def ::score-opts + (s/keys :req-un [::eval-data ::eval-metric] + :opt-un [::num-batch ::reset ::epoch])) (defn exec-group [mod] (.execGroup mod)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 7ee25d4dd25e..9dc6c8f88ddd 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -250,7 +250,7 @@ shape))) (defn map->scala-tuple-seq - "* Convert a map to a scala-Seq of scala-Tubple. + "* Convert a map to a scala-Seq of scala-Tuple. * Should also work if a seq of seq of 2 things passed. * Otherwise passed through unchanged." [map-or-tuple-seq] From 3f3ba92ae1468d08de088d2291ca14e2d5dc5515 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 16 Apr 2019 10:00:54 -0700 Subject: [PATCH 6/9] [numpy] Support zero-dim and zero-size tensors in MXNet (#14661) * [numpy] Shape support scalar tensor (#14315) * Support scalar and zero-size tensors with np.sum * Add sanity check when ndim is set * [Numpy] Change semantics of ndim for operators in `src/operator/contrib` (#14409) * Initial commit * Address comments * [WIP] Use new shape definition (#14453) * Init checkin * Fix ndarray alloc bug * Use TShape(0) as default empty tuple params * Fix bugs * Fix TShape init value * Fix infer shape pass shape type and reshape infer shape func * [numpy] Fix unit tests after introducing numpy compatible shapes (#14487) * Fix infer shape rnn * Fix boolean mask and custom op unit tests * Fix multi proposal * Fix diag * Add global switch for backward compatibility and fix infer shape bugs * Fix slice op infer shape * Fix rnn infer shape * Add util funcs for ndim_is_known and dim_size_is_known * Revert rnn_cell.py * Fix a bug to pass the test in test_contrib_rnn (#14520) * fix. * remove type conversion. * remove type cast. * [numpy] Fix test_dynamic_shape.test_dynamic_shape (#14538) * Initial commit * Address comments from Jun * [numpy] Fix numpy import in python2 (#14537) * Fix several test failures * Fix subgraph op infer shape * Fix sparse slice * Fix deconv infer shape * Fix numpy import compatibility problem in python2 * fix concat and slice (#14549) * fix R-package (#14536) * Fix cpp package build after using new shape definition (#14554) * Fix pooling_v1 and deformable_convolution param initialization (#14577) * Fix pooling_v1 param initialization * Fix deformable_convolution param initialization * [Numpy] Misc fix (#14612) * [Numpy] Misc Fix * fix build * !shape_is_none => shape_is_known * Address comments * Fix * [Numpy] fix test_operator_gpu.test_upsampling_bilinear_with_type (#14557) * Fix test_operator_gpu.test_upsampling_bilinear_with_type * Address comments * [Numpy] Java/Scala modification (#14625) * modify jni to support 0 dim/shape * fix transpose axes default value * fix shape index bug (#14630) * fix jni lint (#14634) * [numpy] Fix numpy branch failing tests in CI (#14639) * Remove numpy namespaces for operator registration * Fix bug when shape is compeltely unknown * Fix singed/unsigned compare warning * Fix CI * Fix pylint * Avoid launching gpu kernels for zero-size output tensors * Fix test_ndarray * Fix binary broadcast with zero-size tensors * Better error message for infer shape failure in imperative * Fix TShape constructor ambiguity on certain platforms * Fix mkldnn build failure * Fix build failure in gpu and cpp test * Fix gpu cpp test build with mkldnn * Fix mkldnn cpp test * Fix concatenating zero-size tensors * Avoid letting mkldnn handle zero-size tensors in concat * Fix quantized_concat infer shape * Try to fix perl c api * fix invalid ndarray dispose (#14657) * swig fixes for the changes in c_api.h (#14655) * Rename np_comp to np_compat for readability * Fix import error * Keep old c apis unchanged * Fix lint * Rebase and fix build * Fix R build failure * Fix Perl build failure * Rebase with master * Address cr comments * Use just one scope to represent numpy compatibility * Add code comment to NumpyScope object in Scala * Add use_np_compat decorator * Fix pylint --- R-package/src/ndarray.cc | 6 +- R-package/src/symbol.cc | 20 +- cpp-package/include/mxnet-cpp/ndarray.hpp | 8 +- cpp-package/include/mxnet-cpp/symbol.hpp | 32 +- include/mxnet/c_api.h | 212 +++++++- include/mxnet/imperative.h | 16 + include/mxnet/ndarray.h | 11 +- include/mxnet/tensor_blob.h | 3 +- include/mxnet/tuple.h | 160 ++++-- .../AI-MXNet/lib/AI/MXNet/Executor.pm | 2 +- perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm | 4 +- perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm | 4 +- perl-package/AI-MXNetCAPI/mxnet.i | 172 +++---- perl-package/AI-MXNetCAPI/mxnet_typemaps.i | 22 +- python/mxnet/__init__.py | 2 +- python/mxnet/base.py | 151 +++++- python/mxnet/executor.py | 48 +- python/mxnet/ndarray/_internal.py | 2 - python/mxnet/ndarray/contrib.py | 1 + python/mxnet/ndarray/ndarray.py | 19 +- python/mxnet/ndarray/register.py | 7 +- python/mxnet/operator.py | 26 +- python/mxnet/symbol/_internal.py | 2 - python/mxnet/symbol/register.py | 7 +- python/mxnet/symbol/symbol.py | 113 +++-- .../scala/org/apache/mxnet/Executor.scala | 143 ++---- .../main/scala/org/apache/mxnet/LibInfo.scala | 32 +- .../main/scala/org/apache/mxnet/NDArray.scala | 10 +- .../scala/org/apache/mxnet/NumpyScope.scala | 63 +++ .../main/scala/org/apache/mxnet/Symbol.scala | 38 +- .../org/apache/mxnet/NumpyScopeSuite.scala | 34 ++ .../apache/mxnet/utils/CToScalaUtils.scala | 3 +- .../native/org_apache_mxnet_native_c_api.cc | 219 ++++++++- .../native/org_apache_mxnet_native_c_api.h | 32 ++ src/c_api/c_api.cc | 41 +- src/c_api/c_api_common.h | 31 ++ src/c_api/c_api_executor.cc | 460 +++++++++++++++++- src/c_api/c_api_ndarray.cc | 12 + src/c_api/c_api_symbolic.cc | 107 +++- src/c_api/c_predict_api.cc | 1 + src/common/exec_utils.h | 4 +- src/common/utils.h | 60 ++- src/executor/graph_executor.cc | 3 +- src/executor/infer_graph_attr_pass.cc | 20 +- src/imperative/imperative.cc | 4 +- src/imperative/imperative_utils.cc | 3 +- src/imperative/imperative_utils.h | 26 +- src/io/image_io.cc | 4 +- src/kvstore/gradient_compression.cc | 10 +- src/ndarray/ndarray.cc | 17 +- src/ndarray/ndarray_function.cc | 6 +- src/ndarray/ndarray_function.h | 2 +- src/nnvm/plan_memory.cc | 6 +- src/operator/batch_norm_v1-inl.h | 2 +- src/operator/bilinear_sampler-inl.h | 4 +- src/operator/channel_op_common.h | 4 + src/operator/contrib/adamw-inl.h | 5 +- .../contrib/adaptive_avg_pooling-inl.h | 6 +- src/operator/contrib/bilinear_resize-inl.h | 2 +- src/operator/contrib/boolean_mask.cc | 2 +- src/operator/contrib/bounding_box-inl.h | 11 +- src/operator/contrib/count_sketch-inl.h | 2 +- .../contrib/deformable_convolution-inl.h | 14 +- src/operator/contrib/dgl_graph.cc | 64 +-- src/operator/contrib/fft-inl.h | 2 +- src/operator/contrib/ifft-inl.h | 2 +- src/operator/contrib/index_copy-inl.h | 5 +- src/operator/contrib/multi_proposal-inl.h | 2 +- src/operator/contrib/multibox_detection-inl.h | 2 +- src/operator/contrib/multibox_prior-inl.h | 4 +- src/operator/contrib/nnvm_to_onnx.cc | 3 +- src/operator/contrib/optimizer_op.cc | 2 +- src/operator/contrib/proposal-inl.h | 2 +- src/operator/contrib/quadratic_op-inl.h | 2 +- src/operator/contrib/sync_batch_norm-inl.h | 2 +- src/operator/contrib/transformer-inl.h | 4 +- src/operator/control_flow.cc | 112 ++--- src/operator/convolution_v1-inl.h | 8 +- src/operator/custom/custom.cc | 18 +- src/operator/image/image_random-inl.h | 4 +- src/operator/image/resize-inl.h | 4 +- src/operator/leaky_relu-inl.h | 6 +- src/operator/loss_binary_op-inl.h | 2 +- src/operator/mxnet_op.h | 2 + src/operator/nn/batch_norm.cc | 2 +- src/operator/nn/concat.cc | 51 +- src/operator/nn/convolution-inl.h | 18 +- src/operator/nn/convolution.cc | 85 ++-- src/operator/nn/ctc_loss-inl.h | 2 +- src/operator/nn/cudnn/cudnn_batch_norm.cc | 2 +- src/operator/nn/cudnn/cudnn_convolution-inl.h | 6 +- .../nn/cudnn/cudnn_deconvolution-inl.h | 6 +- src/operator/nn/deconvolution-inl.h | 28 +- src/operator/nn/deconvolution.cc | 50 +- src/operator/nn/dropout-inl.h | 2 +- src/operator/nn/dropout.cc | 4 +- src/operator/nn/fully_connected.cc | 4 +- src/operator/nn/im2col.h | 4 +- src/operator/nn/layer_norm-inl.h | 2 +- src/operator/nn/layer_norm.cc | 6 +- src/operator/nn/lrn.cc | 2 +- src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_concat.cc | 12 +- src/operator/nn/mkldnn/mkldnn_slice.cc | 6 +- src/operator/nn/mkldnn/mkldnn_transpose.cc | 6 +- src/operator/nn/pooling-inl.h | 14 +- src/operator/nn/pooling.cc | 6 +- src/operator/nn/upsampling.cc | 2 +- src/operator/operator_common.h | 19 +- src/operator/operator_util.cc | 2 +- src/operator/pad-inl.h | 2 +- src/operator/pooling_v1-inl.h | 35 +- src/operator/quantization/dequantize-inl.h | 4 +- .../mkldnn/mkldnn_requantize-inl.h | 2 +- src/operator/quantization/quantize-inl.h | 4 +- src/operator/quantization/quantize_v2-inl.h | 2 +- src/operator/quantization/quantized_concat.cc | 18 +- src/operator/quantization/quantized_conv.cc | 4 +- .../quantization/quantized_flatten-inl.h | 6 +- .../quantization/quantized_fully_connected.cc | 6 +- .../quantization/quantized_pooling.cc | 4 +- src/operator/quantization/requantize-inl.h | 2 +- src/operator/random/multisample_op.h | 2 +- src/operator/random/sample_multinomial_op.h | 18 +- src/operator/random/unique_sample_op.h | 2 +- src/operator/regression_output-inl.h | 2 +- src/operator/rnn.cc | 2 +- src/operator/sequence_last-inl.h | 2 +- src/operator/slice_channel-inl.h | 17 +- src/operator/softmax_output-inl.h | 12 +- src/operator/softmax_output.cc | 12 +- src/operator/spatial_transformer-inl.h | 4 +- src/operator/subgraph_op_common.cc | 4 +- src/operator/subgraph_op_common.h | 12 +- src/operator/svm_output-inl.h | 6 +- src/operator/swapaxis-inl.h | 6 +- src/operator/tensor/broadcast_reduce_op.h | 66 ++- src/operator/tensor/diag_op-inl.h | 12 +- src/operator/tensor/dot-inl.h | 12 +- .../tensor/elemwise_binary_broadcast_op.h | 53 +- .../tensor/elemwise_unary_op_basic.cc | 10 +- src/operator/tensor/histogram-inl.h | 14 +- src/operator/tensor/indexing_op.h | 29 +- src/operator/tensor/init_op.h | 18 +- src/operator/tensor/la_op.h | 2 +- src/operator/tensor/matrix_op-inl.h | 311 ++++++------ src/operator/tensor/matrix_op.cc | 4 +- src/operator/tensor/ordering_op-inl.h | 2 +- src/operator/tensor/slice-inl.h | 6 +- tests/cpp/include/test_mkldnn.h | 18 +- tests/cpp/include/test_util.h | 4 +- tests/cpp/misc/serialization.cc | 2 +- tests/cpp/operator/batchnorm_test.cc | 4 +- tests/cpp/operator/mkldnn_operator_test.cc | 6 +- tests/python/gpu/test_operator_gpu.py | 22 +- tests/python/unittest/test_infer_shape.py | 16 + tests/python/unittest/test_ndarray.py | 6 +- tests/python/unittest/test_operator.py | 88 +++- 158 files changed, 2838 insertions(+), 1145 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala create mode 100644 scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index 94d24f3fb46b..0409d3ba8887 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -179,9 +179,9 @@ Rcpp::RObject NDArrayPacker::CreateNDArrayPacker() { } Rcpp::Dimension NDArray::dim() const { - mx_uint ndim; - const mx_uint *pshape; - MX_CALL(MXNDArrayGetShape( + int ndim; + const int *pshape; + MX_CALL(MXNDArrayGetShapeEx( ptr_->handle, &ndim, &pshape)); Rcpp::IntegerVector dat(pshape, pshape + ndim); std::reverse(dat.begin(), dat.end()); diff --git a/R-package/src/symbol.cc b/R-package/src/symbol.cc index 031c9a254019..317e82568012 100644 --- a/R-package/src/symbol.cc +++ b/R-package/src/symbol.cc @@ -167,8 +167,8 @@ Symbol::RObjectType Symbol::GetOutput(mx_uint index) const { // helper function to convert shape into Rcpp vector inline Rcpp::List BuildShapeData(mx_uint shape_size, - const mx_uint *shape_ndim, - const mx_uint **shape_data, + const int *shape_ndim, + const int **shape_data, const std::vector &names) { Rcpp::List ret(shape_size); for (mx_uint i = 0; i < shape_size; ++i) { @@ -185,7 +185,7 @@ SEXP Symbol::InferShape(const Rcpp::List& kwargs) const { << "Need to pass parameters in key=value style.\n"; std::vector keys = kwargs.names(); std::vector arg_ind_ptr(1, 0); - std::vector arg_shape_data; + std::vector arg_shape_data; for (size_t i = 0; i < kwargs.size(); ++i) { RCHECK(keys[i].length() != 0) @@ -197,17 +197,17 @@ SEXP Symbol::InferShape(const Rcpp::List& kwargs) const { std::vector c_keys = CKeys(keys); mx_uint in_shape_size; - const mx_uint *in_shape_ndim; - const mx_uint **in_shape_data; + const int *in_shape_ndim; + const int **in_shape_data; mx_uint out_shape_size; - const mx_uint *out_shape_ndim; - const mx_uint **out_shape_data; + const int *out_shape_ndim; + const int **out_shape_data; mx_uint aux_shape_size; - const mx_uint *aux_shape_ndim; - const mx_uint **aux_shape_data; + const int *aux_shape_ndim; + const int **aux_shape_data; int complete; - MX_CALL(MXSymbolInferShape( + MX_CALL(MXSymbolInferShapeEx( handle_, static_cast(kwargs.size()), dmlc::BeginPtr(c_keys), dmlc::BeginPtr(arg_ind_ptr), dmlc::BeginPtr(arg_shape_data), &in_shape_size, &in_shape_ndim, &in_shape_data, diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index b667542bffb5..d0438305a62e 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -397,11 +397,11 @@ inline size_t NDArray::Size() const { } inline std::vector NDArray::GetShape() const { - const mx_uint *out_pdata; - mx_uint out_dim; - MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata); + const int *out_pdata; + int out_dim; + MXNDArrayGetShapeEx(blob_ptr_->handle_, &out_dim, &out_pdata); std::vector ret; - for (mx_uint i = 0; i < out_dim; ++i) { + for (int i = 0; i < out_dim; ++i) { ret.push_back(out_pdata[i]); } return ret; diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index aed963949060..2e3fb7a2d5de 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -188,7 +188,7 @@ inline void Symbol::InferShape( std::vector keys; std::vector arg_ind_ptr; - std::vector arg_shape_data; + std::vector arg_shape_data; for (const auto &arg : arg_shapes) { keys.push_back(arg.first.c_str()); @@ -200,40 +200,40 @@ inline void Symbol::InferShape( arg_ind_ptr.push_back(arg_shape_data.size()); mx_uint in_shape_size; - const mx_uint *in_shape_ndim; - const mx_uint **in_shape_data; + const int *in_shape_ndim; + const int **in_shape_data; mx_uint out_shape_size; - const mx_uint *out_shape_ndim; - const mx_uint **out_shape_data; + const int *out_shape_ndim; + const int **out_shape_data; mx_uint aux_shape_size; - const mx_uint *aux_shape_ndim; - const mx_uint **aux_shape_data; + const int *aux_shape_ndim; + const int **aux_shape_data; int complete; - CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(), - arg_ind_ptr.data(), arg_shape_data.data(), - &in_shape_size, &in_shape_ndim, &in_shape_data, - &out_shape_size, &out_shape_ndim, &out_shape_data, - &aux_shape_size, &aux_shape_ndim, &aux_shape_data, - &complete), + CHECK_EQ(MXSymbolInferShapeEx(GetHandle(), keys.size(), keys.data(), + arg_ind_ptr.data(), arg_shape_data.data(), + &in_shape_size, &in_shape_ndim, &in_shape_data, + &out_shape_size, &out_shape_ndim, &out_shape_data, + &aux_shape_size, &aux_shape_ndim, &aux_shape_data, + &complete), 0); if (complete) { for (mx_uint i = 0; i < in_shape_size; ++i) { in_shape->push_back(std::vector()); - for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) { + for (int j = 0; j < in_shape_ndim[i]; ++j) { (*in_shape)[i].push_back(in_shape_data[i][j]); } } for (mx_uint i = 0; i < aux_shape_size; ++i) { aux_shape->push_back(std::vector()); - for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) { + for (int j = 0; j < aux_shape_ndim[i]; ++j) { (*aux_shape)[i].push_back(aux_shape_data[i][j]); } } for (mx_uint i = 0; i < out_shape_size; ++i) { out_shape->push_back(std::vector()); - for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) { + for (int j = 0; j < out_shape_ndim[i]; ++j) { (*out_shape)[i].push_back(out_shape_data[i][j]); } } diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 2f9d74dc5ba0..0acfde0686d4 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -182,7 +182,7 @@ typedef int (*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/, typedef int (*CustomOpDelFunc)(void* /*state*/); typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/); typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/, - unsigned** /*shapes*/, void* /*state*/); + int** /*shapes*/, void* /*state*/); typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/); typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/, int * /*stypes*/, @@ -768,7 +768,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, bool reverse, NDArrayHandle *out); /*! - * \brief get the shape of the array + * \brief DEPRECATED. Use MXNDArrayGetShapeEx instead. + * get the shape of the array * \param handle the handle to the narray * \param out_dim the output dimension * \param out_pdata pointer holder to get data pointer of the shape @@ -777,6 +778,16 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); +/*! + * \brief get the shape of the array + * \param handle the handle to the narray + * \param out_dim the output dimension + * \param out_pdata pointer holder to get data pointer of the shape + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, + int *out_dim, + const int **out_pdata); /*! * \brief get the content of the data in NDArray * \param handle the handle to the ndarray @@ -1048,6 +1059,19 @@ MXNET_DLL int MXAutogradIsRecording(bool* curr); * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXAutogradIsTraining(bool* curr); +/*! + * \brief get whether numpy compatibility is on + * \param curr returns the current status + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXIsNumpyCompatible(bool* curr); +/*! + * \brief set numpy compatibility switch + * \param is_np_comp 1 when numpy compatibility is on, 0 when off + * \param prev returns the previous status before this set + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev); /*! * \brief mark NDArrays as variables to compute gradient for autograd * \param num_var number of variable NDArrays @@ -1468,7 +1492,8 @@ MXNET_DLL int MXSymbolGrad(SymbolHandle sym, const char** wrt, SymbolHandle* out); /*! - * \brief infer shape of unknown input shapes given the known one. + * \brief DEPRECATED. Use MXSymbolInferShapeEx instead. + * infer shape of unknown input shapes given the known one. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. * @@ -1504,8 +1529,47 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete); + /*! - * \brief partially infer shape of unknown input shapes given the known one. + * \brief infer shape of unknown input shapes given the known one. + * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data + * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. + * + * \param sym symbol handle + * \param num_args numbe of input arguments. + * \param keys the key of keyword args (optional) + * \param arg_ind_ptr the head pointer of the rows in CSR + * \param arg_shape_data the content of the CSR + * \param in_shape_size sizeof the returning array of in_shapes + * \param in_shape_ndim returning array of shape dimensions of eachs input shape. + * \param in_shape_data returning array of pointers to head of the input shape. + * \param out_shape_size sizeof the returning array of out_shapes + * \param out_shape_ndim returning array of shape dimensions of eachs input shape. + * \param out_shape_data returning array of pointers to head of the input shape. + * \param aux_shape_size sizeof the returning array of aux_shapes + * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. + * \param aux_shape_data returning array of pointers to head of the auxiliary shape. + * \param complete whether infer shape completes or more information is needed. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const int *arg_shape_data, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *complete); +/*! + * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead. + * partially infer shape of unknown input shapes given the known one. * * Return partially inferred results if not all shapes could be inferred. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data @@ -1544,6 +1608,47 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete); + +/*! + * \brief partially infer shape of unknown input shapes given the known one. + * + * Return partially inferred results if not all shapes could be inferred. + * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data + * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. + * + * \param sym symbol handle + * \param num_args numbe of input arguments. + * \param keys the key of keyword args (optional) + * \param arg_ind_ptr the head pointer of the rows in CSR + * \param arg_shape_data the content of the CSR + * \param in_shape_size sizeof the returning array of in_shapes + * \param in_shape_ndim returning array of shape dimensions of eachs input shape. + * \param in_shape_data returning array of pointers to head of the input shape. + * \param out_shape_size sizeof the returning array of out_shapes + * \param out_shape_ndim returning array of shape dimensions of eachs input shape. + * \param out_shape_data returning array of pointers to head of the input shape. + * \param aux_shape_size sizeof the returning array of aux_shapes + * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. + * \param aux_shape_data returning array of pointers to head of the auxiliary shape. + * \param complete whether infer shape completes or more information is needed. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const int *arg_shape_data, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *complete); + /*! * \brief infer type of unknown input types given the known one. * The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data @@ -1807,7 +1912,8 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, NDArrayHandle *aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); - +/*! \brief DEPRECATED. Use MXExecutorSimpleBindEx instead. + */ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, int dev_type, int dev_id, @@ -1843,8 +1949,44 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, ExecutorHandle shared_exec_handle, ExecutorHandle* out); -/*! - * \brief Return a new executor with the same symbol and shared memory, + +MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); +/*! + * \brief DEPRECATED. Use MXExecutorReshapeEx instead. + * Return a new executor with the same symbol and shared memory, * but different input/output shapes. * * \param partial_shaping Whether to allow changing the shape of unspecified arguments. @@ -1883,6 +2025,46 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping, NDArrayHandle** aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); +/*! + * \brief Return a new executor with the same symbol and shared memory, + * but different input/output shapes. + * + * \param partial_shaping Whether to allow changing the shape of unspecified arguments. + * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original. + * \param dev_type device type of default context + * \param dev_id device id of default context + * \param num_map_keys size of group2ctx map + * \param map_keys keys of group2ctx map + * \param map_dev_types device type of group2ctx map + * \param map_dev_ids device id of group2ctx map + * \param num_in_args length of in_args + * \param in_args in args array + * \param arg_grads arg grads handle array + * \param num_aux_states length of auxiliary states + * \param aux_states auxiliary states array + * \param shared_exec input executor handle for memory sharing + * \param out output executor handle + * \return a new executor + */ +MXNET_DLL int MXExecutorReshapeEx(int partial_shaping, + int allow_up_sizing, + int dev_type, + int dev_id, + mx_uint num_map_keys, + const char** map_keys, + const int* map_dev_types, + const int* map_dev_ids, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec, + ExecutorHandle *out); /*! * \brief get optimized graph from graph executor @@ -2542,7 +2724,8 @@ MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** ar MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shared_id); /*! - * \brief Reconstruct NDArray from shared memory handle + * \brief DEPRECATED. Use MXNDArrayCreateFromSharedMemEx instead. + * Reconstruct NDArray from shared memory handle * \param shared_pid shared PID * \param shared_id shared memory id * \param shape pointer to NDArray dimensions @@ -2553,6 +2736,19 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape, mx_uint ndim, int dtype, NDArrayHandle *out); + +/*! + * \brief Reconstruct NDArray from shared memory handle + * \param shared_pid shared PID + * \param shared_id shared memory id + * \param shape pointer to NDArray dimensions + * \param ndim number of NDArray dimensions + * \param dtype data type of NDArray + * \param out constructed NDArray + */ +MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape, + int ndim, int dtype, NDArrayHandle *out); + /*! * \brief Push an asynchronous operation to the engine. * \param async_func Execution function whici takes a parameter on_complete diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 52cedb2fadd9..ad209913ac53 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -97,6 +97,16 @@ class Imperative { is_recording_ = is_recording; return old; } + /*! brief whether numpy compatibility is on. */ + bool is_np_comp() const { + return is_np_comp_; + } + /*! brief turn on or turn off numpy compatibility switch. */ + bool set_is_np_comp(bool is_np_comp) { + bool old = is_np_comp_; + is_np_comp_ = is_np_comp; + return old; + } /*! \brief to record operator, return corresponding node. */ void RecordOp(nnvm::NodeAttrs&& attrs, const std::vector& inputs, @@ -165,9 +175,15 @@ class Imperative { #if DMLC_CXX11_THREAD_LOCAL static thread_local bool is_train_; static thread_local bool is_recording_; + // TOOD(junwu): Added numpy compatibility switch for backward compatibility. + // Delete it in the next major release. + static thread_local bool is_np_comp_; #else static MX_THREAD_LOCAL bool is_train_; static MX_THREAD_LOCAL bool is_recording_; + // TOOD(junwu): Added numpy compatibility switch for backward compatibility. + // Delete it in the next major release. + static MX_THREAD_LOCAL bool is_np_comp_; #endif /*! \brief node count used for naming */ std::atomic node_count_{0}; diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index d00cb479b92e..13fb42ce521e 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -859,12 +859,15 @@ class NDArray { Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype) : static_data(false), delay_alloc(true), ctx(ctx_), storage_ref_(Storage::_GetSharedRef()) { - auto size = shape.Size(); storage_shape = shape; + if (shape_is_known(storage_shape)) { + shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype); + } var = Engine::Get()->NewVariable(); - shandle.size = size * mshadow::mshadow_sizeof(dtype); shandle.ctx = ctx_; - if (!delay_alloc_) this->CheckAndAlloc(); + if (!delay_alloc_) { + this->CheckAndAlloc(); + } } Chunk(const TBlob &data, int dev_id) @@ -953,7 +956,7 @@ class NDArray { /*! \brief set the shape for ith aux data, and update storage shape if necessary */ inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) { aux_shapes[i] = shape; - if (storage_shape.ndim() > 0) { + if (storage_shape.ndim() >= 0) { if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) { storage_shape[0] = shape[0]; } else if (storage_type == kCSRStorage && i == csr::kIdx) { diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 7d059025b03e..a7a57266dab8 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -198,7 +198,6 @@ class TBlob { << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType::kFlag; return mshadow::Tensor(static_cast(dptr_), shape_.FlatTo2D(), - shape_[shape_.ndim() - 1], stream); } /*! @@ -419,6 +418,8 @@ class TBlob { namespace dmlc { // Add a few patches to support mxnet::TShape in dmlc/parameter. DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)"); +DMLC_DECLARE_TYPE_NAME(mxnet::Tuple, "Shape(tuple)"); +DMLC_DECLARE_TYPE_NAME(mxnet::Tuple>, "Shape(tuple)"); DMLC_DECLARE_TYPE_NAME(nnvm::Tuple, "Shape(tuple)"); DMLC_DECLARE_TYPE_NAME(nnvm::Tuple>, "Shape(tuple)"); diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index 7c1367333630..8431bbb23b96 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2019 by Contributors * \file mxnet/tuple.h * \brief Data structure Tuple and TShape to store dynamic sized shapes. */ @@ -39,11 +39,14 @@ namespace mxnet { /*! * \brief A dynamic sized array data structure that is optimized for storing - * small number of elements with same type. + * small number of elements with same type. * * Data will be stored in stack when number of elements is small. * It is suitable to hold shape of Tensor. * + * The ndim of a valid tuple is an integer in range [0, inf). + * ndim = 0 means the tuple is empty. + * * \tparam ValueType The type of data stored inside tuple. * \sa TShape */ @@ -61,7 +64,11 @@ class Tuple { * \param s the source tuple */ inline Tuple(const Tuple& s) { - this->assign(s.begin(), s.end()); + if (s.ndim() == -1) { + this->SetDim(-1); + } else { + this->assign(s.begin(), s.end()); + } } /*! * \brief constructor from initializer list @@ -106,6 +113,7 @@ class Tuple { inline void assign(RandomAccessIterator begin, RandomAccessIterator end) { this->SetDim(end - begin); + CHECK_GE(ndim(), 0); std::copy(begin, end, this->begin()); } /*! @@ -124,7 +132,11 @@ class Tuple { * \return reference of self */ inline Tuple& operator=(const Tuple& src) { - this->assign(src.begin(), src.end()); + if (src.ndim() == -1) { + this->SetDim(-1); + } else { + this->assign(src.begin(), src.end()); + } return *this; } /*! @@ -151,6 +163,7 @@ class Tuple { */ inline bool operator==(const Tuple &s) const { if (ndim_ != s.ndim_) return false; + if (ndim() == -1) return true; return std::equal(begin(), end(), s.begin()); } /*! @@ -177,7 +190,7 @@ class Tuple { return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); } /*! \return number of dimension of the tuple */ - inline uint32_t ndim() const { + inline int ndim() const { return ndim_; } /*! @@ -185,7 +198,8 @@ class Tuple { * \param i dimension index * \return the corresponding dimension size */ - inline ValueType& operator[](size_t i) { + inline ValueType& operator[](int i) { + CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")"; return begin()[i]; } /*! @@ -193,7 +207,8 @@ class Tuple { * \param i dimension index * \return the corresponding dimension size */ - inline const ValueType& operator[](size_t i) const { + inline const ValueType& operator[](int i) const { + CHECK(i >= 0 && i < ndim()) << "index = " << i << " must be in range [0, " << ndim() << ")"; return begin()[i]; } /*! @@ -220,6 +235,13 @@ class Tuple { * \return the ostream */ friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + if (t.ndim() == -1) { + // If t is an unknown shape, return string "None". + // This is consistent with returning unknown shape in Python and generating + // C++ operator APIs by OpWrapperGenerator.py (defaultString) in cpp-package. + os << "None"; + return os; + } os << '['; const ValueType* begin = t.begin(); const ValueType* end = t.end(); @@ -252,14 +274,16 @@ class Tuple { if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; + } } - } - // Handle empty tuple + // Handle empty tuple. A tensor whose shape is an empty tuple + // represents a scalar with ndim = 0. while (isspace(is.peek())) { is.get(); } if (is.peek() == ')' || is.peek() == ']') { is.get(); + t.SetDim(0); return is; } // Handle non-empty tuple @@ -316,48 +340,85 @@ class Tuple { protected: // stack cache size - static const uint32_t kStackCache = 4; + static const int kStackCache = 4; /*! \brief number of dimension of the tuple */ - uint32_t ndim_{0}; + int ndim_{0}; /*! \brief number of cells allocated in data_heap_ */ - uint32_t num_heap_allocated_{0}; + int num_heap_allocated_{0}; /*! \brief in stack space used to store shape when it is small */ ValueType data_stack_[kStackCache]; /*! \brief space to store shape when dimension is big*/ ValueType* data_heap_{nullptr}; // internal function to change the dimension - inline void SetDim(uint32_t ndim) { + inline void SetDim(int ndim) { + CHECK_GE(ndim, -1) << "ndim cannot be less than -1, received " << ndim; if (ndim > kStackCache && ndim > num_heap_allocated_) { delete [] data_heap_; data_heap_ = new ValueType[ndim]; num_heap_allocated_ = ndim; + } else if (ndim <= 0 && data_heap_ != nullptr) { + delete [] data_heap_; + data_heap_ = nullptr; + num_heap_allocated_ = 0; } ndim_ = ndim; } }; + +/*! brief check if a shape's ndim is known. */ +inline bool ndim_is_known(const int ndim) { + CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim; + return ndim != -1; +} + +/*! brief check if a shape's dim size is known. */ +inline bool dim_size_is_known(const dim_t dim_size) { + CHECK_GE(dim_size, -1) << "shape dim size must be >= -1, while received " << dim_size; + return dim_size != -1; +} + /*! * \brief A Shape class that is used to represent shape of each tensor. + * + * The ndim of a valid shape is an integer in range [-1, inf). + * ndim = -1 means the shape information is unknown and need to be inferred. + * ndim = 0 means the tensor with the shape is a scalar. + * + * The dimension size of a valid shape is an integer in range [-1, inf). + * dim_size = -1 means the size of that dimension is unknown and need to be inferred. + * dim_size = 0 means that dimension is empty. + * + * The definition of ndim = 0 and dim_size = 0 is consistent with NumPy. */ class TShape : public Tuple { public: /*! \brief default constructor */ - TShape() = default; + TShape() { + this->SetDim(-1); + } /*! - * constructor to construct a shape with all 1. + * constructor to construct a shape with all `value`. * \param ndim the number of dimension + * \param value the dimension size for all dims */ - inline TShape(uint32_t ndim) { // NOLINT(*) + inline TShape(const int ndim, const dim_t value) { // NOLINT(*) this->SetDim(ndim); - std::fill_n(begin(), ndim, 1); + if (ndim > 0) { + std::fill_n(begin(), ndim, value); + } } /*! * \brief copy constructor of TShape * \param s source shape. */ inline TShape(const Tuple& s) { // NOLINT(*) - this->assign(s.begin(), s.end()); + if (s.ndim() == -1) { + this->SetDim(-1); + } else { + this->assign(s.begin(), s.end()); + } } /*! * \brief constructor from initializer list @@ -374,12 +435,17 @@ class TShape : public Tuple { this->swap(s); } /*! - * \brief construct the Tuple from content of iterator + * \brief construct the Tuple from content of iterator. + * This function is enforced with template arguments of random access iterator types. + * This is necessary to distinguish from another constructor: TShape(const int, const dim_t). * \param begin the beginning of iterator * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template + template::iterator_category, + std::random_access_iterator_tag>::value, int>::type = 0> inline TShape(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); @@ -390,7 +456,11 @@ class TShape : public Tuple { * \return self. */ inline TShape& operator=(const Tuple& src) { - this->assign(src.begin(), src.end()); + if (src.ndim() == -1) { + this->SetDim(-1); + } else { + this->assign(src.begin(), src.end()); + } return *this; } /*! @@ -404,9 +474,11 @@ class TShape : public Tuple { } /*! \return total number of elements in the shape */ inline size_t Size() const { + CHECK(ndim_is_known(this->ndim())) << "Shape is unknown."; dim_t size = 1; const dim_t* start = begin(), *fin = end(); for (const dim_t* it = start; it != fin; ++it) { + CHECK(dim_size_is_known(*it)) << "Shape dim size cannot be a negative value " << *it; size *= *it; } return size; @@ -417,9 +489,14 @@ class TShape : public Tuple { * \param dimend end dimension */ inline size_t ProdShape(int dimstart, int dimend) const { + CHECK(ndim_is_known(this->ndim())) << "Shape is unknown."; + CHECK_GE(dimstart, 0) << "dimstart must be >= 0, while received " << dimstart; + CHECK_LE(dimend, this->ndim()) << "dimend must be <= " << this->ndim() + << ", while received " << dimend; dim_t num = 1; const dim_t *d = this->data(); for (int i = dimstart; i < dimend; ++i) { + CHECK(dim_size_is_known(d[i])) << "Shape dim size must be known, while received " << d[i]; num *= d[i]; } return num; @@ -460,7 +537,7 @@ class TShape : public Tuple { */ template inline mshadow::Shape get() const { - CHECK_EQ(dim, static_cast(ndim())) + CHECK_EQ(dim, ndim()) << "dimension do not match target dimension " << dim << " vs " << ndim(); const dim_t *d = this->data(); mshadow::Shape s; @@ -475,11 +552,12 @@ class TShape : public Tuple { */ inline mshadow::Shape<2> FlatTo2D(void) const { mshadow::Shape<2> s; - if (ndim() == 0) return mshadow::Shape2(0, 0); + CHECK(ndim_is_known(ndim())) << "shape must have a valid ndim"; + if (ndim() == 0) return mshadow::Shape2(1, 1); const dim_t *d = this->data(); s.shape_[1] = d[ndim() - 1]; dim_t ymax = 1; - for (size_t i = 1; i < ndim(); ++i) { + for (int i = 1; i < ndim(); ++i) { ymax *= d[i - 1]; } s.shape_[0] = ymax; @@ -494,7 +572,8 @@ class TShape : public Tuple { inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const { CHECK(axis_end >= axis_begin); mshadow::Shape<3> s; - if (ndim() == 0) return mshadow::Shape3(0, 0, 0); + CHECK(ndim_is_known(ndim())) << "shape must have a valid ndim"; + if (ndim() == 0) return mshadow::Shape3(1, 1, 1); const dim_t *d = this->data(); s.shape_[0] = 1; s.shape_[1] = 1; @@ -506,7 +585,7 @@ class TShape : public Tuple { for (size_t i = axis_begin; i <= axis_end; ++i) { s.shape_[1] *= d[i]; } - for (size_t i = axis_end + 1; i < ndim(); ++i) { + for (int i = axis_end + 1; i < ndim(); ++i) { s.shape_[2] *= d[i]; } return s; @@ -552,6 +631,28 @@ class TShape : public Tuple { #endif }; +/*! brief check if a shape's ndim is known. */ +inline bool ndim_is_known(const TShape& x) { + return ndim_is_known(x.ndim()); +} + +/*! brief check if a shape's dim size is known. */ +inline bool dim_size_is_known(const TShape& x, const int idx) { + CHECK(idx >= 0 && idx < x.ndim()) + << "idx = " << idx << " exceeds shape dimension range [0, " << x.ndim() << ")"; + return dim_size_is_known(x[idx]); +} + +/*! brief check if shape is known using the NumPy compatible definition. + * zero-dim and zero-size tensors are valid. -1 means unknown.*/ +inline bool shape_is_known(const TShape& x) { + if (!ndim_is_known(x)) return false; + for (int i = 0; i < x.ndim(); ++i) { + if (!dim_size_is_known(x, i)) return false; + } + return true; +} + /*! \brief helper function to cast type of container elements */ template inline DstIter ShapeTypeCast(const SrcIter begin, @@ -567,7 +668,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin, template inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { size_t ndim = std::distance(begin, end); - TShape res(ndim); + TShape res(ndim, -1); ShapeTypeCast(begin, end, res.begin()); return res; } @@ -613,7 +714,7 @@ struct hash > { size_t operator()(const mxnet::Tuple& val) const { std::hash hash_uint; size_t res = hash_uint(val.ndim()); - for (uint32_t i = 0; i < val.ndim(); ++i) { + for (int i = 0; i < val.ndim(); ++i) { res = dmlc::HashCombine(res, val[i]); } return res; @@ -627,7 +728,7 @@ struct hash { size_t operator()(const mxnet::TShape& val) const { std::hash hash_uint; size_t res = hash_uint(val.ndim()); - for (uint32_t i = 0; i < val.ndim(); ++i) { + for (int i = 0; i < val.ndim(); ++i) { res = dmlc::HashCombine(res, val[i]); } return res; @@ -638,6 +739,7 @@ struct hash { namespace dmlc { /*! \brief description for optional TShape */ DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); +DMLC_DECLARE_TYPE_NAME(optional>, "Shape or None"); // avoid low version of MSVC #if !defined(_MSC_VER) template diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm index 573abbf588f2..5844302fce16 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm @@ -471,7 +471,7 @@ method reshape(HashRef[Shape] $kwargs, Int :$partial_shaping=0, Int :$allow_up_s my $shared_handle = $self->handle; my ($in_args_and_grad_handles, $aux_state_handles, $handle) = check_call( - AI::MXNetCAPI::ExecutorReshape( + AI::MXNetCAPI::ExecutorReshapeEx( $partial_shaping, $allow_up_sizing, $self->_ctx->device_type_id, diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm index 72f6cc772178..f466aaa11a3d 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm @@ -535,7 +535,7 @@ method wait_to_read() method shape() { - return scalar(check_call(AI::MXNetCAPI::NDArrayGetShape($self->handle))); + return scalar(check_call(AI::MXNetCAPI::NDArrayGetShapeEx($self->handle))); } =head2 size @@ -1460,7 +1460,7 @@ func _new_alloc_handle($shape, $ctx, $delay_alloc, $dtype) method _new_from_shared_mem($shared_pid, $shared_id, $shape, $dtype) { my $hdl = check_call( - AI::MXNetCAPI::NDArrayCreateFromSharedMem( + AI::MXNetCAPI::NDArrayCreateFromSharedMemEx( $shared_pid, $shared_id, $shape, diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm index 04dd1cbfc441..e4953f17031a 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm @@ -662,7 +662,7 @@ method _infer_shape_impl(Maybe[Str|Shape] @args) push @{ $indptr }, scalar(@{ $sdata }); } } - my $infer_func = $partial ? \&AI::MXNetCAPI::SymbolInferShapePartial : \&AI::MXNetCAPI::SymbolInferShape; + my $infer_func = $partial ? \&AI::MXNetCAPI::SymbolInferShapePartialEx : \&AI::MXNetCAPI::SymbolInferShapeEx; my ($arg_shapes, $out_shapes, $aux_shapes, $complete) = check_call( $infer_func->( $self->handle, @@ -937,7 +937,7 @@ method simple_bind( ($updated_shared_data, $in_arg_handles, $arg_grad_handles, $aux_state_handles, $exe_handle) = check_call( - AI::MXNetCAPI::ExecutorSimpleBind( + AI::MXNetCAPI::ExecutorSimpleBindEx( $self->handle, $ctx->device_type_id, $ctx->device_id, diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index 0e6a05ea9695..e38402c56100 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -640,9 +640,9 @@ int MXNDArrayReshape64(NDArrayHandle handle, * \param out_pdata pointer holder to get data pointer of the shape * \return 0 when success, -1 when failure happens */ -int MXNDArrayGetShape(NDArrayHandle handle, - mx_uint *out_dim, - const mx_uint **out_pdata); +int MXNDArrayGetShapeEx(NDArrayHandle handle, + int *out_dim, + const int **out_pdata); /*! * \brief get the content of the data in NDArray * \param handle the handle to the ndarray @@ -1289,21 +1289,21 @@ int MXSymbolGrad(SymbolHandle sym, * \param complete whether infer shape completes or more information is needed. * \return 0 when success, -1 when failure happens */ -int MXSymbolInferShape(SymbolHandle sym, - mx_uint num_args, - const char** in, - const mx_uint *in, - const mx_uint *in, - mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, - mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, - mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, - int *out); +int MXSymbolInferShapeEx(SymbolHandle sym, + mx_uint num_args, + const char** in, + const mx_uint *in, + const int *in, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *out); /*! * \brief partially infer shape of unknown input shapes given the known one. * @@ -1328,21 +1328,21 @@ int MXSymbolInferShape(SymbolHandle sym, * \param complete whether infer shape completes or more information is needed. * \return 0 when success, -1 when failure happens */ -int MXSymbolInferShapePartial(SymbolHandle sym, - mx_uint num_args, - const char** in, - const mx_uint *in, - const mx_uint *in, - mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, - mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, - mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, - int *out); +int MXSymbolInferShapePartialEx(SymbolHandle sym, + mx_uint num_args, + const char** in, + const mx_uint *in, + const int *in, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *out); /*! * \brief infer type of unknown input types given the known one. @@ -1535,40 +1535,40 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, ExecutorHandle shared_exec, ExecutorHandle *out); -int MXExecutorSimpleBind(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const mx_uint num_g2c_keys, - const char** in, // g2c_keys, - const int* in, // g2c_dev_types, - const int* in, // g2c_dev_ids, - const mx_uint provided_grad_req_list_len, - const char** in, // provided_grad_req_names, - const char** in, // provided_grad_req_types, - const mx_uint num_provided_arg_shapes, - const char** in, // provided_arg_shape_names, - const mx_uint* in, // provided_arg_shape_data, - const mx_uint* in, // provided_arg_shape_idx, - const mx_uint num_provided_arg_dtypes, - const char** in, // provided_arg_dtype_names, - const int* in, // provided_arg_dtypes, - const mx_uint num_provided_arg_stypes, - const char** in, // provided_arg_stype_names, - const int* in, // provided_arg_stypes, - const mx_uint num_shared_arg_names, - const char** in, // shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - mx_uint* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - mx_uint* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out +int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** in, // g2c_keys, + const int* in, // g2c_dev_types, + const int* in, // g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** in, // provided_grad_req_names, + const char** in, // provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** in, // provided_arg_shape_names, + const int* in, // provided_arg_shape_data, + const mx_uint* in, // provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** in, // provided_arg_dtype_names, + const int* in, // provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** in, // provided_arg_stype_names, + const int* in, // provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** in, // shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out ); /*! @@ -1592,25 +1592,25 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, * \param out output executor handle * \return a new executor */ -int MXExecutorReshape(int partial_shaping, - int allow_up_sizing, - int dev_type, - int dev_id, - mx_uint num_map_keys, - const char** in, - const int* in, - const int* in, - const mx_uint num_provided_arg_shapes, - const char** in, - const mx_uint* in, - const mx_uint* in, - mx_uint* couple_out_size, - NDArrayHandle** out_first_array, - NDArrayHandle** out_second_array, - mx_uint* out_size, - NDArrayHandle** out_array, - ExecutorHandle shared_exec, - ExecutorHandle *out); +int MXExecutorReshapeEx(int partial_shaping, + int allow_up_sizing, + int dev_type, + int dev_id, + mx_uint num_map_keys, + const char** in, + const int* in, + const int* in, + const mx_uint num_provided_arg_shapes, + const char** in, + const int* in, + const mx_uint* in, + mx_uint* couple_out_size, + NDArrayHandle** out_first_array, + NDArrayHandle** out_second_array, + mx_uint* out_size, + NDArrayHandle** out_array, + ExecutorHandle shared_exec, + ExecutorHandle *out); /*! * \brief set a call back to notify the completion of operation diff --git a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i index 50296c2aaba5..3ec9f95ea9c3 100644 --- a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i +++ b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i @@ -524,13 +524,13 @@ } } -%typemap(in,numinputs=0) (mx_uint *out_dim, const mx_uint **out_pdata) (mx_uint temp_dim, mx_uint *temp_pdata) +%typemap(in,numinputs=0) (int *out_dim, const int **out_pdata) (int temp_dim, int *temp_pdata) { $1 = &temp_dim; $2 = &temp_pdata; } -%typemap(argout) (mx_uint *out_dim, const mx_uint **out_pdata) +%typemap(argout) (int *out_dim, const int **out_pdata) { if(!result) { @@ -956,12 +956,12 @@ } } -%typemap(in,numinputs=0) (mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data) - (mx_uint temp1, mx_uint *temp2, mx_uint **temp3), - (mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data) - (mx_uint temp1, mx_uint *temp2, mx_uint **temp3), - (mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data) - (mx_uint temp1, mx_uint *temp2, mx_uint **temp3) +%typemap(in,numinputs=0) (mx_uint *in_shape_size, const int **in_shape_ndim, const int ***in_shape_data) + (mx_uint temp1, int *temp2, int **temp3), + (mx_uint *out_shape_size, const int **out_shape_ndim, const int ***out_shape_data) + (mx_uint temp1, int *temp2, int **temp3), + (mx_uint *aux_shape_size, const int **aux_shape_ndim, const int ***aux_shape_data) + (mx_uint temp1, int *temp2, int **temp3) { $1 = &temp1; $2 = &temp2; @@ -969,9 +969,9 @@ *$1 = 0; } -%typemap(argout) (mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data), - (mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data), - (mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data) +%typemap(argout) (mx_uint *in_shape_size, const int **in_shape_ndim, const int ***in_shape_data), + (mx_uint *out_shape_size, const int **out_shape_ndim, const int ***out_shape_data), + (mx_uint *aux_shape_size, const int **aux_shape_ndim, const int ***aux_shape_data) { if(!result && *arg15) { diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 374a3b50bbb5..79eb1f10f427 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -23,7 +23,7 @@ from .context import Context, current_context, cpu, gpu, cpu_pinned from . import engine -from .base import MXNetError +from .base import MXNetError, is_np_compat, set_np_compat, np_compat, use_np_compat from . import base from . import contrib from . import ndarray diff --git a/python/mxnet/base.py b/python/mxnet/base.py index feb4d70b6533..58f222dc1e85 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -20,17 +20,18 @@ """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import +from functools import wraps import atexit import ctypes import os import sys import inspect import platform -import numpy as np +import numpy as _np from . import libinfo -__all__ = ['MXNetError'] +__all__ = ['MXNetError', 'is_np_compat', 'set_np_compat', 'np_compat', 'use_np_compat'] #---------------------------- # library loading #---------------------------- @@ -44,8 +45,8 @@ long = int # pylint: enable=pointless-statement -integer_types = (int, long, np.int32, np.int64) -numeric_types = (float, int, long, np.generic) +integer_types = (int, long, _np.int32, _np.int64) +numeric_types = (float, int, long, _np.generic) string_types = basestring, if sys.version_info[0] > 2: @@ -213,10 +214,11 @@ def _load_lib(): _LIB = _load_lib() # type definitions +mx_int = ctypes.c_int mx_uint = ctypes.c_uint mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) -mx_real_t = np.float32 +mx_real_t = _np.float32 NDArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p OpHandle = ctypes.c_void_p @@ -455,7 +457,7 @@ def ctypes2numpy_shared(cptr, shape): for s in shape: size *= s dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) - return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape) + return _np.frombuffer(dbuffer, dtype=_np.float32).reshape(shape) def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True): @@ -733,3 +735,140 @@ def write_all_str(module_file, module_all_list): ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p + + +def set_np_compat(active): + """ + Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend. + + Parameters + ---------- + active : bool + Indicates whether to turn on/off NumPy compatibility. + + Returns + ------- + A bool value indicating the previous state of NumPy compatibility. + """ + prev = ctypes.c_int() + check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(active), ctypes.byref(prev))) + return bool(prev.value) + + +def is_np_compat(): + """ + Checks whether the NumPy compatibility is currently turned on. + NumPy-compatibility is turned off by default in backend. + + Returns + ------- + A bool value indicating whether the NumPy compatibility is currently on. + """ + curr = ctypes.c_bool() + check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr))) + return curr.value + + +class _NumpyCompatibilityStateScope(object): + """Scope for managing numpy compatibility state. + Do not use this class directly. Use `np_compat(active)` instead. + + Example:: + + with _NumpyCompatibilityStateScope(True): + y = model(x) + backward([y]) + + """ + def __init__(self, is_np_compat): #pylint: disable=redefined-outer-name + self._enter_is_np_compat = is_np_compat + self._prev_is_np_compat = None + + def __enter__(self): + if self._enter_is_np_compat is not None: + self._prev_is_np_compat = set_np_compat(self._enter_is_np_compat) + + def __exit__(self, ptype, value, trace): + if self._enter_is_np_compat is not None and self._prev_is_np_compat != self._enter_is_np_compat: + set_np_compat(self._prev_is_np_compat) + + +def np_compat(active=True): + """Returns an activated/deactivated NumPy compatibility state scope to be used in 'with' statement + and captures code that needs the compatibility. + + Example:: + + with mx.np_compat(active=True): + # A scalar tensor's shape is `()`, whose `ndim` is `0`. + scalar = mx.nd.ones(shape=()) + assert scalar.shape == () + + # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements. + data = mx.sym.var("data", shape=(0, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape() + assert arg_shapes[0] == (0, 2, 3) + assert out_shapes[0] == (0, 2, 3) + + # -1 means unknown shape dimension size in the new NumPy-compatible shape definition + data = mx.sym.var("data", shape=(-1, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == (-1, 2, 3) + assert out_shapes[0] == (-1, 2, 3) + + # When a shape is completely unknown in NumPy-compatible mode, it is + # represented as `None` in Python. + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] is None + assert out_shapes[0] is None + + with mx.np_compat(active=False): + # 0 means unknown shape dimension size in the legacy shape definition. + data = mx.sym.var("data", shape=(0, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == (0, 2, 3) + assert out_shapes[0] == (0, 2, 3) + + # When a shape is completely unknown in the legacy mode (default), its ndim is + # equal to 0 and it is represented as `()` in Python. + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == () + assert out_shapes[0] == () + """ + return _NumpyCompatibilityStateScope(active) + + +def use_np_compat(func): + """Wraps a function with an activated NumPy-compatibility scope. This ensures + that the execution of the function is guaranteed with NumPy compatible semantics, + such as zero-dim and zero size tensors. + + Example:: + import mxnet as mx + @mx.use_np_compat + def scalar_one(): + return mx.nd.ones(()) + print(scalar_one()) + + Parameters + ---------- + func : a user-provided callable function to be scoped by the NumPy compatibility state. + + Returns + ------- + Function + A function for wrapping the user functions in the NumPy compatibility scope. + """ + @wraps(func) + def _with_np_compat(*args, **kwargs): + with np_compat(active=True): + return func(*args, **kwargs) + + return _with_np_compat diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 7bf867579d6b..9dfe63682f86 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -25,7 +25,7 @@ import copy import numpy as np from .base import _LIB -from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str +from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str, mx_int from .base import check_call, c_handle_array, c_array_buf, c_str_array from .ndarray import NDArray from .ndarray import _ndarray_cls @@ -433,29 +433,29 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs): num_aux_states = ctypes.c_uint() aux_state_handles = ctypes.POINTER(NDArrayHandle)() - check_call(_LIB.MXExecutorReshape(ctypes.c_int(int(partial_shaping)), - ctypes.c_int(int(allow_up_sizing)), - ctypes.c_int(self._ctx.device_typeid), - ctypes.c_int(self._ctx.device_id), - mx_uint(len(ctx_map_keys)), - c_str_array(ctx_map_keys), - c_array_buf(ctypes.c_int, - py_array('i', ctx_map_dev_types)), - c_array_buf(ctypes.c_int, - py_array('i', ctx_map_dev_ids)), - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_uint, - py_array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - py_array('I', provided_arg_shape_idx)), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_handle, - ctypes.byref(handle))) + check_call(_LIB.MXExecutorReshapeEx(ctypes.c_int(int(partial_shaping)), + ctypes.c_int(int(allow_up_sizing)), + ctypes.c_int(self._ctx.device_typeid), + ctypes.c_int(self._ctx.device_id), + mx_uint(len(ctx_map_keys)), + c_str_array(ctx_map_keys), + c_array_buf(ctypes.c_int, + py_array('i', ctx_map_dev_types)), + c_array_buf(ctypes.c_int, + py_array('i', ctx_map_dev_ids)), + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int, + py_array('i', provided_arg_shape_data)), + c_array_buf(mx_uint, + py_array('I', provided_arg_shape_idx)), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_handle, + ctypes.byref(handle))) arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] diff --git a/python/mxnet/ndarray/_internal.py b/python/mxnet/ndarray/_internal.py index 5f3ce976dbc5..8045d9bd2b14 100644 --- a/python/mxnet/ndarray/_internal.py +++ b/python/mxnet/ndarray/_internal.py @@ -20,8 +20,6 @@ import os as _os import sys as _sys -import numpy as np - try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: from .._ctypes.ndarray import NDArrayBase, CachedOp diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 74c355dc1288..1718a2c68d13 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -18,6 +18,7 @@ # coding: utf-8 # pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name """Contrib NDArray API of MXNet.""" +from __future__ import absolute_import import math import numpy as np from ..context import current_context diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 87f2712d8a40..97cfd827c7fe 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -35,7 +35,7 @@ import numpy as np from ..base import _LIB, numeric_types, integer_types from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t -from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle +from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int from ..base import ctypes2buffer from ..context import Context, current_context from . import _internal @@ -143,11 +143,11 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): def _new_from_shared_mem(shared_pid, shared_id, shape, dtype): hdl = NDArrayHandle() - check_call(_LIB.MXNDArrayCreateFromSharedMem( + check_call(_LIB.MXNDArrayCreateFromSharedMemEx( ctypes.c_int(shared_pid), ctypes.c_int(shared_id), - c_array(mx_uint, shape), - mx_uint(len(shape)), + c_array(mx_int, shape), + mx_int(len(shape)), ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), ctypes.byref(hdl))) return hdl @@ -1845,11 +1845,14 @@ def shape(self): >>> y.shape (2L, 3L, 4L) """ - ndim = mx_uint() - pdata = ctypes.POINTER(mx_uint)() - check_call(_LIB.MXNDArrayGetShape( + ndim = mx_int() + pdata = ctypes.POINTER(mx_int)() + check_call(_LIB.MXNDArrayGetShapeEx( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) - return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index + if ndim.value == -1: + return None + else: + return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index @property diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index 05d7f17a8fc1..1ccf228698ba 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -16,9 +16,10 @@ # under the License. """Register backend ops in mxnet.ndarray namespace""" +from __future__ import absolute_import import os as _os import ctypes -import numpy as np # pylint: disable=unused-import +import numpy as _np # pylint: disable=unused-import from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import from ..ndarray_doc import _build_doc @@ -103,7 +104,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) if dtype_name is not None: code.append(""" if '%s' in kwargs: - kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( + kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%( dtype_name, dtype_name, dtype_name)) code.append(""" _ = kwargs.pop('name', None) @@ -136,7 +137,7 @@ def %s(%s):"""%(func_name, ', '.join(signature))) code.append(""" if %s is not _Null: keys.append('%s') - vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) if not signature_only: code.append(""" diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index e8fa571d44db..2c69b9b46521 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -28,7 +28,7 @@ from ctypes import CFUNCTYPE, POINTER, Structure, pointer from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool -from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf +from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str from . import symbol, context from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP @@ -164,7 +164,7 @@ def get_symbol(self, *args, **kwargs): fb_functype = CFUNCTYPE(None, c_int, POINTER(POINTER(mx_float)), POINTER(c_int), POINTER(POINTER(mx_uint)), POINTER(c_int), c_void_p) infer_functype = CFUNCTYPE(None, c_int, POINTER(c_int), - POINTER(POINTER(mx_uint)), c_void_p) + POINTER(POINTER(mx_int)), c_void_p) list_functype = CFUNCTYPE(None, POINTER(POINTER(POINTER(c_char))), c_void_p) class NumpyOpInfo(Structure): """Structure that holds Callback information. Passed to NumpyOpProp""" @@ -214,9 +214,9 @@ def infer_shape_entry(num_tensor, tensor_dims, assert len(ishape) == n_in rshape = list(ishape) + list(oshape) for i in range(n_in+n_out): - tensor_shapes[i] = cast(c_array_buf(mx_uint, - array('I', rshape[i])), - POINTER(mx_uint)) + tensor_shapes[i] = cast(c_array_buf(mx_int, + array('i', rshape[i])), + POINTER(mx_int)) tensor_dims[i] = len(rshape[i]) def list_outputs_entry(out, _): @@ -266,7 +266,7 @@ def __init__(self, need_top_grad=True): def get_symbol(self, *args, **kwargs): fb_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_void_p), POINTER(c_int), c_void_p) infer_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_int), - POINTER(POINTER(mx_uint)), c_void_p) + POINTER(POINTER(mx_int)), c_void_p) list_functype = CFUNCTYPE(c_bool, POINTER(POINTER(POINTER(c_char))), c_void_p) deps_functype = CFUNCTYPE(c_bool, c_int_p, c_int_p, c_int_p, c_int_p, POINTER(c_int_p), c_void_p) @@ -335,9 +335,9 @@ def infer_shape_entry(num_tensor, tensor_dims, assert len(ishape) == n_in rshape = list(ishape) + list(oshape) for i in range(n_in+n_out): - tensor_shapes[i] = cast(c_array_buf(mx_uint, - array('I', rshape[i])), - POINTER(mx_uint)) + tensor_shapes[i] = cast(c_array_buf(mx_int, + array('i', rshape[i])), + POINTER(mx_int)) tensor_dims[i] = len(rshape[i]) except Exception: print('Error in NDArrayOp.infer_shape: %s' % traceback.format_exc()) @@ -698,7 +698,7 @@ def do_register(prop_cls): del_functype = CFUNCTYPE(c_int, c_void_p) infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), - POINTER(POINTER(mx_uint)), c_void_p) + POINTER(POINTER(mx_int)), c_void_p) infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p) inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p) inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), \ @@ -747,9 +747,9 @@ def infer_shape_entry(num_tensor, tensor_dims, "shapes, got %d."%(n_aux, len(ashape)) rshape = list(ishape) + list(oshape) + list(ashape) for i in range(n_in+n_out+n_aux): - tensor_shapes[i] = cast(c_array_buf(mx_uint, - array('I', rshape[i])), - POINTER(mx_uint)) + tensor_shapes[i] = cast(c_array_buf(mx_int, + array('i', rshape[i])), + POINTER(mx_int)) tensor_dims[i] = len(rshape[i]) infer_shape_entry._ref_holder = [tensor_shapes] diff --git a/python/mxnet/symbol/_internal.py b/python/mxnet/symbol/_internal.py index 53fc684008cf..7e9787e32b1c 100644 --- a/python/mxnet/symbol/_internal.py +++ b/python/mxnet/symbol/_internal.py @@ -22,8 +22,6 @@ import sys as _sys import os as _os -import numpy as np - try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: from .._ctypes.symbol import SymbolBase, _set_symbol_class diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index 15c8e5e1fa68..ac59f8b97f15 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -17,9 +17,10 @@ # pylint: disable=unused-import """Register backend ops in mxnet.symbol namespace.""" +from __future__ import absolute_import import os as _os import ctypes -import numpy as np +import numpy as _np from . import _internal from ._internal import SymbolBase, _symbol_creator @@ -109,7 +110,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) if dtype_name is not None: code.append(""" if '%s' in kwargs: - kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( + kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%( dtype_name, dtype_name, dtype_name)) code.append(""" attr = kwargs.pop('attr', None) @@ -175,7 +176,7 @@ def %s(%s):"""%(func_name, ', '.join(signature))) code.append(""" if %s is not _Null: _keys.append('%s') - _vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + _vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) code.append(""" if not hasattr(NameManager._current, "value"): diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 91d4ca16df07..4bf60a6a1fcd 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -34,7 +34,7 @@ from ..attribute import AttrScope from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array -from ..base import mx_uint, py_str, string_types, integer_types +from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_compat from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle from ..base import check_call, MXNetError, NotImplementedForSymbol from ..context import Context, current_context @@ -1078,7 +1078,11 @@ def infer_shape(self, *args, **kwargs): arg_names = self.list_arguments() unknowns = [] for name, shape in zip(arg_names, arg_shapes): - if not shape or not _numpy.prod(shape): + if is_np_compat(): + shape_is_none = not shape or -1 in shape + else: + shape_is_none = not shape or 0 in shape + if shape_is_none: if len(unknowns) >= 10: unknowns.append('...') break @@ -1174,25 +1178,25 @@ def _infer_shape_impl(self, partial, *args, **kwargs): indptr.append(len(sdata)) keys = c_str_array(str_keys) arg_shape_size = mx_uint() - arg_shape_ndim = ctypes.POINTER(mx_uint)() - arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + arg_shape_ndim = ctypes.POINTER(mx_int)() + arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() out_shape_size = mx_uint() - out_shape_ndim = ctypes.POINTER(mx_uint)() - out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + out_shape_ndim = ctypes.POINTER(mx_int)() + out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() aux_shape_size = mx_uint() - aux_shape_ndim = ctypes.POINTER(mx_uint)() - aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + aux_shape_ndim = ctypes.POINTER(mx_int)() + aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() complete = ctypes.c_int() if partial: - infer_func = _LIB.MXSymbolInferShapePartial + infer_func = _LIB.MXSymbolInferShapePartialEx else: - infer_func = _LIB.MXSymbolInferShape + infer_func = _LIB.MXSymbolInferShapeEx check_call(infer_func( self.handle, mx_uint(len(indptr) - 1), keys, c_array_buf(mx_uint, array('I', indptr)), - c_array_buf(mx_uint, array('I', sdata)), + c_array_buf(mx_int, array('i', sdata)), ctypes.byref(arg_shape_size), ctypes.byref(arg_shape_ndim), ctypes.byref(arg_shape_data), @@ -1204,12 +1208,15 @@ def _infer_shape_impl(self, partial, *args, **kwargs): ctypes.byref(aux_shape_data), ctypes.byref(complete))) if complete.value != 0: - arg_shapes = [ - tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)] - out_shapes = [ - tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)] - aux_shapes = [ - tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)] + arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) + if arg_shape_ndim[i] >= 0 else None + for i in range(arg_shape_size.value)] + out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) + if out_shape_ndim[i] >= 0 else None + for i in range(out_shape_size.value)] + aux_shapes = [tuple(aux_shape_data[i][:aux_shape_ndim[i]]) + if aux_shape_ndim[i] >= 0 else None + for i in range(aux_shape_size.value)] return (arg_shapes, out_shapes, aux_shapes) else: return (None, None, None) @@ -1564,42 +1571,42 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, aux_state_handles = ctypes.POINTER(NDArrayHandle)() try: - check_call(_LIB.MXExecutorSimpleBind(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_uint, - array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('I', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + check_call(_LIB.MXExecutorSimpleBindEx(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int, + array('I', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) except MXNetError as e: error_msg = "simple_bind error. Arguments:\n" for k, v in kwargs.items(): diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index aec44023a5d3..b0fae0f9d58d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -61,9 +61,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, protected var monitorCallback: MXMonitorCallback = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) - private[mxnet] var ownsArgArrays = false - private[mxnet] var ownsGradArrays = false - private[mxnet] var ownsAuxArrays = false + private var reshaped = false override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree @@ -75,17 +73,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) - // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone - symbol.dispose() - if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())} - if (ownsGradArrays && gradArrays != null) {gradArrays.foreach( + if (reshaped && argArrays != null) {argArrays.foreach(a => a.dispose())} + if (reshaped && gradArrays != null) {gradArrays.foreach( // Symbol will sometimes fill this with nulls so we've got to check the elements too a => if (a != null) {a.dispose()}) } - if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())} - if (_argDict != null) {_argDict.foreach(a => a._2.dispose())} - if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())} - if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())} + if (reshaped && auxArrays != null) {auxArrays.foreach(a => a.dispose())} } } @@ -104,95 +97,59 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, */ def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false, kwargs: Map[String, Shape]): Executor = { - var setArgOwner = false - var setAuxOwner = false - var setGradOwner = false - val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs) - // TODO: more precise error message should be provided by backend - require(argShapes != null, "Shape inference failed." + - s"Known shapes are $kwargs for symbol arguments ${symbol.listArguments()} " + - s"and aux states ${symbol.listAuxiliaryStates()}") - var newArgDict = Map[String, NDArray]() - var newGradDict = Map[String, NDArray]() + val providedArgShapeNames = kwargs.keys + val providedArgShapeData = kwargs.values.flatMap(_.toVector) + val providedArgShapeIdx = kwargs.values.scanLeft(0)((sum, shape) => sum + shape.size) - this.symbol.listArguments().zipWithIndex.foreach { case (name, i) => - val newShape = argShapes(i) - val arr = this.argArrays(i) - val dArr = if (this.gradArrays == null) null else this.gradArrays(i) - if (partialShaping || kwargs.contains(name) || newShape.equals(arr.shape)) { - if (newShape.product > arr.shape.product) { - require(allowUpSizing, s"New shape of arg:$name larger than original. " + - "First making a big executor and then down sizing it " + - "is more efficient than the reverse." + - "If you really want to up size, set allowUpSizing = true " + - "to enable allocation of new arrays.") - newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype)) - setArgOwner = true - if (dArr != null) { - newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype)) - setGradOwner = true - } - } else { - newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray)) - if (dArr != null) { - newGradDict = newGradDict + (name -> dArr.reshape(newShape.toArray)) - } - } - } else { - throw new AssertionError(s"Shape of unspecified array arg:$name changed." + - "This can cause the new executor to not share parameters " + - "with the old one. Please check for error in network." + - "If this is intended, set partialShaping = true to suppress this warning.") - } - } - - var newAuxDict = Map[String, NDArray]() - val zip3 = (this.symbol.listAuxiliaryStates(), auxShapes, this.auxArrays).zipped - zip3.foreach { case (name, newShape, arr) => - if (partialShaping || newShape.equals(arr.shape)) { - if (newShape.product > arr.shape.product) { - require(allowUpSizing, s"New shape of aux:$name larger than original. " + - "First making a big executor and then down sizing it " + - "is more efficient than the reverse." + - "If you really want to up size, set allowUpSizing = true " + - "to enable allocation of new arrays.") - newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context)) - setAuxOwner = true - } else { - newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray)) - } - } else { - throw new AssertionError(s"Shape of unspecified array aux:$name changed." + - "This can cause the new executor to not share parameters " + - "with the old one. Please check for error in network." + - "If this is intended, set partialShaping = true to suppress this warning.") - } + val ctxMapKeys = if (_group2ctx != null) _group2ctx.keys.toArray else Array.empty[String] + val ctxMapDevTypes = if (_group2ctx != null) { + _group2ctx.values.map(_.deviceTypeid).toArray + } else { + Array.empty[Int] } - val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) { - this.symbol.bind(this._ctx, - newArgDict, - newGradDict, - this._gradsReq.asInstanceOf[Seq[String]], - newAuxDict, - this._group2ctx, - this) + val ctxMapDevIds = if (_group2ctx != null) { + _group2ctx.values.map(_.deviceId).toArray } else { - this.symbol.bind(this._ctx, - newArgDict, - newGradDict, - this._gradsReq.asInstanceOf[Map[String, String]], - newAuxDict, - this._group2ctx, - this) + Array.empty[Int] } - // This method has created new NDArrays that will need to be managed by the new Executor - if (setArgOwner) reshapedExecutor.ownsArgArrays = true - if (setGradOwner) reshapedExecutor.ownsGradArrays = true - if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true + val inArgs = ArrayBuffer.empty[NDArrayHandle] + val argGrads = ArrayBuffer.empty[NDArrayHandle] + val auxStates = ArrayBuffer.empty[NDArrayHandle] + val outHandle = new ExecutorHandleRef() + + checkCall(_LIB.mxExecutorReshape( + if (partialShaping) 1 else 0, + if (allowUpSizing) 1 else 0, + _ctx.deviceTypeid, + _ctx.deviceId, + ctxMapKeys.toArray, + ctxMapDevTypes.toArray, + ctxMapDevIds.toArray, + providedArgShapeNames.toArray, + providedArgShapeData.toArray, + providedArgShapeIdx.toArray, + inArgs, + argGrads, + auxStates, + this.handle, + outHandle)) + + val argArrays = inArgs.map(new NDArray(_)).toArray + val gradArrays = argGrads.map(handle => + if (handle == 0) null else new NDArray(handle)).toArray + val auxArrays = auxStates.map(new NDArray(_)).toArray - reshapedExecutor + val executor = new Executor(outHandle.value, this.symbol) + executor._ctx = this._ctx + executor._gradsReq = this._gradsReq + executor._group2ctx = this._group2ctx + executor.argArrays = argArrays + executor.gradArrays = gradArrays + executor.auxArrays = auxArrays + executor.reshaped = true + executor } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 40fc0951e885..aba618540141 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -188,6 +188,23 @@ private[mxnet] class LibInfo { grads: Array[NDArrayHandle]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int + // scalastyle:off parameterNum + @native def mxExecutorReshape(partialShaping: Int, + allowUpSizing: Int, + devType: Int, + devId: Int, + mapKeys: Array[String], + mapDevTypes: Array[Int], + mapDevIds: Array[Int], + providedArgShapeNames: Array[String], + providedArgShapeData: Array[Int], + providedArgShapeIdx: Array[Int], + inArgs: ArrayBuffer[NDArrayHandle], + argGrads: ArrayBuffer[NDArrayHandle], + auxStates: ArrayBuffer[NDArrayHandle], + sharedExec: ExecutorHandle, + out: ExecutorHandleRef): Int + // scalastyle:on parameterNum // Symbols @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int @@ -240,11 +257,20 @@ private[mxnet] class LibInfo { numArgs: MXUint, keys: Array[String], argIndPtr: Array[MXUint], - argShapeData: Array[MXUint], + argShapeData: Array[Int], inShapeData: ListBuffer[Array[Int]], outShapeData: ListBuffer[Array[Int]], auxShapeData: ListBuffer[Array[Int]], complete: RefInt): Int + @native def mxSymbolInferShapePartial(handle: SymbolHandle, + numArgs: MXUint, + keys: Array[String], + argIndPtr: Array[MXUint], + argShapeData: Array[Int], + inShapeData: ListBuffer[Array[Int]], + outShapeData: ListBuffer[Array[Int]], + auxShapeData: ListBuffer[Array[Int]], + complete: RefInt): Int @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int @native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int @native def mxSymbolCreateFromJSON(json: String, handle: SymbolHandleRef): Int @@ -322,4 +348,8 @@ private[mxnet] class LibInfo { @native def mxSetProfilerConfig(keys: Array[String], vals: Array[String]): Int @native def mxSetProfilerState(state: Int): Int @native def mxDumpProfile(finished: Int): Int + + // Numpy + @native def mxIsNumpyCompatible(compatible: RefInt): Int + @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index ab42265ae102..849f4566f528 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -1274,11 +1274,15 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return an array representing shape of current ndarray */ def shape: Shape = { - val ndim = new MXUintRef + val ndim = new RefInt val data = ArrayBuffer[Int]() checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data)) - require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}") - Shape(data) + if (ndim.value == -1) { + null + } else { + require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}") + Shape(data) + } } // Get size of current NDArray. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala new file mode 100644 index 000000000000..d3e76f1044a7 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base._ + +/** + * NumpyScope object provides util functions for turning on/off NumPy compatibility + * and checking whether NumPy compatibility has been turned on/off. NumPy compatibility + * is introduced first to support zero-dim and zero-size tensors as in NumPy. + */ +object NumpyScope { + def setNumpyCompatible(isNpComp: Boolean): Boolean = { + val prev = new RefInt() + checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev)) + if (prev.value != 0) true else false + } + + def isNumpyCompatible: Boolean = { + val curr = new RefInt + checkCall(_LIB.mxIsNumpyCompatible(curr)) + if (curr.value != 0) true else false + } + + def enableNumpyCompatible: NumpyScope = { + new NumpyScope(true) + } + + + def disableNumpyCompatible: NumpyScope = { + new NumpyScope(false) + } +} + +class NumpyScope(var isCompatible: Boolean) { + private var prev: Boolean = false + + def withScope[T](body: => T): T = { + prev = NumpyScope.setNumpyCompatible(isCompatible) + try { + body + } finally { + if (prev != isCompatible) { + NumpyScope.setNumpyCompatible(prev) + } + } + } +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 821e04f08df2..808a23a8c945 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -260,17 +260,45 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int]) : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = { + val res = inferShapeImpl(partial = false, keys, indPtr, values) + if (res._2 == null) { + val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, values) + val argNames = listArguments() + val unknown = (argNames zip argShapes).map { case (name, shape) => + val shapeIsNone = if (NumpyScope.isNumpyCompatible) { + shape == null || shape.toVector.contains(-1) + } else { + shape == null || shape.toVector.contains(0) + } + if (shapeIsNone) s"$name: $shape" else "" + } + logger.warn("Cannot decide shape for the following arguments. " + + "Consider providing them as input: \n\t{}", + unknown.filter(_ != "").mkString("\n\t")) + } + res + } + + private def inferShapeImpl(partial: Boolean, + keys: Array[String], + indPtr: Array[Int], + values: Array[Int]) + : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = { val argShapeData = ListBuffer.empty[Array[Int]] val outShapeData = ListBuffer.empty[Array[Int]] val auxShapeData = ListBuffer.empty[Array[Int]] val complete = new RefInt - - checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values, - argShapeData, outShapeData, auxShapeData, complete)) + if (partial) { + checkCall(_LIB.mxSymbolInferShapePartial(handle, indPtr.length - 1, keys, indPtr, values, + argShapeData, outShapeData, auxShapeData, complete)) + } else { + checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values, + argShapeData, outShapeData, auxShapeData, complete)) + } if (complete.value != 0) { (argShapeData.map(s => Shape(s)).toIndexedSeq, - outShapeData.map(s => Shape(s)).toIndexedSeq, - auxShapeData.map(s => Shape(s)).toIndexedSeq) + outShapeData.map(s => Shape(s)).toIndexedSeq, + auxShapeData.map(s => Shape(s)).toIndexedSeq) } else { (null, null, null) } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala new file mode 100644 index 000000000000..bf6627ac7e91 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll { + test("compatible") { + NumpyScope.enableNumpyCompatible.withScope { + assert(NumpyScope.isNumpyCompatible === true) + } + } + + test("incompatible") { + NumpyScope.disableNumpyCompatible.withScope { + assert(NumpyScope.isNumpyCompatible === false) + } + } +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index 57c4cfba10b7..12d797f9b100 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -47,7 +47,8 @@ private[mxnet] object CToScalaUtils { case "double" | "doubleorNone" => types("double") case "string" => "String" case "boolean" | "booleanorNone" => types("bool") - case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" + case "tupleof" | "tupleof" | "tupleof" | "tupleof" | + "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") } diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 33e4cca99b3a..7323d23ac556 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -354,9 +354,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) { - mx_uint ndim; - const mx_uint *pdata; - int ret = MXNDArrayGetShape(reinterpret_cast(ndArrayPtr), &ndim, &pdata); + int ndim; + const int *pdata; + int ret = MXNDArrayGetShapeEx(reinterpret_cast(ndArrayPtr), &ndim, &pdata); // fill dataBuf jclass integerClass = env->FindClass("java/lang/Integer"); @@ -365,7 +365,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); jmethodID arrayAppend = env->GetMethodID(arrayClass, "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); - for (size_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { jobject data = env->NewObject(integerClass, newInteger, pdata[i]); env->CallObjectMethod(dataBuf, arrayAppend, data); env->DeleteLocalRef(data); @@ -892,6 +892,119 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape + (JNIEnv * env, jobject obj, + jint partialReshaping, jint allowUpSizing, jint devType, jint devId, + jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds, + jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData, + jintArray jprovidedArgShapeIdx, jobject jrefInArgs, jobject jrefArgGrads, + jobject jrefAuxStates, jlong jsharedExec, jobject jrefOut) { + CHECK(jmapKeys != NULL); + CHECK(jprovidedArgShapeNames != NULL); + + int numMapKeys = env->GetArrayLength(jmapKeys); + jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL); + jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL); + const char **mapKeys = NULL; + if (numMapKeys > 0) { + mapKeys = new const char*[numMapKeys]; + for (int i = 0; i < numMapKeys; ++i) { + jstring jkey = reinterpret_cast(env->GetObjectArrayElement(jmapKeys, i)); + mapKeys[i] = env->GetStringUTFChars(jkey, 0); + env->DeleteLocalRef(jkey); + } + } + + int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames); + jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, NULL); + jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, NULL); + const char **providedArgShapeNames = NULL; + if (numProvidedArgShapes > 0) { + providedArgShapeNames = new const char*[numProvidedArgShapes]; + for (int i = 0; i < numProvidedArgShapes; ++i) { + jstring jkey = reinterpret_cast( + env->GetObjectArrayElement(jprovidedArgShapeNames, i)); + providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0); + env->DeleteLocalRef(jkey); + } + } + + mx_uint numInArgs = 0; + NDArrayHandle *inArgs; + NDArrayHandle *argGrads; + + mx_uint numAuxStates = 0; + NDArrayHandle *auxStates; + + ExecutorHandle out; + + int ret = MXExecutorReshapeEx(partialReshaping, + allowUpSizing, + devType, + devId, + static_cast(numMapKeys), + mapKeys, + static_cast(mapDevTypes), + static_cast(mapDevIds), + static_cast(numProvidedArgShapes), + providedArgShapeNames, + static_cast(providedArgShapeData), + reinterpret_cast(providedArgShapeIdx), + &numInArgs, + &inArgs, + &argGrads, + &numAuxStates, + &auxStates, + reinterpret_cast(jsharedExec), + &out); + + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID newLong = env->GetMethodID(longCls, "", "(J)V"); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + + for (size_t i = 0; i < numInArgs; ++i) { + jobject inArg = env->NewObject(longCls, newLong, inArgs[i]); + env->CallObjectMethod(jrefInArgs, arrayAppend, inArg); + env->DeleteLocalRef(inArg); + + jobject argGrad = env->NewObject(longCls, newLong, argGrads[i]); + env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad); + env->DeleteLocalRef(argGrad); + } + + for (size_t i = 0; i < numAuxStates; ++i) { + jobject auxState = env->NewObject(longCls, newLong, auxStates[i]); + env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState); + env->DeleteLocalRef(auxState); + } + + SetLongField(env, jrefOut, reinterpret_cast(out)); + + // release allocated memory + for (int i = 0; i < numMapKeys; i++) { + jstring jkey = reinterpret_cast(env->GetObjectArrayElement(jmapKeys, i)); + env->ReleaseStringUTFChars(jkey, mapKeys[i]); + env->DeleteLocalRef(jkey); + } + if (mapKeys != NULL) { + delete[] mapKeys; + } + + for (int i = 0; i < numProvidedArgShapes; i++) { + jstring jkey = reinterpret_cast(env->GetObjectArrayElement(jprovidedArgShapeNames, i)); + env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]); + env->DeleteLocalRef(jkey); + } + if (providedArgShapeNames != NULL) { + delete[] providedArgShapeNames; + } + + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) { const char *retDebugStr; @@ -1530,23 +1643,27 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile int FillSymbolInferShape (JNIEnv *env, jmethodID listAppend, jobject joutData, - mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) { - for (size_t i = 0; i < shapeSize; ++i) { - jintArray jshape = env->NewIntArray(shapeNdim[i]); - if (jshape == NULL) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; + int shapeSize, const int *shapeNdim, const int **shapeData) { + for (int i = 0; i < shapeSize; ++i) { + jintArray jshape = NULL; + if (shapeNdim[i] >= 0) { + jshape = env->NewIntArray(shapeNdim[i]); + if (jshape == NULL) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } + env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast(shapeData[i])); } - env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast(shapeData[i])); env->CallObjectMethod(joutData, listAppend, jshape); env->DeleteLocalRef(jshape); } return 0; } -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape - (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys, - jintArray jargIndPtr, jintArray jargShapeData, - jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) { + +int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, + jobjectArray jkeys, jintArray jargIndPtr, jintArray jargShapeData, + jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, + jobject jcomplete, bool partial) { const char **keys = NULL; if (jkeys != NULL) { keys = new const char *[jnumArgs]; @@ -1559,26 +1676,28 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape } mx_uint inShapeSize; - const mx_uint *inShapeNdim; - const mx_uint **inShapeData; + const int *inShapeNdim; + const int **inShapeData; mx_uint outShapeSize; - const mx_uint *outShapeNdim; - const mx_uint **outShapeData; + const int *outShapeNdim; + const int **outShapeData; mx_uint auxShapeSize; - const mx_uint *auxShapeNdim; - const mx_uint **auxShapeData; + const int *auxShapeNdim; + const int **auxShapeData; int complete; jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL); jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL); - int ret = MXSymbolInferShape(reinterpret_cast(symbolPtr), + int ret; + if (!partial) { + ret = MXSymbolInferShapeEx(reinterpret_cast(symbolPtr), static_cast(jnumArgs), keys, - reinterpret_cast(argIndPtr), - reinterpret_cast(argShapeData), + reinterpret_cast(argIndPtr), + reinterpret_cast(argShapeData), &inShapeSize, &inShapeNdim, &inShapeData, @@ -1589,6 +1708,23 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape &auxShapeNdim, &auxShapeData, &complete); + } else { + ret = MXSymbolInferShapePartialEx(reinterpret_cast(symbolPtr), + static_cast(jnumArgs), + keys, + reinterpret_cast(argIndPtr), + reinterpret_cast(argShapeData), + &inShapeSize, + &inShapeNdim, + &inShapeData, + &outShapeSize, + &outShapeNdim, + &outShapeData, + &auxShapeSize, + &auxShapeNdim, + &auxShapeData, + &complete); + } env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0); env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0); @@ -1629,6 +1765,24 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape + (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys, + jintArray jargIndPtr, jintArray jargShapeData, + jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) { + + return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData, + jinShapeData, joutShapeData, jauxShapeData, jcomplete, false); +} + +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial + (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys, + jintArray jargIndPtr, jintArray jargShapeData, + jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) { + + return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData, + jinShapeData, joutShapeData, jauxShapeData, jcomplete, true); +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx, jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs, @@ -2551,3 +2705,20 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile (JNIEnv *env, jobject obj, jint finished) { return MXDumpProfile(finished); } + +// Numpy +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible + (JNIEnv *env, jobject obj, jobject compatibleRef) { + bool isCompatible; + int ret = MXIsNumpyCompatible(&isCompatible); + SetIntField(env, compatibleRef, static_cast(isCompatible)); + return ret; +} + +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible + (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) { + int prev; + int ret = MXSetIsNumpyCompatible(isNpComp, &prev); + SetIntField(env, prevRef, prev); + return ret; +} diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index b8a9b3b9e64f..467272cea9cf 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -511,6 +511,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback (JNIEnv *, jobject, jlong, jobject); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxExecutorReshape + * Signature: (IIII[Ljava/lang/String;[I[I[Ljava/lang/String;[I[ILscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;JLorg/apache/mxnet/Base/RefLong;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape + (JNIEnv *, jobject, jint, jint, jint, jint, jobjectArray, jintArray, jintArray, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, jlong, jobject); + /* * Class: org_apache_mxnet_LibInfo * Method: mxSymbolListAtomicSymbolCreators @@ -655,6 +663,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferType JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape (JNIEnv *, jobject, jlong, jint, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, jobject); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxSymbolInferShapePartial + * Signature: (JI[Ljava/lang/String;[I[ILscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lorg/apache/mxnet/Base/RefInt;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial + (JNIEnv *, jobject, jlong, jint, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, jobject); + /* * Class: org_apache_mxnet_LibInfo * Method: mxSymbolGetOutput @@ -855,6 +871,22 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerState JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile (JNIEnv *, jobject, jint); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxIsNumpyCompatible + * Signature: (Lorg/apache/mxnet/Base/RefInt;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible + (JNIEnv *, jobject, jobject); + +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxSetIsNumpyCompatible + * Signature: (ILorg/apache/mxnet/Base/RefInt;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible + (JNIEnv *, jobject, jint, jobject); + #ifdef __cplusplus } #endif diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 45197aafe019..f549ddd13994 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -44,9 +44,11 @@ #include "mxnet/rtc.h" #include "mxnet/storage.h" #include "mxnet/libinfo.h" +#include "mxnet/imperative.h" #include "./c_api_common.h" #include "../operator/custom/custom-inl.h" #include "../operator/tensor/matrix_op-inl.h" +#include "../common/utils.h" using namespace mxnet; @@ -471,7 +473,7 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, NDArray *ptr = new NDArray(); API_BEGIN(); NDArray *arr = static_cast(handle); - nnvm::Tuple shape(dims, dims+ndim); + mxnet::Tuple shape(dims, dims+ndim); CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: " << arr->shape(); mxnet::TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse); @@ -511,6 +513,34 @@ int MXNDArrayGetShape(NDArrayHandle handle, API_END(); } +int MXNDArrayGetShapeEx(NDArrayHandle handle, + int *out_dim, + const int **out_pdata) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + NDArray *arr = static_cast(handle); + if (!arr->is_none()) { + mxnet::TShape s = arr->shape(); + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToLegacyShape(&s); + } + *out_dim = s.ndim(); + if (s.ndim() >= 0) { + std::vector &buffer = ret->arg_shape_buffer_ex; + buffer.resize(s.ndim()); + mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data()); + *out_pdata = buffer.data(); + } + } else { + if (Imperative::Get()->is_np_comp()) { + *out_dim = -1; + } else { + *out_dim = 0; + } + } + API_END(); +} + int MXNDArrayGetData(NDArrayHandle handle, void **out_pdata) { API_BEGIN(); @@ -791,7 +821,7 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { // temp hack to make label 1D // TODO(tianjun) make label 1D when label_width=0 mxnet::TShape shape = db.data[1].shape(); - if (shape[1] == 1) { + if (shape.ndim() > 1 && shape[1] == 1) { *pndarray = db.data[1].Reshape(mshadow::Shape1(shape[0])); } else { *pndarray = db.data[1]; @@ -1402,6 +1432,13 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *s API_END(); } +int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape, + int ndim, int dtype, NDArrayHandle *out) { + API_BEGIN(); + *out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype); + API_END(); +} + typedef Engine::VarHandle VarHandle; typedef Engine::CallbackOnComplete CallbackOnComplete; diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index b5adfa37eca9..013ecab93da8 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -75,12 +75,19 @@ struct MXAPIThreadLocalEntry { std::vector arg_storage_types, out_storage_types, aux_storage_types; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; + /*! \brief result holder for returning shape dimensions */ + std::vector arg_shape_ndim_ex, out_shape_ndim_ex, aux_shape_ndim_ex; /*! \brief result holder for returning shape pointer */ std::vector arg_shape_data, out_shape_data, aux_shape_data; + /*! \brief result holder for returning shape pointer */ + std::vector arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex; /*! \brief uint32_t buffer for returning shape pointer */ std::vector arg_shape_buffer, out_shape_buffer, aux_shape_buffer; + /*! \brief uint32_t buffer for returning shape pointer */ + std::vector arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex; /*! \brief bool buffer */ std::vector save_inputs, save_outputs; + // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead. // helper function to setup return value of shape array inline static void SetupShapeArrayReturnWithBuffer( const mxnet::ShapeVector &shapes, @@ -99,6 +106,30 @@ struct MXAPIThreadLocalEntry { ptr = nnvm::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); } } + // helper function to setup return value of shape array + inline static void SetupShapeArrayReturnWithBufferEx( + const mxnet::ShapeVector &shapes, + std::vector *ndim, + std::vector *data, + std::vector *buffer) { + ndim->resize(shapes.size()); + data->resize(shapes.size()); + size_t size = 0; + for (const auto& s : shapes) { + if (s.ndim() > 0) { + size += s.ndim(); + } + } + buffer->resize(size); + int *ptr = buffer->data(); + for (size_t i = 0; i < shapes.size(); ++i) { + ndim->at(i) = shapes[i].ndim(); + data->at(i) = ptr; + if (shapes[i].ndim() > 0) { + ptr = mxnet::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); + } + } + } }; // define the threadlocal store. diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index a2e8bb810e6f..5352fcfe0951 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -25,8 +25,10 @@ #include #include #include +#include #include "./c_api_common.h" #include "../executor/graph_executor.h" +#include "../common/utils.h" #if MXNET_USE_TENSORRT #include "../executor/trt_graph_executor.h" #endif // MXNET_USE_TENSORRT @@ -183,7 +185,7 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, } /*! - * \brief + * \brief DEPRECATED. Use MXExecutorSimpleBindEx instead. * \param symbol_handle symbol handle * \param dev_type default device type * \param dev_id default device id @@ -416,6 +418,371 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, CHECK(p.second) << "Duplicate shapes are provided for argument " << provided_arg_shape_names[i] << " in simple_bind"; } + if (!Imperative::Get()->is_np_comp()) { + for (auto &kv : arg_shape_map) { + common::ConvertToNumpyShape(&kv.second); + } + } + + // create para name set for sharing data array memory + std::unordered_set shared_arg_name_set(num_shared_arg_names); + for (mx_uint i = 0; i < num_shared_arg_names; ++i) { + shared_arg_name_set.insert(shared_arg_name_list[i]); + } + + // create shared_buffer_map + std::unordered_map shared_buffer_map; + bool use_shared_buffer = (*shared_buffer_len >= 0); + if (*shared_buffer_len > 0) { + // create shared_buffer_map + shared_buffer_map.reserve(*shared_buffer_len); + NDArray** shared_buffer_ptrs = + reinterpret_cast(shared_buffer_handle_list); + for (int i = 0; i < *shared_buffer_len; ++i) { + shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]); + } + } + + // create temporary place holders for the initialized NDArrays + // to be passed back to front end + std::vector in_arg_vec; + std::vector arg_grad_vec; + std::vector aux_state_vec; +#if MXNET_USE_TENSORRT + // If we've built with TensorRT support we by default return an TRTExecutor. + // Users can override this behaviour via env var, which is useful for example for A/B + // performance testing. + if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, + &arg_grad_ctx_vec, &aux_state_ctx_vec, + &arg_shape_map, &arg_dtype_map, &arg_stype_map, + &grad_req_type_vec, shared_arg_name_set, + &in_arg_vec, &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + } else { + // Checks to see if this env var has been set to true or false by the user. + // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn + // them and describe further steps. + const int unset_indicator = std::numeric_limits::quiet_NaN(); + if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) { + LOG(INFO) << "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT " + "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) " + "to enable."; + } +#endif // MXNET_USE_TENSORRT + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); +#if MXNET_USE_TENSORRT + } +#endif // MXNET_USE_TENSORRT + + // copy ndarray ptrs to ret->handles so that front end + // can access them + ret->ret_handles.clear(); + ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size() + +shared_buffer_map.size()); + size_t nd_idx = 0; + for (const auto& nd : in_arg_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); + *in_args = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : arg_grad_vec) { + if (nd.is_none()) { + ret->ret_handles.push_back(nullptr); + } else { + ret->ret_handles.push_back(new NDArray(nd)); + } + } + if (arg_grad_vec.size() > 0) { + *arg_grads = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : aux_state_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); + *aux_states = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + if (use_shared_buffer) { + ret->ret_vec_str.clear(); + ret->ret_vec_str.reserve(shared_buffer_map.size()); + ret->ret_vec_charp.clear(); + ret->ret_vec_charp.reserve(shared_buffer_map.size()); + for (const auto& kv : shared_buffer_map) { + if (kv.second.is_none()) { + LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(kv.second)); + ret->ret_vec_str.emplace_back(kv.first); + ret->ret_vec_charp.push_back(ret->ret_vec_str.back().c_str()); + } + *shared_buffer_len = shared_buffer_map.size(); + *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); + *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]); + } + + API_END(); +} + +/*! + * \brief + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(symbol_handle); + + // get in_arg names + std::vector in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); + std::vector aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); + + // attr_dict for setting up type_dict and arg/aux ctx + std::unordered_map> attr_dict; + if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) { + std::vector> attrs = + sym->ListAttrsRecursive(); + attr_dict.reserve(attrs.size()); + for (const auto& tp : attrs) { + attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp); + } + } + + // setup arg_dtype_map + std::unordered_map arg_dtype_map; + if (nullptr == provided_arg_dtypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__dtype__")) { + arg_dtype_map[arg_name] = mshadow::kFloat32; + } + } + } else { // use user input type_dict + // create dtype map for in_args and aux_states + arg_dtype_map.reserve(num_provided_arg_dtypes); + for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { + arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; + } + } + + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + + // create default ctx + Context ctx = Context::Create(static_cast(dev_type), dev_id); + // create ctx map + std::map ctx_map; + std::vector in_arg_ctx_vec(in_arg_names.size(), ctx); + std::vector aux_state_ctx_vec(aux_state_names.size(), ctx); + if (nullptr != g2c_keys) { // use user input group2ctx dict + for (mx_uint i = 0; i < num_g2c_keys; ++i) { + ctx_map[g2c_keys[i]] = Context::Create( + static_cast(g2c_dev_types[i]), g2c_dev_ids[i]); + } + + // initialize in_arg_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(in_arg_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + in_arg_ctx_vec[i] = it3->second; + } + } + } + } + + // initialize aux_state_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(aux_state_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + aux_state_ctx_vec[i] = it3->second; + } + } + } + } + } + + // create provided_grad_req_map + const std::map req_map = + {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}}; + std::unordered_map provided_grad_req_map; + std::string grad_req_type; + if (0 == provided_grad_req_list_len + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // string, grad_req='write' + CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U) + << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + grad_req_type = "string"; + } else if (provided_grad_req_list_len > 0 + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // list, grad_req=['null', 'write'] + grad_req_type = "list"; + CHECK_EQ(provided_grad_req_list_len, in_arg_names.size()) + << "The length of grad_req list does not match the number of input arguments in simple_bind, " + "expected " << in_arg_names.size() << ", provided " << provided_grad_req_list_len; + } else if (provided_grad_req_list_len > 0 + && nullptr != provided_grad_req_names + && nullptr != provided_grad_req_types) { // dict, grad_req=['lhs': 'null', 'rhs': 'write'] + grad_req_type = "dict"; + provided_grad_req_map.reserve(provided_grad_req_list_len); + for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i]; + } + } else { // grad_req is None + grad_req_type = "none"; + } + + // initialize arg_grad_ctx_vec and grad_req_type_vec + std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); + std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); + if ("none" != grad_req_type) { + for (size_t i = 0; i < in_arg_names.size(); ++i) { + OpReqType cur_req = kNullOp; + if ("string" == grad_req_type) { + cur_req = req_map.at(provided_grad_req_types[0]); + } else if ("list" == grad_req_type) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + cur_req = req_map.at(provided_grad_req_types[i]); + } else if ("dict" == grad_req_type) { + const auto it = provided_grad_req_map.find(in_arg_names[i]); + if (it != provided_grad_req_map.end()) { + cur_req = req_map.at(it->second); + } + } + if (kNullOp != cur_req) { + arg_grad_ctx_vec[i] = in_arg_ctx_vec[i]; + grad_req_type_vec[i] = static_cast(cur_req); + } + } + } + + // create shape map for in_args and aux_states + std::unordered_map arg_shape_map(num_provided_arg_shapes); + for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { + auto p = arg_shape_map.emplace(provided_arg_shape_names[i], + mxnet::TShape(provided_arg_shape_data+provided_arg_shape_idx[i], + provided_arg_shape_data+provided_arg_shape_idx[i+1])); + CHECK(p.second) << "Duplicate shapes are provided for argument " + << provided_arg_shape_names[i] << " in simple_bind"; + } + if (!Imperative::Get()->is_np_comp()) { + for (auto &kv : arg_shape_map) { + common::ConvertToNumpyShape(&kv.second); + } + } // create para name set for sharing data array memory std::unordered_set shared_arg_name_set(num_shared_arg_names); @@ -628,6 +995,97 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete new_exec); } +int MXExecutorReshapeEx(int partial_shaping, + int allow_up_sizing, + int dev_type, + int dev_id, + mx_uint num_map_keys, + const char** map_keys, + const int* map_dev_types, + const int* map_dev_ids, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec, + ExecutorHandle *out) { + Executor* new_exec = nullptr; + + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + *out = nullptr; // ensure we can know whether to free executor on early abort + // create shape map for in_args and aux_states + std::unordered_map kwargs(num_provided_arg_shapes); + for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { + auto p = kwargs.emplace(provided_arg_shape_names[i], + mxnet::TShape(provided_arg_shape_data+provided_arg_shape_idx[i], + provided_arg_shape_data+provided_arg_shape_idx[i+1])); + CHECK(p.second) << "Duplicate shapes are provided for argument " + << provided_arg_shape_names[i] << " in reshape of executor"; + } + + Context ctx = Context::Create(static_cast(dev_type), dev_id); + std::map ctx_map; + for (mx_uint i = 0; i < num_map_keys; ++i) { + ctx_map[std::string(map_keys[i])] = Context::Create( + static_cast(map_dev_types[i]), map_dev_ids[i]); + } + std::vector in_arg_vec; + std::vector arg_grad_vec; + std::vector aux_state_vec; + + Executor* exec = static_cast(shared_exec); + new_exec = exec->Reshape(partial_shaping, allow_up_sizing, ctx, ctx_map, kwargs, + &in_arg_vec, &arg_grad_vec, &aux_state_vec); + *out = new_exec; + + ret->ret_handles.clear(); + ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size()); + + size_t nd_idx = 0; + for (const auto& nd : in_arg_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); + *in_args = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : arg_grad_vec) { + if (nd.is_none()) { + ret->ret_handles.push_back(nullptr); + } else { + ret->ret_handles.push_back(new NDArray(nd)); + } + } + if (arg_grad_vec.size() > 0) { + *arg_grads = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : aux_state_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); + *aux_states = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + API_END_HANDLE_ERROR(delete new_exec); +} + int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out) { auto s = new nnvm::Symbol(); diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 18f6c411e039..0e136b03ecd7 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -276,6 +276,18 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) { API_END(); } +int MXIsNumpyCompatible(bool* curr) { + API_BEGIN(); + *curr = Imperative::Get()->is_np_comp(); + API_END(); +} + +int MXSetIsNumpyCompatible(int is_np_comp, int* prev) { + API_BEGIN(); + *prev = Imperative::Get()->set_is_np_comp(static_cast(is_np_comp)); + API_END(); +} + int MXAutogradMarkVariables(mx_uint num_var, NDArrayHandle *var_handles, mx_uint *reqs_array, diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 545e95f04b79..24a88520376f 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -24,6 +24,7 @@ */ #include "mxnet/base.h" #include "mxnet/c_api.h" +#include "mxnet/imperative.h" #include "nnvm/c_api.h" #include "nnvm/pass.h" #include "nnvm/pass_functions.h" @@ -543,8 +544,14 @@ int MXSymbolInferShape(SymbolHandle sym, throw dmlc::Error(err.msg); } + // if use legacy shape definition, need to convert numpy shape to legacy shape + mxnet::ShapeVector shapes = g.GetAttr("shape"); + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToLegacyShape(&shapes); + } + // copy back - CopyAttr(g.indexed_graph(), g.GetAttr("shape"), + CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back @@ -568,6 +575,79 @@ int MXSymbolInferShape(SymbolHandle sym, API_END(); } +int MXSymbolInferShapeEx(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const int *arg_shape_data, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Graph g = Symbol2Graph(*s); + mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); + if (keys == nullptr && num_args != 0) { + std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + CHECK_LE(num_args, read_only_args.size()); + for (mx_uint i = 0; i < num_args; ++i) { + arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast( + arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); + } + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = mxnet::ShapeTypeCast( + arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); + } + + try { + g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + } catch (const mxnet::op::InferShapeError &err) { + throw dmlc::Error(err.msg); + } + + // if use legacy shape definition, need to convert numpy shape to legacy shape + mxnet::ShapeVector shapes = g.GetAttr("shape"); + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToLegacyShape(&shapes); + } + + // copy back + CopyAttr(g.indexed_graph(), shapes, + &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); + + // copy data back + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, + &(ret->arg_shape_ndim_ex), &(ret->arg_shape_data_ex), &(ret->arg_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, + &(ret->out_shape_ndim_ex), &(ret->out_shape_data_ex), &(ret->out_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, + &(ret->aux_shape_ndim_ex), &(ret->aux_shape_data_ex), &(ret->aux_shape_buffer_ex)); + *in_shape_size = static_cast(ret->arg_shapes.size()); + *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); + *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); + *out_shape_size = static_cast(ret->out_shapes.size()); + *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex); + *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); + *aux_shape_size = static_cast(ret->aux_shapes.size()); + *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex); + *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex); + // mark complete + *complete = (g.GetAttr("shape_num_unknown_nodes") == 0); + API_END(); +} + int MXSymbolInferShapePartial(SymbolHandle sym, mx_uint num_args, const char** keys, @@ -593,6 +673,31 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } +int MXSymbolInferShapePartialEx(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const int *arg_shape_data, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *complete) { + int succ; + *complete = 1; + return MXSymbolInferShapeEx(sym, num_args, keys, + arg_ind_ptr, arg_shape_data, + in_shape_size, in_shape_ndim, in_shape_data, + out_shape_size, out_shape_ndim, out_shape_data, + aux_shape_size, aux_shape_ndim, aux_shape_data, + &succ); +} + int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 3b9f43d86079..7de23ef935ef 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -436,6 +436,7 @@ int MXPredGetOutputShape(PredictorHandle handle, << "Index exceed number of outputs"; const mxnet::TShape& s = p->out_shapes[out_index]; + CHECK_GE(s.ndim(), 0); p->out_shapes_buffer.resize(s.ndim()); nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data()); *shape_data = p->out_shapes_buffer.data(); diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 279ecbd67f09..0551b429f17e 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -380,7 +380,7 @@ inline void HandleInferShapeError(const size_t num_forward_inputs, const uint32_t nid = idx.input_nodes().at(i); const uint32_t eid = idx.entry_id(nid, 0); const mxnet::TShape& inferred_shape = inferred_shapes[eid]; - if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { + if (!shape_is_known(inferred_shape)) { const std::string& arg_name = idx[nid].source->attrs.name; oss << arg_name << ": " << inferred_shape << ", "; if (--cnt == 0) { @@ -390,7 +390,7 @@ inline void HandleInferShapeError(const size_t num_forward_inputs, } } LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " - "(0s means unknown dimensions). Please consider providing them as inputs:\n" + "(-1 means unknown dimensions). Please consider providing them as inputs:\n" << oss.str(); } diff --git a/src/common/utils.h b/src/common/utils.h index 8e6966952890..251a8fe3c190 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -654,7 +654,7 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name, } else if (ctx.dev_mask() == gpu::kDevMask) { return fcompute_gpu.get(op, nullptr); } else { - LOG(FATAL) << "Unknown device mask"; + LOG(FATAL) << "Unknown device mask " << ctx.dev_mask(); return nullptr; } } @@ -734,6 +734,64 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) { } } +/*! + * \brief If numpy compatibility is turned off (default), the shapes passed in + * by users follow the legacy shape definition: + * 1. 0 ndim means the shape is completely unknown. + * 2. 0 dim size means the dim size is unknown. + * We need to convert those shapes to use the numpy shape definition: + * 1. 0 ndim means it's a scalar tensor. + * 2. -1 ndim means the shape is unknown. + * 3. 0 dim size means no elements in that dimension. + * 4. -1 dim size means the dimension's size is unknown. + * so that operator's infer shape function can work in backend. + * \param shape to be converted. + * Note: It is possible that the shape to be converted is already + * numpy compatible. For example, when a subgraph operator's infer + * shape function is called from the infer shape pass of the whole + * graph, its input/output shapes have been converted to numpy + * compatible shapes. + */ +inline void ConvertToNumpyShape(mxnet::TShape* shape) { + if (shape->ndim() == 0) { // legacy shape ndim = 0 means unknown + *shape = mxnet::TShape(); // unknown shape ndim = -1 + } else { + for (int j = 0; j < shape->ndim(); ++j) { + if ((*shape)[j] == 0) { // legacy shape dim_size = 0 means unknown + (*shape)[j] = -1; // unknown dim size = -1 + } + } + } +} + +inline void ConvertToNumpyShape(mxnet::ShapeVector* shapes) { + for (size_t i = 0; i < shapes->size(); ++i) { + ConvertToNumpyShape(&(shapes->at(i))); + } +} + +/*! + * \brief This is function is used to convert shapes returned by + * the infer shape functions/pass to the legacy shape definition. + */ +inline void ConvertToLegacyShape(mxnet::TShape* shape) { + if (!mxnet::ndim_is_known(*shape)) { + *shape = mxnet::TShape(0, -1); + } else { + for (int j = 0; j < shape->ndim(); ++j) { + if (!mxnet::dim_size_is_known(*shape, j)) { + (*shape)[j] = 0; + } + } + } +} + +inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) { + for (size_t i = 0; i < shapes->size(); ++i) { + ConvertToLegacyShape(&(shapes->at(i))); + } +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 460cec371bd4..4a4505581920 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -34,6 +34,7 @@ #include "../common/utils.h" #include "../common/exec_utils.h" #include "../operator/subgraph/subgraph_property.h" +#include "../operator/operator_common.h" namespace mxnet { namespace exec { @@ -966,7 +967,7 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { uint32_t oid = head_grad_map_.at(idx[nid].source); uint32_t eid = idx.entry_id(idx.outputs()[oid]); NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid]; - CHECK_NE(vshape[eid].ndim(), 0U); + CHECK(mxnet::shape_is_known(vshape[eid])); CHECK_NE(vdtype[eid], -1); auto data_eid = idx.entry_id(nid, 0); // initialize based on storage_type diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 6a7fde62c2cf..fa7aee518486 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -24,6 +24,7 @@ #include #include +#include #include "./exec_pass.h" #include "../operator/operator_common.h" #include "../common/exec_utils.h" @@ -467,6 +468,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, std::vector ishape, oshape; // whether a shape is dynamic std::vector is_dynamic(rshape.size(), 0); + + // convert to numpy compatible shape to use operator's infer shape function + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToNumpyShape(&rshape); + } + // inference step function for nid auto infer_step = [&](uint32_t nid, bool last_iter) { const auto& inode = idx[nid]; @@ -483,6 +490,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, if (it != inode.source->attrs.dict.end()) { std::istringstream is(it->second); CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToNumpyShape(&rshape[out_ent_id]); + } } } // assign a default value to node attribute @@ -546,7 +556,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, bool is_input_dynamic_shape = false; for (uint32_t i = 0; i < ishape.size(); ++i) { ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; - if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) { + if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) { is_input_dynamic_shape = true; } if (fis_none(ishape[i])) forward_known = false; @@ -563,7 +573,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, auto finfer = finfer_shape.get(inode.source->op(), fdefault); if (finfer == nullptr || is_input_dynamic_shape) { for (uint32_t i = 0; i < oshape.size(); ++i) { - if (oshape[i].ndim() == 0) { + if (!mxnet::ndim_is_known(oshape[i].ndim())) { is_dynamic[idx.entry_id(nid, i)] = 1; } } @@ -648,14 +658,14 @@ nnvm::Graph InferShape(nnvm::Graph&& graph, std::move(graph), mxnet::TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", "shape_num_unknown_nodes", - [](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + [](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); }, [](const mxnet::TShape& s) { - if (s.ndim() == 0) { // TODO(reminisce): Usage of ndim + if (!mxnet::ndim_is_known(s)) { return static_cast(1); } size_t ret = 0; for (const auto& val : s) { - if (val == 0) { + if (!mxnet::dim_size_is_known(val)) { ++ret; } } diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 3e5b3987522c..b027de0a0f6f 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -25,9 +25,11 @@ namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL thread_local bool Imperative::is_train_ = false; thread_local bool Imperative::is_recording_ = false; +thread_local bool Imperative::is_np_comp_ = false; #else MX_THREAD_LOCAL bool Imperative::is_train_ = false; MX_THREAD_LOCAL bool Imperative::is_recording_ = false; +MX_THREAD_LOCAL bool Imperative::is_np_comp_ = false; #endif Imperative* Imperative::Get() { @@ -109,7 +111,7 @@ OpStatePtr Imperative::Invoke( OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode); // the followinng loop is used for finding out the correct shape when some shapes are dynamic for (size_t i = 0; i < outputs.size(); i++) { - if (outputs[i]->shape().ndim() == 0) { + if (!shape_is_known(outputs[i]->shape())) { // the WaitToRead overhead here does not seem to be avoidable outputs[i]->WaitToRead(); outputs[i]->SetShapeFromChunk(); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index c7204c1d85e6..568d39fc8043 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -19,6 +19,7 @@ #include "./imperative_utils.h" #include "./cached_op.h" +#include "../operator/operator_common.h" namespace { @@ -190,7 +191,7 @@ void NaiveRunGraph( Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state); for (size_t j = 0; j < ndoutputs.size(); ++j) { - if (ndoutputs[j]->shape().ndim() == 0) { + if (mxnet::op::shape_is_none(ndoutputs[j]->shape())) { ndoutputs[j]->WaitToRead(); ndoutputs[j]->SetShapeFromChunk(); } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index d134d47c55cf..9d4e4bd15a37 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -31,6 +31,7 @@ #include "../common/utils.h" #include "../common/exec_utils.h" #include "../operator/nn/mkldnn/mkldnn_base-inl.h" +#include "../operator/operator_common.h" #ifndef MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_ #define MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_ @@ -121,7 +122,28 @@ inline void SetShapeType(const Context& ctx, if (!infershape.count(attrs.op)) { is_dynamic_shape_existing = true; } else { - CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes)); + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToNumpyShape(&in_shapes); + common::ConvertToNumpyShape(&out_shapes); + } + const bool success = infershape[attrs.op](attrs, &in_shapes, &out_shapes); + if (!success) { + std::stringstream os; + os << "Operator " << attrs.op->name << " inferring shapes failed.\n"; + os << "input shapes:\n"; + for (const auto& s : in_shapes) { + os << s << '\n'; + } + os << "output shapes:\n"; + for (const auto& s : out_shapes) { + os << s << '\n'; + } + os << "operator attributes:\n"; + for (const auto& kv : attrs.dict) { + os << kv.first << " : " << kv.second << '\n'; + } + LOG(FATAL) << os.str(); + } CHECK_EQ(out_shapes.size(), outputs.size()); } // infer type @@ -179,7 +201,7 @@ inline void SetShapeType(const Context& ctx, for (size_t i = 0; i < outputs.size(); ++i) { NDArrayStorageType storage_type = static_cast(out_storage_types[i]); - if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) { + if (outputs[i]->is_none() || mxnet::op::shape_is_none(outputs[i]->shape())) { if (is_dynamic_shape_existing) { // once there is dynamic shape somewhere, we could not pre-determine the shape. *outputs[i] = NDArray(ctx, out_types[i]); diff --git a/src/io/image_io.cc b/src/io/image_io.cc index 2196983928bb..965078cb2766 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -189,7 +189,7 @@ void Imdecode(const nnvm::NodeAttrs& attrs, size_t len = inputs[0].shape().Size(); CHECK(len > 0) << "Input cannot be an empty buffer"; - mxnet::TShape oshape(3); + mxnet::TShape oshape(3, 1); oshape[2] = param.flag == 0 ? 1 : 3; if (get_jpeg_size(str_img, len, &oshape[1], &oshape[0])) { } else if (get_png_size(str_img, len, &oshape[1], &oshape[0])) { @@ -229,7 +229,7 @@ void Imread(const nnvm::NodeAttrs& attrs, CHECK(file.good()) << "Failed reading image file: '" << param.filename << "' " << strerror(errno); - mxnet::TShape oshape(3); + mxnet::TShape oshape(3, 1); oshape[2] = param.flag == 0 ? 1 : 3; if (get_jpeg_size(buff.get(), fsize, &oshape[1], &oshape[0])) { } else if (get_png_size(buff.get(), fsize, &oshape[1], &oshape[0])) { diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc index e4a06fa9a1f2..30aaec91e27f 100644 --- a/src/kvstore/gradient_compression.cc +++ b/src/kvstore/gradient_compression.cc @@ -100,9 +100,9 @@ int64_t GradientCompression::GetCompressedSize(const int64_t original_size) { void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, mxnet::NDArray *residual, const int priority) { - CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape"; - CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; - CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape"; + CHECK(shape_is_known(from.shape())) << "source operand has undefined shape"; + CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape"; + CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape"; const int a = from.ctx().dev_mask(); const int b = to->ctx().dev_mask(); const float threshold = threshold_; @@ -137,8 +137,8 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority) { - CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape"; - CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; + CHECK(shape_is_known(from.shape())) << "source operand has undefined shape"; + CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape"; const int a = from.ctx().dev_mask(); const int b = to->ctx().dev_mask(); const float threshold = threshold_; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 377bef072b03..f5aac36a48eb 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -549,7 +549,7 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder( // If they have different shapes, we need to reshape the array first. // Since this method will only be used inside an operator, we can call // MKLDNNDataReshape to reshape an array. - mxnet::TShape required_shape(desc2.data.ndims); + mxnet::TShape required_shape(desc2.data.ndims, -1); for (int i = 0; i < desc2.data.ndims; i++) required_shape[i] = desc2.data.dims[i]; NDArray reshaped = MKLDNNDataReshape(required_shape); @@ -575,7 +575,7 @@ NDArray NDArray::Reorder2Default() const { // create new ndarray from mkldnn layout mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); - mxnet::TShape tshape(from_desc.data.ndims); + mxnet::TShape tshape(from_desc.data.ndims, -1); for (int i = 0; i < from_desc.data.ndims; i++) tshape[i] = from_desc.data.dims[i]; NDArray ret(tshape, ctx(), false, dtype()); mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format); @@ -1191,8 +1191,8 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op CHECK(from.shape() == to.shape()) << "operands shape mismatch" << "from.shape = " << from.shape() << " to.shape=" << to.shape(); - CHECK(from.shape().ndim() != 0) - << "source operands have zero dimension shape"; + CHECK(!mxnet::op::shape_is_none(from.shape())) + << "source operands have undefined shape"; // important: callback must always capture by value const Context from_ctx = from.ctx(); const int a = from_ctx.dev_mask(); @@ -1650,7 +1650,7 @@ bool LegacyTShapeLoad(dmlc::Stream *strm, mxnet::TShape *shape, const uint32_t m default: // meet legacy mxnet::TShape, magic is ndim here uint32_t ndim = magic; - *shape = mxnet::TShape(ndim); + *shape = mxnet::TShape(ndim, -1); std::vector buffer(ndim); size_t nread = ndim * sizeof(uint32_t); if (strm->Read(buffer.data(), nread) != nread) return false; @@ -1663,7 +1663,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { // load shape mxnet::TShape shape; if (!LegacyTShapeLoad(strm, &shape, magic)) return false; - if (shape.ndim() == 0) { + if (mxnet::op::shape_is_none(shape)) { *this = NDArray(); return true; } // load context @@ -1711,7 +1711,10 @@ bool NDArray::Load(dmlc::Stream *strm) { // load shape mxnet::TShape shape; if (!shape.Load(strm)) return false; - if (shape.ndim() == 0) { + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToNumpyShape(&shape); + } + if (mxnet::op::shape_is_none(shape)) { *this = NDArray(); return true; } diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index a613d5a3decc..8f72bc259afc 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -210,8 +210,6 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, Kernel::Launch(s, out_data.Size(), out_data.dptr()); for (size_t i = 0; i < nds.size(); ++i) { const NDArray& nd = nds[i]; - const nnvm::dim_t num_rows = nd.shape()[0]; - const nnvm::dim_t num_cols = nd.shape()[1]; const TBlob& nd_data = nd.data(); if (i == 0) { @@ -234,6 +232,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, case kCSRStorage: { const TBlob& nd_indices = nd.aux_data(csr::kIdx); const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); + const nnvm::dim_t num_rows = nd.shape()[0]; + const nnvm::dim_t num_cols = nd.shape()[1]; MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, { // indptr type if (nd.storage_initialized()) { @@ -248,6 +248,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, } case kRowSparseStorage: { const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx); + const nnvm::dim_t num_rows = nd.shape()[0]; + const nnvm::dim_t num_cols = nd.shape()[1]; MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type if (nd.storage_initialized()) { const nnvm::dim_t nz_rows = nd_indices.Size(); diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 70b626dbb9b7..505bd205a8d5 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -40,7 +40,7 @@ namespace ndarray { struct BinaryBase { inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; - CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape"; + CHECK(!mxnet::op::shape_is_none(lshape)) << "source operand have zero dimension shape"; return lshape; } }; diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc index 2b18f990c845..41b8559d16c2 100644 --- a/src/nnvm/plan_memory.cc +++ b/src/nnvm/plan_memory.cc @@ -30,6 +30,7 @@ #include #include #include "graph_algorithm.h" +#include "../operator/operator_common.h" namespace nnvm { namespace pass { @@ -75,7 +76,7 @@ class GraphAllocator { // request a free storage StorageID Request(int dev_id, int dtype, mxnet::TShape shape, uint32_t node_id) { - if (shape.ndim() == 0) return kBadStorageID; + if (!mxnet::shape_is_known(shape)) return kBadStorageID; // search memory block in [size / match_range_, size * match_range_) // TODO(tqchen) add size of the dtype, assume 4 bytes for now size_t size = shape.Size() * 4; @@ -267,8 +268,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, // only request memory for kBadStorageID if (storage[eid] == GraphAllocator::kBadStorageID) { auto &eshape = shape_vec[eid]; - size_t esize = 0; - if (eshape.ndim() != 0) esize = eshape.Size(); + size_t esize = eshape.Size(); eids.insert(std::make_pair(esize, eid)); } } diff --git a/src/operator/batch_norm_v1-inl.h b/src/operator/batch_norm_v1-inl.h index f407a5cce61b..89412357ac67 100644 --- a/src/operator/batch_norm_v1-inl.h +++ b/src/operator/batch_norm_v1-inl.h @@ -261,7 +261,7 @@ class BatchNormV1Prop : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; in_shape->at(1) = mxnet::TShape(Shape1(dshape[1])); in_shape->at(2) = mxnet::TShape(Shape1(dshape[1])); out_shape->clear(); diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h index 8b1ff38709b6..abb4a61dc84c 100644 --- a/src/operator/bilinear_sampler-inl.h +++ b/src/operator/bilinear_sampler-inl.h @@ -149,10 +149,10 @@ class BilinearSamplerProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]"; const mxnet::TShape &dshape = (*in_shape)[bs::kData]; const mxnet::TShape &lshape = (*in_shape)[bs::kGrid]; - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 4U) \ << "input data should be 4D in batch-num_filter-y-x"; - if (lshape.ndim() == 0) return false; + if (!shape_is_known(lshape)) return false; CHECK_EQ(lshape.ndim(), 4U) \ << "Sampler grid should be 4D in batch-2-y-x"; CHECK_EQ(dshape[0], lshape[0]); diff --git a/src/operator/channel_op_common.h b/src/operator/channel_op_common.h index 1afc13ad2594..43f689d2defa 100644 --- a/src/operator/channel_op_common.h +++ b/src/operator/channel_op_common.h @@ -45,6 +45,8 @@ inline void concatenate_helper(const std::vector(out, begin, end), req, input[i]); begin = end; @@ -80,6 +82,8 @@ void split_helper(const mshadow::Tensor &input, size_t size = out.size(); index_t begin = 0; for (size_t i = 0; i < size; ++i) { + // If out[i] is a zero-size tensor, do nothing. + if (out[i].shape_.Size() == 0) continue; index_t end = begin + out[i].size(cdim); Assign(out[i], req[i], slice(input, begin, end)); begin = end; diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h index 07feaefe87aa..6ae9e46b7def 100644 --- a/src/operator/contrib/adamw-inl.h +++ b/src/operator/contrib/adamw-inl.h @@ -87,8 +87,9 @@ inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; - // rescale_grad.shape = (1,) - SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1)); + // rescale_grad.shape = () + SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape()); + // TODO(@reminisce): change "none" behavior in ElemwiseAttr return ElemwiseAttr( attrs, in_attrs, out_attrs, mxnet::TShape()); } diff --git a/src/operator/contrib/adaptive_avg_pooling-inl.h b/src/operator/contrib/adaptive_avg_pooling-inl.h index 0d66de0a5692..eedab78db0c5 100644 --- a/src/operator/contrib/adaptive_avg_pooling-inl.h +++ b/src/operator/contrib/adaptive_avg_pooling-inl.h @@ -48,9 +48,9 @@ namespace mxnet { namespace op { struct AdaptiveAvgPoolParam : public dmlc::Parameter { - mxnet::TShape output_size; + mxnet::Tuple output_size; DMLC_DECLARE_PARAMETER(AdaptiveAvgPoolParam) { - DMLC_DECLARE_FIELD(output_size).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(output_size).set_default(mxnet::Tuple()) .describe("int (output size) or a tuple of int for output (height, width)."); } }; @@ -125,7 +125,7 @@ static bool AdaptiveAvgPoolOpInferShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_shape->size(), 1U) << "Output:[data]"; const AdaptiveAvgPoolParam& param = nnvm::get(attrs.parsed); mxnet::TShape dshape(in_shape->at(0)); - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; if (param.output_size.ndim() == 0) { dshape[2] = 1; dshape[3] = 1; diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h index 46c8e1aa7c0d..ce9c6c83504c 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -134,7 +134,7 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_shape->size(), 1U) << "Output:[data]"; const BilinearSampleParam& param = nnvm::get(attrs.parsed); mxnet::TShape dshape(in_shape->at(0)); - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; if (param.scale_height.has_value()) { dshape[2] = static_cast(param.scale_height.value() * in_shape->at(0)[2]); } else { diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index e22c493d5e2c..06d8439e23a0 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -121,7 +121,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, const NDArray &out = outputs[0]; CHECK_EQ(axis, 0) << "Not supported yet"; CHECK_EQ(data.shape()[axis], idx.shape()[0]); - CHECK_EQ(idx.shape().ndim(), 1U); + CHECK_EQ(idx.shape().ndim(), 1U); // idx is required to be 1-d. // count the number of 1s in `idx`, so that we could know the output dimension size_t idx_size = idx.shape()[0]; std::vector prefix_sum(idx_size, 0); diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 37c4297ff49d..686f1666a310 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -94,7 +94,8 @@ inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs, const BoxNMSParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 2U); - if (in_attrs->at(0).ndim() == 0U && out_attrs->at(0).ndim() == 0U) { + if (mxnet::op::shape_is_none(in_attrs->at(0)) + && mxnet::op::shape_is_none(out_attrs->at(0))) { return false; } @@ -556,7 +557,7 @@ inline bool BoxOverlapShape(const nnvm::NodeAttrs& attrs, << rdim << " provided"; // assign output shape - mxnet::TShape oshape(lshape.ndim() + rshape.ndim() - 2); + mxnet::TShape oshape(lshape.ndim() + rshape.ndim() - 2, -1); int idx = 0; for (index_t i = 0; i < lshape.ndim() - 1; ++i) { oshape[idx++] = lshape[i]; @@ -565,7 +566,7 @@ inline bool BoxOverlapShape(const nnvm::NodeAttrs& attrs, oshape[idx++] = rshape[i]; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } struct compute_overlap { @@ -669,14 +670,14 @@ inline bool MatchingShape(const nnvm::NodeAttrs& attrs, << dshape.ndim() << " provided"; // assign output shape - mxnet::TShape oshape(dshape.ndim() - 1); + mxnet::TShape oshape(dshape.ndim() - 1, -1); for (index_t i = 0; i < dshape.ndim() - 1; ++i) { oshape[i] = dshape[i]; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); oshape[oshape.ndim() - 1] = dshape[dshape.ndim() - 1]; SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape); - return true; + return shape_is_known(oshape); } struct bipartite_matching { diff --git a/src/operator/contrib/count_sketch-inl.h b/src/operator/contrib/count_sketch-inl.h index f3a294f6ad46..3ea93e63d6fc 100644 --- a/src/operator/contrib/count_sketch-inl.h +++ b/src/operator/contrib/count_sketch-inl.h @@ -151,7 +151,7 @@ class CountSketchProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 3) <<"Input:[data, h, s]"; const mxnet::TShape &dshape = (*in_shape)[CountSketch::kData]; // require data to be known - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; out_shape->clear(); if (dshape.ndim() == 4) { diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h index f50641fca6d6..000d703066d7 100644 --- a/src/operator/contrib/deformable_convolution-inl.h +++ b/src/operator/contrib/deformable_convolution-inl.h @@ -69,11 +69,11 @@ struct DeformableConvolutionParam : public dmlc::Parameter layout; DMLC_DECLARE_PARAMETER(DeformableConvolutionParam) { DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, -1)) .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, -1)) .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, -1)) .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Convolution filter(channel) number"); @@ -127,9 +127,9 @@ class DeformableConvolutionOp : public Operator { Tensor workspace = ctx.requested[conv::kTempSpace] .get_space_typed(Shape1(col_buffer_size_), s); // calculate the shape of col_buffer - mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1); + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, -1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); - for (size_t i = 1; i < col_buffer_shape.ndim(); ++i) { + for (int i = 1; i < col_buffer_shape.ndim(); ++i) { col_buffer_shape[i] = out_data[0].shape_[i + 1]; } // create a column buffer using workspace and col_buffer_shape @@ -189,7 +189,7 @@ class DeformableConvolutionOp : public Operator { Tensor workspace = ctx.requested[conv::kTempSpace] .get_space_typed(Shape1(col_buffer_size_), s); // calculate the shape of col_buffer - mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1); + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, -1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); for (index_t i = 1; i < col_buffer_shape.ndim(); ++i) { col_buffer_shape[i] = out_grad[conv::kData].shape_[i + 1]; @@ -371,7 +371,7 @@ class DeformableConvolutionProp : public OperatorProperty { out_shape->resize(1, mxnet::TShape()); const mxnet::TShape &dshp = (*in_shape)[conv::kData]; const mxnet::TShape &oshp = (*in_shape)[conv::kOffset]; - if (dshp.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshp)) return false; if (param_.kernel.ndim() == 2) { // 2d conv CHECK_EQ(dshp.ndim(), 4U) \ diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc index f19af84ce9c6..313b855f0d2d 100644 --- a/src/operator/contrib/dgl_graph.cc +++ b/src/operator/contrib/dgl_graph.cc @@ -259,34 +259,28 @@ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs, // Output bool success = true; - mxnet::TShape out_shape(1); + mxnet::TShape out_shape(1, -1); // We use the last element to store the actual // number of vertices in the subgraph. out_shape[0] = params.max_num_vertices + 1; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i, out_shape); - success = success && - out_attrs->at(i).ndim() != 0U && - out_attrs->at(i).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i)); } // sub_csr - mxnet::TShape out_csr_shape(2); + mxnet::TShape out_csr_shape(2, -1); out_csr_shape[0] = params.max_num_vertices; out_csr_shape[1] = in_attrs->at(0)[1]; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, out_csr_shape); - success = success && - out_attrs->at(i + num_subgraphs).ndim() != 0U && - out_attrs->at(i + num_subgraphs).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i + num_subgraphs)); } // sub_layer - mxnet::TShape out_layer_shape(1); + mxnet::TShape out_layer_shape(1, -1); out_layer_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_layer_shape); - success = success && - out_attrs->at(i + 2*num_subgraphs).ndim() != 0U && - out_attrs->at(i + 2*num_subgraphs).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs)); } return success; @@ -317,43 +311,35 @@ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs, // Output bool success = true; - mxnet::TShape out_shape(1); + mxnet::TShape out_shape(1, -1); // We use the last element to store the actual // number of vertices in the subgraph. out_shape[0] = params.max_num_vertices + 1; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i, out_shape); - success = success && - out_attrs->at(i).ndim() != 0U && - out_attrs->at(i).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i)); } // sub_csr - mxnet::TShape out_csr_shape(2); + mxnet::TShape out_csr_shape(2, -1); out_csr_shape[0] = params.max_num_vertices; out_csr_shape[1] = in_attrs->at(0)[1]; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, out_csr_shape); - success = success && - out_attrs->at(i + num_subgraphs).ndim() != 0U && - out_attrs->at(i + num_subgraphs).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i + num_subgraphs)); } // sub_probability - mxnet::TShape out_prob_shape(1); + mxnet::TShape out_prob_shape(1, -1); out_prob_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_prob_shape); - success = success && - out_attrs->at(i + 2*num_subgraphs).ndim() != 0U && - out_attrs->at(i + 2*num_subgraphs).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs)); } // sub_layer - mxnet::TShape out_layer_shape(1); + mxnet::TShape out_layer_shape(1, -1); out_layer_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { SHAPE_ASSIGN_CHECK(*out_attrs, i + 3*num_subgraphs, out_prob_shape); - success = success && - out_attrs->at(i + 3*num_subgraphs).ndim() != 0U && - out_attrs->at(i + 3*num_subgraphs).Size() != 0U; + success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 3 * num_subgraphs)); } return success; @@ -679,8 +665,8 @@ static void SampleSubgraph(const NDArray &csr, } } // Construct sub_csr_graph - mxnet::TShape shape_1(1); - mxnet::TShape shape_2(1); + mxnet::TShape shape_1(1, -1); + mxnet::TShape shape_2(1, -1); shape_1[0] = num_edges; shape_2[0] = max_num_vertices+1; sub_csr.CheckAndAllocData(shape_1); @@ -960,13 +946,13 @@ static bool DGLSubgraphShape(const nnvm::NodeAttrs& attrs, size_t num_g = params.num_args - 1; for (size_t i = 0; i < num_g; i++) { - mxnet::TShape gshape(2); + mxnet::TShape gshape(2, -1); gshape[0] = in_attrs->at(i + 1)[0]; gshape[1] = in_attrs->at(i + 1)[0]; out_attrs->at(i) = gshape; } for (size_t i = num_g; i < out_attrs->size(); i++) { - mxnet::TShape gshape(2); + mxnet::TShape gshape(2, -1); gshape[0] = in_attrs->at(i - num_g + 1)[0]; gshape[1] = in_attrs->at(i - num_g + 1)[0]; out_attrs->at(i) = gshape; @@ -1081,9 +1067,9 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, row_idx[i + 1] = col_idx.size(); } - mxnet::TShape nz_shape(1); + mxnet::TShape nz_shape(1, -1); nz_shape[0] = col_idx.size(); - mxnet::TShape indptr_shape(1); + mxnet::TShape indptr_shape(1, -1); indptr_shape[0] = row_idx.size(); // Store the non-zeros in a subgraph with edge attributes of new edge ids. @@ -1199,7 +1185,7 @@ inline bool EdgeIDShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); SHAPE_ASSIGN_CHECK(*in_attrs, 2, out_attrs->at(0)); - return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U; + return !mxnet::op::shape_is_none(out_attrs->at(0)); } inline bool EdgeIDType(const nnvm::NodeAttrs& attrs, @@ -1357,7 +1343,7 @@ inline bool DGLAdjacencyShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U; + return !mxnet::op::shape_is_none(out_attrs->at(0)); } inline bool DGLAdjacencyType(const nnvm::NodeAttrs& attrs, @@ -1460,9 +1446,9 @@ static void CompactSubgraph(const NDArray &csr, const NDArray &vids, CHECK_NE(row_ids[i], -1); } - mxnet::TShape nz_shape(1); + mxnet::TShape nz_shape(1, -1); nz_shape[0] = num_elems; - mxnet::TShape indptr_shape(1); + mxnet::TShape indptr_shape(1, -1); CHECK_EQ(out_csr.shape()[0], graph_size); indptr_shape[0] = graph_size + 1; CHECK_GE(in_ptr_data.shape_[0], indptr_shape[0]); @@ -1540,7 +1526,7 @@ static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs, } for (size_t i = 0; i < num_g; i++) { - mxnet::TShape gshape(2); + mxnet::TShape gshape(2, -1); gshape[0] = params.graph_sizes[i]; gshape[1] = params.graph_sizes[i]; out_attrs->at(i) = gshape; diff --git a/src/operator/contrib/fft-inl.h b/src/operator/contrib/fft-inl.h index 247f6290c02a..a5471b4ba2e2 100644 --- a/src/operator/contrib/fft-inl.h +++ b/src/operator/contrib/fft-inl.h @@ -241,7 +241,7 @@ class FFTProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 1) <<"Input:[data]"; const mxnet::TShape &dshape = (*in_shape)[fft::kData]; // require data to be known - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; out_shape->clear(); if (dshape.ndim() == 4) { diff --git a/src/operator/contrib/ifft-inl.h b/src/operator/contrib/ifft-inl.h index e53c0f60fa9e..7d8422e838b1 100644 --- a/src/operator/contrib/ifft-inl.h +++ b/src/operator/contrib/ifft-inl.h @@ -231,7 +231,7 @@ class IFFTProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 1) <<"Input:[data]"; const mxnet::TShape &dshape = (*in_shape)[ifft::kData]; // require data to be known - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; out_shape->clear(); if (dshape.ndim() == 4) { diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h index 903dee13272b..9f78f0593ed1 100644 --- a/src/operator/contrib/index_copy-inl.h +++ b/src/operator/contrib/index_copy-inl.h @@ -64,7 +64,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->at(1).ndim(), 1); // Shape matching CHECK_EQ(in_attrs->at(0).ndim(), in_attrs->at(2).ndim()); - for (size_t i = 0; i < in_attrs->at(0).ndim(); ++i) { + for (int i = 0; i < in_attrs->at(0).ndim(); ++i) { if (i == 0) { CHECK_GE(in_attrs->at(0)[i], in_attrs->at(2)[i]); } else { @@ -76,8 +76,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]); SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return out_attrs->at(0).ndim() != 0U && - out_attrs->at(0).Size() != 0U; + return !mxnet::op::shape_is_none(out_attrs->at(0)); } } // namespace op diff --git a/src/operator/contrib/multi_proposal-inl.h b/src/operator/contrib/multi_proposal-inl.h index 4b9a41c2fa87..4d278fb40645 100644 --- a/src/operator/contrib/multi_proposal-inl.h +++ b/src/operator/contrib/multi_proposal-inl.h @@ -108,7 +108,7 @@ class MultiProposalProp : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]"; const mxnet::TShape &dshape = in_shape->at(proposal::kClsProb); - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; Shape<4> bbox_pred_shape; bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]); SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred, diff --git a/src/operator/contrib/multibox_detection-inl.h b/src/operator/contrib/multibox_detection-inl.h index 977126ad269d..1ac14e237f0d 100644 --- a/src/operator/contrib/multibox_detection-inl.h +++ b/src/operator/contrib/multibox_detection-inl.h @@ -161,7 +161,7 @@ class MultiBoxDetectionProp : public OperatorProperty { CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0"; CHECK_EQ(ashape[2], 4U); - mxnet::TShape oshape = mxnet::TShape(3); + mxnet::TShape oshape = mxnet::TShape(3, -1); oshape[0] = cshape[0]; oshape[1] = ashape[1]; oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] diff --git a/src/operator/contrib/multibox_prior-inl.h b/src/operator/contrib/multibox_prior-inl.h index 3636a6016bd2..d8929f3deff4 100644 --- a/src/operator/contrib/multibox_prior-inl.h +++ b/src/operator/contrib/multibox_prior-inl.h @@ -180,7 +180,7 @@ class MultiBoxPriorProp: public OperatorProperty { int in_width = dshape[3]; CHECK_GT(in_width, 0) << "Input width should > 0"; // since input sizes are same in each batch, we could share MultiBoxPrior - mxnet::TShape oshape = mxnet::TShape(3); + mxnet::TShape oshape = mxnet::TShape(3, -1); int num_sizes = param_.sizes.ndim(); int num_ratios = param_.ratios.ndim(); oshape[0] = 1; @@ -189,7 +189,7 @@ class MultiBoxPriorProp: public OperatorProperty { out_shape->clear(); out_shape->push_back(oshape); CHECK_EQ(param_.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)"; - return true; + return shape_is_known(oshape); } OperatorProperty* Copy() const override { diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc index 0417a085616a..0c8bd79490e3 100644 --- a/src/operator/contrib/nnvm_to_onnx.cc +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -417,7 +417,8 @@ std::unordered_map GetPlaceholderShapes( for (uint32_t i = 0; i < shape_inputs.size(); ++i) { std::string name = ig[ig.input_nodes()[i]].source->attrs.name; mxnet::TShape shp = shape_inputs[i]; - if (shp.ndim() > 0) { + if (!mxnet::op::shape_is_none(shp)) { + // TODO(@reminisce): confirm placeholder_shapes.emplace(name, shp); } } diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc index 9f948bad81b6..83bbcdab833d 100644 --- a/src/operator/contrib/optimizer_op.cc +++ b/src/operator/contrib/optimizer_op.cc @@ -45,7 +45,7 @@ inline bool GroupAdagradShape(const nnvm::NodeAttrs &attrs, SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); - return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U && + return !mxnet::op::shape_is_none(out_attrs->at(0)) && (in_attrs->at(0)[0] == in_attrs->at(1)[0]) && (in_attrs->at(0)[0] == in_attrs->at(2)[0]); } diff --git a/src/operator/contrib/proposal-inl.h b/src/operator/contrib/proposal-inl.h index 9908ca96ec5f..21e9fe198e63 100644 --- a/src/operator/contrib/proposal-inl.h +++ b/src/operator/contrib/proposal-inl.h @@ -106,7 +106,7 @@ class ProposalProp : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]"; const mxnet::TShape &dshape = in_shape->at(proposal::kClsProb); - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; Shape<4> bbox_pred_shape; bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]); SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred, diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h index e679fedc8e57..a7aca63de17a 100644 --- a/src/operator/contrib/quadratic_op-inl.h +++ b/src/operator/contrib/quadratic_op-inl.h @@ -60,7 +60,7 @@ inline bool QuadraticOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U; + return !mxnet::op::shape_is_none(out_attrs->at(0)); } inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 1e6ab25db0e2..cd1a3285fe06 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -482,7 +482,7 @@ class SyncBatchNormProp : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (mxnet::op::shape_is_none(dshape)) return false; in_shape->at(1) = mxnet::TShape(Shape1(dshape[1])); in_shape->at(2) = mxnet::TShape(Shape1(dshape[1])); out_shape->clear(); diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h index 01faf244aff9..da3d14e33cf4 100644 --- a/src/operator/contrib/transformer-inl.h +++ b/src/operator/contrib/transformer-inl.h @@ -41,7 +41,9 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { mshadow::Stream *s = ctx.get_stream(); - double sqrt_dim = std::sqrt(static_cast(inputs[0].shape_[inputs[0].ndim() - 1])); + CHECK_GE(inputs[0].ndim(), 1); + int last_idx = inputs[0].ndim() - 1; + double sqrt_dim = std::sqrt(static_cast(inputs[0].shape_[last_idx])); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index ac6fea7c143b..4c0d67bb08f7 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -37,11 +37,11 @@ struct ForeachParam : public dmlc::Parameter { int num_outputs; int num_out_data; // The location of states in the subgraph inputs. - nnvm::Tuple in_state_locs; + mxnet::Tuple in_state_locs; // The location of data arrays in the subgraph inputs. - nnvm::Tuple in_data_locs; + mxnet::Tuple in_data_locs; // The location of remaining arrays in the subgraph inputs. - nnvm::Tuple remain_locs; + mxnet::Tuple remain_locs; DMLC_DECLARE_PARAMETER(ForeachParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) .describe("Number of inputs."); @@ -82,7 +82,7 @@ static void ForeachComputeExCPU(const OpStatePtr& state_ptr, CHECK_GT(params.in_data_locs.ndim(), 0); size_t len = inputs[0].shape()[iter_dim]; state.num_iterations = len; - for (size_t i = 1; i < params.in_data_locs.ndim(); i++) + for (int i = 1; i < params.in_data_locs.ndim(); i++) CHECK_EQ(inputs[i].shape()[iter_dim], len); for (size_t i = 0; i < (size_t) params.num_out_data; i++) CHECK_EQ(len, outputs[i].shape()[iter_dim]); @@ -120,7 +120,7 @@ static void ForeachComputeExCPU(const OpStatePtr& state_ptr, // and the loop states. std::vector subg_inputs(inputs.size()); // The remaining arrays (other than input data and states) only need to be set once. - for (size_t j = 0; j < params.remain_locs.ndim(); j++) { + for (int j = 0; j < params.remain_locs.ndim(); j++) { CHECK_LT(params.remain_locs[j], subg_inputs.size()); subg_inputs[params.remain_locs[j]] = inputs[j + params.in_data_locs.ndim() + params.in_state_locs.ndim()]; @@ -148,7 +148,7 @@ static void ForeachComputeExCPU(const OpStatePtr& state_ptr, // Initialize inputs for the subgraph. // Get a slice from the input data arrays. - for (size_t j = 0; j < params.in_data_locs.ndim(); j++) { + for (int j = 0; j < params.in_data_locs.ndim(); j++) { size_t loc = params.in_data_locs[j]; subg_inputs[loc] = inputs[j].At(i); } @@ -161,7 +161,7 @@ static void ForeachComputeExCPU(const OpStatePtr& state_ptr, subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j]; } } else { - for (size_t j = 0; j < params.in_state_locs.ndim(); j++) { + for (int j = 0; j < params.in_state_locs.ndim(); j++) { CHECK_LT(params.in_state_locs[j], subg_inputs.size()); subg_inputs[params.in_state_locs[j]] = inputs[j + params.in_data_locs.ndim()]; } @@ -203,7 +203,7 @@ static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr, // [data vars], [loop vars], [remaining vars] // [remaining vars] - for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + for (int i = 0; i < params.remain_locs.ndim(); i++) { size_t loc = params.remain_locs[i]; size_t orig_loc = i + params.in_data_locs.ndim() + params.in_state_locs.ndim(); subg_igrads[loc] = outputs[orig_loc]; @@ -216,20 +216,20 @@ static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr, if (iter_num < len - 1) { // For the rest of the iterations, we should add graidents to the // remaining vars. - for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + for (int i = 0; i < params.remain_locs.ndim(); i++) { size_t loc = params.remain_locs[i]; subg_req[loc] = kAddTo; } } // [data vars] - for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + for (int i = 0; i < params.in_data_locs.ndim(); i++) { size_t loc = params.in_data_locs[i]; subg_igrads[loc] = outputs[i].At(iter_num); subg_req[loc] = req[i]; } // [loop vars] - for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + for (int i = 0; i < params.in_state_locs.ndim(); i++) { size_t loc = params.in_state_locs[i]; const NDArray &output = outputs[i + params.in_data_locs.ndim()]; if (iter_num != 0) { @@ -258,9 +258,9 @@ static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr, template static void remap(const std::vector &op_in, size_t start, - const nnvm::Tuple &locs, std::vector *subg_in) { + const mxnet::Tuple &locs, std::vector *subg_in) { auto op_in_it = op_in.begin() + start; - for (size_t i = 0; i < locs.ndim(); i++) { + for (int i = 0; i < locs.ndim(); i++) { dim_t loc = locs[i]; subg_in->at(loc) = *(op_in_it + i); } @@ -284,7 +284,7 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector subg_in_shape(in_shape->size()); // data shape std::vector data_1d(params.in_data_locs.ndim(), false); - for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + for (int i = 0; i < params.in_data_locs.ndim(); i++) { size_t loc = params.in_data_locs[i]; if (in_shape->at(i).ndim() == 1) data_1d[i] = true; @@ -301,7 +301,7 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs, for (int i = 0; i < params.num_out_data; i++) { mxnet::TShape shape = subg_out_shape[i]; // If we don't have shape info, we don't need to do anything. - if (shape.ndim() == 0) + if (!mxnet::ndim_is_known(shape)) continue; subg_out_shape[i] = SliceFirstDim(shape); } @@ -317,12 +317,12 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs, for (int i = 0; i < params.num_out_data; i++) { // If the output shape isn't inferred, we don't need to propogate the info. const auto& g_out_shape = subg_out_shape[i]; - if (g_out_shape.ndim() == 0) + if (!mxnet::ndim_is_known(g_out_shape)) continue; - auto out = mxnet::TShape(g_out_shape.ndim() + 1); + auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1); out[0] = len; - for (size_t i = 1; i < out.ndim(); i++) + for (int i = 1; i < out.ndim(); i++) out[i] = g_out_shape[i - 1]; SHAPE_ASSIGN_CHECK(*out_shape, i, out); } @@ -331,34 +331,34 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_shape, i, subg_out_shape[i]); // For the shape of input data. - for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + for (int i = 0; i < params.in_data_locs.ndim(); i++) { size_t loc = params.in_data_locs[i]; const auto &shape = subg_in_shape[loc]; // If the input data shape isn't inferred, we don't need to propogate the // info. - if (shape.ndim() == 0) + if (!mxnet::ndim_is_known(shape)) continue; if (data_1d[i]) { - mxnet::TShape s(1); + mxnet::TShape s(1, -1); s[0] = len; SHAPE_ASSIGN_CHECK(*in_shape, i, s); } else { - auto in = mxnet::TShape(shape.ndim() + 1); + auto in = mxnet::TShape(shape.ndim() + 1, -1); in[0] = len; - for (size_t i = 1; i < in.ndim(); i++) + for (int i = 1; i < in.ndim(); i++) in[i] = shape[i - 1]; SHAPE_ASSIGN_CHECK(*in_shape, i, in); } } // For the shape of state. - for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + for (int i = 0; i < params.in_state_locs.ndim(); i++) { size_t loc = params.in_state_locs[i]; SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim(), subg_in_shape[loc]); } // For the shape of remaining data. - for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + for (int i = 0; i < params.remain_locs.ndim(); i++) { size_t loc = params.remain_locs[i]; SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(), @@ -387,15 +387,15 @@ static bool ForeachType(const nnvm::NodeAttrs& attrs, remap(*in_type, params.in_data_locs.ndim() + params.in_state_locs.ndim(), params.remain_locs, &subg_in_type); bool success = InferSubgraphDataType(*attrs.subgraphs[0], &subg_in_type, out_type); - for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + for (int i = 0; i < params.in_data_locs.ndim(); i++) { size_t loc = params.in_data_locs[i]; TYPE_ASSIGN_CHECK(*in_type, i, subg_in_type[loc]); } - for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + for (int i = 0; i < params.in_state_locs.ndim(); i++) { size_t loc = params.in_state_locs[i]; TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim(), subg_in_type[loc]); } - for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + for (int i = 0; i < params.remain_locs.ndim(); i++) { size_t loc = params.remain_locs[i]; TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(), subg_in_type[loc]); @@ -418,16 +418,16 @@ static bool ForeachStorageType(const nnvm::NodeAttrs& attrs, params.remain_locs, &subg_in_attrs); bool success = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, dispatch_mode, &subg_in_attrs, out_attrs); - for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + for (int i = 0; i < params.in_data_locs.ndim(); i++) { size_t loc = params.in_data_locs[i]; STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, subg_in_attrs[loc]); } - for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + for (int i = 0; i < params.in_state_locs.ndim(); i++) { size_t loc = params.in_state_locs[i]; STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim(), subg_in_attrs[loc]); } - for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + for (int i = 0; i < params.remain_locs.ndim(); i++) { size_t loc = params.remain_locs[i]; STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(), @@ -488,9 +488,9 @@ struct WhileLoopParam : public dmlc::Parameter { // `cond_input_locs' contains indices of inputs fed to `cond', and // `func_input_locs' contains indices of inputs fed to `func'. // `func_var_locs' are indices in which input "variables" are stored in func's inputs. - nnvm::Tuple cond_input_locs; - nnvm::Tuple func_input_locs; - nnvm::Tuple func_var_locs; + mxnet::Tuple cond_input_locs; + mxnet::Tuple func_input_locs; + mxnet::Tuple func_var_locs; DMLC_DECLARE_PARAMETER(WhileLoopParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) .describe("Number of input arguments, including cond and func as two symbol inputs."); @@ -538,12 +538,12 @@ class WhileLoopState: public LoopState { n_iterations(0U), cond_op(LoopState::MakeSharedOp(cond)), oi_map(params.func_var_locs.ndim(), -1) { - const nnvm::Tuple &func_input_locs = params.func_input_locs; - const nnvm::Tuple &func_var_locs = params.func_var_locs; - const nnvm::Tuple &cond_input_locs = params.cond_input_locs; - for (size_t i = 0; i < func_var_locs.ndim(); ++i) { + const mxnet::Tuple &func_input_locs = params.func_input_locs; + const mxnet::Tuple &func_var_locs = params.func_var_locs; + const mxnet::Tuple &cond_input_locs = params.cond_input_locs; + for (int i = 0; i < func_var_locs.ndim(); ++i) { dim_t pos_i = func_input_locs[func_var_locs[i]]; - for (size_t j = 0; j < cond_input_locs.ndim(); ++j) { + for (int j = 0; j < cond_input_locs.ndim(); ++j) { dim_t pos_j = cond_input_locs[j]; if (pos_i == pos_j) { this->oi_map[i] = j; @@ -740,7 +740,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, // infer shape for cond and func auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, ShapeVector *_subg_out, - const nnvm::Tuple &input_locs, + const mxnet::Tuple &input_locs, int num_out_data, bool fill_out_shape) { // create subg_in @@ -781,7 +781,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < subg_in.size(); ++i) { auto eid = idx.entry_id(input_nids[i], 0); auto g_out_shape = new_shapes[eid]; - if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + if (!shape_is_known(g_out_shape)) { // when the shape is not fully inferred continue; } @@ -795,13 +795,13 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, for (int i = 0; i < num_out_data; ++i) { auto eid = idx.entry_id(g.outputs[i]); auto g_out_shape = new_shapes[eid]; - if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + if (!shape_is_known(g_out_shape)) { // when the shape is not fully inferred continue; } - auto out = mxnet::TShape(g_out_shape.ndim() + 1); + auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1); out[0] = params.max_iterations; - for (size_t i = 1; i < out.ndim(); i++) + for (int i = 1; i < out.ndim(); i++) out[i] = g_out_shape[i - 1]; SHAPE_ASSIGN_CHECK(*out_shape, i, out); } @@ -809,7 +809,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, for (size_t i = num_out_data; i < g.outputs.size(); ++i) { auto eid = idx.entry_id(g.outputs[i]); auto g_out_shape = new_shapes[eid]; - if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + if (!shape_is_known(g_out_shape)) { // when the shape is not fully inferred continue; } @@ -817,7 +817,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, } return g.GetAttr("shape_num_unknown_nodes") == 0; }; - mxnet::ShapeVector cond_out_shape{mxnet::TShape(1U)}; // this means: [(1, )] + mxnet::ShapeVector cond_out_shape{mxnet::TShape(1, 1)}; // this means: [(1, )] mxnet::ShapeVector func_out_shape(params.num_outputs); CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); @@ -915,9 +915,9 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& og struct CondParam : public dmlc::Parameter { int num_args; int num_outputs; - nnvm::Tuple cond_input_locs; - nnvm::Tuple then_input_locs; - nnvm::Tuple else_input_locs; + mxnet::Tuple cond_input_locs; + mxnet::Tuple then_input_locs; + mxnet::Tuple else_input_locs; DMLC_DECLARE_PARAMETER(CondParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(3) .describe("Number of input arguments, including cond, then and else as three symbol inputs."); @@ -992,7 +992,7 @@ static void CondComputeExCPU(const OpStatePtr& state_ptr, state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); branch_selection = as_bool_scalar(*cond_output_ptr[0]); // select the right branch - const nnvm::Tuple &func_input_locs = branch_selection + const mxnet::Tuple &func_input_locs = branch_selection ? params.then_input_locs : params.else_input_locs; LoopState &loop_state = branch_selection @@ -1017,7 +1017,7 @@ static void CondGradComputeExCPU(const OpStatePtr& state_ptr, // select the right branch int branch_selection = state.branch_selection; CHECK_NE(branch_selection, -1); - const nnvm::Tuple &func_input_locs = branch_selection + const mxnet::Tuple &func_input_locs = branch_selection ? params.then_input_locs : params.else_input_locs; LoopState &loop_state = branch_selection @@ -1048,7 +1048,7 @@ static bool CondShape(const nnvm::NodeAttrs& attrs, // infer shape for cond, then and else auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, ShapeVector *_subg_out, - const nnvm::Tuple &input_locs, + const mxnet::Tuple &input_locs, bool fill_out_shape) { // create subg_in mxnet::ShapeVector subg_in; @@ -1086,7 +1086,7 @@ static bool CondShape(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < subg_in.size(); ++i) { auto eid = idx.entry_id(input_nids[i], 0); auto g_out_shape = new_shapes[eid]; - if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + if (!shape_is_known(g_out_shape)) { // when the shape is not fully inferred continue; } @@ -1099,7 +1099,7 @@ static bool CondShape(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < g.outputs.size(); ++i) { auto eid = idx.entry_id(g.outputs[i]); auto g_out_shape = new_shapes[eid]; - if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + if (!shape_is_known(g_out_shape)) { // when the shape is not fully inferred continue; } @@ -1107,7 +1107,7 @@ static bool CondShape(const nnvm::NodeAttrs& attrs, } return g.GetAttr("shape_num_unknown_nodes") == 0; }; - ShapeVector cond_out_shape{mxnet::TShape(1U)}; // this means: [(1, )] + ShapeVector cond_out_shape{mxnet::TShape(1, 1)}; // this means: [(1, )] ShapeVector then_out_shape(params.num_outputs); ShapeVector else_out_shape(params.num_outputs); bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \ @@ -1190,7 +1190,7 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args); CHECK_EQ(attrs.subgraphs.size(), 3U); static const std::function is_udf = is_stype_udf; - auto sub_pass = [&](const std::shared_ptr &subg, const nnvm::Tuple &input_locs) { + auto sub_pass = [&](const std::shared_ptr &subg, const mxnet::Tuple &input_locs) { // A. first construct subg_in_attrs // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy) std::vector subg_in_attrs; diff --git a/src/operator/convolution_v1-inl.h b/src/operator/convolution_v1-inl.h index ed6748a9c85c..080c718dc9bf 100644 --- a/src/operator/convolution_v1-inl.h +++ b/src/operator/convolution_v1-inl.h @@ -64,11 +64,11 @@ struct ConvolutionV1Param : public dmlc::Parameter { dmlc::optional layout; DMLC_DECLARE_PARAMETER(ConvolutionV1Param) { DMLC_DECLARE_FIELD(kernel).describe("convolution kernel size: (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, 0)) .describe("convolution stride: (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, 0)) .describe("convolution dilate: (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0)) .describe("pad for convolution: (h, w) or (d, h, w)"); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("convolution filter(channel) number"); @@ -405,7 +405,7 @@ class ConvolutionV1Prop : public OperatorProperty { // CHECK_EQ(out_shape->size(), 1) << "Output: [output]"; out_shape->resize(1, mxnet::TShape()); const mxnet::TShape &dshp = (*in_shape)[conv_v1::kData]; - if (dshp.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshp)) return false; if (param_.kernel.ndim() == 2) { // 2d conv_v1 CHECK_EQ(dshp.ndim(), 4U) \ diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 46249c9bbcc6..412bfa1bc3aa 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -128,17 +128,21 @@ bool InferShape(const NodeAttrs& attrs, const CustomParam& params = nnvm::get(attrs.parsed); size_t total = params.num_args + params.num_outs + params.num_auxs; - std::vector shapes(total); + std::vector shapes(total); std::vector ndims(total); size_t buff_size = 0; - for (const auto& i : *in_shape) buff_size += i.ndim(); - std::vector buff(buff_size); - uint32_t *ptr = buff.data(); + for (const auto& i : *in_shape) { + if (i.ndim() > 0) { + buff_size += i.ndim(); + } + } + std::vector buff(buff_size); + int *ptr = buff.data(); for (size_t i = 0; i < in_shape->size(); ++i) { shapes[i] = ptr; ndims[i] = (*in_shape)[i].ndim(); - for (size_t j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { - *ptr = static_cast((*in_shape)[i][j]); + for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { + *ptr = (*in_shape)[i][j]; } } @@ -263,7 +267,7 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx, for (size_t i = 0; i < in_shape.size(); ++i) { shapes[i] = ptr; ndims[i] = in_shape[i].ndim(); - for (size_t j = 0; j < in_shape[i].ndim(); ++j, ++ptr) { + for (int j = 0; j < in_shape[i].ndim(); ++j, ++ptr) { *ptr = static_cast(in_shape[i][j]); } } diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index c37324678120..182cd682af8b 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -93,7 +93,7 @@ inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape &shp = (*in_attrs)[0]; - if (!shp.ndim()) return false; + if (!shape_is_known(shp)) return false; CHECK((shp.ndim() == 3) || (shp.ndim() == 4)) << "Input image must have shape (height, width, channels), or " @@ -549,7 +549,7 @@ template void FlipImpl(const mxnet::TShape &shape, DType *src, DType *dst) { int head = 1, mid = shape[axis], tail = 1; for (int i = 0; i < axis; ++i) head *= shape[i]; - for (uint32_t i = axis+1; i < shape.ndim(); ++i) tail *= shape[i]; + for (int i = axis+1; i < shape.ndim(); ++i) tail *= shape[i]; for (int i = 0; i < head; ++i) { for (int j = 0; j < (mid >> 1); ++j) { diff --git a/src/operator/image/resize-inl.h b/src/operator/image/resize-inl.h index de2189838d76..4ebebbfb272c 100644 --- a/src/operator/image/resize-inl.h +++ b/src/operator/image/resize-inl.h @@ -49,12 +49,12 @@ void ResizeImplCUDA(Stream *s, #endif // MXNET_USE_CUDA struct ResizeParam : public dmlc::Parameter { - nnvm::Tuple size; + mxnet::Tuple size; bool keep_ratio; int interp; DMLC_DECLARE_PARAMETER(ResizeParam) { DMLC_DECLARE_FIELD(size) - .set_default(nnvm::Tuple()) + .set_default(mxnet::Tuple()) .describe("Size of new image. Could be (width, height) or (size)"); DMLC_DECLARE_FIELD(keep_ratio) .describe("Whether to resize the short edge or both edges to `size`, " diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index cfdd1064d6fb..7f8638630145 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -315,7 +315,7 @@ class LeakyReLUOp : public Operator { return a < b ? (a < c ? a : c) : (b < c ? b : c); } static inline mxnet::TShape expand_shape(const mxnet::TShape& src, const mxnet::TShape& dst) { - mxnet::TShape result(dst.ndim()); + mxnet::TShape result(dst.ndim(), -1); int s = src.ndim() - 1; for (int i = dst.ndim() - 1; i >= 0; i--) { if (s >= 0 && i <= 1 && (dst[i] == src[s] || src[s] == 1)) { @@ -355,10 +355,10 @@ class LeakyReLUProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; } const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; if (param_.act_type == leakyrelu::kPReLU) { const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma); - if (gshape.ndim() == 0) { + if (!mxnet::ndim_is_known(gshape)) { in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1])); } if (dshape == gshape) { diff --git a/src/operator/loss_binary_op-inl.h b/src/operator/loss_binary_op-inl.h index a3853c56359a..1d71993da515 100644 --- a/src/operator/loss_binary_op-inl.h +++ b/src/operator/loss_binary_op-inl.h @@ -43,7 +43,7 @@ inline bool SoftmaxCrossEntropyShape(const nnvm::NodeAttrs& attrs, << "SoftmaxCrossEntropy only accept 1D label"; CHECK_EQ((*in_attrs)[0][0], (*in_attrs)[1][0]) << "SoftmaxCrossEntropy: data label shape mismatch"; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 1)); return true; } diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index a937f839c9bb..e331255c2e50 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -781,6 +781,7 @@ struct Kernel { /*! \brief Launch GPU kernel */ template inline static void Launch(mshadow::Stream *s, int N, Args... args) { + if (0 == N) return; using namespace mshadow::cuda; int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); mxnet_generic_kernel @@ -791,6 +792,7 @@ struct Kernel { template inline static void LaunchEx(mshadow::Stream *s, const int N, Args... args) { + if (0 == N) return; using namespace mshadow::cuda; int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); mxnet_generic_kernel_ex diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 511fe455e946..622952cc4bc5 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -332,7 +332,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, const int channelCount = dshape[channelAxis]; - if (dshape.ndim() == 0) { + if (!mxnet::ndim_is_known(dshape)) { return false; } diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index fa441c45321e..8fb229889332 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -39,39 +39,40 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, const ConcatParam& param_ = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); mxnet::TShape dshape; - index_t size = 0; - bool has_zero = false; + dim_t size = 0; + bool has_unknown_dim_size = false; int axis = -1; for (int i = 0; i < param_.num_args; ++i) { mxnet::TShape tmp = (*in_shape)[i]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - has_zero = tmp[axis] == 0 || has_zero; + has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size; size += tmp[axis]; - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } } mxnet::TShape tmp = (*out_shape)[0]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } - if (dshape.ndim() == 0) return false; + if (dshape.ndim() == -1) return false; + CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated"; for (int i = 0; i < param_.num_args; ++i) { CHECK(shape_assign(&(*in_shape)[i], dshape)) << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; } - if (!has_zero) dshape[axis] = size; + if (!has_unknown_dim_size) dshape[axis] = size; CHECK(shape_assign(&(*out_shape)[0], dshape)) << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; - return dshape.Size() != 0; + return shape_is_known(dshape); } // Concat for RNN param deals with the reverse shape inference from output @@ -90,26 +91,27 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs, int axis = -1; for (int i = 0; i < param_.num_args; ++i) { mxnet::TShape tmp = (*in_shape)[i]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - if (tmp[axis] == 0) { + if (!mxnet::dim_size_is_known(tmp, axis)) { zero_indices.emplace_back(i); } else { + CHECK_GE(tmp[axis], 0); size += tmp[axis]; } - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } } mxnet::TShape tmp = (*out_shape)[0]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; for (int i = 0; i < param_.num_args; ++i) { CHECK(shape_assign(&(*in_shape)[i], dshape)) @@ -119,21 +121,21 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs, if (zero_indices.empty()) dshape[axis] = size; CHECK(shape_assign(&(*out_shape)[0], dshape)) << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; - if ((*out_shape)[0][axis] != 0 && !zero_indices.empty()) { + if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) { int residual = (*out_shape)[0][axis] - size; CHECK_GE(residual, 0) << "Input size already exceeds output size. Residual: " << residual; - CHECK(zero_indices.size() <= 2 && zero_indices.size() >= 0) + CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0) << "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size(); - bool need_infer = !(*out_shape)[0].Size(); + bool need_infer = !shape_is_known((*out_shape)[0]); for (int i : zero_indices) { (*in_shape)[i][axis] = residual / zero_indices.size(); - need_infer = need_infer || !(*in_shape)[i].Size(); + need_infer = need_infer || !shape_is_known((*in_shape)[i]); } return !need_infer; } - return dshape.Size() != 0; + return shape_is_known(dshape); } static bool ConcatType(const nnvm::NodeAttrs& attrs, @@ -232,9 +234,10 @@ bool SupportMKLDNNConcat(const std::vector &arrs) { for (auto &arr : arrs) { if (arr.IsView()) return false; if (arr.dtype() != mshadow::kFloat32) return false; - unsigned ndim = arr.shape().ndim(); - unsigned mkldnn_ndims = - static_cast(arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims); + // DO not support zero-size tensors. + if (arr.shape().Size() == 0) return false; + int ndim = arr.shape().ndim(); + const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims; if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false; } return true; diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index 7ae34ae363b4..7d5f7c7d5757 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -69,11 +69,11 @@ struct ConvolutionParam : public dmlc::Parameter { dmlc::optional layout; DMLC_DECLARE_PARAMETER(ConvolutionParam) { DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (w,), (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, 0)) .describe("Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, 0)) .describe("Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0)) .describe("Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Convolution filter(channel) number"); @@ -209,9 +209,9 @@ class ConvolutionOp { Tensor workspace = ctx.requested[conv::kTempSpace] .get_space_typed(Shape1(col_buffer_size_), s); // calculate the shape of col_buffer - mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1); + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); - for (index_t i = 1; i < col_buffer_shape.ndim(); ++i) { + for (int i = 1; i < col_buffer_shape.ndim(); ++i) { col_buffer_shape[i] = out_data[0].shape_[i+1]; } // create a column buffer using workspace and col_buffer_shape @@ -295,9 +295,9 @@ class ConvolutionOp { Tensor workspace = ctx.requested[conv::kTempSpace] .get_space_typed(Shape1(col_buffer_size_), s); // calculate the shape of col_buffer - mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1); + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); - for (index_t i = 1; i < col_buffer_shape.ndim(); ++i) { + for (int i = 1; i < col_buffer_shape.ndim(); ++i) { col_buffer_shape[i] = out_grad[conv::kData].shape_[i+1]; } // create a column buffer using workspace and col_buffer_shape @@ -342,10 +342,10 @@ class ConvolutionOp { void LayerSetUp(const mxnet::TShape& ishape, const mxnet::TShape& oshape) { channel_axis_ = 1; // hard code channel axis const index_t first_spatial_axis = channel_axis_ + 1; - const index_t num_axes = param_.kernel.ndim() + 2; + const int num_axes = param_.kernel.ndim() + 2; num_spatial_axes_ = num_axes - first_spatial_axis; is_1x1_ = true; - for (index_t i = 0; i < param_.kernel.ndim(); ++i) { + for (int i = 0; i < param_.kernel.ndim(); ++i) { is_1x1_ &= param_.kernel[i] == 1 && param_.stride[i] == 1 && param_.pad[i] == 0; if (!is_1x1_) break; } diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 527a0073930f..536e9a731171 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -96,24 +96,28 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, // CHECK_EQ(out_shape->size(), 1) << "Output: [output]"; out_shape->resize(1, mxnet::TShape()); const mxnet::TShape &dshp = (*in_shape)[conv::kData]; - if (dshp.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshp)) return false; if (param_.kernel.ndim() == 1) { // 1d conv CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x"; Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW); - Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, dshape[1] / param_.num_group, + Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, + mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0]); wshape = ConvertLayout(wshape, kNCW, param_.layout.value()); - wshape[0] *= param_.num_group; + if (wshape[0] >= 0) { + wshape[0] *= param_.num_group; + } SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); } const index_t dilated_ksize_x = param_.DilatedKernelSize(0); - CHECK_EQ(dshape[1] % param_.num_group, 0U) \ - << "input num_filter must divide group size"; + if (dshape[1] != -1) { + CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size"; + } CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ << "output num_filter must divide group size"; CHECK_GT(param_.kernel.Size(), 0U) \ @@ -125,21 +129,21 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, Shape<3> oshape; oshape[0] = dshape[0]; oshape[1] = param_.num_filter; - oshape[2] = dshape[2] ? - (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : 0; + oshape[2] = dshape[2] != -1 ? + (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : -1; SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value())); // Perform incomplete shape inference. Fill in the missing values in data shape. // 1) We can always fill in the batch_size. // 2) We can back-calculate the input height/width if the corresponding stride is 1. oshape = ConvertLayout((*out_shape)[0].get<3>(), param_.layout.value(), kNCW); dshape[0] = oshape[0]; - if (oshape[2] && param_.stride[0] == 1) { + if (oshape[2] != -1 && param_.stride[0] == 1) { dshape[2] = oshape[2] + dilated_ksize_x - 1 - 2 * param_.pad[0]; } SHAPE_ASSIGN_CHECK(*in_shape, conv::kData, ConvertLayout(dshape, kNCW, param_.layout.value())); // Check whether the kernel sizes are valid - if (dshape[2] != 0) { + if (dshape[2] != -1) { CHECK_LE(dilated_ksize_x, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input"; } return true; @@ -149,10 +153,12 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, << "Input data should be 4D in batch-num_filter-y-x"; Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW); Shape<4> wshape = Shape4(param_.num_filter / param_.num_group, - dshape[1] / param_.num_group, + mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1]); wshape = ConvertLayout(wshape, kNCHW, param_.layout.value()); - wshape[0] *= param_.num_group; + if (wshape[0] >= 0) { + wshape[0] *= param_.num_group; + } SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); @@ -160,8 +166,9 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, const index_t dilated_ksize_y = param_.DilatedKernelSize(0); const index_t dilated_ksize_x = param_.DilatedKernelSize(1); - CHECK_EQ(dshape[1] % param_.num_group, 0U) \ - << "input num_filter must divide group size"; + if (dshape[1] != -1) { + CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size"; + } CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ << "output num_filter must divide group size"; CHECK_GT(param_.kernel.Size(), 0U) \ @@ -173,29 +180,29 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, Shape<4> oshape; oshape[0] = dshape[0]; oshape[1] = param_.num_filter; - oshape[2] = dshape[2] ? - (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_y) / param_.stride[0] + 1 : 0; - oshape[3] = dshape[3] ? - (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_x) / param_.stride[1] + 1 : 0; + oshape[2] = dshape[2] != -1 ? + (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_y) / param_.stride[0] + 1 : -1; + oshape[3] = dshape[3] != -1 ? + (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_x) / param_.stride[1] + 1 : -1; SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value())); // Perform incomplete shape inference. Fill in the missing values in data shape. // 1) We can always fill in the batch_size. // 2) We can back-calculate the input height/width if the corresponding stride is 1. oshape = ConvertLayout((*out_shape)[0].get<4>(), param_.layout.value(), kNCHW); dshape[0] = oshape[0]; - if (oshape[2] && param_.stride[0] == 1) { + if (oshape[2] != -1 && param_.stride[0] == 1) { dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param_.pad[0]; } - if (oshape[3] && param_.stride[1] == 1) { + if (oshape[3] != -1 && param_.stride[1] == 1) { dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param_.pad[1]; } SHAPE_ASSIGN_CHECK(*in_shape, conv::kData, ConvertLayout(dshape, kNCHW, param_.layout.value())); // Check whether the kernel sizes are valid - if (dshape[2] != 0) { + if (dshape[2] != -1) { CHECK_LE(dilated_ksize_y, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input"; } - if (dshape[3] != 0) { + if (dshape[3] != -1) { CHECK_LE(dilated_ksize_x, AddPad(dshape[3], param_.pad[1])) << "kernel size exceed input"; } return true; @@ -204,10 +211,13 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshp.ndim(), 5U) \ << "Input data should be 5D in batch-num_filter-depth-y-x"; Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW); - Shape<5> wshape = Shape5(param_.num_filter / param_.num_group, dshape[1] / param_.num_group, + Shape<5> wshape = Shape5(param_.num_filter / param_.num_group, + mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1], param_.kernel[2]); wshape = ConvertLayout(wshape, kNCDHW, param_.layout.value()); - wshape[0] *= param_.num_group; + if (wshape[0] >= 0) { + wshape[0] *= param_.num_group; + } SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); @@ -218,8 +228,9 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, const index_t dilated_ksize_d = param_.DilatedKernelSize(0); const index_t dilated_ksize_y = param_.DilatedKernelSize(1); const index_t dilated_ksize_x = param_.DilatedKernelSize(2); - CHECK_EQ(dshape[1] % param_.num_group, 0U) - << "input num_filter must divide group size"; + if (dshape[1] >= 0) { + CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size"; + } CHECK_EQ(param_.num_filter % param_.num_group, 0U) << "output num_filter must divide group size"; CHECK_GT(param_.kernel.Size(), 0U) \ @@ -233,37 +244,37 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, Shape<5> oshape; oshape[0] = dshape[0]; oshape[1] = param_.num_filter; - oshape[2] = dshape[2] ? - (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_d) / param_.stride[0] + 1 : 0; - oshape[3] = dshape[3] ? - (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_y) / param_.stride[1] + 1 : 0; - oshape[4] = dshape[4] ? - (AddPad(dshape[4], param_.pad[2]) - dilated_ksize_x) / param_.stride[2] + 1 : 0; + oshape[2] = dshape[2] != -1 ? + (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_d) / param_.stride[0] + 1 : -1; + oshape[3] = dshape[3] != -1 ? + (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_y) / param_.stride[1] + 1 : -1; + oshape[4] = dshape[4] != -1 ? + (AddPad(dshape[4], param_.pad[2]) - dilated_ksize_x) / param_.stride[2] + 1 : -1; SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCDHW, param_.layout.value())); // Perform incomplete shape inference. Fill in the missing values in data shape. // 1) We can always fill in the batch_size. // 2) We can back-calculate the input depth/height/width if the corresponding stride is 1. oshape = ConvertLayout((*out_shape)[0].get<5>(), param_.layout.value(), kNCDHW); dshape[0] = oshape[0]; - if (oshape[2] && param_.stride[0] == 1) { + if (oshape[2] != -1 && param_.stride[0] == 1) { dshape[2] = oshape[2] + dilated_ksize_d - 1 - 2 * param_.pad[0]; } - if (oshape[3] && param_.stride[1] == 1) { + if (oshape[3] != -1 && param_.stride[1] == 1) { dshape[3] = oshape[3] + dilated_ksize_y - 1 - 2 * param_.pad[1]; } - if (oshape[4] && param_.stride[2] == 1) { + if (oshape[4] != -1 && param_.stride[2] == 1) { dshape[4] = oshape[4] + dilated_ksize_x - 1 - 2 * param_.pad[2]; } SHAPE_ASSIGN_CHECK(*in_shape, conv::kData, ConvertLayout(dshape, kNCDHW, param_.layout.value())); // Check whether the kernel sizes are valid - if (dshape[2] != 0) { + if (dshape[2] != -1) { CHECK_LE(dilated_ksize_d, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input"; } - if (dshape[3] != 0) { + if (dshape[3] != -1) { CHECK_LE(dilated_ksize_y, AddPad(dshape[3], param_.pad[1])) << "kernel size exceed input"; } - if (dshape[4] != 0) { + if (dshape[4] != -1) { CHECK_LE(dilated_ksize_x, AddPad(dshape[4], param_.pad[2])) << "kernel size exceed input"; } return true; diff --git a/src/operator/nn/ctc_loss-inl.h b/src/operator/nn/ctc_loss-inl.h index 357888dc30f1..8c841dfc24b4 100644 --- a/src/operator/nn/ctc_loss-inl.h +++ b/src/operator/nn/ctc_loss-inl.h @@ -239,7 +239,7 @@ inline bool CTCLossOpShape(const nnvm::NodeAttrs &attrs, "the maximum sequence length of the " "data."; - mxnet::TShape oshape(1); + mxnet::TShape oshape(1, -1); oshape[0] = dshape[1]; // batch size SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); // forward output SHAPE_ASSIGN_CHECK(*out_attrs, 1, dshape); // grad output diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc index 5632028dd769..cb35ce170e8e 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm.cc +++ b/src/operator/nn/cudnn/cudnn_batch_norm.cc @@ -37,7 +37,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_ using namespace mshadow; CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, moving_mean, moving_var]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; in_shape->at(1) = mxnet::TShape(Shape1(dshape[1])); in_shape->at(2) = mxnet::TShape(Shape1(dshape[1])); in_shape->at(3) = mxnet::TShape(Shape1(dshape[1])); diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 55b263896339..679e0cd1057b 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -1015,9 +1015,9 @@ class CuDNNConvolutionOp { // e.g. {shape[0], shape[1], shape[2]} -> {shape[1]*shape[2], shape[2], 1} template inline Shape Strides(const mxnet::TShape &s) { - uint32_t ndim = s.ndim(); - mxnet::TShape strides(ndim); - for (uint32_t i = 0; i != ndim; ++i) + int ndim = s.ndim(); + mxnet::TShape strides(ndim, -1); + for (int i = 0; i != ndim; ++i) strides[i] = s.ProdShape(i+1, ndim); return strides.get(); } diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 47f688c8ab9c..adb6caf1c028 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -933,9 +933,9 @@ class CuDNNDeconvolutionOp { // e.g. {shape[0], shape[1], shape[2]} -> {shape[1]*shape[2], shape[2], 1} template inline Shape Strides(const mxnet::TShape &s) { - uint32_t ndim = s.ndim(); - mxnet::TShape strides(ndim); - for (uint32_t i = 0; i != ndim; ++i) + int ndim = s.ndim(); + mxnet::TShape strides(ndim, -1); + for (int i = 0; i != ndim; ++i) strides[i] = s.ProdShape(i+1, ndim); return strides.get(); } diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index 5248c1211ac7..1eeccb02e030 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -65,13 +65,13 @@ struct DeconvolutionParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(DeconvolutionParam) { DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (w,), (h, w) or (d, h, w). " "This is same as the kernel size used for the corresponding convolution"); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, 0)) .describe("The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). " "Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, 0)) .describe("Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). " "Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0)) .describe("The amount of implicit zero padding added during convolution for each " "dimension of the input: " "(w,), (h, w) or (d, h, w). " @@ -79,11 +79,11 @@ struct DeconvolutionParam : public dmlc::Parameter { "If `target_shape` is set, " "`pad` will be ignored and a padding that will generate the target shape " "will be used. Defaults to no padding."); - DMLC_DECLARE_FIELD(adj).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(adj).set_default(mxnet::TShape(0, 0)) .describe("Adjustment for output shape: (w,), (h, w) or (d, h, w). " "If `target_shape` is set, " "`adj` will be ignored and computed accordingly."); - DMLC_DECLARE_FIELD(target_shape).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(target_shape).set_default(mxnet::TShape(0, 0)) .describe("Shape of the output tensor: (w,), (h, w) or (d, h, w)."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Number of output filters."); @@ -134,16 +134,18 @@ struct DeconvolutionParam : public dmlc::Parameter { for (size_t i = 0; i < ndim; i++) { // input.ndim() can be larger than ndim, in case that the complete input // shape was passed and not only the ndim last ones - o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i); - CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape"; - o_pad[i] -= target_shape[i]; - o_adj[i] = o_pad[i] % 2; - o_pad[i] = (o_pad[i] + 1) / 2; + if (mxnet::dim_size_is_known(input, input_ndim - ndim + i)) { + o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i); + CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape"; + o_pad[i] -= target_shape[i]; + o_adj[i] = o_pad[i] % 2; + o_pad[i] = (o_pad[i] + 1) / 2; + } } } else { - for (size_t i = 0; i < ndim; i++) { - o_pad[i] = pad[i]; - o_adj[i] = adj[i]; + for (int i = 0; i < static_cast(ndim); i++) { + o_pad[i] = i < pad.ndim() ? pad[i] : 0; + o_adj[i] = i < adj.ndim() ? adj[i] : 0; } } } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 27928b9b41c3..09b255d009e0 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -54,7 +54,7 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, } out_shape->resize(1, mxnet::TShape()); const mxnet::TShape &dshape = (*in_shape)[deconv::kData]; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; if (param_.kernel.ndim() == 1) { // 1d conv @@ -90,8 +90,12 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, Shape<3> oshape; oshape[0] = dshape_ncw[0]; oshape[1] = param_.num_filter; - oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) + - dilated_ksize_x - 2 * o_pad[0] + o_adj[0]; + if (mxnet::dim_size_is_known(dshape_ncw[2])) { + oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) + + dilated_ksize_x - 2 * o_pad[0] + o_adj[0]; + } else { + oshape[2] = -1; + } if (param_.target_shape.ndim() > 0) { if (param_.target_shape[0] > 0) { @@ -141,10 +145,18 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, Shape<4> oshape; oshape[0] = dshape_nchw[0]; oshape[1] = param_.num_filter; - oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) + - dilated_ksize_y - 2 * o_pad[0] + o_adj[0]; - oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) + - dilated_ksize_x - 2 * o_pad[1] + o_adj[1]; + if (mxnet::dim_size_is_known(dshape_nchw[2])) { + oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) + + dilated_ksize_y - 2 * o_pad[0] + o_adj[0]; + } else { + oshape[2] = -1; + } + if (mxnet::dim_size_is_known(dshape_nchw[3])) { + oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) + + dilated_ksize_x - 2 * o_pad[1] + o_adj[1]; + } else { + oshape[3] = -1; + } if (param_.target_shape.ndim() > 1) { if (param_.target_shape[0] > 0) { @@ -203,12 +215,24 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, Shape<5> oshape; oshape[0] = dshape_ncdhw[0]; oshape[1] = param_.num_filter; - oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) + - dilated_ksize_d - 2 * o_pad[0] + o_adj[0]; - oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) + - dilated_ksize_y - 2 * o_pad[1] + o_adj[1]; - oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) + - dilated_ksize_x - 2 * o_pad[2] + o_adj[2]; + if (mxnet::dim_size_is_known(dshape_ncdhw[2])) { + oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) + + dilated_ksize_d - 2 * o_pad[0] + o_adj[0]; + } else { + oshape[2] = -1; + } + if (mxnet::dim_size_is_known(dshape_ncdhw[3])) { + oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) + + dilated_ksize_y - 2 * o_pad[1] + o_adj[1]; + } else { + oshape[3] = -1; + } + if (mxnet::dim_size_is_known(dshape_ncdhw[4])) { + oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) + + dilated_ksize_x - 2 * o_pad[2] + o_adj[2]; + } else { + oshape[4] = -1; + } if (param_.target_shape.ndim() > 2) { if (param_.target_shape[0] > 0) { diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 01611dfce191..a34d2992c8c6 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -78,7 +78,7 @@ struct DropoutParam : public dmlc::Parameter { .add_enum("always", dropout::kAlways) .set_default(dropout::kTraining) .describe("Whether to only turn on dropout during training or to also turn on for inference."); - DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0, 0)) .describe("Axes for variational dropout kernel."); DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional(false)) .describe("Whether to turn off cudnn in dropout operator. " diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 5fdc672d766e..afad6fd5cc80 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -95,10 +95,10 @@ Example:: CHECK_EQ(in_shape->size(), 1U); const DropoutParam& param = nnvm::get(attrs.parsed); mxnet::TShape dshape(in_shape->at(0)); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; out_shape->clear(); out_shape->push_back(dshape); - for (index_t i = 0; i < param.axes.ndim(); ++i) { + for (int i = 0; i < param.axes.ndim(); ++i) { dshape[param.axes[i]] = 1; } out_shape->push_back(dshape); diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 2bc321832af6..a097357ef5a3 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -52,7 +52,7 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs, mxnet::TShape dshape = (*in_shape)[fullc::kData]; mxnet::TShape oshape = (*out_shape)[0]; // require data to be known - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; index_t num_input; if (!param.flatten) { @@ -75,7 +75,7 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs, } else { SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden)); } - if (oshape.ndim() != 0) { + if (oshape.ndim() > 0) { dshape[0] = oshape[0]; SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape); } diff --git a/src/operator/nn/im2col.h b/src/operator/nn/im2col.h index 0059a420726d..06a4e1b75b33 100644 --- a/src/operator/nn/im2col.h +++ b/src/operator/nn/im2col.h @@ -152,7 +152,7 @@ inline void im2col_nd_core_cpu(const DType* data_input, const bool im2col, const mxnet::TShape& kernel_shape, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, DType* data_output, OpReqType req = mxnet::kWriteTo) { if (mxnet::kNullOp == req) return; - index_t num_spatial_axes = kernel_shape.ndim(); + int num_spatial_axes = kernel_shape.ndim(); if (!im2col) { index_t im_size = im_shape[1]; // skip batch dim for (index_t i = 0; i < num_spatial_axes; ++i) { @@ -319,7 +319,7 @@ inline void col2im(mshadow::Stream* s, const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, DType* data_im, OpReqType req) { - index_t num_spatial_axes = kernel_shape.ndim(); + int num_spatial_axes = kernel_shape.ndim(); if (2 == num_spatial_axes) { col2im_cpu(data_col, im_shape[1], im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1], pad[0], pad[1], diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index dc4914bf2457..c7de7d734521 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -167,7 +167,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, const LayerNormParam& param = nnvm::get(attrs.parsed); int axis = param.axis; if (axis < 0) { - axis += static_cast(inputs[0].ndim()); + axis += inputs[0].ndim(); } CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; Stream *s = ctx.get_stream(); diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index d4c308398cb7..2e47503a3318 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -41,14 +41,14 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, const mxnet::TShape &dshape = in_shape->at(layernorm::kData); int axis = param.axis; if (axis < 0) { - axis += static_cast(dshape.ndim()); + axis += dshape.ndim(); } - CHECK(axis >= 0 && axis < static_cast(dshape.ndim())) + CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; const int channelCount = dshape[axis]; - if (dshape.ndim() == 0) { + if (!mxnet::ndim_is_known(dshape)) { return false; } diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 410bdab667e5..b632e35b57fe 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -40,7 +40,7 @@ bool LRNShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; out_shape->clear(); out_shape->push_back(dshape); out_shape->push_back(dshape); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index a460e33fa548..3da3f23d7683 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -464,7 +464,7 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p mkldnn_memory_format_t format); inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) { - if (shape.ndim() != (size_t)ndims) + if (shape.ndim() != ndims) return false; for (int i = 0; i < ndims; i++) if (shape[i] != dims[i]) diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index 8e2b57781a18..7b266efc2a14 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -92,13 +92,13 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, auto gz_mem = inputs[0].GetMKLDNNData(); mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc(); /* init the offset */ - mkldnn::memory::dims offsets = {0, 0, 0, 0}; + mkldnn::memory::dims offsets(outputs[0].shape().ndim()); + for (auto &v : offsets) { + v = 0; + } + for (int i = 0; i < num_in_data; i++) { - mkldnn::memory::dims diff_src_tz - = {static_cast(outputs[i].shape()[0]), - static_cast(outputs[i].shape()[1]), - static_cast(outputs[i].shape()[2]), - static_cast(outputs[i].shape()[3])}; + mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end()); auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc(); auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); // create view from gy to gxs[i] diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 3f3d82020598..2a817a25a5b8 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -37,12 +37,12 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, const NDArray &out) { const mxnet::TShape ishape = in.shape(); const mxnet::TShape oshape = out.shape(); - uint32_t N = ishape.ndim(); + const int N = ishape.ndim(); mkldnn::memory::dims dims(N); mkldnn::memory::dims offsets(N); - for (uint32_t i = 0; i < N; ++i) { + for (int i = 0; i < N; ++i) { int s = 0; - if (param.begin[i]) { + if (i < param.begin.ndim() && param.begin[i]) { s = *param.begin[i]; if (s < 0) s += ishape[i]; } diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 0986d0616f75..eec19bababb7 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -55,9 +55,9 @@ class MKLDNNTransposeForward { auto shape = data.shape(); auto data_ndim = shape.ndim(); auto axes_ndim = param.axes.ndim(); - auto axes = mxnet::TShape(data_ndim); + auto axes = mxnet::TShape(data_ndim, -1); if (axes_ndim == 0) { - for (size_t i = 0; i < data_ndim; i++) { + for (int i = 0; i < data_ndim; i++) { axes[i] = data_ndim - i - 1; } } else { @@ -79,7 +79,7 @@ class MKLDNNTransposeForward { dst_fmt.data_type = mkldnn_f32; dst_fmt.format = mkldnn_blocked; - for (size_t i = 0; i < data_ndim; i++) + for (int i = 0; i < data_ndim; i++) dst_fmt.dims[i] = shape[i]; unsigned int total_stride = 1; diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index 9e1e73bf19e2..03f0fa8edd6c 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -55,7 +55,7 @@ struct PoolingParam : public dmlc::Parameter { dmlc::optional count_include_pad; dmlc::optional layout; DMLC_DECLARE_PARAMETER(PoolingParam) { - DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape()) // add default value here + DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape(0, 0)) // add default value here .enforce_nonzero() .describe("Pooling kernel size: (y, x) or (d, y, x)"); @@ -78,11 +78,11 @@ struct PoolingParam : public dmlc::Parameter { .add_enum("same", pool_enum::kSame) .describe("Pooling convention to be applied."); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, 0)) .enforce_nonzero() .describe("Stride: for pooling (y, x) or (d, y, x). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0)) .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding."); DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional()) @@ -200,11 +200,11 @@ class PoolingOp { kernel = mxnet::TShape(ishape.data() + 2, ishape.data() + ishape.ndim()); } - padding = mxnet::TShape(ishape.ndim() - 2); + padding = mxnet::TShape(ishape.ndim() - 2, 0); for (index_t i = 0; i < ishape.ndim() - 2; i++) { padding[i] = 0; } - stride = mxnet::TShape(ishape.ndim() - 2); + stride = mxnet::TShape(ishape.ndim() - 2, 1); } const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? param_.p_value.value() : 1; @@ -257,11 +257,11 @@ class PoolingOp { kernel = mxnet::TShape(ishape.data() + 2, ishape.data() + ishape.ndim()); } - padding = mxnet::TShape(ishape.ndim() - 2); + padding = mxnet::TShape(ishape.ndim() - 2, 0); for (index_t i = 0; i < ishape.ndim() - 2; i++) { padding[i] = 0; } - stride = mxnet::TShape(ishape.ndim() - 2); + stride = mxnet::TShape(ishape.ndim() - 2, 1); } const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 2d16604baa20..3e081c9a0552 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -114,11 +114,11 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; int layout = param.GetLayout(dshape.ndim()); if (param.global_pool) { mxnet::TShape oshape = dshape; - size_t c_index = 0; + int c_index = 0; switch (layout) { case mshadow::kNCW: case mshadow::kNCHW: @@ -133,7 +133,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, default: LOG(FATAL) << "Unsupported tensor layout " << param.layout.value(); } - for (size_t i{1}; i < dshape.ndim(); i++) + for (int i = 1; i < dshape.ndim(); i++) if (i != c_index) oshape[i] = 1; out_shape->clear(); diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index d09017bf713e..ac638162dc6d 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -60,7 +60,7 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; CHECK_EQ(dshape.ndim(), 4U) << \ "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; int kernel = 2 * param_.scale - param_.scale % 2; SHAPE_ASSIGN_CHECK(*in_shape, up_enum::kWeight, diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index f629534dabd0..59f572211d0e 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -103,9 +103,10 @@ struct InferStorageTypeError : public dmlc::Error { : dmlc::Error(msg_), msg(msg_), index(index) {} }; -/*! \brief check if shape is empty or contains unknown (0) dim. */ +/*! \brief check if shape is empty or contains unknown (0) dim. + * DEPRECATED. */ inline bool shape_is_none(const mxnet::TShape& x) { - return x.ndim() == 0 || x.Size() == 0; + return !mxnet::shape_is_known(x); } /*! \brief check if type is none (-1) */ @@ -120,7 +121,7 @@ inline bool storage_type_is_none(const int& x) { /*! \brief check if shape is scalar({1}). */ inline bool shape_is_scalar(const mxnet::TShape& x) { - return x.ndim() == 1 && x.Size() == 1; + return x.ndim() == 0; } /*! \brief get string representation of shape */ @@ -159,16 +160,16 @@ inline std::string type_string(const int& x) { * \return whether x and y are compatible. */ inline bool shape_assign(mxnet::TShape *y, const mxnet::TShape& x) { - if (y->ndim() == 0) { + if (!mxnet::ndim_is_known(*y)) { *y = x; return true; } else if (y->ndim() != x.ndim()) { - return x.ndim() == 0; + return !mxnet::ndim_is_known(x); } else { - for (size_t i = 0; i < y->ndim(); ++i) { - if ((*y)[i] == 0) { + for (int i = 0; i < y->ndim(); ++i) { + if (!mxnet::dim_size_is_known(*y, i)) { (*y)[i] = x[i]; - } else if ((*y)[i] != x[i] && x[i] != 0) { + } else if ((*y)[i] != x[i] && x[i] >= 0) { return false; } } @@ -563,7 +564,7 @@ class OpSignature { } void AddSign(const mxnet::TShape &shape) { - for (size_t i = 0; i < shape.ndim(); i++) { + for (int i = 0; i < shape.ndim(); i++) { hash = hash * 2 + shape[i]; eles.push_back(shape[i]); } diff --git a/src/operator/operator_util.cc b/src/operator/operator_util.cc index b87428ca2b64..bc097a5b0c1c 100644 --- a/src/operator/operator_util.cc +++ b/src/operator/operator_util.cc @@ -774,7 +774,7 @@ class SimpleUnaryOpProp : public SimpleOpPropBase { using namespace mshadow; CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; out_shape->clear(); if (source->unary_shape_ == nullptr) { out_shape->push_back(dshape); diff --git a/src/operator/pad-inl.h b/src/operator/pad-inl.h index 140d7099e817..89b0ab7780b6 100644 --- a/src/operator/pad-inl.h +++ b/src/operator/pad-inl.h @@ -230,7 +230,7 @@ class PadProp : public OperatorProperty { } } mxnet::TShape oshape = dshape; - for (size_t i = 0; i < dshape.ndim(); ++i) { + for (int i = 0; i < dshape.ndim(); ++i) { oshape[i] = param_.pad_width[2 * i] + param_.pad_width[2 * i + 1] + dshape[i]; } diff --git a/src/operator/pooling_v1-inl.h b/src/operator/pooling_v1-inl.h index 4e0ccc1caeb9..4241b08a0c5e 100644 --- a/src/operator/pooling_v1-inl.h +++ b/src/operator/pooling_v1-inl.h @@ -55,7 +55,7 @@ struct PoolingV1Param : public dmlc::Parameter { int pooling_convention; bool global_pool; DMLC_DECLARE_PARAMETER(PoolingV1Param) { - DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape(0, -1)) .enforce_nonzero() .describe("pooling kernel size: (y, x) or (d, y, x)"); @@ -73,11 +73,11 @@ struct PoolingV1Param : public dmlc::Parameter { .add_enum("valid", pool_v1_enum::kValid) .describe("Pooling convention to be applied."); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, -1)) .enforce_nonzero() .describe("stride: for pooling (y, x) or (d, y, x)"); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, -1)) .describe("pad for pooling: (y, x) or (d, y, x)"); } }; @@ -217,19 +217,20 @@ class PoolingV1Prop : public OperatorProperty { void Init(const std::vector >& kwargs) override { using namespace mshadow; param_.Init(kwargs); - if (!param_.global_pool) { - if (param_.kernel.ndim() == 2) { - if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); - if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); - } else { - CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() << "D pooling not supported"; - if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1); - if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0); - } - CHECK_EQ(param_.stride.ndim(), param_.kernel.ndim()) - << "stride and kernel should have the same length"; - CHECK_EQ(param_.pad.ndim(), param_.kernel.ndim()) - << "pad and kernel should have the same length"; + if (param_.kernel.ndim() == 1) { + if (param_.stride.ndim() == 0) param_.stride = Shape1(1); + if (param_.pad.ndim() == 0) param_.pad = Shape1(0); + } else if (param_.kernel.ndim() == 2) { + if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); + if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); + } else { + // ignore kernel size only if global_pool not assigned false + if (param_.global_pool == false) { + CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() + << "D pooling not supported"; + } + if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1); + if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0); } } @@ -247,7 +248,7 @@ class PoolingV1Prop : public OperatorProperty { CHECK_LE(dshape.ndim(), 5U) << "Pooling: Input data should be 4D in (batch, channel, y, x) " << "Or 5D in (batch, channel, d, y, x)"; mxnet::TShape oshape = dshape; - if (dshape.ndim() == 0) return false; + if (dshape.ndim() == -1) return false; if (param_.global_pool) { if (dshape.ndim() == 4) { oshape[2] = 1; diff --git a/src/operator/quantization/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h index dcda5a8b4bef..7c91ad507fd9 100644 --- a/src/operator/quantization/dequantize-inl.h +++ b/src/operator/quantization/dequantize-inl.h @@ -99,11 +99,11 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); for (size_t i = 1; i < 3; ++i) { - SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1)); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - return !shape_is_none(out_attrs->at(0)); + return shape_is_known(out_attrs->at(0)); } inline bool DequantizeType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h index 45713589dd48..ac414c72d51a 100644 --- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h @@ -115,7 +115,7 @@ static void MKLDNNRequantizeForward(const nnvm::NodeAttrs& attrs, const size_t actual_float_size = sizeof(float); const size_t actual_quantized_size = sizeof(SrcDType); const size_t temp_reduce_size = ConfigReduce(s, - inputs[0].shape(), mxnet::TShape({1}), &src_shape, &dst_shape); + inputs[0].shape(), mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s); diff --git a/src/operator/quantization/quantize-inl.h b/src/operator/quantization/quantize-inl.h index 1ad0016c52bc..7b856579a7b5 100644 --- a/src/operator/quantization/quantize-inl.h +++ b/src/operator/quantization/quantize-inl.h @@ -120,13 +120,13 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 3U); for (size_t i = 1; i < 3; ++i) { - SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1)); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1}); SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1}); - return !shape_is_none(out_attrs->at(0)); + return shape_is_known(out_attrs->at(0)); } inline bool QuantizeType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 02ace6c39fac..9ebb645e1ba6 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -175,7 +175,7 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, mxnet::TShape src_shape, dst_shape; const size_t actual_float_size = sizeof(float); const size_t temp_reduce_size = ConfigReduce( - s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2 * actual_float_size + temp_reduce_size), s); const int dev_id = ctx.run_ctx.ctx.dev_id; diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc index e32bb5a18e1a..d6aeb41da1f8 100644 --- a/src/operator/quantization/quantized_concat.cc +++ b/src/operator/quantization/quantized_concat.cc @@ -35,34 +35,34 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha CHECK_EQ(out_shape->size(), 3U); mxnet::TShape dshape; index_t size = 0; - bool has_zero = false; + bool has_unknown_dim_size = false; int axis = -1; for (int i = 0; i < param_.num_args; ++i) { mxnet::TShape tmp = (*in_shape)[i]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - has_zero = tmp[axis] == 0 || has_zero; + has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size; size += tmp[axis]; - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } } mxnet::TShape tmp = (*out_shape)[0]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; for (int i = 0; i < param_.num_args; ++i) { CHECK(shape_assign(&(*in_shape)[i], dshape)) << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; } - if (!has_zero) dshape[axis] = size; + if (!has_unknown_dim_size) dshape[axis] = size; CHECK(shape_assign(&(*out_shape)[0], dshape)) << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; @@ -71,7 +71,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha } SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape{1}); SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape{1}); - return dshape.Size() != 0; + return shape_is_known(dshape); } static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector* in_type, diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc index 1a801ee50744..aa3f5ce1ad61 100644 --- a/src/operator/quantization/quantized_conv.cc +++ b/src/operator/quantization/quantized_conv.cc @@ -78,8 +78,8 @@ bool QuantizedConvShape(const nnvm::NodeAttrs& attrs, oshape[W] = (AddPad(dshape[W], param.pad[1]) - wshape[W]) / param.stride[1] + 1; SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); - SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); - SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1)); return true; } diff --git a/src/operator/quantization/quantized_flatten-inl.h b/src/operator/quantization/quantized_flatten-inl.h index 99a262de19ca..de051b969659 100644 --- a/src/operator/quantization/quantized_flatten-inl.h +++ b/src/operator/quantization/quantized_flatten-inl.h @@ -86,10 +86,10 @@ inline bool QuantizedFlattenShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 3U); const mxnet::TShape &dshape = (*in_attrs)[0]; - if (shape_is_none(dshape)) return false; + if (!shape_is_known(dshape)) return false; - uint32_t target_dim = 1; - for (uint32_t i = 1; i < dshape.ndim(); ++i) { + dim_t target_dim = 1; + for (int i = 1; i < dshape.ndim(); ++i) { target_dim *= dshape[i]; } diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index 0a04e71b9093..e42ea3020352 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -47,7 +47,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_shape->size(), num_inputs * 3); CHECK_EQ(out_shape->size(), 3U); - CHECK(!shape_is_none(in_shape->at(0))) + CHECK(shape_is_known(in_shape->at(0))) << "QuantizedFullyConnectedOp input data shape must be given"; const mxnet::TShape& dshape = in_shape->at(0); index_t num_input; @@ -75,8 +75,8 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, } else { SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden)); } - SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); - SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1)); return true; } diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc index af604080a756..1839e2a29d77 100644 --- a/src/operator/quantization/quantized_pooling.cc +++ b/src/operator/quantization/quantized_pooling.cc @@ -35,7 +35,7 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_shape) { const PoolingParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 3U); - if (shape_is_none(in_shape->at(0))) return false; + if (!shape_is_known(in_shape->at(0))) return false; const mxnet::TShape &dshape = (*in_shape)[0]; CHECK_EQ(dshape.ndim(), 4U) << "quantized_pooling: Input data should be 4D in " @@ -45,7 +45,7 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs, << "QuantizedPoolingOp only supports NCHW layout for now, saw " << layout; // NCHW layout const int N = 0, H = 2, W = 3, C = 1; - mxnet::TShape oshape(4); + mxnet::TShape oshape(4, -1); CHECK_EQ(param.kernel.ndim(), 2) << "QuantizedPoolingOp only supports 2D pooling for now"; CHECK(param.kernel[0] <= dshape[H] + 2 * param.pad[0]) << "kernel size (" << param.kernel[0] diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h index 21d58d4607eb..9106c7fe4716 100644 --- a/src/operator/quantization/requantize-inl.h +++ b/src/operator/quantization/requantize-inl.h @@ -111,7 +111,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, const size_t actual_float_size = sizeof(float); const size_t actual_quantized_size = sizeof(SrcDType); const size_t temp_reduce_size = ConfigReduce( - s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s); diff --git a/src/operator/random/multisample_op.h b/src/operator/random/multisample_op.h index e9f266932e13..7d5e256297ad 100644 --- a/src/operator/random/multisample_op.h +++ b/src/operator/random/multisample_op.h @@ -66,7 +66,7 @@ inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs, // Get shape to be sampled for each parameter set. const MultiSampleParam& param = nnvm::get(attrs.parsed); mxnet::TShape sshape = param.shape; - for (size_t i = 0; i < sshape.ndim(); ++i) { + for (int i = 0; i < sshape.ndim(); ++i) { CHECK_GT(sshape[i], 0) << "shape parameter must be non-zero within each dimension"; } // Examine output shape whether it is already defined. diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index e76cd646b850..b38aefbc1634 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -41,7 +41,7 @@ struct SampleMultinomialParam : public dmlc::Parameter { int dtype; DMLC_DECLARE_PARAMETER(SampleMultinomialParam) { DMLC_DECLARE_FIELD(shape) - .set_default(mxnet::TShape()) + .set_default(mxnet::TShape(0, 1)) .describe("Shape to be sampled from each random distribution."); DMLC_DECLARE_FIELD(get_prob) .set_default(false) @@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U); const mxnet::TShape& ishape = (*in_attrs)[0]; - if (!ishape.ndim()) return false; + if (!shape_is_known(ishape)) return false; MSHADOW_TYPE_SWITCH(param.dtype, DType, { CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue()) @@ -76,26 +76,26 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs, }); if (ishape.ndim() == 1) { - if (param.shape.ndim()) { + if (param.shape.ndim() > 0) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape); if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, param.shape); } else { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1)); - if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 1)); + if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1)); } return true; } - mxnet::TShape oshape(ishape.ndim() - 1 + param.shape.ndim()); - for (size_t i = 0; i < ishape.ndim() - 1; ++i) { + mxnet::TShape oshape(ishape.ndim() - 1 + param.shape.ndim(), -1); + for (int i = 0; i < ishape.ndim() - 1; ++i) { oshape[i] = ishape[i]; } - for (size_t i = 0; i < param.shape.ndim(); ++i) { + for (int i = 0; i < param.shape.ndim(); ++i) { oshape[i + ishape.ndim() - 1] = param.shape[i]; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape); - return true; + return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)); } diff --git a/src/operator/random/unique_sample_op.h b/src/operator/random/unique_sample_op.h index 87998c8f46b1..e88b95a8bdd6 100644 --- a/src/operator/random/unique_sample_op.h +++ b/src/operator/random/unique_sample_op.h @@ -60,7 +60,7 @@ inline bool SampleUniqueShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 0U); CHECK_EQ(out_attrs->size(), 2U); // output shape is known - if ((*out_attrs)[0].ndim() == 2 && param.shape.ndim() == 0) { + if ((*out_attrs)[0].ndim() == 2 && !mxnet::ndim_is_known(param.shape)) { SHAPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::Shape1((*out_attrs)[0][0])); return true; } diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h index 8b63a8a2cff6..d8f102de1675 100644 --- a/src/operator/regression_output-inl.h +++ b/src/operator/regression_output-inl.h @@ -57,7 +57,7 @@ inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; CHECK_EQ(in_attrs->size(), 2U) << "Input:[data, label]"; const mxnet::TShape &dshape = in_attrs->at(0); - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; auto &lshape = (*in_attrs)[1]; if (lshape.ndim() == 0) { // special treatment for 1D output, to allow 1D label by default. diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 74c563afceb1..7012a3c22f50 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -50,7 +50,7 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size(); } const TShape &dshape = (*in_shape)[rnn_enum::kData]; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 3U) \ << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; // data: [sequence len, batch, input dimension] diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index b4db80bdd721..4c42934f1618 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -263,7 +263,7 @@ class SequenceLastProp : public OperatorProperty { SHAPE_ASSIGN_CHECK(*in_shape, seq_last::kSequenceLength, Shape1(sbatch)); // calculate output size - mxnet::TShape shape_o(dshape.ndim() - 1); + mxnet::TShape shape_o(dshape.ndim() - 1, -1); shape_o[0] = sbatch; for (index_t i = 1; i < shape_o.ndim(); ++i) shape_o[i] = dshape[i + 1]; diff --git a/src/operator/slice_channel-inl.h b/src/operator/slice_channel-inl.h index 6125782d525b..e37ffdcf1b91 100644 --- a/src/operator/slice_channel-inl.h +++ b/src/operator/slice_channel-inl.h @@ -195,9 +195,9 @@ class SliceChannelProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 1U); mxnet::TShape dshape = in_shape->at(slice_enum::kData); mxnet::TShape ishape = in_shape->at(slice_enum::kData); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; if (param_.axis >= 0) { - CHECK_LT(static_cast(param_.axis), dshape.ndim()); + CHECK_LT(param_.axis, dshape.ndim()); } else { CHECK_LT(param_.axis + dshape.ndim(), dshape.ndim()); } @@ -212,15 +212,18 @@ class SliceChannelProp : public OperatorProperty { << " evenly sized chunks, but this is not possible because " << param_.num_outputs << " does not evenly divide " << dshape[real_axis]; - if (param_.squeeze_axis && ishape[real_axis] != 0) { - CHECK_EQ(ishape[real_axis], static_cast(param_.num_outputs)) + if (param_.squeeze_axis && ishape[real_axis] != -1) { + CHECK_EQ(ishape[real_axis], param_.num_outputs) << "If squeeze axis is True, the size of the sliced axis must be the same as num_outputs." << " Input shape=" << ishape << ", axis=" << real_axis << ", num_outputs=" << param_.num_outputs << "."; } - dshape[real_axis] /= param_.num_outputs; - if (param_.squeeze_axis && (dshape[real_axis] == 1 || ishape[real_axis] == 0)) { - for (int d = real_axis; d < static_cast(dshape.ndim()) - 1; ++d) { + if (dshape[real_axis] >= 0) { + dshape[real_axis] /= param_.num_outputs; + } + if (param_.squeeze_axis && (dshape[real_axis] == 1 + || !mxnet::dim_size_is_known(ishape, real_axis))) { + for (int d = real_axis; d < dshape.ndim() - 1; ++d) { dshape[d] = dshape[d+1]; } dshape = mxnet::TShape(&dshape[0], &dshape[dshape.ndim()-1]); diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 5dca8bac14a3..80ab40ef6c50 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -337,19 +337,19 @@ class SoftmaxOutputProp : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; // label.shape == data.shape: use probability as label if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { if (param_.multi_output) { mxnet::TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); - mxnet::TShape lshape2(dshape.ndim() - 1); + mxnet::TShape lshape2(dshape.ndim() - 1, -1); lshape2[0] = dshape[0]; - for (index_t i = 2; i < dshape.ndim(); ++i) + for (int i = 2; i < dshape.ndim(); ++i) lshape2[i-1] = dshape[i]; mxnet::TShape lshape3 = dshape; lshape3[1] = 1; - if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { + if (!mxnet::ndim_is_known(in_shape->at(softmaxout_enum::kLabel))) { in_shape->at(softmaxout_enum::kLabel) = lshape1; } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { @@ -361,8 +361,8 @@ class SoftmaxOutputProp : public OperatorProperty { throw InferShapeError(os.str(), softmaxout_enum::kLabel); } } else { - mxnet::TShape label_shape(dshape.ndim() - 1); - for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + mxnet::TShape label_shape(dshape.ndim() - 1, -1); + for (int i = 0; i + 1 < dshape.ndim(); ++i) label_shape[i] = dshape[i]; SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index b17ef3527297..548225f0496b 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -85,19 +85,19 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs, const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; // label.shape == data.shape: use probability as label if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { if (param.multi_output) { mxnet::TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); - mxnet::TShape lshape2(dshape.ndim() - 1); + mxnet::TShape lshape2(dshape.ndim() - 1, -1); lshape2[0] = dshape[0]; - for (index_t i = 2; i < dshape.ndim(); ++i) + for (int i = 2; i < dshape.ndim(); ++i) lshape2[i-1] = dshape[i]; mxnet::TShape lshape3 = dshape; lshape3[1] = 1; - if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { + if (!mxnet::ndim_is_known(in_shape->at(softmaxout_enum::kLabel))) { in_shape->at(softmaxout_enum::kLabel) = lshape1; } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { @@ -109,8 +109,8 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs, throw InferShapeError(os.str(), softmaxout_enum::kLabel); } } else { - mxnet::TShape label_shape(dshape.ndim() - 1); - for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + mxnet::TShape label_shape(dshape.ndim() - 1, -1); + for (int i = 0; i + 1 < dshape.ndim(); ++i) label_shape[i] = dshape[i]; SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); } diff --git a/src/operator/spatial_transformer-inl.h b/src/operator/spatial_transformer-inl.h index 9e5dee842d0d..660d57d55bab 100644 --- a/src/operator/spatial_transformer-inl.h +++ b/src/operator/spatial_transformer-inl.h @@ -190,10 +190,10 @@ class SpatialTransformerProp : public OperatorProperty { CHECK_EQ(param_.sampler_type, st::kBilinear) << "only supports bilinear sampling currently"; const mxnet::TShape &dshape = (*in_shape)[st::kData]; const mxnet::TShape &lshape = (*in_shape)[st::kLoc]; - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 4U) \ << "input data should be 4D in batch-num_filter-y-x"; - if (lshape.ndim() == 0) return false; + if (!shape_is_known(lshape)) return false; CHECK_EQ(lshape.ndim(), 2U) \ << "locolisation paramter should be 4D in batch-num_hidden"; if (param_.transform_type == st::kAffine) { diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index 8934438d428a..e53d911614a0 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -178,7 +178,7 @@ bool as_bool_scalar(const NDArray &a) { } bool is_shape_udf(const mxnet::TShape &x) { - return x.ndim() == 0 || x.Size() == 0; + return !shape_is_known(x); } bool is_stype_udf(const int &x) { @@ -225,7 +225,7 @@ void LoopState::Forward(int iter_no, if (!out_bufs[i].IsSame(coutputs[i])) { // The line below checks whether dynamic shape exists. // If so, re-initialize the shape. - if (coutputs[i].shape().ndim() == 0) { + if (!shape_is_known(coutputs[i].shape())) { const_cast(coutputs[i]).Init(out_bufs[i].shape()); } CopyFromTo(out_bufs[i], coutputs[i]); diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 91adf576dc07..19528349c0c7 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -67,7 +67,7 @@ bool is_type_udf(const int &x); template void extract_by_loc(const std::vector &array, - const nnvm::Tuple input_locs, + const mxnet::Tuple input_locs, std::vector *out) { out->clear(); out->reserve(input_locs.ndim()); @@ -94,11 +94,11 @@ bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { } template -bool sync_in_in(const nnvm::Tuple &input_locs, - std::vector *in, - std::vector *subg_in, - std::function is_empty) { - for (size_t i = 0; i < input_locs.ndim(); ++i) { +bool sync_in_in(const mxnet::Tuple &input_locs, + std::vector *in, + std::vector *subg_in, + std::function is_empty) { + for (int i = 0; i < input_locs.ndim(); ++i) { T &x = in->at(input_locs[i]); T &y = subg_in->at(i); fill_value(&x, &y, is_empty(x), is_empty(y)); diff --git a/src/operator/svm_output-inl.h b/src/operator/svm_output-inl.h index 1609764f0ebe..dfe9fa606e95 100644 --- a/src/operator/svm_output-inl.h +++ b/src/operator/svm_output-inl.h @@ -143,9 +143,9 @@ class SVMOutputProp : public OperatorProperty { using namespace mshadow; CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; const mxnet::TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; - mxnet::TShape label_shape(dshape.ndim() - 1); - for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + if (!mxnet::ndim_is_known(dshape)) return false; + mxnet::TShape label_shape(dshape.ndim() - 1, -1); + for (int i = 0; i + 1 < dshape.ndim(); ++i) label_shape[i] = dshape[i]; SHAPE_ASSIGN_CHECK(*in_shape, svm_enum::kLabel, label_shape); out_shape->clear(); diff --git a/src/operator/swapaxis-inl.h b/src/operator/swapaxis-inl.h index ce835084ab32..7335daa48392 100644 --- a/src/operator/swapaxis-inl.h +++ b/src/operator/swapaxis-inl.h @@ -69,11 +69,11 @@ class SwapAxisOp : public Operator { void Reshape2Five(mshadow::Shape<5> *inter_shape, const mxnet::TShape &shape, - uint32_t dim1, uint32_t dim2) { + int dim1, int dim2) { using namespace mshadow; using namespace mshadow::expr; - index_t ndim_in = shape.ndim(); - index_t si; + int ndim_in = shape.ndim(); + int si; if (dim1 > dim2) { std::swap(dim1, dim2); diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 069c8ddb04fb..fc51d8af0f01 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -139,9 +139,9 @@ struct BroadcastAxesParam : public dmlc::Parameter { mxnet::TShape axis; mxnet::TShape size; DMLC_DECLARE_PARAMETER(BroadcastAxesParam) { - DMLC_DECLARE_FIELD(axis).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(axis).set_default(mxnet::TShape(0, -1)) .describe("The axes to perform the broadcasting."); - DMLC_DECLARE_FIELD(size).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(size).set_default(mxnet::TShape(0, -1)) .describe("Target sizes of the broadcasting axes."); } }; @@ -149,7 +149,7 @@ struct BroadcastAxesParam : public dmlc::Parameter { struct BroadcastToParam : public dmlc::Parameter { mxnet::TShape shape; DMLC_DECLARE_PARAMETER(BroadcastToParam) { - DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape(0, -1)) .describe("The shape of the desired array." " We can set the dim to zero if it's same as the original." " E.g `A = broadcast_to(B, shape=(10, 0, 0))` " @@ -175,7 +175,7 @@ inline int CheckAxis(int axis, int ndim) { } inline mxnet::TShape AxisShapeCompact(mxnet::TShape shape, int *axis, bool allow_2d) { - int ndim = static_cast(shape.ndim()); + int ndim = shape.ndim(); index_t leading = 1, trailing = 1, M = shape[*axis]; for (int i = 0; i < *axis; ++i) leading *= shape[i]; for (int i = *axis + 1; i < ndim; ++i) trailing *= shape[i]; @@ -196,7 +196,7 @@ inline mxnet::TShape ReduceAxisShapeImpl(const mxnet::TShape& ishape, bool keepdims) { if (!axis || ishape.ndim() == 1) { if (keepdims) { - return mxnet::TShape(ishape.ndim()); + return mxnet::TShape(ishape.ndim(), 1); } return mshadow::Shape1(1); } @@ -208,7 +208,7 @@ inline mxnet::TShape ReduceAxisShapeImpl(const mxnet::TShape& ishape, return oshape; } - mxnet::TShape oshape(ishape.ndim() - 1); + mxnet::TShape oshape(ishape.ndim() - 1, 1); for (int i = 0; i < new_axis; ++i) oshape[i] = ishape[i]; for (int i = new_axis+1; i < static_cast(ishape.ndim()); ++i) { oshape[i-1] = ishape[i]; @@ -222,7 +222,7 @@ inline bool ReduceAxisShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; - if (ishape.ndim() == 0) return false; + if (!shape_is_known(ishape)) return false; const ReduceAxisParam& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*out_attrs, 0, @@ -233,12 +233,12 @@ inline bool ReduceAxisShape(const nnvm::NodeAttrs& attrs, inline mxnet::TShape ReduceAxesShapeImpl(const mxnet::TShape& ishape, const dmlc::optional& axis, bool keepdims, bool exclude) { - // if axis doesn't have value, treat it same mxnet::TShape(). + // if axis doesn't have value, treat it same mxnet::TShape(0). if (!axis.has_value() || axis.value().ndim() == 0) { if (keepdims) { - return mxnet::TShape(ishape.ndim()); + return mxnet::TShape(ishape.ndim(), 1); } else { - return mxnet::TShape(1); + return mxnet::TShape(1, 1); } } // axis has value @@ -266,9 +266,9 @@ inline mxnet::TShape ReduceAxesShapeImpl(const mxnet::TShape& ishape, if (keepdims) { oshape = mxnet::TShape(ishape); } else if (exclude) { - oshape = mxnet::TShape(axes.ndim()); + oshape = mxnet::TShape(axes.ndim(), 1); } else { - oshape = mxnet::TShape(std::max(1, ishape.ndim() - axes.ndim())); + oshape = mxnet::TShape(std::max(1, ishape.ndim() - axes.ndim()), 1); } if (keepdims && exclude) { @@ -304,7 +304,7 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if ((*in_attrs)[0].ndim() == 0) return false; + if (!shape_is_known((*in_attrs)[0])) return false; const ReduceAxesParam& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ReduceAxesShapeImpl((*in_attrs)[0], param.axis, @@ -334,7 +334,7 @@ inline bool NormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if ((*in_attrs)[0].ndim() == 0) return false; + if (!shape_is_known((*in_attrs)[0])) return false; const NormParam& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ReduceAxesShapeImpl((*in_attrs)[0], param.axis, @@ -347,12 +347,12 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if ((*in_attrs)[0].ndim() == 0) return false; + if (!shape_is_known((*in_attrs)[0])) return false; const BroadcastAxesParam& param = nnvm::get(attrs.parsed); CHECK_EQ(param.axis.ndim() , param.size.ndim()); mxnet::TShape &ishape = (*in_attrs)[0]; mxnet::TShape oshape = ishape; - for (index_t i = 0; i < param.axis.ndim(); ++i) { + for (int i = 0; i < param.axis.ndim(); ++i) { CHECK_EQ(oshape[param.axis[i]], 1U) << "Broadcasting axis must have size 1"; oshape[param.axis[i]] = param.size[i]; } @@ -366,13 +366,13 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; - if (ishape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(ishape)) return false; const BroadcastToParam& param = nnvm::get(attrs.parsed); CHECK_EQ(ishape.ndim(), param.shape.ndim()) << "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape; mxnet::TShape oshape = param.shape; - for (index_t i = 0; i < ishape.ndim(); ++i) { - if (oshape[i] != 0) { + for (int i = 0; i < ishape.ndim(); ++i) { + if (oshape[i] != -1) { CHECK(ishape[i] == oshape[i] || ishape[i] == 1) << "Array cannot be broadcasted from " << ishape << " to " << param.shape; } else { @@ -391,7 +391,7 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& lhs_shape = (*in_attrs)[0]; mxnet::TShape& rhs_shape = (*in_attrs)[1]; - if ((lhs_shape.ndim() == 0) || (lhs_shape.ndim() == 0)) { + if (!mxnet::ndim_is_known(lhs_shape) || !mxnet::ndim_is_known(rhs_shape)) { return false; } @@ -404,8 +404,8 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs, << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape; oshape = mxnet::TShape(rhs_shape); - for (index_t i = 0; i < lhs_shape.ndim(); ++i) { - if (rhs_shape[i] != 0) { + for (int i = 0; i < lhs_shape.ndim(); ++i) { + if (rhs_shape[i] != -1) { CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1) << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape; } else { @@ -423,7 +423,7 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs, << "Empty axes tuple is not allowed"; oshape = mxnet::TShape(lhs_shape); - for (index_t i = 0; i < lhs_axes.ndim(); ++i) { + for (int i = 0; i < lhs_axes.ndim(); ++i) { auto copyfrom = lhs_axes[i]; if (copyfrom < 0) { copyfrom = lhs_shape.ndim() + copyfrom; @@ -450,9 +450,9 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs, inline void BroadcastReduceShapeCompact(const mxnet::TShape& big, const mxnet::TShape& small, mxnet::TShape *new_big, mxnet::TShape *new_small) { - index_t idim = std::max(big.ndim(), MXNET_SPECIAL_MAX_NDIM); - *new_big = mxnet::TShape(idim); - *new_small = mxnet::TShape(idim); + const int idim = std::max(big.ndim(), MXNET_SPECIAL_MAX_NDIM); + *new_big = mxnet::TShape(idim, 1); + *new_small = mxnet::TShape(idim, 1); index_t j = 0; if (small.Size() == 1) { (*new_big)[j++] = big.Size(); @@ -478,12 +478,10 @@ inline void BroadcastReduceShapeCompact(const mxnet::TShape& big, const mxnet::T ++j; } } - if (j <= 2) { - new_small->assign(&(*new_small)[0], &(*new_small)[2]); - new_big->assign(&(*new_big)[0], &(*new_big)[2]); - } else if (j <= MXNET_SPECIAL_MAX_NDIM) { - new_small->assign(&(*new_small)[0], &(*new_small)[MXNET_SPECIAL_MAX_NDIM]); - new_big->assign(&(*new_big)[0], &(*new_big)[MXNET_SPECIAL_MAX_NDIM]); + if (j <= MXNET_SPECIAL_MAX_NDIM) { + const int ndim = (j <= 2? 2 : MXNET_SPECIAL_MAX_NDIM); + new_small->assign(new_small->begin(), new_small->begin() + ndim); + new_big->assign(new_big->begin(), new_big->begin() + ndim); } else { LOG(FATAL) << "Too many reduction axes from " << big << " to " << small; } @@ -889,7 +887,7 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { mshadow::Shape<5> in_shape; mshadow::Shape<5> out_shape; - for (uint32_t i = 0; i < 5; ++i) { + for (int i = 0; i < 5; ++i) { if (i < dst_shape.ndim()) { in_shape[i] = src_shape[i]; out_shape[i] = dst_shape[i]; @@ -1229,7 +1227,7 @@ void LpNormGradCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); mshadow::Shape<5> in_shape; mshadow::Shape<5> out_shape; - for (uint32_t i = 0; i < 5; ++i) { + for (int i = 0; i < 5; ++i) { if (i < dst_shape.ndim()) { in_shape[i] = src_shape[i]; out_shape[i] = dst_shape[i]; diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h index 1e3c1c9701d4..c95c1ce414f2 100644 --- a/src/operator/tensor/diag_op-inl.h +++ b/src/operator/tensor/diag_op-inl.h @@ -84,19 +84,19 @@ inline mxnet::TShape DiagShapeImpl(const mxnet::TShape& ishape, const int k, auto s = std::min(h, w); if (s < 0) { - s = 0; + s = -1; } if (x1 > x2) { std::swap(x1, x2); } - int32_t n_dim = static_cast(ishape.ndim()) - 1; - mxnet::TShape oshape(n_dim); + int32_t n_dim = ishape.ndim() - 1; + mxnet::TShape oshape(n_dim, -1); // remove axis1 and axis2 and append the new axis to the end uint32_t idx = 0; - for (int32_t i = 0; i <= n_dim; ++i) { + for (int i = 0; i <= n_dim; ++i) { if (i != x1 && i != x2) { oshape[idx++] = ishape[i]; } @@ -114,7 +114,7 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& ishape = (*in_attrs)[0]; - if (ishape.ndim() == 0) { + if (!mxnet::ndim_is_known(ishape)) { return false; } @@ -129,7 +129,7 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs, } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return out_attrs->at(0).ndim() != 0U; + return shape_is_known(out_attrs->at(0)); } inline bool DiagOpType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 8a1eda0350b0..318254b26b9f 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -1241,20 +1241,20 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, if (Ta) { L[0] = mshadow::Shape1(lshape[0]); L[1] = lshape.ndim() > 1 ? - mxnet::TShape(&lshape[1], &lshape[lshape.ndim()]) : mxnet::TShape(1); + mxnet::TShape(&lshape[1], lshape.end()) : mxnet::TShape(1, 1); } else { L[0] = lshape.ndim() > 1 ? - mxnet::TShape(&lshape[0], &lshape[lshape.ndim()-1]) : mxnet::TShape(1); + mxnet::TShape(&lshape[0], &lshape[lshape.ndim()-1]) : mxnet::TShape(1, 1); L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); } if (Tb) { R[0] = rshape.ndim() > 1 ? - mxnet::TShape(&rshape[0], &rshape[rshape.ndim()-1]) : mxnet::TShape(1); + mxnet::TShape(&rshape[0], &rshape[rshape.ndim()-1]) : mxnet::TShape(1, 1); R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); } else { R[0] = mshadow::Shape1(rshape[0]); R[1] = rshape.ndim() > 1 ? - mxnet::TShape(&rshape[1], &rshape[rshape.ndim()]) : mxnet::TShape(1); + mxnet::TShape(&rshape[1], rshape.end()) : mxnet::TShape(1, 1); } if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { @@ -1262,8 +1262,8 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, << "dot shape error: " << lshape << " X " << rshape; } std::vector buf; - if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); - if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); + if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], L[Ta].end()); + if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], R[!Tb].end()); mxnet::TShape oshape(buf.begin(), buf.end()); SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); } diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 1d2b7c9c1163..73019fa8389b 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -48,33 +48,32 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& rhs = (*in_attrs)[1]; // avoid pre-mature shape inference. - if (lhs.ndim() == 0 || rhs.ndim() == 0) return false; + if (!mxnet::ndim_is_known(lhs) || !mxnet::ndim_is_known(rhs)) return false; if (lhs == rhs) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, lhs); - return true; + return shape_is_known(lhs); } - mxnet::TShape out(std::max(lhs.ndim(), rhs.ndim())); - index_t bl = out.ndim() - lhs.ndim(); - index_t br = out.ndim() - rhs.ndim(); - for (index_t i = 0; i < out.ndim(); ++i) { - index_t l = 1, r = 1; + mxnet::TShape out(std::max(lhs.ndim(), rhs.ndim()), -1); + const int bl = out.ndim() - lhs.ndim(); + const int br = out.ndim() - rhs.ndim(); + for (int i = 0; i < out.ndim(); ++i) { + int l = 1, r = 1; if (i >= bl) l = lhs[i-bl]; if (i >= br) r = rhs[i-br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) continue; if (l != r) { - if (l == 0 || r == 0) { - out[i] = 0; - } else { - CHECK(l == 1 || r == 1) - << "operands could not be broadcast together with shapes " << lhs << " " << rhs; - out[i] = std::max(l, r); - } + // Make it compatible with NumPy. + // For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can broadcast to (2, 0, 3). + CHECK(l == 1 || r == 1) + << "operands could not be broadcast together with shapes " << lhs << " " << rhs; + out[i] = (l == 1 ? r : l); } else { out[i] = l; } } SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); - return true; + return shape_is_known(lhs) && shape_is_known(rhs) && shape_is_known(out); } inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs, @@ -146,15 +145,15 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: const mxnet::TShape& oshape, mxnet::TShape *new_lshape, mxnet::TShape *new_rshape, mxnet::TShape *new_oshape) { if (lshape == rshape) return 0; - index_t odim = std::max(oshape.ndim(), broadcast::MAX_DIM); - *new_lshape = mxnet::TShape(odim); - *new_rshape = mxnet::TShape(odim); - *new_oshape = mxnet::TShape(odim); - index_t bl = oshape.ndim() - lshape.ndim(); - index_t br = oshape.ndim() - rshape.ndim(); - index_t j = 0, lprod = 1, rprod = 1, oprod = 1; - for (index_t i = 0; i < oshape.ndim(); ++i) { - index_t l = 1, r = 1, o = oshape[i]; + const int odim = std::max(oshape.ndim(), broadcast::MAX_DIM); + *new_lshape = mxnet::TShape(odim, 1); + *new_rshape = mxnet::TShape(odim, 1); + *new_oshape = mxnet::TShape(odim, 1); + int bl = oshape.ndim() - lshape.ndim(); + int br = oshape.ndim() - rshape.ndim(); + int j = 0, lprod = 1, rprod = 1, oprod = 1; + for (int i = 0; i < oshape.ndim(); ++i) { + int l = 1, r = 1, o = oshape[i]; if (i >= bl) l = lshape[i-bl]; if (i >= br) r = rshape[i-br]; if ((lprod != rprod || l != r) && @@ -176,9 +175,9 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: } if (j <= broadcast::MAX_DIM) { BROADCAST_NDIM_SWITCH(j, NDim, { - new_lshape->assign(&(*new_lshape)[0], &(*new_lshape)[NDim]); - new_rshape->assign(&(*new_rshape)[0], &(*new_rshape)[NDim]); - new_oshape->assign(&(*new_oshape)[0], &(*new_oshape)[NDim]); + new_lshape->assign(new_lshape->begin(), new_lshape->begin() + NDim); + new_rshape->assign(new_rshape->begin(), new_rshape->begin() + NDim); + new_oshape->assign(new_oshape->begin(), new_oshape->begin() + NDim); }); } else { LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape; diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 19a9ac8359eb..5114a5d0dbe3 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -413,9 +413,9 @@ bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs, GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin, &rhs_end); - int lhsrank = static_cast(lshape.ndim()); + int lhsrank = lshape.ndim(); int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin); - mxnet::TShape oshape(orank); + mxnet::TShape oshape(orank, -1); for (int i = 0; i < lhs_begin; ++i) oshape[i] = lshape[i]; @@ -436,7 +436,7 @@ bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs, << "shape " << oshape << " because they have different " << "size."; SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } DMLC_REGISTER_PARAMETER(ReshapeLikeParam); @@ -537,7 +537,7 @@ Example:: mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - mxnet::TShape target_shape(1); + mxnet::TShape target_shape(1, -1); target_shape[0] = in_attrs->at(0).ndim(); SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape); return !shape_is_none(out_attrs->at(0)); @@ -589,7 +589,7 @@ Example:: mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, 1U); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 1)); return !shape_is_none(out_attrs->at(0)); }) .set_attr("FInferType", diff --git a/src/operator/tensor/histogram-inl.h b/src/operator/tensor/histogram-inl.h index 51d0bdb6c2b6..7194445d7b52 100644 --- a/src/operator/tensor/histogram-inl.h +++ b/src/operator/tensor/histogram-inl.h @@ -46,13 +46,13 @@ namespace op { struct HistogramParam : public dmlc::Parameter { dmlc::optional bin_cnt; - dmlc::optional> range; + dmlc::optional> range; DMLC_DECLARE_PARAMETER(HistogramParam) { DMLC_DECLARE_FIELD(bin_cnt) .set_default(dmlc::optional()) .describe("Number of bins for uniform case"); DMLC_DECLARE_FIELD(range) - .set_default(dmlc::optional>()) + .set_default(dmlc::optional>()) .describe("The lower and upper range of the bins. if not provided, " "range is simply (a.min(), a.max()). values outside the " "range are ignored. the first element of the range must be " @@ -86,9 +86,9 @@ inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs, if (has_cnt) { // if cnt is specified, the output histogram has shape (cnt,) // while output bins has shape (cnt+1,) - const int bin_cnt = param.bin_cnt.value(); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({bin_cnt})); - SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape({bin_cnt + 1})); + const dim_t bin_cnt = param.bin_cnt.value(); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, bin_cnt)); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, bin_cnt + 1)); } else { // if cnt is not specified, the output histogram has shape (bins.Size() - 1) // while output bins has same shape as input bins @@ -97,11 +97,11 @@ inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(oshape.ndim(), 1U) << "bins argument should be an 1D vector"; CHECK_GE(oshape.Size(), 2U) << "number of bounds should be >= 2"; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({(oshape[0] - 1)})); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, oshape[0] - 1)); SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1)); } - return !shape_is_none(out_attrs->at(0)) && !shape_is_none(out_attrs->at(1)) && + return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)) && out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1; } diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 8979531fef4e..6469aae17558 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -145,20 +145,20 @@ inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { using namespace mshadow; const mxnet::TShape &dshape = (*in_attrs)[embedding::kData]; - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; const ParamType& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim, param.output_dim)); out_attrs->clear(); - mxnet::TShape oshape(dshape.ndim()+1); - for (size_t i = 0; i < dshape.ndim(); ++i) { + mxnet::TShape oshape(dshape.ndim()+1, -1); + for (int i = 0; i < dshape.ndim(); ++i) { oshape[i] = dshape[i]; } oshape[dshape.ndim()] = param.output_dim; out_attrs->push_back(oshape); - return true; + return shape_is_known(oshape); } template @@ -682,18 +682,18 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; const mxnet::TShape &arrshape = (*in_attrs)[take_::kArr]; const mxnet::TShape &idxshape = (*in_attrs)[take_::kIdx]; - if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false; + if (!shape_is_known(idxshape)) return false; const TakeParam& param = nnvm::get(attrs.parsed); if (param.mode == take_::kRaise) { LOG(FATAL) << "Raise is not supported for the time being..."; } - CHECK(param.axis >= -1 * (int)arrshape.ndim() && param.axis < (int)arrshape.ndim()) + CHECK(param.axis >= -1 * arrshape.ndim() && param.axis < arrshape.ndim()) << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor"; out_attrs->clear(); const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); - mxnet::TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); + mxnet::TShape oshape(idxshape.ndim() + arrshape.ndim() - 1, -1); for (index_t i = 0; i < idxshape.ndim(); ++i) { oshape[i + actual_axis] = idxshape[i]; } @@ -705,7 +705,7 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, } } out_attrs->push_back(oshape); - return true; + return shape_is_known(oshape); } inline bool TakeOpType(const nnvm::NodeAttrs& attrs, @@ -1170,6 +1170,7 @@ inline bool OneHotOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); // The shape of indices const mxnet::TShape& ishape = (*in_attrs)[0]; + if (!shape_is_known(ishape)) return false; int depth = 0; double on_value = 1.0; @@ -1177,13 +1178,13 @@ inline bool OneHotOpShape(const nnvm::NodeAttrs& attrs, int dtype = mshadow::kFloat32; GetOneHotParams(param, &depth, &on_value, &off_value, &dtype); - mxnet::TShape oshape(ishape.ndim() + 1); + mxnet::TShape oshape(ishape.ndim() + 1, -1); for (index_t i = 0; i < ishape.ndim(); ++i) { oshape[i] = ishape[i]; } oshape[oshape.ndim()-1] = depth; SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } inline bool OneHotOpType(const nnvm::NodeAttrs& attrs, @@ -1270,15 +1271,15 @@ inline bool GatherNDShape(const nnvm::NodeAttrs& attrs, CHECK_LE(ishape[0], 10) << "gather_nd supports indexing along at most 10 dimensions."; - mxnet::TShape oshape(ishape.ndim() - 1 + dshape.ndim() - ishape[0]); + mxnet::TShape oshape(ishape.ndim() - 1 + dshape.ndim() - ishape[0], -1); - for (size_t i = 0; i < ishape.ndim() - 1; ++i) oshape[i] = ishape[i+1]; + for (int i = 0; i < ishape.ndim() - 1; ++i) oshape[i] = ishape[i+1]; for (int i = 0; i < dshape.ndim() - ishape[0]; ++i) { oshape[ishape.ndim()-1+i] = dshape[ishape[0] + i]; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } inline bool GatherNDType(const nnvm::NodeAttrs& attrs, @@ -1370,7 +1371,7 @@ inline bool ScatterNDShape(const nnvm::NodeAttrs& attrs, bool valid = dshape.ndim() == ishape.ndim() - 1 + oshape.ndim() - ishape[0]; - for (size_t i = 0; i < ishape.ndim() - 1; ++i) { + for (int i = 0; i < ishape.ndim() - 1; ++i) { valid = valid && dshape[i] == ishape[i+1]; } for (int i = 0; i < oshape.ndim() - ishape[0]; ++i) { diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index fe1a1f62954a..b2e3830064ae 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -49,7 +50,7 @@ struct InitOpParam : public dmlc::Parameter { int dtype; DMLC_DECLARE_PARAMETER(InitOpParam) { DMLC_DECLARE_FIELD(shape) - .set_default(mxnet::TShape()) + .set_default(mxnet::TShape(0, 1)) .describe("The shape of the output"); DMLC_DECLARE_FIELD(ctx) .set_default("") @@ -213,14 +214,13 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs, const ParamType& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 0U); CHECK_EQ(out_attrs->size(), 1U); - if ((*out_attrs)[0].ndim() != 0 && param.shape.ndim() == 0) return true; - for (unsigned int i=0 ; i < param.shape.ndim() ; ++i) { - if (param.shape[i] < 0U) { - LOG(FATAL) << "Shape cannot contain negative values " << param.shape; - } + mxnet::TShape param_shape = param.shape; + if (!Imperative::Get()->is_np_comp()) { + common::ConvertToNumpyShape(¶m_shape); } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape); - return true; + if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape); + return shape_is_known(out_attrs->at(0)); } template @@ -278,6 +278,8 @@ inline bool InitStorageType(const nnvm::NodeAttrs& attrs, */ template void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueType val) { + // If b is a zero-size tensor, do nothing. + if (b.Size() == 0) return; if (req != kNullOp) { const size_t size = b.Size(); if (val == 0) { diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 5e18e0ef5a25..db4607fe9262 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -384,7 +384,7 @@ mshadow::Tensor LaOpFlatten(const TBlob& blob, } // Collapse ranges [0,axis-1] and [axis+1,ndim-2]. CHECK_EQ(dim, 4); - mxnet::TShape shape(dim); + mxnet::TShape shape(dim, -1); shape[0] = 1; for (int i = 0; i < axis; ++i) { shape[0] *= blob.shape_[i]; diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index ba62d0e9def7..0e7f66240926 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -49,17 +49,17 @@ namespace op { struct ReshapeParam : public dmlc::Parameter { mxnet::TShape target_shape; bool keep_highest; - nnvm::Tuple shape; + mxnet::Tuple shape; bool reverse; DMLC_DECLARE_PARAMETER(ReshapeParam) { DMLC_DECLARE_FIELD(shape) - .set_default(nnvm::Tuple()) + .set_default(mxnet::Tuple()) .describe("The target shape"); DMLC_DECLARE_FIELD(reverse) .set_default(false) .describe("If true then the special values are inferred from right to left"); DMLC_DECLARE_FIELD(target_shape) - .set_default(mxnet::TShape()) + .set_default(mxnet::TShape(0, -1)) .describe("(Deprecated! Use ``shape`` instead.) " "Target new shape. One and only one dim can be 0, " "in which case it will be inferred from the rest of dims"); @@ -71,11 +71,11 @@ struct ReshapeParam : public dmlc::Parameter { }; template -inline mxnet::TShape InferReshapeShape(const nnvm::Tuple& shape, - const mxnet::TShape& dshape, bool reverse) { +inline mxnet::TShape InferReshapeShape(const mxnet::Tuple& shape, + const mxnet::TShape& dshape, bool reverse) { std::vector dshape_vec; std::vector param_shape_vec(shape.begin(), shape.end()); - for (index_t i = 0; i < dshape.ndim(); ++i) { + for (int i = 0; i < dshape.ndim(); ++i) { dshape_vec.push_back(dshape[i]); } std::vector tmp; @@ -102,28 +102,31 @@ inline mxnet::TShape InferReshapeShape(const nnvm::Tuple& shape, } else if (proposed_dim == -2) { // copy all remaining dims from source while (src_idx < dshape_len) { - size_t dn = dshape_vec[src_idx++]; + const int dn = dshape_vec[src_idx++]; tmp.push_back(dn); } } else if (proposed_dim == -3) { // merge two dims from source CHECK_LT(src_idx, dshape_len-1); - size_t d1 = dshape_vec[src_idx++]; - size_t d2 = dshape_vec[src_idx++]; - size_t dn = d1 * d2; - tmp.push_back(dn); + const int d1 = dshape_vec[src_idx++]; + const int d2 = dshape_vec[src_idx++]; + if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) { + tmp.push_back(-1); + } else { + tmp.push_back(d1 * d2); + } } else if (proposed_dim == -4) { // split the source dim s into two dims // read the left dim and then the right dim (either can be -1) CHECK_LT(i + 2, params_len); CHECK_LT(src_idx, dshape_len); - size_t d0 = dshape_vec[src_idx++]; + const int d0 = dshape_vec[src_idx++]; IType d1 = param_shape_vec[++i]; IType d2 = param_shape_vec[++i]; CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1."; - if (d1 == -1) d1 = d0 / d2; - if (d2 == -1) d2 = d0 / d1; - CHECK(d1 * d2 == static_cast(d0) || static_cast(d0) == IType(0)) << + if (d1 == -1 && d0 >= 0) d1 = d0 / d2; // d0 must be known to do this + if (d2 == -1 && d0 >= 0) d2 = d0 / d1; // d0 must be known to do this + CHECK(d1 * d2 == static_cast(d0) || static_cast(d0) == IType(-1)) << "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0; tmp.push_back(d1); tmp.push_back(d2); @@ -135,12 +138,12 @@ inline mxnet::TShape InferReshapeShape(const nnvm::Tuple& shape, } if (inf_idx >= 0) { - if (dshape.Size() > 0) { + if (shape_is_known(dshape)) { IType new_size = 1; for (IType x : tmp) new_size *= x; tmp[inf_idx] = dshape.Size() / new_size; } else { - tmp[inf_idx] = 0; + tmp[inf_idx] = -1; } } if (reverse) { @@ -153,24 +156,24 @@ inline mxnet::TShape InferReshapeShape(const nnvm::Tuple& shape, } inline bool ReverseReshapeInferShape(mxnet::TShape *in, const mxnet::TShape& out) { - if (in->Size() && out.Size()) { + if (shape_is_known(*in) && shape_is_known(out)) { return true; - } else if (!out.Size()) { + } else if (!shape_is_known(out)) { return false; } else { int zero_axis = -1; - int non_zero_prod = 1; - for (index_t i = 0; i < in->ndim(); i++) { - if ((*in)[i] == 0) { + int known_dim_size_prod = 1; + for (int i = 0; i < in->ndim(); i++) { + if (!mxnet::dim_size_is_known(*in, i)) { if (zero_axis != -1) return false; // more than 1 zero found. else zero_axis = i; } else { - non_zero_prod *= (*in)[i]; + known_dim_size_prod *= (*in)[i]; } } - (*in)[zero_axis] = out.Size() / non_zero_prod; + (*in)[zero_axis] = out.Size() / known_dim_size_prod; return true; } } @@ -182,11 +185,11 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape &dshape = (*in_attrs)[0]; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; mxnet::TShape oshape; if (param_.shape.ndim() != 0) { oshape = InferReshapeShape(param_.shape, dshape, param_.reverse); - } else if (param_.target_shape.ndim()) { + } else if (param_.target_shape.ndim() != -1) { LOG(INFO) << "Using target_shape will be deprecated."; oshape = param_.target_shape; int neg_count = 0; @@ -195,7 +198,7 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, if (param_.keep_highest) { oshape[0] = dshape[0]; } - for (index_t i = start_idx; i < oshape.ndim(); ++i) { + for (int i = start_idx; i < oshape.ndim(); ++i) { if (oshape[i] == 0) { neg_count++; inf_idx = i; @@ -206,13 +209,16 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, oshape[inf_idx] = dshape.Size() / oshape.Size(); } } else { - return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); + return shape_is_known((*out_attrs)[0]) + && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); } ReverseReshapeInferShape(&dshape, oshape); +#if 0 CHECK_EQ(oshape.Size(), dshape.Size()) << "Target shape size is different to source. " << "Target: " << oshape << "\nSource: " << dshape; +#endif SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); } @@ -223,9 +229,9 @@ inline bool FlattenShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape &dshape = (*in_attrs)[0]; - if (dshape.ndim() == 0) return false; - uint32_t target_dim = 1; - for (uint32_t i = 1; i < dshape.ndim(); ++i) { + if (!shape_is_known(dshape)) return false; + int target_dim = 1; + for (int i = 1; i < dshape.ndim(); ++i) { target_dim *= dshape[i]; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(dshape[0], target_dim)); @@ -235,7 +241,7 @@ inline bool FlattenShape(const nnvm::NodeAttrs& attrs, struct TransposeParam : public dmlc::Parameter { mxnet::TShape axes; DMLC_DECLARE_PARAMETER(TransposeParam) { - DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0, -1)) .describe("Target axis order. By default the axes will be inverted."); } @@ -314,8 +320,8 @@ void Transpose(const nnvm::NodeAttrs& attrs, const TransposeParam& param = nnvm::get(attrs.parsed); CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; if (param.axes.ndim() == 0) { - mxnet::TShape axes = mxnet::TShape(inputs[0].ndim()); - for (index_t i = 0; i < axes.ndim(); ++i) { + mxnet::TShape axes(inputs[0].ndim(), -1); + for (int i = 0; i < axes.ndim(); ++i) { axes[i] = axes.ndim() - 1 - i; } TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); @@ -332,20 +338,20 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions"; - mxnet::TShape ret(shp.ndim()); + mxnet::TShape ret(shp.ndim(), -1); if (param.axes.ndim() == 0) { - for (index_t i = 0; i < shp.ndim(); ++i) { + for (int i = 0; i < shp.ndim(); ++i) { ret[i] = shp[shp.ndim()-1-i]; } } else { CHECK_EQ(shp.ndim(), param.axes.ndim()); - for (size_t i = 0; i < shp.ndim(); ++i) { + for (int i = 0; i < shp.ndim(); ++i) { CHECK(param.axes[i] < static_cast(shp.ndim())); ret[i] = shp[param.axes[i]]; } } SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); - return true; + return shape_is_known(ret); } @@ -366,7 +372,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, const ExpandDimParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if (in_attrs->at(0).ndim() == 0U && out_attrs->at(0).ndim() == 0U) { + if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) { return false; } @@ -374,7 +380,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& oshape = (*out_attrs)[0]; int indim = ishape.ndim(); bool unknown_ishape = false; - if (0 == indim) { + if (-1 == indim) { indim = oshape.ndim() - 1; unknown_ishape = true; } @@ -386,27 +392,27 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis <= indim) << "axis must be in the range [" << -indim << ", " << indim << "] (" << param.axis << " provided)"; - mxnet::TShape ret(indim + 1); + mxnet::TShape ret(indim + 1, -1); for (int i = 0; i < axis; ++i) { - ret[i] = (unknown_ishape? 0 : ishape[i]); + ret[i] = (unknown_ishape? -1 : ishape[i]); } ret[axis] = 1; for (int i = axis+1; i < indim+1; ++i) { - ret[i] = (unknown_ishape? 0 : ishape[i-1]); + ret[i] = (unknown_ishape? -1 : ishape[i-1]); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); - ret = mxnet::TShape(indim); + ret = mxnet::TShape(indim, -1); for (int i = 0; i < axis; ++i) ret[i] = oshape[i]; for (int i = axis+1; i < indim+1; ++i) ret[i-1] = oshape[i]; SHAPE_ASSIGN_CHECK(*in_attrs, 0, ret); - return true; + return shape_is_known(in_attrs->at(0)) && shape_is_known(out_attrs->at(0)); } // Currently MKLDNN only supports step = 1 or step has no value inline bool SupportMKLDNNSlice(const SliceParam& param) { if (param.step.ndim() == 0U) return true; - for (uint32_t i = 0; i < param.step.ndim(); ++i) { + for (int i = 0; i < param.step.ndim(); ++i) { if (param.step[i].has_value() && param.step[i].value() != 1) return false; } @@ -589,11 +595,11 @@ void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, const mxnet::TShape ishape = in.shape(); const mxnet::TShape oshape = out.shape(); - uint32_t N = ishape.ndim(); - mxnet::TShape begin(N), end(N); - for (uint32_t i = 0; i < N; ++i) { + int N = ishape.ndim(); + mxnet::TShape begin(N, -1), end(N, -1); + for (int i = 0; i < N; ++i) { int s = 0; - if (param.begin[i]) { + if (i < param.begin.ndim() && param.begin[i]) { s = *param.begin[i]; if (s < 0) s += ishape[i]; } @@ -634,9 +640,9 @@ void SliceEx(const nnvm::NodeAttrs& attrs, template inline void GetIndexRange(const mxnet::TShape& dshape, - const nnvm::Tuple>& param_begin, - const nnvm::Tuple>& param_end, - const nnvm::Tuple>& param_step, + const mxnet::Tuple>& param_begin, + const mxnet::Tuple>& param_end, + const mxnet::Tuple>& param_step, common::StaticArray* begin, common::StaticArray* end, common::StaticArray* step) { @@ -651,7 +657,7 @@ inline void GetIndexRange(const mxnet::TShape& dshape, << "Static array size=" << ndim << " is not equal to data shape ndim=" << dshape.ndim(); - if (param_step.ndim() != 0U) { + if (param_step.ndim() != 0) { CHECK_EQ(param_step.ndim(), param_begin.ndim()) << "step and begin must have the same length"; } @@ -723,7 +729,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; const SliceParam& param = nnvm::get(attrs.parsed); mxnet::TShape oshape = dshape; @@ -737,7 +743,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, }); SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return !shape_is_none(dshape) && !shape_is_none(oshape); + return shape_is_known(oshape); } template @@ -943,7 +949,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, MXNET_NDIM_SWITCH(dshape.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); - for (index_t i = 0; i < param.begin.ndim(); ++i) { + for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; SetSliceOpOutputDimSize(i, b, e, s, &vshape); } @@ -997,8 +1003,8 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs, struct SliceAssignScalarParam : public dmlc::Parameter { double scalar; - nnvm::Tuple> begin, end; - nnvm::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) { DMLC_DECLARE_FIELD(scalar) .set_default(0) @@ -1008,7 +1014,7 @@ struct SliceAssignScalarParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(nnvm::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } }; @@ -1019,7 +1025,7 @@ inline bool SliceAssignScalarOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; - if (dshape.ndim() == 0U || dshape.Size() == 0U) return false; + if (!shape_is_known(dshape)) return false; SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); return true; } @@ -1114,9 +1120,9 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const mxnet::TShape& int* axis, index_t* begin, index_t* end) { *axis = param.axis; if (*axis < 0) { - *axis += static_cast(ishape.ndim()); + *axis += ishape.ndim(); } - CHECK(*axis < static_cast(ishape.ndim()) && *axis >= 0) << + CHECK(*axis < ishape.ndim() && *axis >= 0) << "Transformed axis must be smaller than the source ndim and larger than zero! Recieved axis=" << param.axis << ", src_ndim=" << ishape.ndim() << ", transformed axis=" << *axis; index_t axis_size = static_cast(ishape[*axis]); @@ -1125,7 +1131,7 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const mxnet::TShape& if (*begin < 0) { *begin += axis_size; } - if (axis_size) { + if (axis_size > 0) { if (!static_cast(param.end)) { *end = axis_size; } else { @@ -1153,11 +1159,16 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) return false; int axis; index_t begin, end; GetSliceAxisParams(param, ishape, &axis, &begin, &end); - mxnet::TShape shape(ishape.ndim()); - for (index_t i = 0; i < ishape.ndim(); ++i) { + if (!mxnet::dim_size_is_known(ishape, axis)) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape); + return false; + } + mxnet::TShape shape(ishape.ndim(), -1); + for (int i = 0; i < ishape.ndim(); ++i) { if (static_cast(i) == axis) { shape[i] = static_cast(end - begin); } else { @@ -1165,7 +1176,7 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs, } } SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape); - return true; + return shape_is_known(shape); } @@ -1181,7 +1192,7 @@ void SliceAxis(const nnvm::NodeAttrs& attrs, int axis; index_t begin, end; GetSliceAxisParams(param, inputs[0].shape_, &axis, &begin, &end); - int ndim = static_cast(outputs[0].ndim()); + int ndim = outputs[0].ndim(); if (axis + 1 == ndim) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { @@ -1252,9 +1263,9 @@ void SliceAxisGrad_(const nnvm::NodeAttrs& attrs, } struct SliceLikeParam : public dmlc::Parameter { - mxnet::TShape axes; + mxnet::Tuple axes; DMLC_DECLARE_PARAMETER(SliceLikeParam) { - DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape()) + DMLC_DECLARE_FIELD(axes).set_default(mxnet::Tuple()) .describe("List of axes on which input data will be sliced according to the " "corresponding size of the second input. By default will slice on " "all axes. Negative axes are supported."); @@ -1273,7 +1284,7 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(ishape.ndim(), from_shape.ndim()) << "By default slice_axis performs slice on all axes, but ndim mismatch " "for inputs: " << ishape.ndim() << " vs. " << from_shape.ndim(); - for (index_t i = 0; i < ishape.ndim(); ++i) { + for (int i = 0; i < ishape.ndim(); ++i) { CHECK_GE(ishape[i], from_shape[i]) << "Slice axis " << i << " with size " << from_shape[i] << "exceeds limit of input with size " << ishape[i]; @@ -1281,7 +1292,7 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, 0, from_shape); } else { mxnet::TShape shape(ishape); - for (index_t i = 0; i < param.axes.ndim(); ++i) { + for (int i = 0; i < param.axes.ndim(); ++i) { int axis = static_cast(param.axes[i]); if (axis < 0) { axis += static_cast(ishape.ndim()); @@ -1304,21 +1315,21 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, inline void SliceLikeInferRanges(const mxnet::TShape& dshape, const mxnet::TShape& fshape, - const mxnet::TShape& axes, - nnvm::Tuple>* param_begin, - nnvm::Tuple>* param_end, - nnvm::Tuple>* param_step) { + const mxnet::Tuple& axes, + mxnet::Tuple>* param_begin, + mxnet::Tuple>* param_end, + mxnet::Tuple>* param_step) { std::vector> pb(dshape.ndim()); std::vector> pe(dshape.ndim()); std::vector> ps(dshape.ndim()); if (axes.ndim() == 0) { - for (index_t i = 0; i < dshape.ndim(); ++i) { + for (int i = 0; i < dshape.ndim(); ++i) { pb[i] = 0; pe[i] = fshape[i]; ps[i] = 1; } } else { - for (index_t i = 0; i < axes.ndim(); ++i) { + for (int i = 0; i < axes.ndim(); ++i) { int axis = static_cast(axes[i]); if (axis < 0) { axis += static_cast(dshape.ndim()); @@ -1334,9 +1345,9 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape, ps[axis] = 1; } } - *param_begin = nnvm::Tuple>(pb.begin(), pb.end()); - *param_end = nnvm::Tuple>(pe.begin(), pe.end()); - *param_step = nnvm::Tuple>(ps.begin(), ps.end()); + *param_begin = mxnet::Tuple>(pb.begin(), pb.end()); + *param_end = mxnet::Tuple>(pe.begin(), pe.end()); + *param_step = mxnet::Tuple>(ps.begin(), ps.end()); } template @@ -1355,9 +1366,9 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs, const TBlob& out = outputs[0]; const mxnet::TShape& ishape = data.shape_; const mxnet::TShape& from_shape = inputs[1].shape_; - nnvm::Tuple> param_begin; - nnvm::Tuple> param_end; - nnvm::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(data.ndim(), ndim, { @@ -1403,9 +1414,9 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& ishape = ograd.shape_; const mxnet::TShape& from_shape = outputs[1].shape_; - nnvm::Tuple> param_begin; - nnvm::Tuple> param_end; - nnvm::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { @@ -1546,7 +1557,7 @@ inline void GetRepeatParams(const RepeatParam& param, const mxnet::TShape& ishap CHECK_GE(*repeats, 0) << "repeats cannot be a negative number"; *axisOpt = param.axis; if (static_cast(*axisOpt)) { - int ndims = static_cast(ishape.ndim()); + int ndims = ishape.ndim(); int axis = axisOpt->value(); if (axis < 0) { axis += ndims; @@ -1565,34 +1576,33 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs, int repeats = 0; dmlc::optional axisOpt; GetRepeatParams(param, ishape, &repeats, &axisOpt); - // If 0 repeats, return an empty 0 dim array + // If 0 repeats, return an empty 1-dim, 0-size array if (0 == repeats) { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 0)); return true; } // If repeats > 0, multiply the size of the corresponding axis by repeats if (static_cast(axisOpt)) { - int ndims = static_cast(ishape.ndim()); + int ndims = ishape.ndim(); int axis = axisOpt.value(); if (axis < 0) { axis += ndims; } - mxnet::TShape shape(ishape.ndim()); - for (index_t i = 0; i < ishape.ndim(); ++i) { - if (static_cast(i) == axis) { - shape[i] = static_cast(repeats) * ishape[i]; + mxnet::TShape shape(ishape.ndim(), -1); + for (int i = 0; i < ishape.ndim(); ++i) { + if (i == axis) { + shape[i] = repeats * ishape[i]; } else { shape[i] = ishape[i]; } } SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape); } else { // If axis is not input by user, return a flat 1D array of size = in.size*repeats - mxnet::TShape shape(1); - shape[0] = ishape.Size() * static_cast(repeats); + mxnet::TShape shape(1, ishape.Size() * repeats); SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape); } - return true; + return shape_is_known(out_attrs->at(0)); } inline bool RepeatOpType(const nnvm::NodeAttrs& attrs, @@ -1620,16 +1630,16 @@ inline std::pair ReshapeInputOutputForRepeatOp( const int repeats) { if (static_cast(axisOpt)) { int axis = axisOpt.value(); - int ndim = static_cast(ishape.ndim()); + int ndim = ishape.ndim(); if (axis < 0) { axis += ndim; } - CHECK(axis >= 0 && axis < static_cast(ishape.ndim())) << "Invalid input of axis"; + CHECK(axis >= 0 && axis < ishape.ndim()) << "Invalid input of axis"; // reshape the input tensor by adding a dim at the (axis+1)-th dim - mxnet::TShape rshape(ishape.ndim()+1); + mxnet::TShape rshape(ishape.ndim()+1, 1); // the shape we want to broadcast to - mxnet::TShape bshape(rshape.ndim()); + mxnet::TShape bshape(rshape.ndim(), 1); int i = 0; while (i <= axis) { rshape[i] = bshape[i] = ishape[i]; @@ -1637,7 +1647,7 @@ inline std::pair ReshapeInputOutputForRepeatOp( } rshape[i] = 1; bshape[i] = repeats; - while (i < static_cast(ishape.ndim())) { + while (i < ishape.ndim()) { rshape[i+1] = ishape[i]; bshape[i+1] = ishape[i]; ++i; @@ -1648,11 +1658,11 @@ inline std::pair ReshapeInputOutputForRepeatOp( // reshape the tensor into shape (ishape.Size(), 1) // then add one dim at axis = 1 and broadcast to // shape (ishape.Size(), repeats) - mxnet::TShape rshape(2); + mxnet::TShape rshape(2, 1); rshape[0] = ishape.Size(); rshape[1] = 1; - mxnet::TShape bshape(2); + mxnet::TShape bshape(2, 1); bshape[0] = rshape[0]; bshape[1] = repeats; return std::make_pair(rshape, bshape); @@ -1667,7 +1677,7 @@ void RepeatOpForward(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { const TBlob& iTBlob = inputs[0]; const mxnet::TShape& ishape = iTBlob.shape_; - if (ishape.ndim() == 0) return; + if (!shape_is_known(ishape)) return; int repeats = 0; dmlc::optional axisOpt; @@ -1711,7 +1721,7 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); const mxnet::TShape& oshape = outputs[0].shape_; - if (oshape.ndim() == 0) return; + if (!shape_is_known(oshape)) return; int repeats = 0; dmlc::optional axisOpt; @@ -1737,7 +1747,7 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, } struct TileParam : public dmlc::Parameter { - mxnet::TShape reps; + mxnet::Tuple reps; DMLC_DECLARE_PARAMETER(TileParam) { DMLC_DECLARE_FIELD(reps) .describe("The number of times for repeating the tensor a. Each dim size of reps" @@ -1755,19 +1765,22 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const TileParam& param = nnvm::get(attrs.parsed); const mxnet::TShape& ishape = (*in_attrs)[0]; - const mxnet::TShape& reps = param.reps; + if (!shape_is_known(ishape)) { + return false; + } + const mxnet::Tuple& reps = param.reps; // If reps is empty, return a identical input array - if (reps.ndim() == 0 || ishape.ndim() == 0) { + if (reps.ndim() == 0) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape); return true; } - for (size_t i = 0; i < reps.ndim(); ++i) { + for (int i = 0; i < reps.ndim(); ++i) { CHECK_GT(reps[i], 0) << "invalid reps=" << i << ", dim size must be greater than zero"; } - mxnet::TShape oshape(std::max(ishape.ndim(), reps.ndim())); - int i1 = static_cast(ishape.ndim()) - 1; - int i2 = static_cast(reps.ndim()) - 1; - for (int i = static_cast(oshape.ndim()) - 1; i >= 0; --i) { + mxnet::TShape oshape(std::max(ishape.ndim(), reps.ndim()), -1); + int i1 = ishape.ndim() - 1; + int i2 = reps.ndim() - 1; + for (int i = oshape.ndim() - 1; i >= 0; --i) { if (i1 >= 0 && i2 >= 0) { oshape[i] = ishape[i1--] * reps[i2--]; } else if (i1 >= 0) { @@ -1777,7 +1790,7 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs, } } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } inline bool TileOpType(const nnvm::NodeAttrs& attrs, @@ -1801,20 +1814,20 @@ inline bool TileOpType(const nnvm::NodeAttrs& attrs, */ inline std::pair ReshapeInputOutputForTileOp( const mxnet::TShape& ishape, - const mxnet::TShape& reps) { + const mxnet::Tuple& reps) { if (ishape.ndim() == 0 || reps.ndim() == 0) { return std::make_pair(ishape, ishape); } // The shape we want to broadcast to - mxnet::TShape bshape(std::max(ishape.ndim(), reps.ndim()) * 2); + mxnet::TShape bshape(std::max(ishape.ndim(), reps.ndim()) * 2, 1); // The shape of the input tensor after adding new axes before each dim - mxnet::TShape rshape(bshape.ndim()); + mxnet::TShape rshape(bshape.ndim(), 1); - int i1 = static_cast(ishape.ndim()) - 1; - int i2 = static_cast(reps.ndim()) - 1; - for (int i = static_cast(bshape.ndim()) - 1; i >= 0; --i) { + int i1 = ishape.ndim() - 1; + int i2 = reps.ndim() - 1; + for (int i = bshape.ndim() - 1; i >= 0; --i) { if (0 == (i & 1)) { bshape[i] = (i2 >= 0? reps[i2--] : 1); rshape[i] = 1; @@ -1854,10 +1867,10 @@ void TileOpForward(const nnvm::NodeAttrs& attrs, if (inputs[0].Size() == 0) return; const mxnet::TShape& ishape = inputs[0].shape_; - const mxnet::TShape& reps = nnvm::get(attrs.parsed).reps; + const mxnet::Tuple& reps = nnvm::get(attrs.parsed).reps; // If any one of the number in reps is zero, return immediately - for (index_t i = 0; i < reps.ndim(); ++i) { + for (int i = 0; i < reps.ndim(); ++i) { if (0 == reps[i]) return; } @@ -1896,10 +1909,10 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, if (inputs[0].Size() == 0) return; const mxnet::TShape& oshape = outputs[0].shape_; - const mxnet::TShape& reps = nnvm::get(attrs.parsed).reps; + const mxnet::Tuple& reps = nnvm::get(attrs.parsed).reps; // If any one of the number in reps is zero, return immediately - for (index_t i = 0; i < reps.ndim(); ++i) { + for (int i = 0; i < reps.ndim(); ++i) { if (0 == reps[i]) return; } @@ -1919,7 +1932,7 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, } struct ReverseParam : public dmlc::Parameter { - nnvm::Tuple axis; + mxnet::Tuple axis; DMLC_DECLARE_PARAMETER(ReverseParam) { DMLC_DECLARE_FIELD(axis) .describe("The axis which to reverse elements."); @@ -1990,10 +2003,10 @@ void ReverseOpForward(const nnvm::NodeAttrs& attrs, std::vector trailing_(param.axis.ndim()); index_t reverse_index = 0; for (int axis : param.axis) { - CHECK_LT(axis, static_cast(ishape.ndim())); + CHECK_LT(axis, ishape.ndim()); stride_[reverse_index] = ishape[axis]; trailing_[reverse_index] = 1; - for (index_t i2 = axis + 1; i2 < ishape.ndim(); ++i2) { + for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) { trailing_[reverse_index] *= ishape[i2]; } reverse_index++; @@ -2054,9 +2067,9 @@ inline bool StackOpShape(const nnvm::NodeAttrs& attrs, for (const mxnet::TShape& i : (*in_attrs)) { shape_assign(&dshape, i); } - if (dshape.ndim() == 0) return false; + if (!shape_is_known(dshape)) return false; - mxnet::TShape oshape(dshape.ndim() + 1); + mxnet::TShape oshape(dshape.ndim() + 1, -1); int axis = CheckAxis(param.axis, oshape.ndim()); for (int i = 0; i < axis; ++i) { oshape[i] = dshape[i]; @@ -2067,7 +2080,7 @@ inline bool StackOpShape(const nnvm::NodeAttrs& attrs, } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return shape_is_known(oshape); } @@ -2140,10 +2153,10 @@ void StackOpBackward(const nnvm::NodeAttrs& attrs, } struct SqueezeParam : public dmlc::Parameter { - dmlc::optional axis; + dmlc::optional> axis; DMLC_DECLARE_PARAMETER(SqueezeParam) { DMLC_DECLARE_FIELD(axis) - .set_default(dmlc::optional()) + .set_default(dmlc::optional>()) .describe("Selects a subset of the single-dimensional entries in the shape." " If an axis is selected with shape entry greater than one, an error is raised."); } @@ -2156,7 +2169,7 @@ struct SqueezeParam : public dmlc::Parameter { inline size_t SqueezeShapeHelper(mxnet::TShape* shape) { CHECK(shape != nullptr); size_t count = 0; - for (size_t i = 0; i < shape->ndim(); ++i) { + for (int i = 0; i < shape->ndim(); ++i) { if ((*shape)[i] == 0) { ++count; } else { @@ -2174,12 +2187,12 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = in_attrs->at(0); const int dndim = dshape.ndim(); - if (shape_is_none(dshape)) return false; + if (!shape_is_known(dshape)) return false; mxnet::TShape oshape = dshape; if (param.axis.has_value()) { // preprocess axis - mxnet::TShape axes = param.axis.value(); - for (size_t i = 0; i < axes.ndim(); ++i) { + mxnet::Tuple axes = param.axis.value(); + for (int i = 0; i < axes.ndim(); ++i) { if (axes[i] < 0) { axes[i] += dndim; CHECK_GE(axes[i], 0) @@ -2194,7 +2207,7 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, oshape[axes[i]] = 0; } } else { - for (size_t i = 0; i < oshape.ndim(); ++i) { + for (int i = 0; i < oshape.ndim(); ++i) { if (oshape[i] == 1) oshape[i] = 0; } } @@ -2223,7 +2236,7 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); CHECK_EQ(in_attrs->at(0).ndim(), 4) << "Operation Depth To Space requires exactly 4D tensor"; - mxnet::TShape expected_out(4); + mxnet::TShape expected_out(4, -1); mxnet::TShape& in_shape = in_attrs->at(0); int block = param.block_size; @@ -2241,14 +2254,14 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs, expected_out[0] = in_shape[0]; expected_out[1] = in_shape[1] / (block * block); - size_t i = 2; + int i = 2; while (i < expected_out.ndim()) { expected_out[i] = in_shape[i] * block; ++i; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out); - return true; + return shape_is_known(expected_out); } inline bool DepthToSpaceOpType(const nnvm::NodeAttrs& attrs, @@ -2387,7 +2400,7 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); CHECK_EQ(in_attrs->at(0).ndim(), 4) << "Operation Space To Depth requires exactly 4D tensor"; - mxnet::TShape expected_out(in_attrs->at(0).ndim()); + mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1); mxnet::TShape& in_shape = in_attrs->at(0); int block = param.block_size; @@ -2408,14 +2421,14 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs, expected_out[0] = in_shape[0]; expected_out[1] = in_shape[1] * block * block; - uint32_t i = 2; + int i = 2; while (i < expected_out.ndim()) { expected_out[i] = in_shape[i] / block; ++i; } SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out); - return true; + return shape_is_known(expected_out); } inline bool SpaceToDepthOpType(const nnvm::NodeAttrs& attrs, @@ -2556,7 +2569,7 @@ struct SplitParam : public dmlc::Parameter { }; // struct SplitParam inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) { - mxnet::TShape indices(sections+1); + mxnet::TShape indices(sections+1, -1); indices[0] = 0; int64_t section_size = ishape[axis] / sections; for (int i = 0; i < sections; ++i) { @@ -2588,7 +2601,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); mxnet::TShape dshape = in_attrs->at(split_enum::kData); mxnet::TShape ishape = in_attrs->at(split_enum::kData); - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; if (param.axis >= 0) { CHECK_LT(static_cast(param.axis), dshape.ndim()); } else { @@ -2603,7 +2616,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs, int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim(); // Pre-compute squeezed output shape for future usage mxnet::TShape squeezed_dshape = dshape; - for (int d = real_axis; d < static_cast(squeezed_dshape.ndim()) - 1; ++d) { + for (int d = real_axis; d < squeezed_dshape.ndim() - 1; ++d) { squeezed_dshape[d] = squeezed_dshape[d+1]; } squeezed_dshape = mxnet::TShape(&squeezed_dshape[0], &squeezed_dshape[squeezed_dshape.ndim()-1]); @@ -2635,7 +2648,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs, back_calculate_dshape[real_axis] += (*out_attrs)[i][real_axis]; } } - for (int d = real_axis + 1; d < static_cast(ishape.ndim()); ++d) { + for (int d = real_axis + 1; d < ishape.ndim(); ++d) { if (param.squeeze_axis) { back_calculate_dshape[d] = (*out_attrs)[0][d - 1]; } else { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 1431fef13594..b80c9a54510f 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -410,8 +410,8 @@ Examples:: "transpose", n, ograds, {}, std::unordered_map()); } else { - mxnet::TShape axes = mxnet::TShape(param.axes.ndim()); - for (index_t i = 0; i < axes.ndim(); ++i) { + mxnet::TShape axes = mxnet::TShape(param.axes.ndim(), -1); + for (int i = 0; i < axes.ndim(); ++i) { axes[param.axes[i]] = i; } std::ostringstream os; diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 5a95e05ffb65..1dda90104205 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -149,7 +149,7 @@ inline void ParseTopKParam(const mxnet::TShape& src_shape, const TopKParam& para << src_shape.ndim() << ", found axis=" << *axis; *batch_size = src_shape.Size() / src_shape[*axis]; *element_num = src_shape[*axis]; - if (*axis != static_cast(src_shape.ndim()) - 1) { + if (*axis != src_shape.ndim() - 1) { *do_transpose = true; } } diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h index 4e94cbeda46c..78a2bd8c7b45 100644 --- a/src/operator/tensor/slice-inl.h +++ b/src/operator/tensor/slice-inl.h @@ -34,15 +34,15 @@ namespace mxnet { namespace op { struct SliceParam : public dmlc::Parameter { - nnvm::Tuple> begin, end; - nnvm::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceParam) { DMLC_DECLARE_FIELD(begin) .describe("starting indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(nnvm::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } bool operator==(const SliceParam& other) const { diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h index a379dab7bf90..f1682772a14a 100644 --- a/tests/cpp/include/test_mkldnn.h +++ b/tests/cpp/include/test_mkldnn.h @@ -49,7 +49,7 @@ inline static mkldnn::memory::primitive_desc GetMemPD(const mxnet::TShape s, int inline static mkldnn::memory::primitive_desc GetExpandedMemPD( mkldnn::memory::primitive_desc pd, float scale, int dim = 0) { CHECK(dim < pd.desc().data.ndims) << "dimension cannot be larger than total dimensions of input"; - mxnet::TShape s(pd.desc().data.ndims); + mxnet::TShape s(pd.desc().data.ndims, -1); for (size_t i = 0; i < pd.desc().data.ndims; i++) s[i] = pd.desc().data.dims[i]; s[dim] = static_cast(s[dim] * scale); @@ -165,7 +165,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals std::vector pds; { // 1D - mxnet::TShape s(1); + mxnet::TShape s(1, -1); s[0] = 279936; shapes.push_back(s); pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::x)); @@ -175,7 +175,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 2D - mxnet::TShape s(2); + mxnet::TShape s(2, -1); s[0] = 96; s[1] = 2916; shapes.push_back(s); @@ -187,12 +187,12 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 4D - mxnet::TShape s1(4); + mxnet::TShape s1(4, -1); s1[0] = 10; s1[1] = 96; s1[2] = 54; s1[3] = 54; shapes.push_back(s1); pds.push_back(GetMemPD(s1, dtype, mkldnn::memory::format::nchw)); - mxnet::TShape s2(4); + mxnet::TShape s2(4, -1); s2[0] = 96; s2[1] = 3; s2[2] = 11; s2[3] = 11; shapes.push_back(s2); pds.push_back(GetMemPD(s2, dtype, mkldnn::memory::format::oihw)); @@ -204,7 +204,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 5D - mxnet::TShape s(5); + mxnet::TShape s(5, -1); s[0] = 96; s[1] = 1; s[2] = 3; s[3] = 11; s[4] = 11; shapes.push_back(s); pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::goihw)); @@ -259,7 +259,7 @@ enum ArrayTypes { inline NDArray CreateKernelNDArray(mxnet::TShape kernel, int num_filters, mxnet::TShape input, bool is_deconv = false) { CHECK_EQ(kernel.ndim(), 2) << "mkldnn only supports 2d filters on 4d inputs"; - mxnet::TShape target_shape(4); + mxnet::TShape target_shape(4, -1); target_shape[0] = is_deconv ? input[1] : num_filters; target_shape[1] = is_deconv ? num_filters : input[1]; target_shape[2] = kernel[0]; @@ -470,7 +470,7 @@ inline std::vector GetTestOutputArrays( in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1), "Reshaped NDArray"); } - mxnet::TShape s(1); + mxnet::TShape s(1, -1); if (types & ArrayTypes::NormalReused) { // Type 5. // Get a reused version. @@ -528,7 +528,7 @@ inline std::vector GetTestOutputArrays( // Type 8, 9. // Get a reused version. - mxnet::TShape s(1); + mxnet::TShape s(1, -1); s[0] = shape.Size(); NDArray arr = NDArray(s, Context()); arr = arr.AsArray(shape, arr.dtype()); diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h index e0caddbcd027..b0114e1721ef 100644 --- a/tests/cpp/include/test_util.h +++ b/tests/cpp/include/test_util.h @@ -353,14 +353,14 @@ inline StreamType& print_blob_(const RunContext& ctx, if (dim == 1) { // probably a 1d tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), mxnet::TShape(3), blob.dev_mask(), blob.dev_id()); + TBlob changed(blob.dptr(), mxnet::TShape(3, -1), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; return print_blob_(ctx, &os, changed, false, false, add_endl); } else if (dim == 2) { // probably a 2d tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), mxnet::TShape(4), blob.dev_mask(), blob.dev_id()); + TBlob changed(blob.dptr(), mxnet::TShape(4, -1), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc index 77014238c2fa..2509a43c27ee 100644 --- a/tests/cpp/misc/serialization.cc +++ b/tests/cpp/misc/serialization.cc @@ -48,7 +48,7 @@ TEST(SerializerTest, OutputMapCorrect) { std::map > output_map; output_map.emplace("output_0", std::make_tuple(1, mxnet::TShape({23, 12, 63, 432}), 0, 1)); output_map.emplace("another_output", std::make_tuple(2, mxnet::TShape({23, 123}), 14, -23)); - output_map.emplace("last_output", std::make_tuple(0, mxnet::TShape({0}), -1, 0)); + output_map.emplace("last_output", std::make_tuple(0, mxnet::TShape(1, 0), -1, 0)); std::string serialized_data; common::Serialize(output_map, &serialized_data); std::map > deserialized_output_map; diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index d74493a0f7fb..ed0e70b831f1 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -1266,7 +1266,7 @@ static void testSaveAndLoad(const std::vector& dims, ChannelAxisTestData data; data.channel_data_ = inputChannelData; - mxnet::TShape shape(dims.size()); + mxnet::TShape shape(dims.size(), -1); for (size_t i = 0, n = dims.size(); i < n; ++i) { shape[i] = index_t(dims[i]); } @@ -1322,7 +1322,7 @@ static mxnet::TShape MakeShape(const std::vector& shape, } CHECK_LT(channelAxis, shape.size() + 1); const index_t dim = index_t(shape.size()) + 1; - mxnet::TShape newShape(dim); + mxnet::TShape newShape(dim, -1); for (size_t x = 0; x < static_cast(channelAxis); ++x) { newShape[x] = index_t(shape[x]); } diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc index 559ab5da0ccc..961785dcfc87 100644 --- a/tests/cpp/operator/mkldnn_operator_test.cc +++ b/tests/cpp/operator/mkldnn_operator_test.cc @@ -916,13 +916,13 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards if (in_shape.ndim() < 2) continue; - mxnet::TShape wt_shape(2); + mxnet::TShape wt_shape(2, -1); wt_shape[0] = num_hid; wt_shape[1] = GetFCWeightDim2(in_shape); NDArray weights(wt_shape, Context()); InitDefaultArray(&weights, false); - mxnet::TShape bias_shape(1); + mxnet::TShape bias_shape(1, -1); bias_shape[0] = num_hid; NDArray bias(bias_shape, Context()); InitDefaultArray(&bias, false); @@ -931,7 +931,7 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards inputs[1] = &weights; inputs[2] = &bias; - mxnet::TShape out_shape(2); + mxnet::TShape out_shape(2, -1); out_shape[0] = in_shape[0]; out_shape[1] = num_hid; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index fbbfc53a9a5e..19fc1eca89ce 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1963,19 +1963,21 @@ def check_proposal_consistency(op, batch_size, with_nms=False): # The following 2 functions launch 0-thread kernels, an error that should be caught and signaled. def kernel_error_check_imperative(): os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) - b = mx.nd.array([],ctx=mx.gpu(0)) - c = (a / b).asnumpy() + with mx.np_compat(active=True): + a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) + b = mx.nd.array([],ctx=mx.gpu(0)) + c = (a / b).asnumpy() def kernel_error_check_symbolic(): os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - a = mx.sym.Variable('a') - b = mx.sym.Variable('b') - c = a / b - f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), - 'b':mx.nd.array([],ctx=mx.gpu(0))}) - f.forward() - g = f.outputs[0].asnumpy() + with mx.np_compat(active=True): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a / b + f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), + 'b':mx.nd.array([],ctx=mx.gpu(0))}) + f.forward() + g = f.outputs[0].asnumpy() def test_kernel_error_checking(): # Running tests that may throw exceptions out of worker threads will stop CI testing diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 73654a604135..612861bd8303 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -147,6 +147,21 @@ def test_fc_infer_type(): assert arg_type_dict[k] == v +def test_shape_completely_unknown(): + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == () + assert out_shapes[0] == () + + with mx.np_compat(): + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] is None + assert out_shapes[0] is None + + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() @@ -156,3 +171,4 @@ def test_fc_infer_type(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() + test_shape_completely_unknown() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 2446107ad466..94777677354d 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -122,7 +122,11 @@ def test_ndarray_setitem(): # numpy assignment for empty axis for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]: - x = mx.nd.zeros(trivial_shape) + if trivial_shape == tuple(): + with mx.np_compat(): + x = mx.nd.zeros(trivial_shape) + else: + x = mx.nd.zeros(trivial_shape) x[:] = np.ones(trivial_shape) x_np = np.ones(trivial_shape, dtype=x.dtype) assert x.shape == trivial_shape diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 287b974d151e..9db1b5154bd1 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4405,7 +4405,8 @@ def test_invalid_reps(): assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3)) test_normal_case() - test_empty_tensor() + with mx.np_compat(): + test_empty_tensor() test_empty_reps() test_tile_backward() test_tile_numeric_gradient() @@ -4465,7 +4466,8 @@ def test_zero_depth(): test_normal_case(index_type=np.float64) test_normal_case(index_type=np.float32) test_normal_case(index_type=np.float16) - test_empty_indices() + with mx.np_compat(): + test_empty_indices() test_zero_depth() @@ -6933,6 +6935,20 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape): check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0)) check_slice_axis_partial_infer(var1, 1, 0, 5, (10, 0)) + with mx.np_compat(): + var1 = mx.sym.var(name="data", shape=(-1, 20)) + check_slice_partial_infer(var1, (None, None), (None, 10), [], (-1, 10)) + check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (-1, 5)) + check_slice_partial_infer(var1, (None, 3), (None, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (None, 3), (5, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (-1, 3)) + + var1 = mx.sym.var(name='data', shape=(10, -1)) + check_slice_axis_partial_infer(var1, 0, 0, 5, (5, -1)) + check_slice_axis_partial_infer(var1, 1, 0, 5, (10, -1)) + @with_seed() def test_float16_min_max(): @@ -7981,6 +7997,74 @@ def test_image_normalize(): check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) +@with_seed() +def test_scalar_tensor_creation(): + assertRaises(MXNetError, mx.nd.zeros, shape=()) + assertRaises(MXNetError, mx.nd.ones, shape=()) + with mx.np_compat(): + data_mx = mx.nd.ones(shape=()) + data_np = np.ones((), dtype=data_mx.dtype) + assert same(data_mx.asnumpy(), data_np) + + +@with_seed() +def test_zero_size_tensor_creation(): + assertRaises(MXNetError, mx.nd.zeros, shape=(0, 1, 3, 0)) + assertRaises(MXNetError, mx.nd.ones, shape=(0, 1, 3, 0)) + with mx.np_compat(): + data_mx = mx.nd.ones(shape=(0, 1, 0, 4)) + data_np = np.ones(shape=data_mx.shape, dtype=data_mx.dtype) + assert same(data_mx.asnumpy(), data_np) + + +@with_seed() +def test_concat_with_zero_size_tensor(): + with mx.np_compat(): + data1 = mx.nd.ones((0, 8, 12)) + data2 = mx.nd.ones((3, 8, 12)) + data3 = mx.nd.ones((0, 8, 12)) + ret = mx.nd.Concat(data1, data2, data3, dim=0) + assert ret.shape == (3, 8, 12) + + data1 = mx.nd.ones((0, 3, 10)) + data2 = mx.nd.ones((0, 4, 10)) + data3 = mx.nd.ones((0, 5, 10)) + ret = mx.nd.Concat(data1, data2, data3, dim=1) + assert ret.shape == (0, 12, 10) + + +@with_seed() +def test_np_compat_decorator(): + @mx.use_np_compat + def check_scalar_one(): + """Generate scalar one tensor""" + return mx.nd.ones(shape=()) + assert check_scalar_one.__name__ == "check_scalar_one" + assert check_scalar_one.__doc__ == "Generate scalar one tensor" + assert check_scalar_one().shape == () + for active in [True, False]: + with mx.np_compat(active=active): + assert check_scalar_one.__name__ == "check_scalar_one" + assert check_scalar_one.__doc__ == "Generate scalar one tensor" + assert check_scalar_one().shape == () + + @mx.use_np_compat + def check_concat(shape1, shape2, axis): + data1 = mx.nd.ones(shape1) + data2 = mx.nd.ones(shape2) + ret = mx.nd.Concat(data1, data2, dim=axis) + expected_ret = np.concatenate((data1.asnumpy(), data2.asnumpy()), axis=axis) + assert ret.shape == expected_ret.shape + + check_concat((0, 3, 4), (5, 3, 4), 0) + check_concat((8, 0, 5), (8, 7, 5), 1) + check_concat((8, 0, 0), (8, 0, 0), 2) + for active in [True, False]: + check_concat((0, 3, 4), (5, 3, 4), 0) + check_concat((8, 0, 5), (8, 7, 5), 1) + check_concat((8, 0, 0), (8, 0, 0), 2) + + if __name__ == '__main__': import nose nose.runmodule() From 51d3291aec4fc5c057e1f56df17c4fed09745951 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Tue, 16 Apr 2019 12:19:01 -0700 Subject: [PATCH 7/9] Updated docs for R-package installation (#14269) * Updated docs for R-package installation * Addressed PR feedback * Removed spaces in cuda100 path * Removed cu80 mention --- docs/install/index.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/install/index.md b/docs/install/index.md index f1e959eaf34c..5708a5b7d4c5 100644 --- a/docs/install/index.md +++ b/docs/install/index.md @@ -1125,7 +1125,8 @@ You can [build MXNet-R from source](windows_setup.html#install-mxnet-package-for options(repos = cran) install.packages("mxnet") ``` -Change cu92 to cu80, cu90 or cu91 based on your CUDA toolkit version. Currently, MXNet supports these versions of CUDA. +Change cu92 to cu90, cu91 or cuda100 based on your CUDA toolkit version. Currently, MXNet supports these versions of CUDA. +Note : You also need to have cuDNN installed on Windows. Check out this [guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installwindows) on the steps for installation. From 8e04b8870c01e2c7e88511aec5c7baa668c69a61 Mon Sep 17 00:00:00 2001 From: champagne828 <49048274+champagne828@users.noreply.github.com> Date: Wed, 17 Apr 2019 04:07:07 +0800 Subject: [PATCH 8/9] Update inception_inference.cpp (#14674) fix a bug when use --gpu --- cpp-package/example/inference/inception_inference.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp-package/example/inference/inception_inference.cpp b/cpp-package/example/inference/inception_inference.cpp index 78487e6ee0cd..fa5600190f95 100644 --- a/cpp-package/example/inference/inception_inference.cpp +++ b/cpp-package/example/inference/inception_inference.cpp @@ -301,7 +301,7 @@ void Predictor::PredictImage(const std::string& image_file) { executor->Forward(false); // The output is available in executor->outputs. - auto array = executor->outputs[0].Copy(global_ctx); + auto array = executor->outputs[0].Copy(Context::cpu()); /* * Find out the maximum accuracy and the index associated with that accuracy. From ff04de0a79a8f0c03948b304a02cfe25a7be0e00 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Tue, 16 Apr 2019 15:07:34 -0700 Subject: [PATCH 9/9] Add vim-nox to ci/docker/install/ubuntu_core.sh (#14632) --- ci/docker/install/ubuntu_core.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/docker/install/ubuntu_core.sh b/ci/docker/install/ubuntu_core.sh index 61a4637830da..3cb806e0aadd 100755 --- a/ci/docker/install/ubuntu_core.sh +++ b/ci/docker/install/ubuntu_core.sh @@ -45,6 +45,7 @@ apt-get install -y \ software-properties-common \ sudo \ unzip \ + vim-nox \ wget # Use libturbojpeg package as it is correctly compiled with -fPIC flag