diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 754408f4ca4e4..971654436cab2 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -503,8 +503,8 @@ Do not modify directly.* |||[1, 9]|**T** = tensor(float)| |Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[6, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|24+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |Transpose|*in* data:**T**
*out* transposed:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 9f19a20a2e680..5dcfccd6ff5b9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -587,6 +587,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, double, TopK); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int64_t, TopK); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int32_t, TopK); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int8_t, TopK); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, int16_t, TopK); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 23, uint8_t, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float_int64_t_int64_t, OneHot); #if !defined(DISABLE_STRING_TYPE) @@ -1455,6 +1458,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, double, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int64_t, TopK); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int32_t, TopK); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int8_t, TopK); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, int16_t, TopK); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 24, uint8_t, TopK); // Opset 25 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Cast); @@ -2281,6 +2287,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, TopK)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 25 diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index 1666bfa7b2d03..ab28b2bacf1ab 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -542,17 +542,23 @@ static void TopkOpset11ConstructorCommon(const OpKernelInfo& op_kernel_info, return ComputeImplOpset1011(p_op_kernel_context, axis_, largest_, sorted_); \ } -// Generate specializations for opset 11 (used by versioned kernel 11-23) +// Generate specializations for opset 11-23 (used by versioned kernel 11-23) TOPK_MODERN_OPSET_SPECIALIZATIONS(23, float); TOPK_MODERN_OPSET_SPECIALIZATIONS(23, double); TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int32_t); TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int64_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int8_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(23, int16_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(23, uint8_t); // Generate specializations for opset 24 (used by current kernel 24+) TOPK_MODERN_OPSET_SPECIALIZATIONS(24, float); TOPK_MODERN_OPSET_SPECIALIZATIONS(24, double); TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int32_t); TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int64_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int8_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(24, int16_t); +TOPK_MODERN_OPSET_SPECIALIZATIONS(24, uint8_t); // Register necessary kernels // spec https://github.com/onnx/onnx/blob/main/docs/Operators.md#TopK @@ -582,10 +588,16 @@ REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, float); REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, double); REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int64_t); REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int32_t); +REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int8_t); +REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, int16_t); +REGISTER_TOPK_VERSIONED_TYPED_KERNEL(11, 23, uint8_t); REGISTER_TOPK_TYPED_KERNEL(24, float); REGISTER_TOPK_TYPED_KERNEL(24, double); REGISTER_TOPK_TYPED_KERNEL(24, int64_t); REGISTER_TOPK_TYPED_KERNEL(24, int32_t); +REGISTER_TOPK_TYPED_KERNEL(24, int8_t); +REGISTER_TOPK_TYPED_KERNEL(24, int16_t); +REGISTER_TOPK_TYPED_KERNEL(24, uint8_t); } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index d5732d9c6da02..8dbad50344ddf 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -621,6 +621,12 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) { top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted top_1_explicit_axis_MultiD_input_smallest(11); top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted + top_1_explicit_axis_MultiD_input_smallest(11); + top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted + top_1_explicit_axis_MultiD_input_smallest(11); + top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted + top_1_explicit_axis_MultiD_input_smallest(11); + top_1_explicit_axis_MultiD_input_smallest(11, 0); // unsorted } // test path where SelectTopK is used (select using std::nth_element) @@ -909,5 +915,80 @@ TEST(TopKOperator, SelectTopKThreaded) { TestThreaded(k, n, batch_size); } +// Tests for INT8, INT16, UINT8 types (opset 11+) +TEST(TopKOperator, TopK_Int8) { + std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {40, 30, 40, 30}; + std::vector expected_indices = {3, 1, 2, 1}; + std::vector expected_dimensions = {2, 2}; + RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +TEST(TopKOperator, TopK_Int8_Negative) { + std::vector input_vals = {-10, -30, -20, -40, -10, -30, -40, -20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {-10, -20, -10, -20}; + std::vector expected_indices = {0, 2, 0, 3}; + std::vector expected_dimensions = {2, 2}; + RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +TEST(TopKOperator, TopK_Int8_Smallest) { + std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {10, 20, 10, 20}; + std::vector expected_indices = {0, 2, 0, 3}; + std::vector expected_dimensions = {2, 2}; + RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, -1, 0); +} + +TEST(TopKOperator, TopK_Int16) { + std::vector input_vals = {100, 300, 200, 400, 100, 300, 400, 200}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {400, 300, 400, 300}; + std::vector expected_indices = {3, 1, 2, 1}; + std::vector expected_dimensions = {2, 2}; + RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +TEST(TopKOperator, TopK_Uint8) { + std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {40, 30, 40, 30}; + std::vector expected_indices = {3, 1, 2, 1}; + std::vector expected_dimensions = {2, 2}; + RunTest(11, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +TEST(TopKOperator, TopK_Int8_ExplicitAxis) { + std::vector input_vals = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector input_dimensions = {2, 2, 2}; + std::vector expected_vals = {3, 4, 7, 8}; + std::vector expected_indices = {1, 1, 1, 1}; + std::vector expected_dimensions = {2, 1, 2}; + int64_t axis = 1; + RunTest(11, 1, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); +} + +// Opset 24 tests for new types +TEST(TopKOperator, TopK_Int8_Opset24) { + std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {40, 30, 40, 30}; + std::vector expected_indices = {3, 1, 2, 1}; + std::vector expected_dimensions = {2, 2}; + RunTest(24, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + +TEST(TopKOperator, TopK_Uint8_Opset24) { + std::vector input_vals = {10, 30, 20, 40, 10, 30, 40, 20}; + std::vector input_dimensions = {2, 4}; + std::vector expected_vals = {40, 30, 40, 30}; + std::vector expected_indices = {3, 1, 2, 1}; + std::vector expected_dimensions = {2, 2}; + RunTest(24, 2, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + } // namespace test } // namespace onnxruntime