@@ -234,68 +234,65 @@ struct Relu<Device, qint8> {
234234 reinterpret_cast <int32*>(output.data ())));
235235 }
236236};
237- #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
238237
239- #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <class T>
__global__ void GeluKernel(const T* __restrict__ in, T* __restrict__ out,
                           int32 count) {
  // Tanh-approximated GELU:
  //   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  // One thread per element; grid is expected to cover `count` (tail guarded).
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= count) return;

  constexpr bool is_half = std::is_same_v<T, Eigen::half>;
  if constexpr (is_half || std::is_same_v<T, Eigen::bfloat16>) {
    // 16-bit types: reinterpret as the native CUDA type and compute in float
    // so intermediate rounding of half/bfloat16 does not degrade accuracy.
    using NT = std::conditional_t<is_half, half, bfloat16>;
    const auto* xin = reinterpret_cast<const NT*>(in);
    auto* xout = reinterpret_cast<NT*>(out);
    constexpr float p1 = 0.7978845608028654f;  // sqrt(2 / pi)
    constexpr float p3 = 0.044715f * 0.7978845608028654f;
    const float x = static_cast<float>(xin[i]);
    // tanhf keeps the whole expression in single precision.
    const float y = 0.5f * x * (1.f + tanhf(p1 * x + p3 * x * x * x));
    xout[i] = static_cast<NT>(y);
  } else {
    // float/double path: keep every constant in T so the float instantiation
    // is not silently promoted to double by bare double literals.
    const T p1 = static_cast<T>(0.7978845608028654);  // sqrt(2 / pi)
    const T p3 = static_cast<T>(0.044715 * 0.7978845608028654);
    const T x = in[i];
    out[i] = static_cast<T>(0.5) * x *
             (static_cast<T>(1) + tanh(p1 * x + p3 * x * x * x));
  }
}
250263
template <class T>
__global__ void GeluGradKernel(const T* __restrict__ gradient,
                               const T* __restrict__ feature,
                               T* __restrict__ backprop, int32 count) {
  // Gradient of the tanh-approximated GELU. With z = p1*x + p3*x^3:
  //   d/dx [0.5*x*(1 + tanh(z))]
  //     = 0.5*(1 + tanh(z)) + 0.5*x*(p1 + 3*p3*x^2)*sech^2(z)
  // backprop[i] = gradient[i] * that derivative, evaluated at feature[i].
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= count) return;

  constexpr bool is_half = std::is_same_v<T, Eigen::half>;
  if constexpr (is_half || std::is_same_v<T, Eigen::bfloat16>) {
    // 16-bit types: reinterpret as the native CUDA type and compute in float
    // so intermediate rounding of half/bfloat16 does not degrade accuracy.
    using NT = std::conditional_t<is_half, half, bfloat16>;
    const auto* xgrad = reinterpret_cast<const NT*>(gradient);
    const auto* xfeature = reinterpret_cast<const NT*>(feature);
    auto* xbackprop = reinterpret_cast<NT*>(backprop);
    constexpr float p1 = 0.7978845608028654f;  // sqrt(2 / pi)
    constexpr float p3 = 0.044715f * 0.7978845608028654f;
    const float x = static_cast<float>(xfeature[i]);
    const float z = p1 * x + p3 * x * x * x;
    const float g = static_cast<float>(xgrad[i]);
    // coshf/tanhf keep the whole expression in single precision.
    const float cz = 1.f / coshf(z);  // sech(z)
    const float y =
        g * 0.5f * (1.f + tanhf(z) + x * (p1 + 3.f * p3 * x * x) * cz * cz);
    xbackprop[i] = static_cast<NT>(y);
  } else {
    // float/double path: keep every constant in T so the float instantiation
    // is not silently promoted to double by bare double literals.
    const T p1 = static_cast<T>(0.7978845608028654);  // sqrt(2 / pi)
    const T p3 = static_cast<T>(0.044715 * 0.7978845608028654);
    const T x = feature[i];
    const T z = p1 * x + p3 * x * x * x;
    const T g = gradient[i];
    const T cz = static_cast<T>(1) / cosh(z);  // sech(z)
    backprop[i] = g * static_cast<T>(0.5) *
                  (static_cast<T>(1) + tanh(z) +
                   x * (p1 + static_cast<T>(3) * p3 * x * x) * cz * cz);
  }
}
300297
301298template <typename T>
@@ -338,9 +335,7 @@ TF_CALL_half(DEFINE_GPU_NO_MLIR_KERNELS);
338335TF_CALL_float (DEFINE_GPU_NO_MLIR_KERNELS);
339336TF_CALL_double (DEFINE_GPU_NO_MLIR_KERNELS);
340337#endif
341- #if GOOGLE_CUDA
342338TF_CALL_bfloat16 (DEFINE_GPU_NO_MLIR_KERNELS);
343- #endif
344339#undef DEFINE_GPU_NO_MLIR_KERNELS
345340
346341// Definition of the GPU implementations declared in relu_op.cc.
@@ -356,9 +351,7 @@ TF_CALL_bfloat16(DEFINE_GPU_NO_MLIR_KERNELS);
356351 template struct functor ::GeluGrad<GPUDevice, T>;
357352
358353TF_CALL_GPU_NUMBER_TYPES_NO_BF16 (DEFINE_GPU_KERNELS);
359- #if GOOGLE_CUDA
360354TF_CALL_bfloat16 (DEFINE_GPU_KERNELS);
361- #endif
362355template struct functor ::Relu<GPUDevice, qint8>;
363356
364357} // end namespace tensorflow
0 commit comments