
Commit ae92d59

ONNX Gelu Op in Opset 20 (#19560)
### ONNX Gelu Op in Opset 20

Refactor code to support the MS-domain Gelu and the ONNX Gelu op introduced in opset 20. The two values of the ONNX `approximate` attribute are sketched in code after this list.

1. Move the CPU Gelu implementation from `onnxruntime/contrib_ops/cpu/activations.h/cc` to `onnxruntime/core/providers/cpu/tensor/gelu.h/cc`; it serves as the implementation when the `approximate` attribute is 'none'.
2. Duplicate some logic from `onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc` into `onnxruntime/core/providers/cpu/tensor/gelu.h/cc` as the implementation when the `approximate` attribute is 'tanh'.
3. Register the ONNX-domain Gelu CPU kernel for opset 20 in `onnxruntime/core/providers/cpu/cpu_execution_provider.cc`.
4. Move `onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h/cu` to `onnxruntime/core/providers/cuda/tensor/gelu_impl.h` and `onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu`, respectively, as the implementation when the `approximate` attribute is 'tanh'.
5. Implement the `approximate` = 'none' path in `onnxruntime/core/providers/cuda/tensor/gelu_impl.cu`.
6. Register the ONNX-domain Gelu CUDA kernel for opset 20 in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
7. Make the corresponding ROCm EP changes.
8. Enrich the tests for the ONNX-domain Gelu in `onnxruntime/test/providers/cpu/activation/activation_op_test.cc`.
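For reference, a minimal standalone sketch of the two formulas the `approximate` attribute selects (standard Gelu definitions; `GeluOpset20` and its signature are illustrative, not the actual kernel added by this commit):

```cpp
#include <cmath>
#include <cstdint>
#include <string>

// approximate = "none": Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// approximate = "tanh": Gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
void GeluOpset20(const float* x, float* y, int64_t n, const std::string& approximate) {
  if (approximate == "tanh") {
    constexpr float kAlpha = 0.7978845608f;  // sqrt(2/pi)
    for (int64_t i = 0; i < n; ++i) {
      float v = x[i];
      y[i] = 0.5f * v * (1.0f + std::tanh(kAlpha * (v + 0.044715f * v * v * v)));
    }
  } else {  // "none"
    for (int64_t i = 0; i < n; ++i) {
      y[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] * 0.70710678f));  // x / sqrt(2)
    }
  }
}
```

The real CPU kernel chunks this work across the thread pool and computes erf via MLAS, as visible in the block removed from `contrib_ops/cpu/activations.h` in the diff below.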
1 parent 29b1106 commit ae92d59

File tree

27 files changed: +395 −197 lines

cmake/onnxruntime_rocm_hipify.cmake

-4
@@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
   "bert/fastertransformer_decoder_attention/*"
   "bert/multihead_attention.cc"
   "bert/multihead_attention.h"
-  "bert/fast_gelu_impl.cu"
-  "bert/fast_gelu_impl.h"
-  "bert/fast_gelu.cc"
-  "bert/fast_gelu.h"
   "bert/relative_attn_bias.cc"
   "bert/relative_attn_bias.h"
   "bert/relative_attn_bias_impl.cu"

csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs

+1-1
@@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
 private void CanRunInferenceOnAModelWithTensorRT()
 {
     string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
-
+
     int deviceId = 0;
     string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
     if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)

docs/OperatorKernels.md

+2
@@ -127,6 +127,7 @@ Do not modify directly.*
 |GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
 |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
 |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
 |Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[11, 12]|**T** = tensor(double), tensor(float)|
 |||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -606,6 +607,7 @@ Do not modify directly.*
 |GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
 |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
 |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|

include/onnxruntime/core/providers/cuda/cuda_resource.h

+1-1
@@ -19,4 +19,4 @@ enum CudaResource : int {
   enable_skip_layer_norm_strict_mode_t,
   prefer_nhwc_t,
   use_tf32_t,
-};
+};

onnxruntime/contrib_ops/cpu/activations.cc

+1-9
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.

 #include "core/providers/cpu/activation/activations.h"
-#include "activations.h"
+#include "contrib_ops/cpu/activations.h"

 namespace onnxruntime {
 namespace contrib {
@@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
     KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     ThresholdedRelu<float>);

-ONNX_OPERATOR_KERNEL_EX(
-    Gelu,
-    kMSDomain,
-    1,
-    kCpuExecutionProvider,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    Gelu<float>);
-
 ONNX_OPERATOR_KERNEL_EX(
     QuickGelu,
     kMSDomain,

onnxruntime/contrib_ops/cpu/activations.h

-41
@@ -54,47 +54,6 @@ namespace contrib {
 DEFINE_ELE_KERNEL(ScaledTanh);
 DEFINE_ELE_KERNEL(ParametricSoftplus);

-template <typename T>
-class Gelu : public OpKernel {
- public:
-  Gelu(const OpKernelInfo& info) : OpKernel(info) {
-  }
-
-  Status Compute(OpKernelContext* context) const override {
-    const Tensor* input = context->Input<Tensor>(0);
-    const T* input_data = input->Data<T>();
-
-    Tensor* output = context->Output(0, input->Shape());
-    T* output_data = output->MutableData<T>();
-
-    concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
-    int64_t elem_count = input->Shape().Size();
-    constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
-    int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
-    concurrency::ThreadPool::TryBatchParallelFor(
-        tp, static_cast<int32_t>(task_count),
-        [&](ptrdiff_t task_idx) {
-          const auto start = task_idx * length_per_task;
-          const T* p_input = input_data + start;
-          T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
-
-          for (int64_t i = 0; i < count; i++) {
-            T value = p_input[i];
-            p_output[i] = value * static_cast<T>(M_SQRT1_2);
-          }
-
-          MlasComputeErf(p_output, p_output, narrow<size_t>(count));
-
-          for (int64_t i = 0; i < count; i++) {
-            p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
-          }
-        },
-        0);
-    return Status::OK();
-  }
-};
-
 // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
 // MlasComputeLogistic instead of using Eigen for better perf.
 template <typename T>

onnxruntime/contrib_ops/cuda/activation/activations.cc

-1
@@ -49,7 +49,6 @@ namespace cuda {
 UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
 UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
 UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
-UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
 UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);

 REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)

onnxruntime/contrib_ops/cuda/activation/activations.h

-11
@@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise {
   float beta_;
 };

-template <typename T>
-class Gelu final : public UnaryElementwise {
- public:
-  Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
-
-  Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
-  MAKE_FUNC_CTX_NULL()
-};
-
 template <typename T>
 class QuickGelu final : public UnaryElementwise {
  public:

onnxruntime/contrib_ops/cuda/activation/activations_impl.cu

-14
@@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh {
   }
 };

-template <typename T>
-struct OP_Gelu : public CtxGelu {
-  __device__ __inline__ T operator()(const T& a) const {
-    return _Gelu(a);
-  }
-};
-
-template <>
-struct OP_Gelu<half> : public CtxGelu {
-  __device__ __inline__ half operator()(const half& a) const {
-    return static_cast<half>(_Gelu(static_cast<float>(a)));
-  }
-};
-
 template <typename T>
 struct OP_QuickGelu : public CtxQuickGelu {
   __device__ __inline__ T operator()(const T& a) const {

onnxruntime/contrib_ops/cuda/activation/activations_impl.h

-2
@@ -11,14 +11,12 @@ namespace cuda {
 typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
 typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
 typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
-typedef onnxruntime::cuda::CtxNull CtxGelu;
 typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;

 #define UNARY_CONTRIB_ACTIVATION_OPS() \
   UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
   UNARY_ACTIVATION_OP_NAME(Affine) \
   UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
-  UNARY_ACTIVATION_OP_NAME(Gelu) \
   UNARY_ACTIVATION_OP_NAME(QuickGelu)

 #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);

onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc

+18-2
@@ -4,9 +4,14 @@
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cudnn_common.h"
 #include "fast_gelu.h"
-#include "fast_gelu_impl.h"
+#include "core/providers/cuda/tensor/gelu_impl.h"
 #include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "transformer_common.h"
+#ifdef USE_ROCM
+#include "contrib_ops/rocm/bert/elementwise.h"
+#endif
+#ifdef USE_CUDA
+#include "contrib_ops/cuda/bert/transformer_common.h"
+#endif

 namespace onnxruntime {
 namespace contrib {
@@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE;

 template <typename T>
 FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+#ifdef USE_CUDA
   const TransformerOptions* options = TransformerOptions::GetInstance();
   use_half2_ = !options->DisableHalf2();
+#endif
 }

 template <typename T>
@@ -50,6 +57,14 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
   int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
   typedef typename ToCudaType<T>::MappedType CudaT;

+#ifdef USE_ROCM
+  return LaunchElementwiseKernel<functor::FastGeLU, CudaT>(
+      GetTuningContext(), context->GetComputeStream(),
+      reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
+      (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
+      reinterpret_cast<CudaT*>(output->MutableData<T>()));
+#endif
+#ifdef USE_CUDA
   return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
                                      Stream(context),
                                      static_cast<int>(input_length),
@@ -58,6 +73,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
                                      (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
                                      reinterpret_cast<CudaT*>(output->MutableData<T>()),
                                      use_half2_);
+#endif
 }

 }  // namespace cuda
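The `#ifdef` blocks above let a single `fast_gelu.cc` serve both EPs: the ROCm build dispatches to the tunable elementwise kernel, while the CUDA build keeps the original `LaunchFastGeluKernel` path. The new exact (`approximate` = 'none') CUDA path added in `gelu_impl.cu` is not shown in this excerpt; a minimal sketch of what such a kernel can look like (the names `GeluNoneKernel` and `LaunchGeluNone` are illustrative, not the ORT source):

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Exact Gelu: y = 0.5 * x * (1 + erf(x / sqrt(2))), one thread per element.
__global__ void GeluNoneKernel(const float* x, float* y, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    y[i] = 0.5f * x[i] * (1.0f + erff(x[i] * 0.70710678f));  // erff: CUDA float erf
  }
}

void LaunchGeluNone(const float* x, float* y, int64_t n, cudaStream_t stream) {
  constexpr int kBlockSize = 256;
  const int grid = static_cast<int>((n + kBlockSize - 1) / kBlockSize);
  GeluNoneKernel<<<grid, kBlockSize, 0, stream>>>(x, y, n);
}
```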

onnxruntime/contrib_ops/cuda/bert/fast_gelu.h

+1-1
@@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* ctx) const override;

  private:
-  bool use_half2_;
+  bool use_half2_;  // Only applicable to CUDA kernel (not ROCM).
 };

 }  // namespace cuda

onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc

-59
This file was deleted.

onnxruntime/contrib_ops/rocm/bert/fast_gelu.h

-24
This file was deleted.

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

+2
@@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
 #if !defined(DISABLE_FLOAT8_TYPES)
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
                                                                   IsNaN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
 #if !defined(DISABLE_FLOAT8_TYPES)
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
                                                                   IsNaN)>,
