diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc index fdf62b6302d2b..b0fd5dcdf921b 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather.cc @@ -5,17 +5,31 @@ #include "core/providers/cpu/tensor/gather.h" #include "core/common/common.h" #include "core/platform/threadpool.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/providers/op_kernel_type_control_utils.h" namespace onnxruntime { +namespace op_kernel_type_control { +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1, int32_t, int64_t); +} + +namespace { +using EnabledIndexTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1); + +const auto index_type_constraints = + BuildKernelDefConstraintsFunctorFromTypeList{}(); +} // namespace + ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Gather, 1, 10, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("Tind", index_type_constraints), Gather); ONNX_CPU_OPERATOR_VERSIONED_KERNEL( @@ -24,8 +38,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 12, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("Tind", index_type_constraints), Gather); ONNX_CPU_OPERATOR_KERNEL( @@ -33,8 +46,7 @@ ONNX_CPU_OPERATOR_KERNEL( 13, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("Tind", index_type_constraints), Gather); Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { @@ -132,16 +144,18 @@ Status Gather::Compute(OpKernelContext* context) const { concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - if (p.indices_tensor->IsDataType()) { + if (utils::HasTypeWithSameSize() && + p.indices_tensor->IsDataType()) { return GatherCopyData(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes, block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp); } - if (p.indices_tensor->IsDataType()) { + if (utils::HasTypeWithSameSize() && + p.indices_tensor->IsDataType()) { return GatherCopyData(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes, block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp); } - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather."); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Gather Tind type not supported in this build."); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/op_kernel_type_control.h b/onnxruntime/core/providers/op_kernel_type_control.h index d70ca7600ba5d..a9030838acd6c 100644 --- a/onnxruntime/core/providers/op_kernel_type_control.h +++ b/onnxruntime/core/providers/op_kernel_type_control.h @@ -20,7 +20,7 @@ * - Enabled types are the types that are supported in the actual, compiled implementation. They are obtained from the * intersection of supported and allowed types. * - * The types are associated with an Op kernel argument. It is also possible to specify a global list of allowed types. + * The types are associated with an Op argument. It is also possible to specify a global list of allowed types. * * Use of these utilities is optional. They are useful for cases where one registered Op kernel handles multiple types. * @@ -239,8 +239,8 @@ struct EnabledTypes { * namespace onnxruntime { * namespace op_kernel_type_control { * // specify supported types, i.e., the full set of types that can be enabled - * ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( - * MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0, + * ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( + * MyProvider, DomainContainingMyOp, MyOp, Input, 0, * int, float, double); * } // namespace op_kernel_type_control * } // namespace onnxruntime @@ -249,7 +249,7 @@ struct EnabledTypes { * * // get enabled types * using MyOpFirstInputEnabledTypes = - * ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0); + * ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(MyProvider, DomainContainingMyOp, MyOp, Input, 0); * * // ... * diff --git a/onnxruntime/core/providers/op_kernel_type_control_utils.h b/onnxruntime/core/providers/op_kernel_type_control_utils.h index 3e1354d43aaf6..e79bcdb8ba492 100644 --- a/onnxruntime/core/providers/op_kernel_type_control_utils.h +++ b/onnxruntime/core/providers/op_kernel_type_control_utils.h @@ -24,8 +24,8 @@ using SizeOfT = boost::mp11::mp_size_t; /** * Check if the set of types contains a type with the same size as T. -* -* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t)) +* +* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t)) * such as int32_t, uint32_t or float. */ template