diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index b36d79acfc7b..0ccbe410a1e1 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -377,7 +377,6 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
  * \param resource temporary resource handler
  * \param src the Source blob
  * \param ret the destination blobs
- * \param k the K elements to keep
  * \param param the topk parameters
  * \tparam xpu the device type.
  * \tparam DType type of the output value/mask.
@@ -563,13 +562,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
           const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
         TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
       })
     });
   } else {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
     });
   }
@@ -695,13 +694,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                    const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnBoth) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
         TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
       });
     });
   } else if (param.ret_typ == topk_enum::kReturnValue) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
     });
   } else {
@@ -732,7 +731,6 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
-  int data_type = -1;
   size_t in_size = in_attrs->size();
   size_t out_size = out_attrs->size();
   CHECK_EQ(in_size, 1);
@@ -756,15 +754,9 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
     CHECK(type_assign(&(*out_attrs)[0], param.dtype))
         << "Failed to set the type of ret_indices.";
   } else {
-    CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
-                                                   << (*in_attrs)[0];
-    CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
-                                                    << (*out_attrs)[0];
-    CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
-                                                   << (*in_attrs)[0];
-    CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
-                                                    << (*out_attrs)[0];
-    if (data_type == -1) return false;
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    return out_attrs->at(0) != -1;
   }
   return true;
 }
diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh
index f0caee4f5cb6..b20b466d9c2b 100644
--- a/src/operator/tensor/sort_op-inl.cuh
+++ b/src/operator/tensor/sort_op-inl.cuh
@@ -60,24 +60,50 @@ struct greater_half
   typedef T second_argument_type;
   typedef bool result_type;
   __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
-    return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
+    return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
   }
 };
 }
 
+#ifndef SORT_WITH_THRUST
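+// Computes the byte counts of the key and value staging buffers carved out of
+// the shared sort workspace, each padded to the alignment of the wider element type.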
+template <typename KDType, typename VDType>
+inline void WorkspaceSize4KeysAndValues(
+  const size_t num_keys, size_t *pKeys_bytes, size_t *pValues_bytes) {
+  const size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
+  *pKeys_bytes = PadBytes(num_keys * sizeof(KDType), alignment);
+  *pValues_bytes = PadBytes(num_keys * sizeof(VDType), alignment);
+}
+
+template <typename KDType, typename VDType>
+inline typename std::enable_if<!std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+SortPairsWorkspaceSize(const size_t num_keys) {
+  size_t sortpairs_bytes = 0;
+  cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
+    NULL, NULL, NULL, NULL, num_keys);
+  return sortpairs_bytes;
+}
+
+template <typename KDType, typename VDType>
+inline typename std::enable_if<std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+SortPairsWorkspaceSize(const size_t num_keys) {
+  size_t sortpairs_bytes = 0;
+  cub::DeviceRadixSort::SortPairs<__half, VDType>(NULL, sortpairs_bytes,
+    NULL, NULL, NULL, NULL, num_keys);
+  return sortpairs_bytes;
+}
+#endif
+
 template <typename KDType, typename VDType, typename xpu>
 inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
 SortByKeyWorkspaceSize(const size_t num_keys) {
 #ifdef SORT_WITH_THRUST
   return 0;
 #else
-  size_t sortpairs_bytes = 0;
-  cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
-    NULL, NULL, NULL, NULL, num_keys);
-  size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
-  size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment);
-  size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment);
-  return (keys_bytes + values_bytes + sortpairs_bytes);
+  size_t keys_bytes, values_bytes;
+  WorkspaceSize4KeysAndValues<KDType, VDType>(num_keys, &keys_bytes, &values_bytes);
+  return keys_bytes + values_bytes + SortPairsWorkspaceSize<KDType, VDType>(num_keys);
 #endif
 }
@@ -142,11 +168,11 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + keys.size(0), value_iter, thrust::less<KDType>());
+      key_iter, key_iter + keys.size(0), value_iter.get(), thrust::less<KDType>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + keys.size(0), value_iter, thrust::greater<KDType>());
+      key_iter, key_iter + keys.size(0), value_iter.get(), thrust::greater<KDType>());
   }
 #ifndef SORT_WITH_THRUST
   }
@@ -169,8 +195,8 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
   thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
-  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(values.dptr_));
+  thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(values.dptr_));
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
@@ -197,17 +223,17 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, mshadow::half::half_t> keys,
   CHECK_EQ(values.CheckContiguous(), true);
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
-  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(keys.dptr_));
+  thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(keys.dptr_));
   thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
   }
   MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
 #else
@@ -227,18 +253,20 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, mshadow::half::half_t> keys,
   CHECK_EQ(values.CheckContiguous(), true);
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
-  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(keys.dptr_));
-  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(values.dptr_));
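+  // Keys and values are both half_t here; reinterpret them as CUDA __half so
+  // thrust can use the native half-precision comparators.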
+  thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(keys.dptr_));
+  thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(values.dptr_));
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
   }
   MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
 #else
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 466cee823029..4de7d16e6f61 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4304,15 +4304,6 @@ def get_large_matrix():
 
     large_matrix_npy = get_large_matrix()
 
-    for axis in [1, 3, None]:
-        K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
-        for k in K:
-            for is_ascend in [True, False]:
-                b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
-                out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
-                check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
-                check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
-
     for axis in [1, 3, None]:
         for is_ascend in [True, False]:
             b = mx.sym.sort(a, axis=axis, is_ascend=is_ascend)
@@ -4332,22 +4323,6 @@ def get_large_matrix():
                                                  ret_typ="indices", k=5,
                                                  is_ascend=is_ascend)])
 
-        b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
-        check_symbolic_backward(sym=b, location={'a': a_npy},
-                                out_grads=[np.random.normal(size=(5, 5, 5, 3))],
-                                expected=[np.zeros((5, 5, 5, 5))])
-        check_symbolic_forward(b, location={'a': a_npy},
-                               expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
-                                                 is_ascend=False)])
-
-        b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
-        check_symbolic_backward(sym=b, location={'a': a_npy},
-                                out_grads=[np.random.normal(size=(5, 5, 5, 5))],
-                                expected=[np.zeros((5, 5, 5, 5))])
-        check_symbolic_forward(b, location={'a': a_npy},
-                               expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
-                                                 is_ascend=True)])
-
         b = mx.sym.argsort(a, axis=1, is_ascend=False)
         check_symbolic_backward(sym=b, location={'a': a_npy},
                                 out_grads=[np.random.normal(size=(5, 5, 5, 5))],
@@ -4372,6 +4347,47 @@ def get_large_matrix():
                                expected=[gt_topk(dat=a_npy, axis=1, ret_typ="indices", k=1,
                                                  is_ascend=True)])
 
+    for dtype in [np.float16, np.float32, np.float64]:
+        dshape = (5, 5, 5, 5)
+        a_npy = np.arange(np.prod(dshape)).astype(dtype)
+        np.random.shuffle(a_npy)
+        a_npy = a_npy.reshape(dshape)
+        a = mx.sym.Variable('a')
+        for axis in [1, 3, None]:
+            K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
+            for k in K:
+                for is_ascend in [True, False]:
+                    b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
+                    out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
+                    check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
+                    check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
+
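+        # The indices and mask return types are not differentiable, so the
+        # backward pass should produce all-zero input gradients.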
ret_typ="value", k=k, is_ascend=is_ascend) + check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx) + check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy]) + + b = mx.sym.topk(a, axis=1, is_ascend=is_ascend, ret_typ="indices", k=5) + check_symbolic_backward(sym=b, location={'a': large_matrix_npy}, + out_grads=[np.random.normal(size=(100, 5))], + expected=[np.zeros((100, 300096))]) + check_symbolic_forward(b, location={'a': large_matrix_npy}, + expected=[gt_topk(dat=large_matrix_npy, axis=1, + ret_typ="indices", k=5, is_ascend=is_ascend)]) + + b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3) + check_symbolic_backward(sym=b, location={'a': a_npy}, + out_grads=[np.random.normal(size=(5, 5, 5, 3))], + expected=[np.zeros((5, 5, 5, 5))]) + check_symbolic_forward(b, location={'a': a_npy}, + expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3, + is_ascend=False)]) + + b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3) + check_symbolic_backward(sym=b, location={'a': a_npy}, + out_grads=[np.random.normal(size=(5, 5, 5, 5))], + expected=[np.zeros((5, 5, 5, 5))]) + check_symbolic_forward(b, location={'a': a_npy}, + expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3, + is_ascend=True)]) + @with_seed() def test_blockgrad():