This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add fp16 support for topk #15560

Merged (5 commits, Aug 23, 2019)
22 changes: 7 additions & 15 deletions src/operator/tensor/ordering_op-inl.h
@@ -377,7 +377,6 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
  * \param resource temporary resource handler
  * \param src the Source blob
  * \param ret the destination blobs
- * \param k the K elements to keep
  * \param param the topk parameters
  * \tparam xpu the device type.
  * \tparam DType type of the output value/mask.
@@ -563,13 +562,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
           const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
         TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
       })
     });
   } else {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
     });
   }
@@ -695,13 +694,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                    const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnBoth) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
         TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
       });
     });
   } else if (param.ret_typ == topk_enum::kReturnValue) {
-    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
     });
   } else {
@@ -732,7 +731,6 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
-  int data_type = -1;
   size_t in_size = in_attrs->size();
   size_t out_size = out_attrs->size();
   CHECK_EQ(in_size, 1);
@@ -756,15 +754,9 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
     CHECK(type_assign(&(*out_attrs)[0], param.dtype))
         << "Failed to set the type of ret_indices.";
   } else {
-    CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
-                                                   << (*in_attrs)[0];
-    CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
-                                                    << (*out_attrs)[0];
-    CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
-                                                   << (*in_attrs)[0];
-    CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
-                                                    << (*out_attrs)[0];
-    if (data_type == -1) return false;
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    return out_attrs->at(0) != -1;
   }
   return true;
 }
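The key change in this file: MXNET_NO_FLOAT16_TYPE_SWITCH rejects float16 at dispatch time, while MSHADOW_TYPE_SWITCH includes a float16 case, so swapping the macro is what lets fp16 tensors reach TopKImpl; the TYPE_ASSIGN_CHECK pair also lets the input and output dtypes propagate in both directions. Below is a minimal, self-contained sketch of how such a runtime type-flag switch works; the macro body, the TypeFlag enum, and the half_t stand-in are simplifications for illustration, not the actual mshadow definitions.

#include <cstdint>
#include <stdexcept>

// Simplified stand-ins for illustration; the real definitions live in mshadow.
enum TypeFlag { kFloat32 = 0, kFloat64 = 1, kFloat16 = 2 };
struct half_t { uint16_t bits_; };

// A type-flag switch expands its body once per supported type, binding DType
// to the concrete C++ type. MXNET_NO_FLOAT16_TYPE_SWITCH is the same idea
// minus the kFloat16 case, which is why topk used to reject fp16 inputs.
#define TYPE_SWITCH(flag, DType, ...)                                  \
  switch (flag) {                                                      \
    case kFloat32: { typedef float DType; { __VA_ARGS__ } } break;     \
    case kFloat64: { typedef double DType; { __VA_ARGS__ } } break;    \
    case kFloat16: { typedef half_t DType; { __VA_ARGS__ } } break;    \
    default: throw std::runtime_error("unknown type flag");            \
  }

template <typename DType>
void TypedTopK() { /* a typed kernel would run here */ }

void DispatchTopK(int type_flag) {
  // With the fp16 case present, type_flag == kFloat16 dispatches cleanly.
  TYPE_SWITCH(type_flag, DType, { TypedTopK<DType>(); });
}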
68 changes: 46 additions & 22 deletions src/operator/tensor/sort_op-inl.cuh
@@ -60,24 +60,48 @@ struct greater_half
   typedef T second_argument_type;
   typedef bool result_type;
   __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
-    return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
+    return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
   }
 };
}

+#ifndef SORT_WITH_THRUST
+template <typename KDType, typename VDType>
+inline void WorkspaceSize4KeysAndValues(
+  const size_t num_keys, size_t *pKeys_bytes, size_t *pValues_bytes) {
+  const size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
+  *pKeys_bytes = PadBytes(num_keys * sizeof(KDType), alignment);
+  *pValues_bytes = PadBytes(num_keys * sizeof(VDType), alignment);
+}
+
+template <typename KDType, typename VDType>
+inline typename std::enable_if<!std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+SortPairsWorkspaceSize(const size_t num_keys) {
+  size_t sortpairs_bytes = 0;
+  cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
+    NULL, NULL, NULL, NULL, num_keys);
+  return sortpairs_bytes;
+}
+
+template <typename KDType, typename VDType>
+inline typename std::enable_if<std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+SortPairsWorkspaceSize(const size_t num_keys) {
+  size_t sortpairs_bytes = 0;
+  cub::DeviceRadixSort::SortPairs<__half, VDType>(NULL, sortpairs_bytes,
+    NULL, NULL, NULL, NULL, num_keys);
+  return sortpairs_bytes;
+}
+#endif

 template <typename KDType, typename VDType, typename xpu>
 inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
 SortByKeyWorkspaceSize(const size_t num_keys) {
 #ifdef SORT_WITH_THRUST
   return 0;
 #else
-  size_t sortpairs_bytes = 0;
-  cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
-    NULL, NULL, NULL, NULL, num_keys);
-  size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
-  size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment);
-  size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment);
-  return (keys_bytes + values_bytes + sortpairs_bytes);
+  size_t keys_bytes, values_bytes;
+  WorkspaceSize4KeysAndValues<KDType, VDType>(num_keys, &keys_bytes, &values_bytes);
+  return keys_bytes + values_bytes + SortPairsWorkspaceSize<KDType, VDType>(num_keys);
 #endif
 }
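For orientation, a sketch of how a caller might combine the helpers above: compute the single contiguous workspace size (padded key bytes + padded value bytes + CUB scratch), verify the provided buffer is large enough, then sort. The SortByKey call and the CHECK_GE below reflect assumed surrounding MXNet conventions, not code in this diff.

// Hypothetical caller; the workspace-taking SortByKey overload is assumed.
template <typename KDType, typename VDType>
void SortWithWorkspace(mshadow::Tensor<gpu, 1, KDType> keys,
                       mshadow::Tensor<gpu, 1, VDType> values,
                       mshadow::Tensor<gpu, 1, char> workspace,
                       bool is_ascend) {
  // Padded keys copy + padded values copy + cub::DeviceRadixSort scratch.
  const size_t needed = SortByKeyWorkspaceSize<KDType, VDType, gpu>(keys.size(0));
  CHECK_GE(workspace.size(0), needed);
  SortByKey(keys, values, is_ascend, &workspace);
}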

@@ -142,11 +166,11 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + keys.size(0), value_iter, thrust::less<KDType>());
+      key_iter, key_iter + keys.size(0), value_iter.get(), thrust::less<KDType>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + keys.size(0), value_iter, thrust::greater<KDType>());
+      key_iter, key_iter + keys.size(0), value_iter.get(), thrust::greater<KDType>());
   }
#ifndef SORT_WITH_THRUST
}
@@ -169,8 +193,8 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
   thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
-  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(values.dptr_));
+  thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
@@ -197,17 +221,17 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
   CHECK_EQ(values.CheckContiguous(), true);
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
-  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(keys.dptr_));
+  thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(keys.dptr_));
   thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
   }
   MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
 #else
@@ -227,18 +251,18 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
   CHECK_EQ(values.CheckContiguous(), true);
 #if CUDA_VERSION >= 9000
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
-  thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(keys.dptr_));
-  thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-    reinterpret_cast<half*>(values.dptr_));
+  thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(keys.dptr_));
+  thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+    reinterpret_cast<__half*>(values.dptr_));
   if (is_ascend) {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
   } else {
     thrust::stable_sort_by_key(
       thrust::cuda::par.on(stream),
-      key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+      key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
   }
   MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
 #else
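Two details worth noting in this file. First, the one-character fix at the top: cuda::greater_half previously compared with <, so a descending fp16 sort was actually ascending; it now uses >. Second, the std::enable_if overloads route mshadow::half::half_t keys through CUDA's native __half (the reinterpret_cast works because both use the same 16-bit half-precision layout, which CUB and thrust understand). A standalone illustration of that dispatch pattern, using stand-in types rather than the real mshadow/CUDA ones:

#include <cstdio>
#include <type_traits>

struct half_t { unsigned short bits_; };  // stand-in for mshadow::half::half_t

// Generic path: the key type is handed to the sorting backend as-is.
template <typename KDType>
typename std::enable_if<!std::is_same<KDType, half_t>::value, const char*>::type
SortPath() { return "sort keys as KDType"; }

// half_t path: keys go to the backend as the native 16-bit float type
// (__half in the real code).
template <typename KDType>
typename std::enable_if<std::is_same<KDType, half_t>::value, const char*>::type
SortPath() { return "sort keys reinterpreted as __half"; }

int main() {
  std::printf("%s\n", SortPath<float>());   // picks the generic overload
  std::printf("%s\n", SortPath<half_t>());  // picks the half_t overload
  return 0;
}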
64 changes: 39 additions & 25 deletions tests/python/unittest/test_operator.py
@@ -4304,15 +4304,6 @@ def get_large_matrix():

large_matrix_npy = get_large_matrix()

-    for axis in [1, 3, None]:
-        K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
-        for k in K:
-            for is_ascend in [True, False]:
-                b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
-                out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
-                check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
-                check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
-
for axis in [1, 3, None]:
for is_ascend in [True, False]:
b = mx.sym.sort(a, axis=axis, is_ascend=is_ascend)
@@ -4332,22 +4323,6 @@ def get_large_matrix():
ret_typ="indices", k=5,
is_ascend=is_ascend)])

-        b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
-        check_symbolic_backward(sym=b, location={'a': a_npy},
-                                out_grads=[np.random.normal(size=(5, 5, 5, 3))],
-                                expected=[np.zeros((5, 5, 5, 5))])
-        check_symbolic_forward(b, location={'a': a_npy},
-                               expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
-                                                 is_ascend=False)])
-
-        b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
-        check_symbolic_backward(sym=b, location={'a': a_npy},
-                                out_grads=[np.random.normal(size=(5, 5, 5, 5))],
-                                expected=[np.zeros((5, 5, 5, 5))])
-        check_symbolic_forward(b, location={'a': a_npy},
-                               expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
-                                                 is_ascend=True)])

b = mx.sym.argsort(a, axis=1, is_ascend=False)
check_symbolic_backward(sym=b, location={'a': a_npy},
out_grads=[np.random.normal(size=(5, 5, 5, 5))],
@@ -4372,6 +4347,45 @@ def get_large_matrix():
expected=[gt_topk(dat=a_npy, axis=1, ret_typ="indices", k=1,
is_ascend=True)])

+    for dtype in [np.float16, np.float32, np.float64]:
+        dshape = (5, 5, 5, 5)
+        a_npy = np.arange(np.prod(dshape)).astype(dtype)
+        np.random.shuffle(a_npy)
+        a_npy = a_npy.reshape(dshape)
+        a = mx.sym.Variable('a')
+        for axis in [1, 3, None]:
+            K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
+            for k in K:
+                for is_ascend in [True, False]:
+                    b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
+                    out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
+                    check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
+                    check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
+
+        b = mx.sym.topk(a, axis=1, is_ascend=is_ascend, ret_typ="indices", k=5)
+        check_symbolic_backward(sym=b, location={'a': large_matrix_npy},
+                                out_grads=[np.random.normal(size=(100, 5))],
+                                expected=[np.zeros((100, 300096))])
+        check_symbolic_forward(b, location={'a': large_matrix_npy},
+                               expected=[gt_topk(dat=large_matrix_npy, axis=1,
+                                                 ret_typ="indices", k=5, is_ascend=is_ascend)])
+
+        b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
+        check_symbolic_backward(sym=b, location={'a': a_npy},
+                                out_grads=[np.random.normal(size=(5, 5, 5, 3))],
+                                expected=[np.zeros((5, 5, 5, 5))])
+        check_symbolic_forward(b, location={'a': a_npy},
+                               expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
+                                                 is_ascend=False)])
+
+        b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
+        check_symbolic_backward(sym=b, location={'a': a_npy},
+                                out_grads=[np.random.normal(size=(5, 5, 5, 5))],
+                                expected=[np.zeros((5, 5, 5, 5))])
+        check_symbolic_forward(b, location={'a': a_npy},
+                               expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
+                                                 is_ascend=True)])


@with_seed()
def test_blockgrad():