Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
new raise mode for nd.take and fix backward for wrap mode
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Aug 20, 2019
1 parent 3dfb19a commit f6ce09c
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 30 deletions.
26 changes: 21 additions & 5 deletions 3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel);
}

template<int x_bits, typename DType, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
template<bool clip, int x_bits, typename DType, typename DstPlan,
typename SrcPlan1, typename SrcPlan2>
__global__ void AddTakeGradKernel(DstPlan dst,
SrcPlan1 index, SrcPlan2 src,
index_t ymax, index_t xmax, const int K) {
Expand All @@ -606,8 +607,13 @@ __global__ void AddTakeGradKernel(DstPlan dst,
for (unsigned y = 0; y < ymax; ++y) {
if (threadIdx.x == 0) {
ptr = index.Eval(0, y);
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
if (clip) {
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
} else {
ptr %= K;
if (ptr < 0) ptr += K;
}
}
__syncthreads();
if (xindex < xmax) {
Expand Down Expand Up @@ -671,7 +677,7 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
}
}

template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
Expand All @@ -688,13 +694,23 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
const int K = dst.shape_[0];

AddTakeGradKernel<kUnitBits, DType>
if (clip) {
AddTakeGradKernel<true, kUnitBits, DType>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
} else {
AddTakeGradKernel<false, kUnitBits, DType>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

Expand Down
4 changes: 2 additions & 2 deletions 3rdparty/mshadow/mshadow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
Expand All @@ -829,7 +829,7 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
Expand Down
11 changes: 8 additions & 3 deletions 3rdparty/mshadow/mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,20 @@ inline void Softmax(Tensor<cpu, 3, DType> dst,
}
}

template<typename IndexType, typename DType>
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
const int K = dst.shape_[0];
for (index_t y = 0; y < index.size(0); ++y) {
int j = index[y];
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
if (clip) {
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
} else {
j %= K;
if (j < 0) j += K;
}
dst[j] += src[y];
}
}
Expand Down
4 changes: 2 additions & 2 deletions 3rdparty/mshadow/mshadow/tensor_gpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<typename IndexType, typename DType>
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad(dst, index, src);
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}

template<typename IndexType, typename DType>
Expand Down
15 changes: 12 additions & 3 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,17 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
Stream<cpu> *s = ctx.get_stream<cpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, { // index data type
if (param.mode == take_::kRaise) {
IType min = 0;
IType max = static_cast<IType>(arrshape[actual_axis] - 1);
// check with single thread is faster since data is small
IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
size_t idx_size = idxshape.Size();
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
}
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
Expand Down Expand Up @@ -326,7 +335,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
in_strides, out_strides, arrshape.ndim(),
oshape.ndim(), idxshape.ndim(),
arrshape[actual_axis], actual_axis);
} else if (param.mode == take_::kWrap) {
} else {
Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
Expand Down
18 changes: 15 additions & 3 deletions src/operator/tensor/indexing_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,20 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
Stream<gpu> *s = ctx.get_stream<gpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, { // index data type
if (param.mode == take_::kRaise) {
// check out-of-bound indices
IType min = 0;
IType max = static_cast<IType>(arrshape[actual_axis] - 1);
IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
size_t idx_size = idxshape.Size();
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckIndexOutOfBound(s, idx_ptr, idx_size, min, max, is_valid_ptr);
CHECK(is_valid) << "Take indices contains indices out of bound";
}
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
Expand Down Expand Up @@ -516,7 +528,7 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
inputs[take_::kIdx].dptr<IType>(),
in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
idxshape.ndim(), arrshape[actual_axis], actual_axis);
} else if (param.mode == take_::kWrap) {
} else {
Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
Expand Down
27 changes: 19 additions & 8 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,6 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape &idxshape = (*in_attrs)[take_::kIdx];
if (!shape_is_known(idxshape)) return false;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
if (param.mode == take_::kRaise) {
LOG(FATAL) << "Raise is not supported for the time being...";
}
CHECK(param.axis >= -1 * arrshape.ndim() && param.axis < arrshape.ndim())
<< "Axis should be in the range of [-r, r-1] where r is the rank of input tensor";

Expand Down Expand Up @@ -813,14 +810,15 @@ struct TakeGradGeneralKernel {
const IType* src_indptr, const IType* original_idx,
mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
const int in_ndims, const int out_ndims, const int idx_ndims,
const int axis) {
const int axis, const int K) {
const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
const int in_mid_index = in_rest_index / in_strides[axis];
const int in_tail_index = (axis == in_ndims - 1) ?
0 : (in_rest_index % in_strides[axis]);
for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
const int out_mid_index = original_idx[i];
int out_mid_index = original_idx[i];
out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index;
int target = in_tail_index + out_mid_index * in_strides[axis];
target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
arr_grad[tid] += ograd[target];
Expand Down Expand Up @@ -894,7 +892,7 @@ void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
Kernel<TakeGradGeneralKernel, cpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr,
original_idx_ptr, in_strides, out_strides,
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
});
});
}
Expand Down Expand Up @@ -968,6 +966,15 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
int num_bits = common::ilog2ui(static_cast<unsigned int>(idxshape.Size()) - 1);
Tensor<gpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
cub::DeviceHistogram::HistogramEven(temp_storage_ptr,
temp_storage_bytes,
sorted_idx_ptr,
src_indptr_ptr,
static_cast<int>(arrshape[axis] + 1),
0,
static_cast<int>(arrshape[axis] + 1),
static_cast<int>(idxshape.Size()),
mshadow::Stream<gpu>::GetStream(s));
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
src_indptr_ptr,
Expand All @@ -989,7 +996,7 @@ void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
Kernel<TakeGradGeneralKernel, gpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast<int>(arrshape[axis]));
});
});
}
Expand Down Expand Up @@ -1044,7 +1051,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
AddTakeGrad(grad_in, idx, grad_out);
if (param.mode == take_::kClip) {
AddTakeGrad(grad_in, idx, grad_out);
} else {
AddTakeGrad<false>(grad_in, idx, grad_out);
}
} else {
LOG(FATAL) << "wrong req";
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def test_ndarray_lesser_equal():


@with_seed()
def test_take():
def test_ndarray_take():
for data_ndim in range(2, 5):
for idx_ndim in range(1, 4):
data_shape = ()
Expand Down
25 changes: 22 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4412,15 +4412,21 @@ def grad_helper(grad_in, axis, idx):
else:
raise ValueError("axis %d is not supported..." % axis)

def check_output_n_grad(data_shape, idx_shape, axis, mode):
def check_output_n_grad(data_shape, idx_shape, axis, mode, out_of_range=True):
data = mx.sym.Variable('a')
idx = mx.sym.Variable('indices')
idx = mx.sym.BlockGrad(idx)
result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
exe = result.simple_bind(default_context(), a=data_shape,
indices=idx_shape, axis=axis, mode=mode)
data_real = np.random.normal(size=data_shape).astype('float32')
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if out_of_range:
idx_real = np.random.randint(low=-data_shape[axis], high=data_shape[axis], size=idx_shape)
if mode == 'raise':
idx_real[idx_real == 0] = 1
idx_real *= data_shape[axis]
else:
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
if axis < 0:
axis += len(data_shape)

Expand All @@ -4430,9 +4436,20 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
exe.arg_dict['a'][:] = mx.nd.array(data_real)
exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
exe.forward(is_train=True)
if out_of_range and mode == 'raise':
try:
mx_out = exe.outputs[0].asnumpy()
except MXNetError as e:
return
else:
# Did not raise exception
assert False, "did not raise %s" % MXNetError.__name__

assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode))

for i in np.nditer(idx_real):
if mode == 'clip':
i = np.clip(i, 0, data_shape[axis])
grad_helper(grad_in, axis, i)

exe.backward([mx.nd.array(grad_out)])
Expand Down Expand Up @@ -4463,7 +4480,7 @@ def check_autograd_req():
x.backward()
assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())

for mode in ['clip', 'wrap']:
for mode in ['clip', 'wrap', 'raise']:
for data_ndim in range(1, 5):
for idx_ndim in range(1, 4):
for axis in range(-data_ndim, data_ndim):
Expand All @@ -4473,6 +4490,8 @@ def check_autograd_req():
idx_shape = ()
for _ in range(idx_ndim):
idx_shape += (np.random.randint(low=1, high=5), )
if mode == 'raise':
check_output_n_grad(data_shape, idx_shape, axis, 'raise', False)
check_output_n_grad(data_shape, idx_shape, axis, mode)

check_autograd_req()
Expand Down

0 comments on commit f6ce09c

Please sign in to comment.