Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

new 'raise' mode for nd.take and fix for backward of 'wrap' mode #15887

Merged
merged 1 commit into from
Sep 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions 3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel);
}

template<int x_bits, typename DType, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
template<bool clip, int x_bits, typename DType, typename DstPlan,
typename SrcPlan1, typename SrcPlan2>
__global__ void AddTakeGradKernel(DstPlan dst,
SrcPlan1 index, SrcPlan2 src,
index_t ymax, index_t xmax, const int K) {
Expand All @@ -606,8 +607,13 @@ __global__ void AddTakeGradKernel(DstPlan dst,
for (unsigned y = 0; y < ymax; ++y) {
if (threadIdx.x == 0) {
ptr = index.Eval(0, y);
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
if (clip) {
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
} else {
ptr %= K;
if (ptr < 0) ptr += K;
}
}
__syncthreads();
if (xindex < xmax) {
Expand Down Expand Up @@ -671,7 +677,7 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
}
}

template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
Expand All @@ -688,13 +694,23 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
const int K = dst.shape_[0];

AddTakeGradKernel<kUnitBits, DType>
if (clip) {
AddTakeGradKernel<true, kUnitBits, DType>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
} else {
AddTakeGradKernel<false, kUnitBits, DType>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

Expand Down
4 changes: 2 additions & 2 deletions 3rdparty/mshadow/mshadow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
Expand All @@ -829,7 +829,7 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
Expand Down
11 changes: 8 additions & 3 deletions 3rdparty/mshadow/mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,20 @@ inline void Softmax(Tensor<cpu, 3, DType> dst,
}
}

template<typename IndexType, typename DType>
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
const int K = dst.shape_[0];
for (index_t y = 0; y < index.size(0); ++y) {
int j = index[y];
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
if (clip) {
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
} else {
j %= K;
if (j < 0) j += K;
}
dst[j] += src[y];
}
}
Expand Down
4 changes: 2 additions & 2 deletions 3rdparty/mshadow/mshadow/tensor_gpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<typename IndexType, typename DType>
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad(dst, index, src);
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}

template<typename IndexType, typename DType>
Expand Down
15 changes: 12 additions & 3 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,17 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
Stream<cpu> *s = ctx.get_stream<cpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, { // index data type
if (param.mode == take_::kRaise) {
IType min = 0;
IType max = static_cast<IType>(arrshape[actual_axis] - 1);
// check with single thread is faster since data is small
IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
size_t idx_size = idxshape.Size();
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
}
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
Expand Down Expand Up @@ -326,7 +335,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
in_strides, out_strides, arrshape.ndim(),
oshape.ndim(), idxshape.ndim(),
arrshape[actual_axis], actual_axis);
} else if (param.mode == take_::kWrap) {
} else {
Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
Expand Down
18 changes: 15 additions & 3 deletions src/operator/tensor/indexing_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,20 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
Stream<gpu> *s = ctx.get_stream<gpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, { // index data type
if (param.mode == take_::kRaise) {
// check out-of-bound indices
IType min = 0;
IType max = static_cast<IType>(arrshape[actual_axis] - 1);
IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
size_t idx_size = idxshape.Size();
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckIndexOutOfBound(s, idx_ptr, idx_size, min, max, is_valid_ptr);
CHECK(is_valid) << "Take indices contains indices out of bound";
}
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
Expand Down Expand Up @@ -516,7 +528,7 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
inputs[take_::kIdx].dptr<IType>(),
in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
idxshape.ndim(), arrshape[actual_axis], actual_axis);
} else if (param.mode == take_::kWrap) {
} else {
Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
Expand Down
27 changes: 19 additions & 8 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,6 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape &idxshape = (*in_attrs)[take_::kIdx];
if (!shape_is_known(idxshape)) return false;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
if (param.mode == take_::kRaise) {
LOG(FATAL) << "Raise is not supported for the time being...";
}
CHECK(param.axis >= -1 * arrshape.ndim() && param.axis < arrshape.ndim())
<< "Axis should be in the range of [-r, r-1] where r is the rank of input tensor";

Expand Down Expand Up @@ -813,14 +810,15 @@ struct TakeGradGeneralKernel {
const IType* src_indptr, const IType* original_idx,
mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
const int in_ndims, const int out_ndims, const int idx_ndims,
const int axis) {
const int axis, const int K) {
const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
const int in_mid_index = in_rest_index / in_strides[axis];
const int in_tail_index = (axis == in_ndims - 1) ?
0 : (in_rest_index % in_strides[axis]);
for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
const int out_mid_index = original_idx[i];
int out_mid_index = original_idx[i];
out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index;
int target = in_tail_index + out_mid_index * in_strides[axis];
target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
arr_grad[tid] += ograd[target];
Expand Down Expand Up @@ -894,7 +892,7 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
Kernel<TakeGradGeneralKernel, cpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr,
original_idx_ptr, in_strides, out_strides,
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
});
});
}
Expand Down Expand Up @@ -968,6 +966,15 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
int num_bits = common::ilog2ui(static_cast<unsigned int>(idxshape.Size()) - 1);
Tensor<gpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
cub::DeviceHistogram::HistogramEven(temp_storage_ptr,
temp_storage_bytes,
sorted_idx_ptr,
src_indptr_ptr,
static_cast<int>(arrshape[axis] + 1),
0,
static_cast<int>(arrshape[axis] + 1),
static_cast<int>(idxshape.Size()),
mshadow::Stream<gpu>::GetStream(s));
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
src_indptr_ptr,
Expand All @@ -989,7 +996,7 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
Kernel<TakeGradGeneralKernel, gpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
});
});
}
Expand Down Expand Up @@ -1044,7 +1051,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
AddTakeGrad(grad_in, idx, grad_out);
if (param.mode == take_::kClip) {
AddTakeGrad(grad_in, idx, grad_out);
} else {
AddTakeGrad<false>(grad_in, idx, grad_out);
}
} else {
LOG(FATAL) << "wrong req";
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ def test_ndarray_lesser_equal():


@with_seed()
def test_take():
def test_ndarray_take():
for data_ndim in range(2, 5):
for idx_ndim in range(1, 4):
data_shape = ()
Expand Down
25 changes: 22 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4426,15 +4426,21 @@ def grad_helper(grad_in, axis, idx):
else:
raise ValueError("axis %d is not supported..." % axis)

def check_output_n_grad(data_shape, idx_shape, axis, mode):
def check_output_n_grad(data_shape, idx_shape, axis, mode, out_of_range=True):
data = mx.sym.Variable('a')
idx = mx.sym.Variable('indices')
idx = mx.sym.BlockGrad(idx)
result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
exe = result.simple_bind(default_context(), a=data_shape,
indices=idx_shape, axis=axis, mode=mode)
data_real = np.random.normal(size=data_shape).astype('float32')
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if out_of_range:
idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
if mode == 'raise':
idx_real[idx_real == 0] = 1
idx_real *= data_shape[axis]
else:
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if axis < 0:
axis += len(data_shape)

Expand All @@ -4444,9 +4450,20 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
exe.arg_dict['a'][:] = mx.nd.array(data_real)
exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
exe.forward(is_train=True)
if out_of_range and mode == 'raise':
try:
mx_out = exe.outputs[0].asnumpy()
except MXNetError as e:
return
else:
# Did not raise exception
assert False, "did not raise %s" % MXNetError.__name__

assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode))

for i in np.nditer(idx_real):
if mode == 'clip':
i = np.clip(i, 0, data_shape[axis])
grad_helper(grad_in, axis, i)

exe.backward([mx.nd.array(grad_out)])
Expand Down Expand Up @@ -4477,7 +4494,7 @@ def check_autograd_req():
x.backward()
assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())

for mode in ['clip', 'wrap']:
for mode in ['clip', 'wrap', 'raise']:
for data_ndim in range(1, 5):
for idx_ndim in range(1, 4):
for axis in range(-data_ndim, data_ndim):
Expand All @@ -4487,6 +4504,8 @@ def check_autograd_req():
idx_shape = ()
for _ in range(idx_ndim):
idx_shape += (np.random.randint(low=1, high=5), )
if mode == 'raise':
check_output_n_grad(data_shape, idx_shape, axis, 'raise', False)
check_output_n_grad(data_shape, idx_shape, axis, mode)

check_autograd_req()
Expand Down