Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,8 @@ Do not modify directly.*
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|24+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[11, 23]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|Transpose|*in* data:**T**<br> *out* transposed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 23, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty);
Expand Down Expand Up @@ -1634,6 +1634,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
#endif

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
Expand Down Expand Up @@ -2070,7 +2071,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 23, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty)>,
Expand Down Expand Up @@ -2717,6 +2718,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Unsqueeze)>,

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
Expand Down
17 changes: 16 additions & 1 deletion onnxruntime/core/providers/cuda/math/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,25 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
TopK,
kOnnxDomain,
11, 23,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()})
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

ONNX_OPERATOR_KERNEL_EX(
TopK,
kOnnxDomain,
11,
24,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
Expand Down
Loading