diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 754408f4ca4e4..d91c188dd4c4c 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -996,8 +996,8 @@ Do not modify directly.* |||1+|**T** = tensor(double), tensor(float), tensor(float16)| |Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| -|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |Transpose|*in* data:**T**
*out* transposed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index bab6f15f2c774..b877fdc508aed 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -46,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .TypeConstraint("I", DataTypeImpl::GetTensorType()), TopK); @@ -62,6 +65,9 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("I", DataTypeImpl::GetTensorType()), TopK); @@ -137,6 +143,9 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const { if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t); if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t); + if (IS_PRIM_TYPE(int8_t)) return TOPKIMPL(int8_t); + if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t); + if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t); if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16); if (IS_PRIM_TYPE(float)) return TOPKIMPL(float); if (IS_PRIM_TYPE(double)) return TOPKIMPL(double); diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu new file mode 100644 index 0000000000000..e194bd1bfd15a --- /dev/null +++ b/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define TOPK_IMPL_TYPE int16_t +#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu new file mode 100644 index 0000000000000..db32e9e43392f --- /dev/null +++ b/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define TOPK_IMPL_TYPE int8_t +#include "topk_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu new file mode 100644 index 0000000000000..7fcd4b81b3bf9 --- /dev/null +++ b/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define TOPK_IMPL_TYPE uint8_t +#include "topk_impl.cuh"