Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,8 @@ Do not modify directly.*
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|24+|**I** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[11, 23]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|Transpose|*in* data:**T**<br> *out* transposed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 23, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty);
Expand Down Expand Up @@ -1634,6 +1634,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
#endif

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
Expand Down Expand Up @@ -2070,7 +2071,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 23, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty)>,
Expand Down Expand Up @@ -2717,6 +2718,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Unsqueeze)>,

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
Expand Down
21 changes: 19 additions & 2 deletions onnxruntime/core/providers/cuda/math/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

ONNX_OPERATOR_KERNEL_EX(
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
TopK,
kOnnxDomain,
11,
11, 23,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
Expand All @@ -50,6 +50,22 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

// TopK registration for opset 24+. Same kernel implementation (TopK<true>:
// K supplied as input 1) as the 11-23 registration above, with BFloat16
// added to the supported "T" types. Input 1 (K) must reside in CPU memory.
ONNX_OPERATOR_KERNEL_EX(
TopK,
kOnnxDomain,
24,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<BFloat16>()})
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

template <bool inputk>
TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
Expand Down Expand Up @@ -124,6 +140,7 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16);
if (IS_PRIM_TYPE(float)) return TOPKIMPL(float);
if (IS_PRIM_TYPE(double)) return TOPKIMPL(double);
if (IS_PRIM_TYPE(BFloat16)) return TOPKIMPL(BFloat16);

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
}
Expand Down
55 changes: 42 additions & 13 deletions onnxruntime/core/providers/cuda/math/topk_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@

using namespace cub;

// CUB's radix sort requires types it natively recognizes (float, double, __half, __nv_bfloat16).
// ORT's BFloat16 wrapper is bitwise compatible with __nv_bfloat16, so map it for CUB sort operations.
// Primary template: most key types are passed to CUB's radix sort unchanged.
template <typename T>
struct CubSortType {
using type = T;
};

// ORT's BFloat16 wrapper is not a type CUB recognizes; substitute the
// CUDA-native __nv_bfloat16 (same 16-bit representation) for sorting.
template <>
struct CubSortType<BFloat16> {
using type = __nv_bfloat16;
};

template <typename T>
struct KV {
T key;
Expand Down Expand Up @@ -170,6 +182,10 @@
return SamePrefix((const int16_t*)f0, (const int16_t*)f1, skip);
}

// BFloat16 prefix comparison: delegate to the 16-bit integer overload on the
// raw bit pattern, mirroring the MLFloat16 overload above. Uses
// reinterpret_cast rather than a C-style cast (cpplint readability/casting).
__device__ __forceinline__ bool SamePrefix(const BFloat16* f0, const BFloat16* f1, int64_t skip) {
  return SamePrefix(reinterpret_cast<const int16_t*>(f0), reinterpret_cast<const int16_t*>(f1), skip);
}

__device__ __forceinline__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
return SamePrefix((const int32_t*)f0, (const int32_t*)f1, skip);
}
Expand All @@ -187,6 +203,10 @@
return Radix((const int16_t*)f, skip);
}

// BFloat16 radix digit extraction: delegate to the 16-bit integer overload on
// the raw bit pattern, mirroring the MLFloat16 overload above. Uses
// reinterpret_cast rather than a C-style cast (cpplint readability/casting).
__device__ __forceinline__ int32_t Radix(const BFloat16* f, int64_t skip) {
  return Radix(reinterpret_cast<const int16_t*>(f), skip);
}

__device__ __forceinline__ int32_t Radix(const float* f, int64_t skip) {
return Radix((const int32_t*)f, skip);
}
Expand All @@ -204,6 +224,10 @@
SetByte((int16_t*)f, byte);
}

// BFloat16 radix byte write: delegate to the 16-bit integer overload on the
// raw bit pattern, mirroring the MLFloat16 overload above. reinterpret_cast
// replaces the C-style cast flagged by cpplint (readability/casting).
__device__ __forceinline__ void SetByte(BFloat16* f, int64_t byte) {
  SetByte(reinterpret_cast<int16_t*>(f), byte);
}

__device__ __forceinline__ void SetByte(float* f, int64_t byte) {
SetByte((int32_t*)f, byte);
}
Expand All @@ -224,11 +248,12 @@
T Kth = (T)0, sign = (T)1;
typedef BlockScan<uint32_t, THREADS> BlockScan;
typedef BlockReduce<uint32_t, THREADS> BlockReduce;
typedef BlockRadixSort<T, THREADS, KPT, int64_t> BlockRadixSort;
using CubT = typename CubSortType<T>::type;
typedef cub::BlockRadixSort<CubT, THREADS, KPT, int64_t> CubBlockRadixSort;
__shared__ union {
typename BlockScan::TempStorage scan;
typename BlockReduce::TempStorage reduce;
typename BlockRadixSort::TempStorage sort;
typename CubBlockRadixSort::TempStorage sort;
} temp_storage;
uint32_t positive = 0, negative = 0;
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
Expand Down Expand Up @@ -330,34 +355,34 @@
}
__syncthreads();
if (1 == sorted) {
T keys[KPT];
CubT keys[KPT];
int64_t vals[KPT];
for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) {
if (k_i < K) {
auto to_i = TO(k_i);
keys[k_c] = V[to_i];
memcpy(&keys[k_c], &V[to_i], sizeof(CubT));
vals[k_c] = I[to_i];
} else {
if (1 == largest) {
keys[k_c] = type_min;
memcpy(&keys[k_c], &type_min, sizeof(CubT));
} else {
keys[k_c] = type_max;
memcpy(&keys[k_c], &type_max, sizeof(CubT));
}
}
}
__syncthreads();
if (1 == largest) {
BlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
CubBlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
} else {
BlockRadixSort(temp_storage.sort).Sort(keys, vals);
CubBlockRadixSort(temp_storage.sort).Sort(keys, vals);
}
__syncthreads();
#pragma unroll
for (int64_t k_c = 0; k_c < KPT; ++k_c) {
auto k_i = tid * KPT + k_c;
if (k_i < K) {
auto to_i = TO(k_i);
V[to_i] = keys[k_c];
memcpy(&V[to_i], &keys[k_c], sizeof(CubT));
I[to_i] = vals[k_c];
}
}
Expand Down Expand Up @@ -399,6 +424,7 @@
const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest,
int64_t sorted, int64_t N, int64_t dimension) {
typedef typename ToCudaType<T>::MappedType CudaT;
using CubT = typename CubSortType<CudaT>::type;
const CudaT* input_x_ptr = reinterpret_cast<const CudaT*>(input_x);
CudaT* output_v_ptr = reinterpret_cast<CudaT*>(output_v);
cudaStream_t stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
Expand Down Expand Up @@ -444,21 +470,24 @@
auto* output_key = output_key_buffer.get();
auto* input_value = input_value_buffer.get();
auto* output_value = output_value_buffer.get();
// CUB sort requires native CUDA types; cast through CubSortType for BFloat16 → __nv_bfloat16 mapping.
auto* input_key_cub = reinterpret_cast<CubT*>(input_key);
auto* output_key_cub = reinterpret_cast<CubT*>(output_key);
size_t temp_bytes = 0;
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T) * 8, stream));
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key_cub, output_key_cub, input_value, output_value, dimension, 0, sizeof(CubT) * 8, stream));
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes, ort_stream);
auto* temp_storage = temp_storage_buffer.get();
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
for (int64_t i = 0; i < N; i++) {
FillInput<CudaT><<<blocks_per_grid_D, BT, 0, stream>>>(input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T) * 8, stream)
: cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T) * 8, stream));
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key_cub, output_key_cub, input_value, output_value, dimension, 0, sizeof(CubT) * 8, stream)
: cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key_cub, output_key_cub, input_value, output_value, dimension, 0, sizeof(CubT) * 8, stream));
if (1 == sorted) {
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0, stream>>>(output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
} else { // reorder by ascending index
ExcludeOutput<int64_t><<<blocks_per_grid_D, BT, 0, stream>>>(output_value, K, dimension);
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension, 0, sizeof(T) * 8, stream));
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key_cub, input_key_cub, dimension, 0, sizeof(CubT) * 8, stream));
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0, stream>>>(input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
}
}
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/math/topk_impl_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#define TOPK_IMPL_TYPE BFloat16
#include "topk_impl.cuh"

Check warning on line 5 in onnxruntime/core/providers/cuda/math/topk_impl_bf16.cu

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/cuda/math/topk_impl_bf16.cu:5: Include the directory when naming header files [build/include_subdir] [4]
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ struct NumericLimits<half> {
}
};

// Finite extrema of bfloat16, given directly as bit patterns:
// sign bit | exponent 0xFE (max finite) | mantissa 0x7F.
// FromBits avoids any float conversion on either host or device.
template <>
struct NumericLimits<BFloat16> {
__inline__ __host__ __device__ static BFloat16 Lowest() {
return BFloat16::FromBits(0xFF7FU); // -3.38953139e38, most negative finite value
}

__inline__ __host__ __device__ static BFloat16 Max() {
return BFloat16::FromBits(0x7F7FU); // 3.38953139e38, largest finite value
}
};

// TODO Where to put this? good places might be
// core/framework/tensor_shape.h
// core/util/matrix_layout.h
Expand Down
52 changes: 50 additions & 2 deletions onnxruntime/test/providers/cpu/math/topk_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ static void RunTest(int op_set,
test.AddAttribute("axis", axis);
if (op_set <= 9)
test.AddAttribute("k", k);
if (op_set == 11 && largest != 1)
if (op_set >= 11 && largest != 1)
test.AddAttribute("largest", largest);
if (op_set == 11 && sorted != 1)
if (op_set >= 11 && sorted != 1)
test.AddAttribute("sorted", sorted);

// Inputs
Expand Down Expand Up @@ -679,6 +679,54 @@ TEST(TopKOperator, NthElementHalf_NegtiveVals) {
RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}

// TopK opset 24, BFloat16, 1-D input: top-4 of six values with sorted=false,
// exercising the unsorted (index-ordered) output path.
TEST(TopKOperator, NthElementBFloat16) {
  if (!CudaHasBF16Support()) {
    return;
  }

  const std::vector<BFloat16> x = FloatsToBFloat16s(std::vector<float>{10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f});
  const std::vector<BFloat16> top_vals = FloatsToBFloat16s(std::vector<float>{10.0f, 8.0f, 7.0f, 6.0f});
  const std::vector<int64_t> x_dims = {6};
  const std::vector<int64_t> top_idx = {0, 1, 2, 5};
  const std::vector<int64_t> top_dims = {4};
  RunTest(24, 4, x, x_dims, top_vals, top_idx, top_dims, false);
}

// TopK opset 24, BFloat16, mostly-negative input: verifies sign handling in
// the bfloat16 path (negative keys sort below the lone positive value).
TEST(TopKOperator, NthElementBFloat16_NegativeVals) {
  if (!CudaHasBF16Support()) {
    return;
  }

  const std::vector<BFloat16> x = FloatsToBFloat16s(std::vector<float>{10.0f, -8.0f, -7.0f, -4.0f, -5.0f, -6.0f});
  const std::vector<BFloat16> top_vals = FloatsToBFloat16s(std::vector<float>{10.0f, -4.0f, -5.0f, -6.0f});
  const std::vector<int64_t> x_dims = {6};
  const std::vector<int64_t> top_idx = {0, 3, 4, 5};
  const std::vector<int64_t> top_dims = {4};
  RunTest(24, 4, x, x_dims, top_vals, top_idx, top_dims, false);
}

// TopK opset 24, BFloat16, 2-D input: top-2 per row along the default (last)
// axis, including a row with duplicate values (0.3 appears twice).
TEST(TopKOperator, TopKBFloat16_2D) {
  if (!CudaHasBF16Support()) {
    return;
  }

  const std::vector<BFloat16> x = FloatsToBFloat16s(std::vector<float>{0.1f, 0.3f, 0.2f, 0.4f,
                                                                       0.1f, 0.3f, 0.3f, 0.2f});
  const std::vector<BFloat16> top_vals = FloatsToBFloat16s(std::vector<float>{0.4f, 0.3f,
                                                                              0.3f, 0.3f});
  const std::vector<int64_t> x_dims = {2, 4};
  const std::vector<int64_t> top_idx = {3, 1,
                                        1, 2};
  const std::vector<int64_t> top_dims = {2, 2};
  RunTest(24, 2, x, x_dims, top_vals, top_idx, top_dims, false);
}

// test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], ie. [257, 512]
TEST(TopKOperator, SmallArrayTopKSorted) {
std::vector<float> input_vals(400, 0.0f);
Expand Down
Loading