diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 754408f4ca4e4..d91c188dd4c4c 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -996,8 +996,8 @@ Do not modify directly.*
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**
or
*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)|
-|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
-|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|Transpose|*in* data:**T**
*out* transposed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc
index bab6f15f2c774..b877fdc508aed 100644
--- a/onnxruntime/core/providers/cuda/math/topk.cc
+++ b/onnxruntime/core/providers/cuda/math/topk.cc
@@ -46,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
DataTypeImpl::GetTensorType(),
DataTypeImpl::GetTensorType(),
DataTypeImpl::GetTensorType(),
- DataTypeImpl::GetTensorType()})
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()})
.TypeConstraint("I", DataTypeImpl::GetTensorType()),
TopK);
@@ -62,6 +65,9 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType(),
DataTypeImpl::GetTensorType(),
DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType(),
DataTypeImpl::GetTensorType()})
.TypeConstraint("I", DataTypeImpl::GetTensorType()),
TopK);
@@ -137,6 +143,9 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const {
if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t);
if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t);
+ if (IS_PRIM_TYPE(int8_t)) return TOPKIMPL(int8_t);
+ if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t);
+ if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t);
if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16);
if (IS_PRIM_TYPE(float)) return TOPKIMPL(float);
if (IS_PRIM_TYPE(double)) return TOPKIMPL(double);
diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu
new file mode 100644
index 0000000000000..e194bd1bfd15a
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/math/topk_impl_i16.cu
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#define TOPK_IMPL_TYPE int16_t
+#include "topk_impl.cuh"
diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu
new file mode 100644
index 0000000000000..db32e9e43392f
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/math/topk_impl_i8.cu
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#define TOPK_IMPL_TYPE int8_t
+#include "topk_impl.cuh"
diff --git a/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu b/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu
new file mode 100644
index 0000000000000..7fcd4b81b3bf9
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/math/topk_impl_u8.cu
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#define TOPK_IMPL_TYPE uint8_t
+#include "topk_impl.cuh"