diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 67de3ba84856..ff2ca26a0e76 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -149,7 +149,6 @@ def __getitem__(self, key):  # pylint: disable = too-many-return-statements, inco
                              ' are supported! Received key={}'.format(key))
         if is_symbol_tuple:
             return result
-        new_shape += (-4,)
         sliced = _npi.slice(self, begin, end, step)
         return _npi.reshape(sliced, new_shape)
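The hunk above drops a stray `(-4,)` that was being appended to the reshape target in the slice-and-reshape path of symbolic `__getitem__`. Below is a minimal sketch of the kind of basic indexing that exercises this path, written against the imperative `mxnet.np` frontend (which mirrors the symbolic one); shapes and values are illustrative, not taken from the patch:

```python
# Hypothetical usage sketch: basic indexing lowers to _npi.slice followed by
# an _npi.reshape that drops the integer-indexed axes.
import mxnet as mx
from mxnet import npx

npx.set_np()  # enable NumPy semantics

x = mx.np.arange(24).reshape(2, 3, 4)
y = x[1, 0:2]       # integer + slice key: slice out x[1:2, 0:2], then reshape
print(y.shape)      # (2, 4) -- the integer axis is removed by the reshape
```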
" + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -261,10 +262,11 @@ void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Multi-dimension boolean indexing is not supported."; - } else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 || + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "Multi-dimension boolean indexing is not supported."; + } else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 || inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) { using namespace mshadow; const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape(); @@ -274,7 +276,7 @@ void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, return; } - CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array + CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array mxnet::TShape oshape(arrshape.ndim() - 1, -1); oshape[0] = arrshape[0]; @@ -297,15 +299,16 @@ void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max); CHECK(is_valid) << "take operator contains indices out of bound"; Kernel::Launch(s, idxshape.Size(), - outputs[np_indexing_::kOut].data().dptr(), - inputs[np_indexing_::kArr].data().dptr(), - inputs[np_indexing_::kIdx].data().dptr(), - oshape.Size()/idxshape.Size(), arrshape[1]); + outputs[np_indexing_::kOut].data().dptr(), + inputs[np_indexing_::kArr].data().dptr(), + inputs[np_indexing_::kIdx].data().dptr(), + oshape.Size()/idxshape.Size(), arrshape[1]); }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. " - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type. " + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -348,9 +351,9 @@ void AdvancedIndexingOpBackward(const nnvm::NodeAttrs& attrs, } }); }); - } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || + } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) { using namespace mshadow; using namespace mshadow::expr; @@ -396,12 +399,12 @@ void AdvancedIndexingOpBackward(const nnvm::NodeAttrs& attrs, } else { LOG(FATAL) << "wrong req"; } - }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. " - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type. 
" + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -444,9 +447,9 @@ void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs, } }); }); - } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || + } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) { using namespace mxnet_op; using namespace mshadow; @@ -463,20 +466,23 @@ void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(inputs[2].dtype(), IType, { // index type if (req[0] != kAddTo) outputs[0].data().FlatTo1D(s) = 0; if (trailing == 1) { - Kernel, cpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr(), - inputs[0].data().dptr(), inputs[2].data().dptr(), - M, 1, Shape2(leading, M), Shape2(leading, 1)); + Kernel, cpu>::Launch(s, inputs[0].data().Size(), + outputs[0].data().dptr(), inputs[0].data().dptr(), + inputs[2].data().dptr(), M, + 1, Shape2(leading, M), Shape2(leading, 1)); } else { - Kernel, cpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr(), - inputs[0].data().dptr(), inputs[2].data().dptr(), - M, trailing, Shape3(leading, M, trailing), - Shape3(leading, 1, trailing)); + Kernel, cpu>::Launch(s, inputs[0].data().Size(), + outputs[0].data().dptr(), inputs[0].data().dptr(), + inputs[2].data().dptr(), M, + trailing, Shape3(leading, M, trailing), + Shape3(leading, 1, trailing)); } }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. " - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type." + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -545,7 +551,8 @@ which stands for the rows in x where the corresonding element in index is non-ze }) .set_attr("FInferType", AdvancedIndexingOpType) .set_attr("FComputeEx", AdvancedIndexingMultipleOpForward) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_np_advanced_indexing_multiple"}) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_np_advanced_indexing_multiple"}) .set_attr("FInferStorageType", AdvancedIndexingOpStorageType) .add_argument("data", "NDArray-or-Symbol", "Data") .add_argument("indices", "NDArray-or-Symbol", "Indices"); diff --git a/src/operator/numpy/np_indexing_op.cu b/src/operator/numpy/np_indexing_op.cu index 0b6339ed36db..f623654ecca1 100644 --- a/src/operator/numpy/np_indexing_op.cu +++ b/src/operator/numpy/np_indexing_op.cu @@ -21,7 +21,6 @@ * \file np_indexing_op.cu */ -#include #include "./np_indexing_op.h" #include @@ -60,13 +59,13 @@ struct AdvancedIndexingTakeGPU { template MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int64_t M, const int64_t K) { - int64_t j = static_cast(idx[i]); - j = j % K; - j += (j < 0) ? K : 0; + int64_t j = static_cast(idx[i]); + j = j % K; + j += (j < 0) ? K : 0; - for(int64_t k=0; k MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int64_t M, const int64_t K) { - int64_t j = static_cast(idx[i]); - j = j % K; - j += (j < 0) ? 
diff --git a/src/operator/numpy/np_indexing_op.cu b/src/operator/numpy/np_indexing_op.cu
index 0b6339ed36db..f623654ecca1 100644
--- a/src/operator/numpy/np_indexing_op.cu
+++ b/src/operator/numpy/np_indexing_op.cu
@@ -21,7 +21,6 @@
  * \file np_indexing_op.cu
  */
 
-#include
 #include "./np_indexing_op.h"
 #include
@@ -60,13 +59,13 @@ struct AdvancedIndexingTakeGPU {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
                                   const IType* idx, const int64_t M, const int64_t K) {
-    int64_t j = static_cast<int64_t>(idx[i]);
-    j = j % K;
-    j += (j < 0) ? K : 0;
+    int64_t j = static_cast<int64_t>(idx[i]);
+    j = j % K;
+    j += (j < 0) ? K : 0;
-    for(int64_t k=0; k
 
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
                                   const IType* idx, const int64_t M, const int64_t K) {
-    int64_t j = static_cast<int64_t>(idx[i]);
-    j = j % K;
-    j += (j < 0) ? K : 0;
+    int64_t j = static_cast<int64_t>(idx[i]);
+    j = j % K;
+    j += (j < 0) ? K : 0;
-    for(int64_t k=0; k
 
 template<>
 inline void AdvancedIndexingOpForward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
 
-  if(inputs[np_indexing_::kIdx].dtype() == mshadow::kBool){
+  if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) {
     CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
     const int axis = 0;
     const NDArray &data = inputs[0];
@@ -155,13 +154,14 @@ inline void AdvancedIndexingOpForward(const nnvm::NodeAttrs& attrs,
   MSHADOW_TYPE_SWITCH_WITH_BOOL(out.dtype(), DType, {
     if (valid_num > 0) {
       mxnet_op::Kernel::Launch(
-          s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+          s, input_size, out.data().dptr<DType>(),
+          data.data().dptr<DType>(), prefix_sum, col_size);
     }
   });
-} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
-           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
-           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
-           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64){
+} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
+           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
+           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
+           inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) {
   using namespace mxnet_op;
   const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape();
   const mxnet::TShape& arrshape = inputs[np_indexing_::kArr].shape();
@@ -188,7 +188,7 @@ inline void AdvancedIndexingOpForward(const nnvm::NodeAttrs& attrs,
   Stream<gpu> *s = ctx.get_stream<gpu>();
 
   MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[np_indexing_::kOut].dtype(), DType, {  // output data type
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[np_indexing_::kIdx].dtype(), IType, { // index data type
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[np_indexing_::kIdx].dtype(), IType, {
       IType min = 0;
       IType max = static_cast<IType>(arrshape[0] - 1);
       // check with single thread is faster since data is small
@@ -200,17 +200,17 @@ inline void AdvancedIndexingOpForward(const nnvm::NodeAttrs& attrs,
      bool is_valid = CheckIndexOutOfBound(s, idx_ptr, idx_size, min, max, is_valid_ptr);
      CHECK(is_valid) << "take operator contains indices out of bound";
      Kernel::Launch(s, idxshape.Size(),
-         outputs[np_indexing_::kOut].data().dptr<DType>(),
-         inputs[np_indexing_::kArr].data().dptr<DType>(),
-         inputs[np_indexing_::kIdx].data().dptr<IType>(),
-         oshape.Size()/idxshape.Size(), arrshape[0]);
+                    outputs[np_indexing_::kOut].data().dptr<DType>(),
+                    inputs[np_indexing_::kArr].data().dptr<DType>(),
+                    inputs[np_indexing_::kIdx].data().dptr<IType>(),
+                    oshape.Size()/idxshape.Size(), arrshape[0]);
     });
   });
 } else {
-  dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. "
-      << "Use np.astype() to cast indices to integer or boolean.";
+  dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
+      << "arrays used as indices must be explicitly declared as integer (or boolean) type. "
+      << "Use np.astype() to cast indices to integer or boolean.";
 }
-
 }
 
 template<>
" + << "Use np.astype() to cast indices to integer or boolean."; } - } template<> @@ -223,7 +223,7 @@ inline void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); - if(inputs[np_indexing_::kIdx].dtype() == mshadow::kBool){ + if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) { CHECK(req[0] == kWriteTo || req[0] == kWriteInplace); const int axis = 0; const NDArray &data = inputs[0]; @@ -279,13 +279,14 @@ inline void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH_WITH_BOOL(out.dtype(), DType, { if (valid_num > 0) { mxnet_op::Kernel::Launch( - s, input_size, out.data().dptr(), data.data().dptr(), prefix_sum, col_size); + s, input_size, out.data().dptr(), + data.data().dptr(), prefix_sum, col_size); } }); - } else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 || - inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64){ + } else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) { using namespace mxnet_op; const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape(); const mxnet::TShape& arrshape = inputs[np_indexing_::kArr].shape(); @@ -293,8 +294,8 @@ inline void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, if (idxshape.Size() == 0 || idxshape.Size() == 1) { return; } - - CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array + + CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array mxnet::TShape oshape(arrshape.ndim() - 1, -1); oshape[0] = arrshape[0]; @@ -307,8 +308,8 @@ inline void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[np_indexing_::kOut].dtype(), DType, { // output data type - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[np_indexing_::kIdx].dtype(), IType, { // index data type + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[np_indexing_::kOut].dtype(), DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[np_indexing_::kIdx].dtype(), IType, { IType min = 0; IType max = static_cast(arrshape[0] - 1); // check with single thread is faster since data is small @@ -320,17 +321,17 @@ inline void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, bool is_valid = CheckIndexOutOfBound(s, idx_ptr, idx_size, min, max, is_valid_ptr); CHECK(is_valid) << "take operator contains indices out of bound"; Kernel::Launch(s, idxshape.Size(), - outputs[np_indexing_::kOut].data().dptr(), - inputs[np_indexing_::kArr].data().dptr(), - inputs[np_indexing_::kIdx].data().dptr(), - oshape.Size()/idxshape.Size(), arrshape[1]); + outputs[np_indexing_::kOut].data().dptr(), + inputs[np_indexing_::kArr].data().dptr(), + inputs[np_indexing_::kIdx].data().dptr(), + oshape.Size()/idxshape.Size(), arrshape[1]); }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. " - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type. 
" + << "Use np.astype() to cast indices to integer or boolean."; } - } template<> @@ -392,9 +393,9 @@ inline void AdvancedIndexingOpBackward(const nnvm::NodeAttrs& attrs, prefix_sum, col_size); } }); - } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || + } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) { using namespace mshadow; using namespace mshadow::expr; @@ -407,45 +408,45 @@ inline void AdvancedIndexingOpBackward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[2].dtype(), IType, { // index data type - // inputs are specified in the .cc file, which are the gradients from - // the upper layer and the input index - // outputs are the gradients of inputs in the feed-forward pass - const mxnet::TShape& idxshape = inputs[2].shape(); - const mxnet::TShape& arrshape = outputs[0].shape(); - const mxnet::TShape& oshape = inputs[0].shape(); - - if (idxshape.Size() == 0) { - return; - } - - if (req[np_indexing_::kIdx] != kNullOp) { - mxnet_op::Kernel::Launch( - s, idxshape.Size(), outputs[np_indexing_::kIdx].data().dptr()); - } - - int idxndim = idxshape.ndim(); - Tensor idx = inputs[2].data().get_with_shape( - Shape1(idxshape.ProdShape(0, idxndim)), s); - Tensor grad_out = inputs[0].data().get_with_shape( - Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s); - Tensor grad_in = outputs[0].data().get_with_shape( - Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s); - - // re-using the previous code for axis = 0 case - if (req[np_indexing_::kArr] == kWriteTo || req[np_indexing_::kArr] == kAddTo) { - if (req[np_indexing_::kArr] == kWriteTo) { - grad_in = scalar(0.0f); - } - AddTakeGrad(grad_in, idx, grad_out); - } else { - LOG(FATAL) << "wrong req"; - } - + // inputs are specified in the .cc file, which are the gradients from + // the upper layer and the input index + // outputs are the gradients of inputs in the feed-forward pass + const mxnet::TShape& idxshape = inputs[2].shape(); + const mxnet::TShape& arrshape = outputs[0].shape(); + const mxnet::TShape& oshape = inputs[0].shape(); + + if (idxshape.Size() == 0) { + return; + } + + if (req[np_indexing_::kIdx] != kNullOp) { + mxnet_op::Kernel::Launch( + s, idxshape.Size(), outputs[np_indexing_::kIdx].data().dptr()); + } + + int idxndim = idxshape.ndim(); + Tensor idx = inputs[2].data().get_with_shape( + Shape1(idxshape.ProdShape(0, idxndim)), s); + Tensor grad_out = inputs[0].data().get_with_shape( + Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s); + Tensor grad_in = outputs[0].data().get_with_shape( + Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s); + + // re-using the previous code for axis = 0 case + if (req[np_indexing_::kArr] == kWriteTo || req[np_indexing_::kArr] == kAddTo) { + if (req[np_indexing_::kArr] == kWriteTo) { + grad_in = scalar(0.0f); + } + AddTakeGrad(grad_in, idx, grad_out); + } else { + LOG(FATAL) << "wrong req"; + } }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. 
" - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type. " + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -508,9 +509,9 @@ inline void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs prefix_sum, col_size); } }); - } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || - inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || + } else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 || inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) { using namespace mxnet_op; using namespace mshadow; @@ -527,20 +528,23 @@ inline void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs MSHADOW_TYPE_SWITCH(inputs[2].dtype(), IType, { // index type if (req[0] != kAddTo) outputs[0].data().FlatTo1D(s) = 0; if (trailing == 1) { - Kernel, gpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr(), - inputs[0].data().dptr(), inputs[2].data().dptr(), - M, 1, Shape2(leading, M), Shape2(leading, 1)); + Kernel, gpu>::Launch(s, inputs[0].data().Size(), + outputs[0].data().dptr(), inputs[0].data().dptr(), + inputs[2].data().dptr(), M, + 1, Shape2(leading, M), Shape2(leading, 1)); } else { - Kernel, gpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr(), - inputs[0].data().dptr(), inputs[2].data().dptr(), - M, trailing, Shape3(leading, M, trailing), - Shape3(leading, 1, trailing)); + Kernel, gpu>::Launch(s, inputs[0].data().Size(), + outputs[0].data().dptr(), inputs[0].data().dptr(), + inputs[2].data().dptr(), M, + trailing, Shape3(leading, M, trailing), + Shape3(leading, 1, trailing)); } }); }); } else { - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. " - << "Use np.astype() to cast indices to integer or boolean."; + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() + << "arrays used as indices must be explictly declared as integer (or boolean) type. 
" + << "Use np.astype() to cast indices to integer or boolean."; } } @@ -567,4 +571,4 @@ NNVM_REGISTER_OP(_backward_np_advanced_indexing_multiple) .set_attr("FComputeEx", AdvancedIndexingMultipleOpBackward); } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet diff --git a/src/operator/numpy/np_indexing_op.h b/src/operator/numpy/np_indexing_op.h index 2ba007b4b3fb..3b33aeae2fbe 100644 --- a/src/operator/numpy/np_indexing_op.h +++ b/src/operator/numpy/np_indexing_op.h @@ -25,6 +25,7 @@ #ifndef MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_ +#include #include "../contrib/boolean_mask-inl.h" #include "../tensor/indexing_op.h" #include "../tensor/broadcast_reduce_op.h" @@ -63,21 +64,21 @@ void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs); -template +template void AdvancedIndexingOpBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs); -template -void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, +template +void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, const std::vector& outputs); } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_ \ No newline at end of file +#endif // MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_