Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix cuda sort routine for half precision
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya committed Aug 22, 2019
1 parent 314370e commit c9cd5c8
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions src/operator/tensor/sort_op-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,43 @@ struct greater_half
};
}

#ifndef SORT_WITH_THRUST
/*!
 * \brief Compute the padded byte sizes of the key buffer and the value buffer.
 *
 * Both sizes are padded with the same alignment: the larger of sizeof(KDType)
 * and sizeof(VDType).
 *
 * \tparam KDType element type of the keys
 * \tparam VDType element type of the values
 * \param num_keys number of key/value pairs
 * \param pKeys_bytes output: padded size in bytes of num_keys keys
 * \param pValues_bytes output: padded size in bytes of num_keys values
 *
 * NOTE(review): PadBytes is defined elsewhere; presumably it rounds its first
 * argument up to a multiple of the alignment -- confirm against its definition.
 */
template <typename KDType, typename VDType>
inline void WorkspaceSize4KeysAndValues(
    const size_t num_keys, size_t *pKeys_bytes, size_t *pValues_bytes) {
  // One shared alignment for both buffers: the wider of the two element types.
  const size_t key_elem_bytes = sizeof(KDType);
  const size_t value_elem_bytes = sizeof(VDType);
  const size_t shared_alignment =
      (key_elem_bytes < value_elem_bytes) ? value_elem_bytes : key_elem_bytes;

  const size_t raw_keys_bytes = num_keys * key_elem_bytes;
  const size_t raw_values_bytes = num_keys * value_elem_bytes;

  *pKeys_bytes = PadBytes(raw_keys_bytes, shared_alignment);
  *pValues_bytes = PadBytes(raw_values_bytes, shared_alignment);
}

/*!
 * \brief Query the temporary storage size cub::DeviceRadixSort::SortPairs
 *        reports for num_keys pairs. Overload for key types other than
 *        mshadow::half::half_t.
 *
 * All buffer pointers are passed as null, so only the size argument is
 * expected to be filled in (CUB's size-query calling convention).
 *
 * \tparam KDType key type (this overload is disabled for mshadow::half::half_t)
 * \tparam VDType value type
 * \param num_keys number of key/value pairs to be sorted
 * \return number of bytes reported by CUB through its size argument
 */
template <typename KDType, typename VDType>
inline typename std::enable_if<!std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
SortPairsWorkspaceSize(const size_t num_keys) {
  size_t temp_storage_bytes = 0;
  // Null temp storage and null key/value buffers: size query only.
  cub::DeviceRadixSort::SortPairs<KDType, VDType>(
      nullptr, temp_storage_bytes,
      nullptr, nullptr,
      nullptr, nullptr,
      num_keys);
  return temp_storage_bytes;
}

/*!
 * \brief Query the temporary storage size cub::DeviceRadixSort::SortPairs
 *        reports for num_keys pairs. Overload selected when the key type is
 *        mshadow::half::half_t.
 *
 * The query is issued with `half` as the CUB key type instead of
 * mshadow::half::half_t (presumably the CUDA native half type, which CUB can
 * radix-sort -- confirm against the includes of this file).
 *
 * \tparam KDType key type; must be mshadow::half::half_t for this overload
 * \tparam VDType value type
 * \param num_keys number of key/value pairs to be sorted
 * \return number of bytes reported by CUB through its size argument
 */
template <typename KDType, typename VDType>
inline typename std::enable_if<std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
SortPairsWorkspaceSize(const size_t num_keys) {
  size_t temp_storage_bytes = 0;
  // Null temp storage and null key/value buffers: size query only.
  cub::DeviceRadixSort::SortPairs<half, VDType>(
      nullptr, temp_storage_bytes,
      nullptr, nullptr,
      nullptr, nullptr,
      num_keys);
  return temp_storage_bytes;
}
#endif

/*!
 * \brief Total workspace size in bytes needed to sort num_keys key/value
 *        pairs by key on the GPU.
 *
 * When SORT_WITH_THRUST is defined no workspace is requested and 0 is
 * returned. Otherwise the result is the sum of:
 *   - the padded key buffer size,
 *   - the padded value buffer size,
 *   - the temporary storage size reported by SortPairsWorkspaceSize, which
 *     selects the half-precision-aware CUB query for mshadow::half::half_t keys.
 *
 * \tparam KDType key type
 * \tparam VDType value type
 * \tparam xpu device type; this overload is enabled only for gpu
 * \param num_keys number of key/value pairs to be sorted
 * \return required workspace size in bytes
 */
template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
SortByKeyWorkspaceSize(const size_t num_keys) {
#ifdef SORT_WITH_THRUST
  return 0;
#else
  // Zero-initialized so the sum below never reads indeterminate values.
  size_t keys_bytes = 0;
  size_t values_bytes = 0;
  WorkspaceSize4KeysAndValues<KDType, VDType>(num_keys, &keys_bytes, &values_bytes);
  // Delegate the CUB size query to the helper instead of calling
  // SortPairs<KDType, VDType> directly, so half_t keys use the `half` path.
  return keys_bytes + values_bytes + SortPairsWorkspaceSize<KDType, VDType>(num_keys);
#endif
}

Expand Down

0 comments on commit c9cd5c8

Please sign in to comment.