From 4c72d2775c97eac143a8587c1a22c9893d8da126 Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Tue, 3 Sep 2019 14:11:27 +0800
Subject: [PATCH] new raise mode for nd.take and fix backward for wrap mode
 (#15887)

---
 .../mshadow/mshadow/cuda/tensor_gpu-inl.cuh | 26 ++++++++++++++----
 3rdparty/mshadow/mshadow/tensor.h           |  4 +--
 3rdparty/mshadow/mshadow/tensor_cpu-inl.h   | 11 +++++---
 3rdparty/mshadow/mshadow/tensor_gpu-inl.h   |  4 +--
 src/operator/tensor/indexing_op.cc          | 15 ++++++++---
 src/operator/tensor/indexing_op.cu          | 18 ++++++++++---
 src/operator/tensor/indexing_op.h           | 27 +++++++++++++------
 tests/python/unittest/test_ndarray.py       |  2 +-
 tests/python/unittest/test_operator.py      | 25 ++++++++++++++---
 9 files changed, 102 insertions(+), 30 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
index 72e4b7eb9ee9..3ebe83cff146 100755
--- a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
+++ b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -596,7 +596,8 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
   MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel);
 }
 
-template<typename DstPlan, typename SrcPlan1, typename SrcPlan2>
+template<bool clip = true,
+         typename DstPlan, typename SrcPlan1, typename SrcPlan2>
 __global__ void AddTakeGradKernel(DstPlan dst,
                                   SrcPlan1 index, SrcPlan2 src,
                                   index_t ymax, index_t xmax, const int K) {
@@ -606,8 +607,13 @@ __global__ void AddTakeGradKernel(DstPlan dst,
   for (unsigned y = 0; y < ymax; ++y) {
     if (threadIdx.x == 0) {
       ptr = index.Eval(0, y);
-      if (ptr <= 0) ptr = 0;
-      else if (ptr >= K) ptr = K - 1;
+      if (clip) {
+        if (ptr <= 0) ptr = 0;
+        else if (ptr >= K) ptr = K - 1;
+      } else {
+        ptr %= K;
+        if (ptr < 0) ptr += K;
+      }
     }
     __syncthreads();
     if (xindex < xmax) {
@@ -671,7 +677,7 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
   }
 }
 
-template<typename IndexType, typename DType>
+template<bool clip = true, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                         const Tensor<gpu, 1, IndexType>& index,
                         const Tensor<gpu, 2, DType> &src) {
@@ -688,13 +694,23 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
   const int K = dst.shape_[0];
 
-  AddTakeGradKernel
+  if (clip) {
+    AddTakeGradKernel<true>
+        <<<dimGrid, dimBlock, 0, stream>>>
+        (expr::MakePlan(dst),
+         expr::MakePlan(index),
+         expr::MakePlan(src),
+         src.size(0),
+         src.size(1), K);
+  } else {
+    AddTakeGradKernel<false>
         <<<dimGrid, dimBlock, 0, stream>>>
         (expr::MakePlan(dst),
          expr::MakePlan(index),
          expr::MakePlan(src),
          src.size(0),
          src.size(1), K);
+  }
   MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
 }
 
diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h
index ad29e751a050..0add5a87d63b 100755
--- a/3rdparty/mshadow/mshadow/tensor.h
+++ b/3rdparty/mshadow/mshadow/tensor.h
@@ -817,7 +817,7 @@ inline void SoftmaxGrad(const Tensor<cpu, 3, DType> &dst,
  * \param index index to take
  * \param src source output
  */
-template<typename IndexType, typename DType>
+template<bool clip = true, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                         const Tensor<cpu, 1, IndexType>& index,
                         const Tensor<cpu, 2, DType> &src);
@@ -829,7 +829,7 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
  * \param index index to take
  * \param src source output
  */
-template<typename IndexType, typename DType>
+template<bool clip = true, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                         const Tensor<gpu, 1, IndexType>& index,
                         const Tensor<gpu, 2, DType> &src);
diff --git a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
index ab5f9a68df14..b7ae77fe56a4 100755
--- a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h
@@ -494,15 +494,20 @@ inline void Softmax(Tensor<cpu, 2, DType> dst,
   }
 }
 
-template<typename IndexType, typename DType>
+template<bool clip = true, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                         const Tensor<cpu, 1, IndexType>& index,
                         const Tensor<cpu, 2, DType> &src) {
   const int K = dst.shape_[0];
   for (index_t y = 0; y < index.size(0); ++y) {
     int j = index[y];
-    if (j <= 0) j = 0;
-    else if (j >= K) j = K - 1;
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {
+      j %= K;
+      if (j < 0) j += K;
+    }
     dst[j] += src[y];
   }
 }
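The mshadow changes above thread a `clip` template flag through AddTakeGradKernel and AddTakeGrad, so the backward pass normalizes indices the same way the forward pass does: clamp into [0, K-1] for clip mode, fold into [0, K) for wrap mode. A rough NumPy model of the CPU AddTakeGrad loop (the function name and shapes here are illustrative, not part of mshadow):

    import numpy as np

    def add_take_grad(dst, index, src, clip=True):
        # dst: (K, M) input gradient, index: (N,), src: (N, M) output gradient
        K = dst.shape[0]
        for y in range(index.shape[0]):
            j = int(index[y])
            if clip:
                j = min(max(j, 0), K - 1)   # clamp to the edge rows
            else:
                j %= K                      # Python % is already non-negative for K > 0;
                                            # the C++ code needs the extra `if (j < 0) j += K;`
            dst[j] += src[y]
        return dst
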
diff --git a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
index 94fdb0527e72..009e5d6530ff 100755
--- a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
+++ b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h
@@ -213,11 +213,11 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
   cuda::SoftmaxGrad(dst, src, label, ignore_label);
 }
 
-template<typename IndexType, typename DType>
+template<bool clip, typename IndexType, typename DType>
 inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                         const Tensor<gpu, 1, IndexType>& index,
                         const Tensor<gpu, 2, DType> &src) {
-  cuda::AddTakeGrad(dst, index, src);
+  cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
 }
 
 template<typename IndexType, typename DType>
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 21aefc5b2fd4..147205505e24 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -291,8 +291,17 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   Stream<cpu> *s = ctx.get_stream<cpu>();
   const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
-    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
+  MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, {  // index data type
+      if (param.mode == take_::kRaise) {
+        IType min = 0;
+        IType max = static_cast<IType>(arrshape[actual_axis] - 1);
+        // checking with a single thread is faster since the data is small
+        IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
+        size_t idx_size = idxshape.Size();
+        bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
+        CHECK(is_valid) << "take operator contains indices out of bound";
+      }
       if (actual_axis == 0) {
         if (param.mode == take_::kClip) {
           Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
@@ -326,7 +335,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
                                        in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
                                        idxshape.ndim(), arrshape[actual_axis], actual_axis);
-      } else if (param.mode == take_::kWrap) {
+      } else {
         Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
                                          outputs[take_::kOut].dptr<DType>(),
                                          inputs[take_::kArr].dptr<DType>(),
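The raise-mode branch added here validates every index against [0, K-1] on the CPU before any take kernel runs, where K is the length of the picked axis; CheckIndexOutOfBound reduces the whole index array to a single validity flag. In NumPy terms the predicate is simply (a sketch with a hypothetical helper name):

    import numpy as np

    def check_index_out_of_bound(idx, k):
        # valid only if every index already lies in [0, k - 1];
        # raise mode rejects out-of-range indices instead of normalizing them
        return bool(((idx >= 0) & (idx < k)).all())

    assert check_index_out_of_bound(np.array([0, 2, 4]), 5)
    assert not check_index_out_of_bound(np.array([0, 2, 5]), 5)   # 5 is out of bounds for k = 5
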
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 94fe377ebbc7..9a46d894ee22 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -482,8 +482,20 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   Stream<gpu> *s = ctx.get_stream<gpu>();
   const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
-    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
+  MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, {  // index data type
+      if (param.mode == take_::kRaise) {
+        // check out-of-bound indices
+        IType min = 0;
+        IType max = static_cast<IType>(arrshape[actual_axis] - 1);
+        IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
+        size_t idx_size = idxshape.Size();
+        Tensor<gpu, 1, char> workspace =
+          ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(1), s);
+        char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
+        bool is_valid = CheckIndexOutOfBound(s, idx_ptr, idx_size, min, max, is_valid_ptr);
+        CHECK(is_valid) << "take operator contains indices out of bound";
+      }
       if (actual_axis == 0) {
         if (param.mode == take_::kClip) {
           Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
@@ -516,7 +528,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
                                         inputs[take_::kIdx].dptr<IType>(),
                                         in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
                                         idxshape.ndim(), arrshape[actual_axis], actual_axis);
-      } else if (param.mode == take_::kWrap) {
+      } else {
         Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
                                          outputs[take_::kOut].dptr<DType>(),
                                          inputs[take_::kArr].dptr<DType>(),
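On the GPU the same check runs as a kernel with a one-byte workspace holding the validity flag, and the shape-inference rejection of raise mode is removed in indexing_op.h below, so all three modes become usable from Python. A minimal usage sketch, assuming a build that includes this patch:

    import mxnet as mx

    a = mx.nd.array([[1, 2], [3, 4], [5, 6]])        # K = 3 along axis 0
    idx = mx.nd.array([0, 4, -1])                    # 4 and -1 are out of range

    print(mx.nd.take(a, idx, axis=0, mode='clip'))   # rows 0, 2, 0 (clamped)
    print(mx.nd.take(a, idx, axis=0, mode='wrap'))   # rows 0, 1, 2 (wrapped)
    out = mx.nd.take(a, idx, axis=0, mode='raise')
    out.asnumpy()   # fails here: execution is asynchronous, so the error
                    # surfaces when the result is synchronized
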
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 7dfc77d25b27..161acae0ebf2 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -684,9 +684,6 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape &idxshape = (*in_attrs)[take_::kIdx];
   if (!shape_is_known(idxshape)) return false;
   const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
-  if (param.mode == take_::kRaise) {
-    LOG(FATAL) << "Raise is not supported for the time being...";
-  }
   CHECK(param.axis >= -1 * arrshape.ndim() && param.axis < arrshape.ndim())
     << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor";
 
@@ -813,14 +810,15 @@ struct TakeGradGeneralKernel {
                                     const IType* src_indptr, const IType* original_idx,
                                     mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
                                     const int in_ndims, const int out_ndims, const int idx_ndims,
-                                    const int axis) {
+                                    const int axis, const int K) {
     const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
     const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
     const int in_mid_index = in_rest_index / in_strides[axis];
     const int in_tail_index = (axis == in_ndims - 1) ?
                               0 : (in_rest_index % in_strides[axis]);
     for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
-      const int out_mid_index = original_idx[i];
+      int out_mid_index = original_idx[i];
+      out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index;
       int target = in_tail_index + out_mid_index * in_strides[axis];
       target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
       arr_grad[tid] += ograd[target];
@@ -894,7 +892,7 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
         Kernel<TakeGradGeneralKernel, cpu>::Launch(
           s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
           src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
-          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
+          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
       });
     });
   }
@@ -968,6 +966,15 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
     int num_bits = common::ilog2ui(static_cast<unsigned int>(idxshape.Size()) - 1);
     Tensor<gpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
     SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
+    cub::DeviceHistogram::HistogramEven(temp_storage_ptr,
+                                        temp_storage_bytes,
+                                        sorted_idx_ptr,
+                                        src_indptr_ptr,
+                                        static_cast<int>(arrshape[axis] + 1),
+                                        0,
+                                        static_cast<int>(arrshape[axis] + 1),
+                                        static_cast<int>(idxshape.Size()),
+                                        mshadow::Stream<gpu>::GetStream(s));
     cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
                                   temp_storage_bytes,
                                   src_indptr_ptr,
@@ -989,7 +996,7 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
         Kernel<TakeGradGeneralKernel, gpu>::Launch(
           s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
           src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
-          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
+          arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
       });
     });
   }
@@ -1044,7 +1051,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
         if (req[take_::kArr] == kWriteTo) {
           grad_in = scalar<DType>(0.0f);
         }
-        AddTakeGrad(grad_in, idx, grad_out);
+        if (param.mode == take_::kClip) {
+          AddTakeGrad(grad_in, idx, grad_out);
+        } else {
+          AddTakeGrad<false>(grad_in, idx, grad_out);
+        }
       } else {
         LOG(FATAL) << "wrong req";
       }
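For a non-zero axis the backward pass batches gradients by destination row: the indices are sorted, then a histogram plus an exclusive sum (cub::DeviceHistogram::HistogramEven followed by cub::DeviceScan::ExclusiveSum on the GPU) turns them into a CSR-style row pointer src_indptr, and TakeGradGeneralKernel now also receives K so it can fold negative wrap-mode indices back into [0, K) before computing the scatter target. A rough NumPy model of the row-pointer construction (the helper name is illustrative):

    import numpy as np

    def build_src_indptr(idx, k):
        normalized = idx % k                             # fold negatives into [0, k)
        counts = np.bincount(normalized, minlength=k)    # HistogramEven
        src_indptr = np.zeros(k + 1, dtype=np.int64)
        src_indptr[1:] = np.cumsum(counts)               # ExclusiveSum
        return src_indptr

    print(build_src_indptr(np.array([2, 0, -1, 2]), 3))  # [0 1 1 4]: row 0 once, row 2 three times
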
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e15240c30dd3..c1f0ef05f979 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1136,7 +1136,7 @@ def test_ndarray_lesser_equal():
 
 
 @with_seed()
-def test_take():
+def test_ndarray_take():
     for data_ndim in range(2, 5):
         for idx_ndim in range(1, 4):
             data_shape = ()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e4617afd112a..f0e16b2729a2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4426,7 +4426,7 @@ def grad_helper(grad_in, axis, idx):
             else:
                 raise ValueError("axis %d is not supported..." % axis)
 
-    def check_output_n_grad(data_shape, idx_shape, axis, mode):
+    def check_output_n_grad(data_shape, idx_shape, axis, mode, out_of_range=True):
         data = mx.sym.Variable('a')
         idx = mx.sym.Variable('indices')
         idx = mx.sym.BlockGrad(idx)
@@ -4434,7 +4434,13 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
         exe = result.simple_bind(default_context(), a=data_shape,
                                  indices=idx_shape, axis=axis, mode=mode)
         data_real = np.random.normal(size=data_shape).astype('float32')
-        idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
+        if out_of_range:
+            idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
+            if mode == 'raise':
+                idx_real[idx_real == 0] = 1
+                idx_real *= data_shape[axis]
+        else:
+            idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
         if axis < 0:
             axis += len(data_shape)
 
@@ -4444,9 +4450,20 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
         exe.arg_dict['a'][:] = mx.nd.array(data_real)
         exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
         exe.forward(is_train=True)
+        if out_of_range and mode == 'raise':
+            try:
+                mx_out = exe.outputs[0].asnumpy()
+            except MXNetError as e:
+                return
+            else:
+                # Did not raise exception
+                assert False, "did not raise %s" % MXNetError.__name__
+
         assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode))
 
         for i in np.nditer(idx_real):
+            if mode == 'clip':
+                i = np.clip(i, 0, data_shape[axis])
             grad_helper(grad_in, axis, i)
 
         exe.backward([mx.nd.array(grad_out)])
@@ -4477,7 +4494,7 @@ def check_autograd_req():
         x.backward()
         assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())
 
-    for mode in ['clip', 'wrap']:
+    for mode in ['clip', 'wrap', 'raise']:
         for data_ndim in range(1, 5):
             for idx_ndim in range(1, 4):
                 for axis in range(-data_ndim, data_ndim):
@@ -4487,6 +4504,8 @@ def check_autograd_req():
                         idx_shape = ()
                         for _ in range(idx_ndim):
                             idx_shape += (np.random.randint(low=1, high=5), )
+                        if mode == 'raise':
+                            check_output_n_grad(data_shape, idx_shape, axis, 'raise', False)
                         check_output_n_grad(data_shape, idx_shape, axis, mode)
 
     check_autograd_req()
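The updated test now draws out-of-range indices for clip and wrap and compares against np.take, which implements the same mode semantics; for raise it asserts that evaluation fails. A condensed, self-contained version of that check (assuming a build with this patch):

    import numpy as np
    import mxnet as mx
    from mxnet.base import MXNetError

    data = np.random.normal(size=(3, 4)).astype('float32')
    idx = np.array([1, 4, -2])   # 4 and -2 are out of range for axis 0

    for mode in ['clip', 'wrap']:
        out = mx.nd.take(mx.nd.array(data), mx.nd.array(idx), axis=0, mode=mode)
        np.testing.assert_allclose(out.asnumpy(),
                                   np.take(data, idx, axis=0, mode=mode))

    try:
        mx.nd.take(mx.nd.array(data), mx.nd.array(idx), axis=0, mode='raise').asnumpy()
        assert False, 'expected MXNetError for out-of-bound indices'
    except MXNetError:
        pass   # matches np.take, which raises IndexError for these indices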