diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 754408f4ca4e4..971654436cab2 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -503,8 +503,8 @@ Do not modify directly.*
|||[1, 9]|**T** = tensor(float)|
|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**
or
*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
-|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
+|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)|
|||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)|
|Transpose|*in* data:**T**
*out* transposed:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 9f19a20a2e680..5dcfccd6ff5b9 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -587,6 +587,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, double, TopK);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int64_t, TopK);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int32_t, TopK);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int8_t, TopK);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int16_t, TopK);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, uint8_t, TopK);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float_int64_t_int64_t, OneHot);
#if !defined(DISABLE_STRING_TYPE)
@@ -1455,6 +1458,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, double, TopK);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int64_t, TopK);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int32_t, TopK);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int8_t, TopK);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int16_t, TopK);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, uint8_t, TopK);
// Opset 25
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Cast);
@@ -2281,6 +2287,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
int64_t, TopK)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
// opset 25
diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc
index 1666bfa7b2d03..ab28b2bacf1ab 100644
--- a/onnxruntime/core/providers/cpu/math/top_k.cc
+++ b/onnxruntime/core/providers/cpu/math/top_k.cc
@@ -542,17 +542,23 @@ static void TopkOpset11ConstructorCommon(const OpKernelInfo& op_kernel_info,
return ComputeImplOpset1011(p_op_kernel_context, axis_, largest_, sorted_); \
}
-// Generate specializations for opset 11 (used by versioned kernel 11-23)
+// Generate specializations for opset 11-23 (used by versioned kernel 11-23)
TOPK_MODERN_OPSET_SPECIALIZATIONS(23, float);
TOPK_MODERN_OPSET_SPECIALIZATIONS(23, double);
TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int32_t);
TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int64_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int8_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int16_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(23, uint8_t);
// Generate specializations for opset 24 (used by current kernel 24+)
TOPK_MODERN_OPSET_SPECIALIZATIONS(24, float);
TOPK_MODERN_OPSET_SPECIALIZATIONS(24, double);
TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int32_t);
TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int64_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int8_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int16_t);
+TOPK_MODERN_OPSET_SPECIALIZATIONS(24, uint8_t);
// Register necessary kernels
// spec https://github.com/onnx/onnx/blob/main/docs/Operators.md#TopK
@@ -582,10 +588,16 @@ REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, float);
REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, double);
REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int64_t);
REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int32_t);
+REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int8_t);
+REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int16_t);
+REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, uint8_t);
REGISTER_TOPK_TYPED_KERNEL(24, float);
REGISTER_TOPK_TYPED_KERNEL(24, double);
REGISTER_TOPK_TYPED_KERNEL(24, int64_t);
REGISTER_TOPK_TYPED_KERNEL(24, int32_t);
+REGISTER_TOPK_TYPED_KERNEL(24, int8_t);
+REGISTER_TOPK_TYPED_KERNEL(24, int16_t);
+REGISTER_TOPK_TYPED_KERNEL(24, uint8_t);
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc
index d5732d9c6da02..8dbad50344ddf 100644
--- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc
+++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc
@@ -621,6 +621,12 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) {
top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted
top_1_explicit_axis_MultiD_input_smallest(11);
top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted
+ top_1_explicit_axis_MultiD_input_smallest(11);
+ top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted
+ top_1_explicit_axis_MultiD_input_smallest(11);
+ top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted
+ top_1_explicit_axis_MultiD_input_smallest(11);
+ top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted
}
// test path where SelectTopK is used (select using std::nth_element)
@@ -909,5 +915,80 @@ TEST(TopKOperator, SelectTopKThreaded) {
TestThreaded(k, n, batch_size);
}
+// Tests for INT8, INT16, UINT8 types (opset 11+)
+TEST(TopKOperator, TopK_Int8) {
+ std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {40, 30, 40, 30};
+ std::vector expected_indices = {3, 1, 2, 1};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
+TEST(TopKOperator, TopK_Int8_Negative) {
+ std::vector input_vals = {-10, -30, -20, -40, -10, -30, -40, -20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {-10, -20, -10, -20};
+ std::vector expected_indices = {0, 2, 0, 3};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
+TEST(TopKOperator, TopK_Int8_Smallest) {
+ std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {10, 20, 10, 20};
+ std::vector expected_indices = {0, 2, 0, 3};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, -1, 0);
+}
+
+TEST(TopKOperator, TopK_Int16) {
+ std::vector input_vals = {100, 300, 200, 400, 100, 300, 400, 200};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {400, 300, 400, 300};
+ std::vector expected_indices = {3, 1, 2, 1};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
+TEST(TopKOperator, TopK_Uint8) {
+ std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {40, 30, 40, 30};
+ std::vector expected_indices = {3, 1, 2, 1};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
+TEST(TopKOperator, TopK_Int8_ExplicitAxis) {
+ std::vector input_vals = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::vector input_dimensions = {2, 2, 2};
+ std::vector expected_vals = {3, 4, 7, 8};
+ std::vector expected_indices = {1, 1, 1, 1};
+ std::vector expected_dimensions = {2, 1, 2};
+ int64_t axis = 1;
+ RunTest(11, 1, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis);
+}
+
+// Opset 24 tests for new types
+TEST(TopKOperator, TopK_Int8_Opset24) {
+ std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {40, 30, 40, 30};
+ std::vector expected_indices = {3, 1, 2, 1};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(24, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
+TEST(TopKOperator, TopK_Uint8_Opset24) {
+ std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20};
+ std::vector input_dimensions = {2, 4};
+ std::vector expected_vals = {40, 30, 40, 30};
+ std::vector expected_indices = {3, 1, 2, 1};
+ std::vector expected_dimensions = {2, 2};
+ RunTest(24, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
+}
+
} // namespace test
} // namespace onnxruntime