Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions onnxruntime/core/providers/cpu/tensor/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,31 @@
#include "core/providers/cpu/tensor/gather.h"
#include "core/common/common.h"
#include "core/platform/threadpool.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"

namespace onnxruntime {

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1, int32_t, int64_t);
}

namespace {
using EnabledIndexTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1);

const auto index_type_constraints =
BuildKernelDefConstraintsFunctorFromTypeList<EnabledIndexTypes>{}();
} // namespace

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Gather,
1,
10,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
.TypeConstraint("Tind", index_type_constraints),
Gather);

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Expand All @@ -24,17 +38,15 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
12,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
.TypeConstraint("Tind", index_type_constraints),
Gather);

ONNX_CPU_OPERATOR_KERNEL(
Gather,
13,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
.TypeConstraint("Tind", index_type_constraints),
Gather);

Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
Expand Down Expand Up @@ -132,16 +144,18 @@ Status Gather::Compute(OpKernelContext* context) const {

concurrency::ThreadPool* tp = context->GetOperatorThreadPool();

if (p.indices_tensor->IsDataType<int32_t>()) {
if (utils::HasTypeWithSameSize<EnabledIndexTypes, int32_t>() &&
p.indices_tensor->IsDataType<int32_t>()) {
return GatherCopyData<int32_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
}
if (p.indices_tensor->IsDataType<int64_t>()) {
if (utils::HasTypeWithSameSize<EnabledIndexTypes, int64_t>() &&
p.indices_tensor->IsDataType<int64_t>()) {
return GatherCopyData<int64_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
}

return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Gather Tind type not supported in this build.");
}

} // namespace onnxruntime
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/op_kernel_type_control.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
* - Enabled types are the types that are supported in the actual, compiled implementation. They are obtained from the
* intersection of supported and allowed types.
*
* The types are associated with an Op kernel argument. It is also possible to specify a global list of allowed types.
* The types are associated with an Op argument. It is also possible to specify a global list of allowed types.
*
* Use of these utilities is optional. They are useful for cases where one registered Op kernel handles multiple types.
*
Expand Down Expand Up @@ -239,8 +239,8 @@ struct EnabledTypes {
* namespace onnxruntime {
* namespace op_kernel_type_control {
* // specify supported types, i.e., the full set of types that can be enabled
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
* MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0,
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
* int, float, double);
* } // namespace op_kernel_type_control
* } // namespace onnxruntime
Expand All @@ -249,7 +249,7 @@ struct EnabledTypes {
*
* // get enabled types
* using MyOpFirstInputEnabledTypes =
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0);
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(MyProvider, DomainContainingMyOp, MyOp, Input, 0);
*
* // ...
*
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/op_kernel_type_control_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ using SizeOfT = boost::mp11::mp_size_t<sizeof(T)>;

/**
* Check if the set of types contains a type with the same size as T.
*
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
*
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
* such as int32_t, uint32_t or float.
*/
template <typename TypeSet, typename T>
Expand Down