From 7546fae30ccb26d86e31964fe71c079c07a79308 Mon Sep 17 00:00:00 2001 From: bear-zd Date: Thu, 10 Oct 2024 19:24:30 +0800 Subject: [PATCH 1/5] [GELU] Add f32/x4, f16/x2/x8/x8pack kernel. --- gelu/.gitignore | 10 +++ gelu/README.md | 164 ++++++++++++++++++++++++++++++++++ gelu/gelu.cu | 228 ++++++++++++++++++++++++++++++++++++++++++++++++ gelu/gelu.py | 80 +++++++++++++++++ 4 files changed, 482 insertions(+) create mode 100644 gelu/.gitignore create mode 100755 gelu/README.md create mode 100644 gelu/gelu.cu create mode 100644 gelu/gelu.py diff --git a/gelu/.gitignore b/gelu/.gitignore new file mode 100644 index 00000000..eb33da95 --- /dev/null +++ b/gelu/.gitignore @@ -0,0 +1,10 @@ +*.so +*.a +*.dylib +*.dll +*.lib +.DS_Store +build +*.whl +tmp + diff --git a/gelu/README.md b/gelu/README.md new file mode 100755 index 00000000..c6dd37be --- /dev/null +++ b/gelu/README.md @@ -0,0 +1,164 @@ +# GELU + +## 0x00 说明 + +包含以下内容: + +- [X] gelu_f32_kernel +- [X] gelu_f32x4_kernel(float4向量化版本) +- [X] gelu_f16_kernel +- [X] gelu_f16x2_kernel(half2向量化) +- [X] gelu_f16x8_kernel(unpack版本) +- [X] gelu_f16x8_pack_kernel(pack版本) +- [X] PyTorch bindings + + +## 测试 + +对于半精度(half)的GELU操作,由于CUDA的半精度计算中并不包含tanh操作,因此需要使用hexp来替代对应的操作,因此会引入较大的误差。(或许可以考虑从汇编上解决这个问题);而torch是通过转化数据类型完成的。想要测试很简单,修改一下cu中f16里面的代码做一下强制类型转换即可: + +```cpp +// line 96 +y[idx] = HALF_GELU_OPS(__half2float(v)); +// line 109 , line 110 +reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x)); +reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y)); +``` +测试结果如下(由于不是所有数据都会掉误差所以取了会有误差的情况,可见修改后out_f16和out_f16x2的结果和torch相同了): +```bash + S=2048, K=4096 + out_f32: [-0.08196318, -0.1613517], time:0.13425708ms + out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms + out_f32_th: [-0.08196313, -0.1613517], time:0.08195782ms +------------------------------------------------------------------------------------- + out_f16: [-0.08197021, -0.16137695], time:0.12120271ms + out_f16x2: [-0.08197021, -0.16137695], time:0.12122369ms + out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms + out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms + out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms + +``` +相关参考: +- (pytorch-c10-BFloat16.h)[https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h] +- (math ptx)[https://github.com/pavanky/math_ptx] + +此外仿照torch实现了在float下tanh和none两种近似下的GELU函数,可以在gelu.cu的宏中进行修改实现不同的版本的编译。 + +```bash +# 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ... 
+export TORCH_CUDA_ARCH_LIST=Ada +python3 gelu.py +``` + +输出(不做类型转换导致half误差): + +```bash +------------------------------------------------------------------------------------- + S=1024, K=1024 + out_f32: [0.93880296, 0.15988638], time:0.02785468ms + out_f32x4: [0.93880296, 0.15988638], time:0.02076554ms + out_f32_th: [0.93880296, 0.15988638], time:0.01221609ms +------------------------------------------------------------------------------------- + out_f16: [0.93798828, 0.15979004], time:0.00964093ms + out_f16x2: [0.93798828, 0.15979004], time:0.00525022ms + out_f16x8: [0.93798828, 0.15979004], time:0.00469351ms + out_f16x8pack: [0.93798828, 0.15979004], time:0.00465655ms + out_f16_th: [0.93847656, 0.15991211], time:0.00669861ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=1024, K=2048 + out_f32: [-0.14857908, 0.10128548], time:0.03697181ms + out_f32x4: [-0.14857908, 0.10128548], time:0.03849959ms + out_f32_th: [-0.14857908, 0.10128548], time:0.02257371ms +------------------------------------------------------------------------------------- + out_f16: [-0.14904785, 0.10119629], time:0.01546693ms + out_f16x2: [-0.14904785, 0.10119629], time:0.01501513ms + out_f16x8: [-0.14904785, 0.10119629], time:0.01015544ms + out_f16x8pack: [-0.14904785, 0.10119629], time:0.01015282ms + out_f16_th: [-0.14855957, 0.10125732], time:0.01221085ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=1024, K=4096 + out_f32: [-0.16260667, 2.28252459], time:0.07104182ms + out_f32x4: [-0.16260667, 2.28252459], time:0.08304977ms + out_f32_th: [-0.16260667, 2.28252459], time:0.04243922ms +------------------------------------------------------------------------------------- + out_f16: [-0.16296387, 2.28125], time:0.02782536ms + out_f16x2: [-0.16296387, 2.28125], time:0.02191663ms + out_f16x8: [-0.16296387, 2.28125], time:0.02220559ms + out_f16x8pack: [-0.16296387, 2.28125], time:0.02232957ms + out_f16_th: [-0.16259766, 2.28320312], time:0.02265978ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=2048, K=1024 + out_f32: [-0.16840045, -0.14960197], time:0.05070662ms + out_f32x4: [-0.16840045, -0.14960197], time:0.03644156ms + out_f32_th: [-0.16840045, -0.14960195], time:0.02212596ms +------------------------------------------------------------------------------------- + out_f16: [-0.16845703, -0.1496582], time:0.02071333ms + out_f16x2: [-0.16845703, -0.1496582], time:0.01206446ms + out_f16x8: [-0.16845703, -0.1496582], time:0.00981784ms + out_f16x8pack: [-0.16845703, -0.1496582], time:0.00988960ms + out_f16_th: [-0.16845703, -0.1496582], time:0.01215363ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=2048, K=2048 + out_f32: [-0.16697021, -0.16277096], time:0.06218576ms + out_f32x4: [-0.16697021, -0.16277096], time:0.06344438ms + out_f32_th: [-0.16697019, -0.16277094], time:0.04222322ms +------------------------------------------------------------------------------------- + out_f16: [-0.16699219, -0.16271973], time:0.02624702ms + out_f16x2: [-0.16699219, -0.16271973], time:0.02568126ms + out_f16x8: 
[-0.16699219, -0.16271973], time:0.02205300ms + out_f16x8pack: [-0.16699219, -0.16271973], time:0.02210712ms + out_f16_th: [-0.16699219, -0.16271973], time:0.02253604ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=2048, K=4096 + out_f32: [-0.09021921, -0.16487332], time:0.13927341ms + out_f32x4: [-0.09021921, -0.16487332], time:0.14096951ms + out_f32_th: [-0.09021921, -0.16487332], time:0.08194113ms +------------------------------------------------------------------------------------- + out_f16: [-0.09033203, -0.16503906], time:0.05144143ms + out_f16x2: [-0.09033203, -0.16503906], time:0.04174685ms + out_f16x8: [-0.09033203, -0.16503906], time:0.04198074ms + out_f16x8pack: [-0.09033203, -0.16503906], time:0.04212999ms + out_f16_th: [-0.09020996, -0.16491699], time:0.04287744ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=4096, K=1024 + out_f32: [0.07282269, -0.06332674], time:0.09058189ms + out_f32x4: [0.07282269, -0.06332674], time:0.06340218ms + out_f32_th: [0.07282269, -0.06332674], time:0.04206586ms +------------------------------------------------------------------------------------- + out_f16: [0.07281494, -0.06335449], time:0.03970504ms + out_f16x2: [0.07281494, -0.06335449], time:0.02199268ms + out_f16x8: [0.07281494, -0.06335449], time:0.02213860ms + out_f16x8pack: [0.07281494, -0.06335449], time:0.02209067ms + out_f16_th: [0.07281494, -0.06335449], time:0.02286220ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=4096, K=2048 + out_f32: [-0.11318169, -0.10836542], time:0.12297416ms + out_f32x4: [-0.11318169, -0.10836542], time:0.12383652ms + out_f32_th: [-0.11318169, -0.10836542], time:0.08190846ms +------------------------------------------------------------------------------------- + out_f16: [-0.11315918, -0.10888672], time:0.05153990ms + out_f16x2: [-0.11315918, -0.10888672], time:0.04872131ms + out_f16x8: [-0.11315918, -0.10888672], time:0.04182482ms + out_f16x8pack: [-0.11315918, -0.10888672], time:0.04196978ms + out_f16_th: [-0.11322021, -0.1083374], time:0.04286408ms +------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------- + S=4096, K=4096 + out_f32: [-0.16762884, 0.33026037], time:0.26759410ms + out_f32x4: [-0.16762884, 0.33026037], time:0.27700567ms + out_f32_th: [-0.16762884, 0.33026037], time:0.16148257ms +------------------------------------------------------------------------------------- + out_f16: [-0.16760254, 0.33032227], time:0.10299659ms + out_f16x2: [-0.16760254, 0.33032227], time:0.08103538ms + out_f16x8: [-0.16760254, 0.33032227], time:0.08191633ms + out_f16x8pack: [-0.16760254, 0.33032227], time:0.08227539ms + out_f16_th: [-0.16760254, 0.33032227], time:0.08262110ms +------------------------------------------------------------------------------------- +``` diff --git a/gelu/gelu.cu b/gelu/gelu.cu new file mode 100644 index 00000000..d06230df --- /dev/null +++ b/gelu/gelu.cu @@ -0,0 +1,228 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define WARP_SIZE 32 +#define INT4(value) 
(reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +#define MAX_EXP_F32 88.3762626647949f +#define MIN_EXP_F32 -88.3762626647949f +#define MAX_EXP_F16 __float2half(11.089866488461016f) +#define MIN_EXP_F16 __float2half(-9.704060527839234f) +#define SQRT_2_PI M_SQRT2 * M_2_SQRTPI * 0.5f +#define HALF_1 __float2half(1.0f) +#define HALF_2 __float2half(2.0f) +#define HALF_DIV2 __float2half(0.5f) +// to clear the error among self defined gelu and pytorch gelu. Calculate $\sqrt{\frac{\pi}{2}}$ by $\sqrt{2 * \pi} / 2$ +#define HALF_SQRT_2_PI __float2half(M_SQRT2) * __float2half(M_2_SQRTPI) * HALF_DIV2 +#define HALF_V_APP __float2half(0.044715f) + +#define HALF_GELU_OPS gelu_tanh_approximate +#define GELU_OPS gelu_tanh_approximate + +// There is no half presicion operation like sinh, cosh, tanh. [Half Math Functions](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF__FUNCTIONS.html#group__CUDA__MATH____HALF__FUNCTIONS) +// $$ tanh(x) = \frac{exp^{2x} - 1}{exp^{2x} + 1}$$ +// But ops above will introduce error. +// pytorch transform type while do tanh operator which include in the [pytorch/c10/util/BFloat16-math.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16-math.h) +__inline__ __device__ half gelu_tanh_approximate(half x){ + + half x_cube = x * x * x; + // compute mid value : inner = 0.7978845608 * (x + 0.044715 * x * x * x) + half inner = HALF_SQRT_2_PI * (x + HALF_V_APP * x_cube); + // compute tanh + return HALF_DIV2 * x * (HALF_1 + ((hexp(inner * HALF_2) - HALF_1) / (hexp(inner * HALF_2) + HALF_1))); +} + +__inline__ __device__ float gelu_tanh_approximate(float x){ + return 0.5f * x * (1.0f + tanhf(SQRT_2_PI * (x + 0.044715f * x * x * x))); +} + + +__inline__ __device__ float gelu_none_approximate(float x){ + return x * 0.5 * (1 + erff(x * M_SQRT1_2)); +} + +// -------------------------------------- FP32 -------------------------------------- +// GELU tanh approximate: x, y:x 0.5 * x * (1.0 + tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x))) +// grid(N/256), block(K=256) + +__global__ void gelu_f32_kernel(float* x, float* y, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + float v = fminf(fmaxf(x[idx], MIN_EXP_F32), MAX_EXP_F32); + y[idx] = GELU_OPS(v); + } +} + +// GELU tanh approximate; Vec4 +// grid(N/256), block(256/4) +__global__ void gelu_f32x4_kernel(float* x, float* y, int N) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + float4 reg_x = FLOAT4(x[idx]); + float4 reg_y; + + reg_x.x = fminf(fmaxf(reg_x.x, MIN_EXP_F32), MAX_EXP_F32); + reg_x.y = fminf(fmaxf(reg_x.y, MIN_EXP_F32), MAX_EXP_F32); + reg_x.z = fminf(fmaxf(reg_x.z, MIN_EXP_F32), MAX_EXP_F32); + reg_x.w = fminf(fmaxf(reg_x.w, MIN_EXP_F32), MAX_EXP_F32); + + reg_y.x = GELU_OPS(reg_x.x); + reg_y.y = GELU_OPS(reg_x.y); + reg_y.z = GELU_OPS(reg_x.z); + reg_y.w = GELU_OPS(reg_x.w); + + if ((idx + 0) < N) { FLOAT4(y[idx]) = reg_y; } +} + +// -------------------------------------- FP16 -------------------------------------- +// GELU approximate: x, y:x 0.5 * x * (1.0 + tanh(0.7978845608 (x + 0.044715 * x * x * x))) Vec4 +__global__ void gelu_f16_kernel(half* x, half* y, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + half v = x[idx]; + v = __hmin(__hmax(v, MIN_EXP_F16), MAX_EXP_F16); + + y[idx] = 
HALF_GELU_OPS(v); + } +} + +__global__ void gelu_f16x2_kernel(half* x, half* y, int N) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + + + half2 reg_x = HALF2(x[idx]); + half2 reg_y; + reg_x.x = __hmin(__hmax(reg_x.x, MIN_EXP_F16), MAX_EXP_F16); + reg_x.y = __hmin(__hmax(reg_x.y, MIN_EXP_F16), MAX_EXP_F16); + + reg_y.x = HALF_GELU_OPS(reg_x.x); + reg_y.y = HALF_GELU_OPS(reg_x.y); + if ((idx + 0) < N) { HALF2(y[idx]) = reg_y; } +} + +// unpack f16x8 +__global__ void gelu_f16x8_kernel(half* x, half* y, int N) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + + half2 reg_x_0 = HALF2(x[idx + 0]); + half2 reg_x_1 = HALF2(x[idx + 2]); + half2 reg_x_2 = HALF2(x[idx + 4]); + half2 reg_x_3 = HALF2(x[idx + 6]); + + reg_x_0.x = __hmin(__hmax(reg_x_0.x, MIN_EXP_F16), MAX_EXP_F16); + reg_x_0.y = __hmin(__hmax(reg_x_0.y, MIN_EXP_F16), MAX_EXP_F16); + reg_x_1.x = __hmin(__hmax(reg_x_1.x, MIN_EXP_F16), MAX_EXP_F16); + reg_x_1.y = __hmin(__hmax(reg_x_1.y, MIN_EXP_F16), MAX_EXP_F16); + reg_x_2.x = __hmin(__hmax(reg_x_2.x, MIN_EXP_F16), MAX_EXP_F16); + reg_x_2.y = __hmin(__hmax(reg_x_2.y, MIN_EXP_F16), MAX_EXP_F16); + reg_x_3.x = __hmin(__hmax(reg_x_3.x, MIN_EXP_F16), MAX_EXP_F16); + reg_x_3.y = __hmin(__hmax(reg_x_3.y, MIN_EXP_F16), MAX_EXP_F16); + + half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3; + + reg_x_0.x = HALF_GELU_OPS(reg_x_0.x); + reg_x_0.y = HALF_GELU_OPS(reg_x_0.y); + reg_x_1.x = HALF_GELU_OPS(reg_x_1.x); + reg_x_1.y = HALF_GELU_OPS(reg_x_1.y); + reg_x_2.x = HALF_GELU_OPS(reg_x_2.x); + reg_x_2.y = HALF_GELU_OPS(reg_x_2.y); + reg_x_3.x = HALF_GELU_OPS(reg_x_3.x); + reg_x_3.y = HALF_GELU_OPS(reg_x_3.y); + + if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_x_0; } + if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_x_1; } + if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_x_2; } + if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_x_3; } +} + +// pack f16x8 +__global__ void gelu_f16x8_pack_kernel(half* x, half* y, int N) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + + // temporary register(memory), .local space in ptx, addressable + half pack_x[8], pack_y[8]; // 8x16 bits=128 bits. + // reinterpret as float4 and load 128 bits in 1 memory issue. + LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits + + #pragma unroll + for (int i = 0; i < 8; ++i) { + half v = __hmin(__hmax(pack_x[i], MIN_EXP_F16), MAX_EXP_F16); + pack_y[i] = HALF_GELU_OPS(v); + } + // reinterpret as float4 and store 128 bits in 1 memory issue. 
+ if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); } +} + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define TORCH_BINDING_GELU(packed_type, th_type, element_type, n_elements) \ +void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ + CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \ + CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \ + const int ndim = x.dim(); \ + if (ndim != 2) { \ + int N = 1; \ + for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \ + dim3 block(256 / (n_elements)); \ + dim3 grid((N + 256 - 1) / 256); \ + gelu_##packed_type##_kernel<<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), N); \ + } else { \ + const int S = x.size(0); \ + const int K = x.size(1); \ + const int N = S * K; \ + if ((K/(n_elements)) <= 1024) { \ + dim3 block(K/(n_elements)); \ + dim3 grid(S); \ + gelu_##packed_type##_kernel<<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), N); \ + } else { \ + int N = 1; \ + for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \ + dim3 block(256 / (n_elements)); \ + dim3 grid((N + 256 - 1) / 256); \ + gelu_##packed_type##_kernel<<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), N); \ + } \ + } \ +} + + +TORCH_BINDING_GELU(f32, torch::kFloat32, float, 1) +TORCH_BINDING_GELU(f32x4, torch::kFloat32, float, 4) +TORCH_BINDING_GELU(f16, torch::kHalf, half, 1) +TORCH_BINDING_GELU(f16x2, torch::kHalf, half, 2) +TORCH_BINDING_GELU(f16x8, torch::kHalf, half, 8) +TORCH_BINDING_GELU(f16x8_pack, torch::kHalf, half, 8) + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + TORCH_BINDING_COMMON_EXTENSION(gelu_f32) + TORCH_BINDING_COMMON_EXTENSION(gelu_f32x4) + TORCH_BINDING_COMMON_EXTENSION(gelu_f16) + TORCH_BINDING_COMMON_EXTENSION(gelu_f16x2) + TORCH_BINDING_COMMON_EXTENSION(gelu_f16x8) + TORCH_BINDING_COMMON_EXTENSION(gelu_f16x8_pack) +} diff --git a/gelu/gelu.py b/gelu/gelu.py new file mode 100644 index 00000000..c3ed7ed5 --- /dev/null +++ b/gelu/gelu.py @@ -0,0 +1,80 @@ +import torch.nn +import time +import torch.utils +from torch.utils.cpp_extension import load +from typing import Optional +from functools import partial + +torch.set_grad_enabled(False) + +# Load the CUDA kernel as a python module +lib = load(name='gelu_lib', + sources=['gelu.cu'], + extra_cuda_cflags=[ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ], + extra_cflags=['-std=c++17']) + + +def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str, + out: Optional[torch.Tensor] = None, warmup: int = 10, + iters: int = 1000, show_all: bool = False): + if out is not None: + out.fill_(0) + # warmup + if out is not None: + for i in range(warmup): + perf_func(x, out) + else: + for i in range(warmup): + _ = perf_func(x) + torch.cuda.synchronize() + + start = time.time() + # iters + if out is not None: + for i in range(iters): + perf_func(x, out) + else: + for i in range(iters): + out = perf_func(x) + torch.cuda.synchronize() + end = time.time() + total_time = (end - 
start) * 1000 # ms + mean_time = total_time / iters + out_info = f"out_{tag}" + out_val = out.flatten().detach().cpu().numpy().tolist()[:2] + out_val = [round(v, 8) for v in out_val] + print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms") + if show_all: print(out) + return out, mean_time + +Ss = [1024, 2048, 4096] +Ks = [1024, 2048, 4096] +SKs = [(S, K) for S in Ss for K in Ks] +torch.gelu = torch.nn.GELU("tanh") +for (S, K) in SKs: + print("-" * 85) + print(" " * 40 + f"S={S}, K={K}") + x = torch.randn((S, K)).cuda().float().contiguous() + y = torch.zeros_like(x).cuda().float().contiguous() + run_benchmark(lib.gelu_f32, x, "f32", y) + run_benchmark(lib.gelu_f32x4, x, "f32x4", y) + run_benchmark(partial(torch.gelu), x, "f32_th") + + print("-" * 85) + x_f16 = x.half().contiguous() + y_f16 = y.half().contiguous() + run_benchmark(lib.gelu_f16, x_f16, "f16", y_f16) + run_benchmark(lib.gelu_f16x2, x_f16, "f16x2", y_f16) + run_benchmark(lib.gelu_f16x8, x_f16, "f16x8", y_f16) + run_benchmark(lib.gelu_f16x8_pack, x_f16, "f16x8pack", y_f16) + run_benchmark(partial(torch.gelu), x_f16, "f16_th") + print("-" * 85) From 043be2dbedafa8306df7cbbf00a31773ad752838 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:58:42 +0800 Subject: [PATCH 2/5] Update README.md --- gelu/README.md | 161 ++++++++++++++++++++++++------------------------- 1 file changed, 80 insertions(+), 81 deletions(-) diff --git a/gelu/README.md b/gelu/README.md index c6dd37be..23ea35fe 100755 --- a/gelu/README.md +++ b/gelu/README.md @@ -17,15 +17,14 @@ 对于半精度(half)的GELU操作,由于CUDA的半精度计算中并不包含tanh操作,因此需要使用hexp来替代对应的操作,因此会引入较大的误差。(或许可以考虑从汇编上解决这个问题);而torch是通过转化数据类型完成的。想要测试很简单,修改一下cu中f16里面的代码做一下强制类型转换即可: -```cpp -// line 96 -y[idx] = HALF_GELU_OPS(__half2float(v)); -// line 109 , line 110 -reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x)); +```c++ +y[idx] = HALF_GELU_OPS(__half2float(v)); // line 96 +reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x)); // line 109 , line 110 reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y)); ``` 测试结果如下(由于不是所有数据都会掉误差所以取了会有误差的情况,可见修改后out_f16和out_f16x2的结果和torch相同了): ```bash +------------------------------------------------------------------------------------- S=2048, K=4096 out_f32: [-0.08196318, -0.1613517], time:0.13425708ms out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms @@ -36,11 +35,11 @@ reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y)); out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms - +------------------------------------------------------------------------------------- ``` 相关参考: -- (pytorch-c10-BFloat16.h)[https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h] -- (math ptx)[https://github.com/pavanky/math_ptx] +- [pytorch-c10-BFloat16.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h) +- [math ptx](https://github.com/pavanky/math_ptx) 此外仿照torch实现了在float下tanh和none两种近似下的GELU函数,可以在gelu.cu的宏中进行修改实现不同的版本的编译。 @@ -55,110 +54,110 @@ python3 gelu.py ```bash ------------------------------------------------------------------------------------- S=1024, K=1024 - out_f32: [0.93880296, 0.15988638], time:0.02785468ms - out_f32x4: [0.93880296, 0.15988638], time:0.02076554ms - out_f32_th: [0.93880296, 0.15988638], time:0.01221609ms + out_f32: [-0.13358943, -0.06881647], time:0.01621890ms + out_f32x4: [-0.13358943, -0.06881647], time:0.01278400ms + out_f32_th: [-0.13358943, -0.06881647], time:0.00897789ms 
------------------------------------------------------------------------------------- - out_f16: [0.93798828, 0.15979004], time:0.00964093ms - out_f16x2: [0.93798828, 0.15979004], time:0.00525022ms - out_f16x8: [0.93798828, 0.15979004], time:0.00469351ms - out_f16x8pack: [0.93798828, 0.15979004], time:0.00465655ms - out_f16_th: [0.93847656, 0.15991211], time:0.00669861ms + out_f16: [-0.13378906, -0.06884766], time:0.00663781ms + out_f16x2: [-0.13378906, -0.06884766], time:0.00366306ms + out_f16x8: [-0.13378906, -0.06884766], time:0.00343323ms + out_f16x8pack: [-0.13378906, -0.06884766], time:0.00331473ms + out_f16_th: [-0.13354492, -0.06884766], time:0.00907278ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=1024, K=2048 - out_f32: [-0.14857908, 0.10128548], time:0.03697181ms - out_f32x4: [-0.14857908, 0.10128548], time:0.03849959ms - out_f32_th: [-0.14857908, 0.10128548], time:0.02257371ms + out_f32: [1.38783729, -0.06707606], time:0.02223682ms + out_f32x4: [1.38783729, -0.06707606], time:0.02367806ms + out_f32_th: [1.38783729, -0.06707606], time:0.00959325ms ------------------------------------------------------------------------------------- - out_f16: [-0.14904785, 0.10119629], time:0.01546693ms - out_f16x2: [-0.14904785, 0.10119629], time:0.01501513ms - out_f16x8: [-0.14904785, 0.10119629], time:0.01015544ms - out_f16x8pack: [-0.14904785, 0.10119629], time:0.01015282ms - out_f16_th: [-0.14855957, 0.10125732], time:0.01221085ms + out_f16: [1.38769531, -0.06713867], time:0.00834370ms + out_f16x2: [1.38769531, -0.06713867], time:0.00784707ms + out_f16x8: [1.38769531, -0.06713867], time:0.00499964ms + out_f16x8pack: [1.38769531, -0.06713867], time:0.00461078ms + out_f16_th: [1.38769531, -0.06707764], time:0.00895357ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=1024, K=4096 - out_f32: [-0.16260667, 2.28252459], time:0.07104182ms - out_f32x4: [-0.16260667, 2.28252459], time:0.08304977ms - out_f32_th: [-0.16260667, 2.28252459], time:0.04243922ms + out_f32: [0.47386399, 0.05760021], time:0.04273629ms + out_f32x4: [0.47386399, 0.05760021], time:0.05011940ms + out_f32_th: [0.47386405, 0.05760022], time:0.00933146ms ------------------------------------------------------------------------------------- - out_f16: [-0.16296387, 2.28125], time:0.02782536ms - out_f16x2: [-0.16296387, 2.28125], time:0.02191663ms - out_f16x8: [-0.16296387, 2.28125], time:0.02220559ms - out_f16x8pack: [-0.16296387, 2.28125], time:0.02232957ms - out_f16_th: [-0.16259766, 2.28320312], time:0.02265978ms + out_f16: [0.47387695, 0.05761719], time:0.01495123ms + out_f16x2: [0.47387695, 0.05761719], time:0.01039743ms + out_f16x8: [0.47387695, 0.05761719], time:0.00936055ms + out_f16x8pack: [0.47387695, 0.05761719], time:0.00845838ms + out_f16_th: [0.47387695, 0.05758667], time:0.00918818ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=2048, K=1024 - out_f32: [-0.16840045, -0.14960197], time:0.05070662ms - out_f32x4: [-0.16840045, -0.14960197], time:0.03644156ms - out_f32_th: [-0.16840045, -0.14960195], time:0.02212596ms + out_f32: [1.3562144, 0.40408486], time:0.03009892ms + out_f32x4: [1.3562144, 0.40408486], time:0.02289677ms + out_f32_th: 
[1.3562144, 0.40408486], time:0.00921512ms ------------------------------------------------------------------------------------- - out_f16: [-0.16845703, -0.1496582], time:0.02071333ms - out_f16x2: [-0.16845703, -0.1496582], time:0.01206446ms - out_f16x8: [-0.16845703, -0.1496582], time:0.00981784ms - out_f16x8pack: [-0.16845703, -0.1496582], time:0.00988960ms - out_f16_th: [-0.16845703, -0.1496582], time:0.01215363ms + out_f16: [1.35644531, 0.40405273], time:0.01173806ms + out_f16x2: [1.35644531, 0.40405273], time:0.00565076ms + out_f16x8: [1.35644531, 0.40405273], time:0.00502610ms + out_f16x8pack: [1.35644531, 0.40405273], time:0.00457048ms + out_f16_th: [1.35644531, 0.40429688], time:0.00904894ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=2048, K=2048 - out_f32: [-0.16697021, -0.16277096], time:0.06218576ms - out_f32x4: [-0.16697021, -0.16277096], time:0.06344438ms - out_f32_th: [-0.16697019, -0.16277094], time:0.04222322ms + out_f32: [-0.16498716, -0.15077244], time:0.04273534ms + out_f32x4: [-0.16498716, -0.15077244], time:0.04386163ms + out_f32_th: [-0.16498716, -0.15077244], time:0.00913596ms ------------------------------------------------------------------------------------- - out_f16: [-0.16699219, -0.16271973], time:0.02624702ms - out_f16x2: [-0.16699219, -0.16271973], time:0.02568126ms - out_f16x8: [-0.16699219, -0.16271973], time:0.02205300ms - out_f16x8pack: [-0.16699219, -0.16271973], time:0.02210712ms - out_f16_th: [-0.16699219, -0.16271973], time:0.02253604ms + out_f16: [-0.16516113, -0.15075684], time:0.01495862ms + out_f16x2: [-0.16516113, -0.15075684], time:0.01407337ms + out_f16x8: [-0.16516113, -0.15075684], time:0.00796247ms + out_f16x8pack: [-0.16516113, -0.15075684], time:0.00734925ms + out_f16_th: [-0.16503906, -0.15075684], time:0.00917435ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=2048, K=4096 - out_f32: [-0.09021921, -0.16487332], time:0.13927341ms - out_f32x4: [-0.09021921, -0.16487332], time:0.14096951ms - out_f32_th: [-0.09021921, -0.16487332], time:0.08194113ms + out_f32: [-0.03888749, 0.32139146], time:0.08363676ms + out_f32x4: [-0.03888749, 0.32139146], time:0.09505510ms + out_f32_th: [-0.03888749, 0.32139146], time:0.04022837ms ------------------------------------------------------------------------------------- - out_f16: [-0.09033203, -0.16503906], time:0.05144143ms - out_f16x2: [-0.09033203, -0.16503906], time:0.04174685ms - out_f16x8: [-0.09033203, -0.16503906], time:0.04198074ms - out_f16x8pack: [-0.09033203, -0.16503906], time:0.04212999ms - out_f16_th: [-0.09020996, -0.16491699], time:0.04287744ms + out_f16: [-0.03887939, 0.3215332], time:0.02813959ms + out_f16x2: [-0.03887939, 0.3215332], time:0.01906514ms + out_f16x8: [-0.03887939, 0.3215332], time:0.01664281ms + out_f16x8pack: [-0.03887939, 0.3215332], time:0.01474833ms + out_f16_th: [-0.03887939, 0.32128906], time:0.01357365ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=4096, K=1024 - out_f32: [0.07282269, -0.06332674], time:0.09058189ms - out_f32x4: [0.07282269, -0.06332674], time:0.06340218ms - out_f32_th: [0.07282269, -0.06332674], time:0.04206586ms + out_f32: [-0.13875209, 1.08477271], 
time:0.05790567ms + out_f32x4: [-0.13875209, 1.08477271], time:0.04317236ms + out_f32_th: [-0.13875209, 1.08477271], time:0.00910425ms ------------------------------------------------------------------------------------- - out_f16: [0.07281494, -0.06335449], time:0.03970504ms - out_f16x2: [0.07281494, -0.06335449], time:0.02199268ms - out_f16x8: [0.07281494, -0.06335449], time:0.02213860ms - out_f16x8pack: [0.07281494, -0.06335449], time:0.02209067ms - out_f16_th: [0.07281494, -0.06335449], time:0.02286220ms + out_f16: [-0.13903809, 1.08496094], time:0.02198315ms + out_f16x2: [-0.13903809, 1.08496094], time:0.00964355ms + out_f16x8: [-0.13903809, 1.08496094], time:0.00780869ms + out_f16x8pack: [-0.13903809, 1.08496094], time:0.00729132ms + out_f16_th: [-0.13879395, 1.08496094], time:0.00926042ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=4096, K=2048 - out_f32: [-0.11318169, -0.10836542], time:0.12297416ms - out_f32x4: [-0.11318169, -0.10836542], time:0.12383652ms - out_f32_th: [-0.11318169, -0.10836542], time:0.08190846ms + out_f32: [0.82045084, -0.0894338], time:0.08363843ms + out_f32x4: [0.82045084, -0.0894338], time:0.08431888ms + out_f32_th: [0.82045084, -0.0894338], time:0.03837347ms ------------------------------------------------------------------------------------- - out_f16: [-0.11315918, -0.10888672], time:0.05153990ms - out_f16x2: [-0.11315918, -0.10888672], time:0.04872131ms - out_f16x8: [-0.11315918, -0.10888672], time:0.04182482ms - out_f16x8pack: [-0.11315918, -0.10888672], time:0.04196978ms - out_f16_th: [-0.11322021, -0.1083374], time:0.04286408ms + out_f16: [0.8203125, -0.08947754], time:0.02813506ms + out_f16x2: [0.8203125, -0.08947754], time:0.02643061ms + out_f16x8: [0.8203125, -0.08947754], time:0.01383305ms + out_f16x8pack: [0.8203125, -0.08947754], time:0.01273918ms + out_f16_th: [0.82080078, -0.0894165], time:0.01357722ms ------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------- S=4096, K=4096 - out_f32: [-0.16762884, 0.33026037], time:0.26759410ms - out_f32x4: [-0.16762884, 0.33026037], time:0.27700567ms - out_f32_th: [-0.16762884, 0.33026037], time:0.16148257ms -------------------------------------------------------------------------------------- - out_f16: [-0.16760254, 0.33032227], time:0.10299659ms - out_f16x2: [-0.16760254, 0.33032227], time:0.08103538ms - out_f16x8: [-0.16760254, 0.33032227], time:0.08191633ms - out_f16x8pack: [-0.16760254, 0.33032227], time:0.08227539ms - out_f16_th: [-0.16760254, 0.33032227], time:0.08262110ms + out_f32: [-0.06997654, -0.16092129], time:0.19113564ms + out_f32x4: [-0.06997654, -0.16092129], time:0.20371628ms + out_f32_th: [-0.06997654, -0.16092129], time:0.20496607ms +------------------------------------------------------------------------------------- + out_f16: [-0.07012939, -0.16113281], time:0.05451322ms + out_f16x2: [-0.07012939, -0.16113281], time:0.03633785ms + out_f16x8: [-0.07012939, -0.16113281], time:0.03115463ms + out_f16x8pack: [-0.07012939, -0.16113281], time:0.02735877ms + out_f16_th: [-0.07000732, -0.16088867], time:0.03889561ms ------------------------------------------------------------------------------------- ``` From 6867836fa3b2f36e029c635dbb1f7ba45f10af66 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 11 Oct 
2024 09:03:03 +0800 Subject: [PATCH 3/5] Update gelu.cu --- gelu/gelu.cu | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/gelu/gelu.cu b/gelu/gelu.cu index d06230df..9f23322a 100644 --- a/gelu/gelu.cu +++ b/gelu/gelu.cu @@ -36,19 +36,17 @@ // But ops above will introduce error. // pytorch transform type while do tanh operator which include in the [pytorch/c10/util/BFloat16-math.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16-math.h) __inline__ __device__ half gelu_tanh_approximate(half x){ - - half x_cube = x * x * x; - // compute mid value : inner = 0.7978845608 * (x + 0.044715 * x * x * x) - half inner = HALF_SQRT_2_PI * (x + HALF_V_APP * x_cube); - // compute tanh - return HALF_DIV2 * x * (HALF_1 + ((hexp(inner * HALF_2) - HALF_1) / (hexp(inner * HALF_2) + HALF_1))); + half x_cube = x * x * x; + // compute mid value : inner = 0.7978845608 * (x + 0.044715 * x * x * x) + half inner = HALF_SQRT_2_PI * (x + HALF_V_APP * x_cube); + // compute tanh + return HALF_DIV2 * x * (HALF_1 + ((hexp(inner * HALF_2) - HALF_1) / (hexp(inner * HALF_2) + HALF_1))); } __inline__ __device__ float gelu_tanh_approximate(float x){ - return 0.5f * x * (1.0f + tanhf(SQRT_2_PI * (x + 0.044715f * x * x * x))); + return 0.5f * x * (1.0f + tanhf(SQRT_2_PI * (x + 0.044715f * x * x * x))); } - __inline__ __device__ float gelu_none_approximate(float x){ return x * 0.5 * (1 + erff(x * M_SQRT1_2)); } @@ -56,7 +54,6 @@ __inline__ __device__ float gelu_none_approximate(float x){ // -------------------------------------- FP32 -------------------------------------- // GELU tanh approximate: x, y:x 0.5 * x * (1.0 + tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x))) // grid(N/256), block(K=256) - __global__ void gelu_f32_kernel(float* x, float* y, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { @@ -100,7 +97,6 @@ __global__ void gelu_f16_kernel(half* x, half* y, int N) { __global__ void gelu_f16x2_kernel(half* x, half* y, int N) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 2; - half2 reg_x = HALF2(x[idx]); half2 reg_y; reg_x.x = __hmin(__hmax(reg_x.x, MIN_EXP_F16), MAX_EXP_F16); @@ -175,8 +171,8 @@ if(((T).options().dtype() != (th_type))) { \ throw std::runtime_error("values must be "#th_type); \ } -#define TORCH_BINDING_GELU(packed_type, th_type, element_type, n_elements) \ -void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ +#define TORCH_BINDING_GELU(packed_type, th_type, element_type, n_elements) \ +void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \ CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \ const int ndim = x.dim(); \ @@ -185,7 +181,7 @@ void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \ dim3 block(256 / (n_elements)); \ dim3 grid((N + 256 - 1) / 256); \ - gelu_##packed_type##_kernel<<>>( \ + gelu_##packed_type##_kernel<<>>( \ reinterpret_cast(x.data_ptr()), \ reinterpret_cast(y.data_ptr()), N); \ } else { \ @@ -195,7 +191,7 @@ void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ if ((K/(n_elements)) <= 1024) { \ dim3 block(K/(n_elements)); \ dim3 grid(S); \ - gelu_##packed_type##_kernel<<>>( \ + gelu_##packed_type##_kernel<<>>( \ reinterpret_cast(x.data_ptr()), \ reinterpret_cast(y.data_ptr()), N); \ } else { \ @@ -203,7 +199,7 @@ void gelu_##packed_type(torch::Tensor x, torch::Tensor y) { \ for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \ dim3 block(256 / (n_elements)); \ dim3 grid((N + 
256 - 1) / 256); \ - gelu_##packed_type##_kernel<<>>( \ + gelu_##packed_type##_kernel<<>>( \ reinterpret_cast(x.data_ptr()), \ reinterpret_cast(y.data_ptr()), N); \ } \ From 1fc33a6501d4f3ce5bdc174a4e2b58e811e743b9 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:04:57 +0800 Subject: [PATCH 4/5] Update gelu.py --- gelu/gelu.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/gelu/gelu.py b/gelu/gelu.py index c3ed7ed5..52cd24cf 100644 --- a/gelu/gelu.py +++ b/gelu/gelu.py @@ -67,14 +67,13 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str, y = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.gelu_f32, x, "f32", y) run_benchmark(lib.gelu_f32x4, x, "f32x4", y) - run_benchmark(partial(torch.gelu), x, "f32_th") - + run_benchmark(partial(torch.gelu), x, "f32_th") print("-" * 85) x_f16 = x.half().contiguous() y_f16 = y.half().contiguous() - run_benchmark(lib.gelu_f16, x_f16, "f16", y_f16) - run_benchmark(lib.gelu_f16x2, x_f16, "f16x2", y_f16) - run_benchmark(lib.gelu_f16x8, x_f16, "f16x8", y_f16) - run_benchmark(lib.gelu_f16x8_pack, x_f16, "f16x8pack", y_f16) - run_benchmark(partial(torch.gelu), x_f16, "f16_th") + run_benchmark(lib.gelu_f16, x_f16, "f16", y_f16) + run_benchmark(lib.gelu_f16x2, x_f16, "f16x2", y_f16) + run_benchmark(lib.gelu_f16x8, x_f16, "f16x8", y_f16) + run_benchmark(lib.gelu_f16x8_pack, x_f16, "f16x8pack", y_f16) + run_benchmark(partial(torch.gelu), x_f16, "f16_th") print("-" * 85) From f24aab8e7e8b165ba116692079f4526b5d105bad Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:07:31 +0800 Subject: [PATCH 5/5] Update README.md --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index cbb62810..bcd8fc17 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,12 @@ | ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️| | ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️| | ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️| +| ✔️ [gelu_f32](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️| +| ✔️ [gelu_f32x4](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️| +| ✔️ [gelu_f16](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️| +| ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️| +| ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️| +| ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️| | ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️| | ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️| | ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
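
For reference, the precision issue discussed in gelu/README.md (CUDA half-precision math has no native tanh, so the f16 kernels rebuild it from hexp, while PyTorch casts to float before applying tanh) can be reproduced with a minimal standalone sketch like the one below. It is not part of the patches above; the helper names (`gelu_tanh_hexp`, `compare`), the sample range, the file name, and the compile command are illustrative assumptions rather than anything defined by this PR.

```c++
// Standalone sketch (assumption, not part of the patch series):
// compare the hexp-based half-precision GELU tanh approximation against
// the float-cast path that matches PyTorch's behaviour.
// Build (assumption): nvcc -arch=sm_70 gelu_half_check.cu -o gelu_half_check
#include <cstdio>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// float reference: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__device__ __forceinline__ float gelu_tanh_f32(float x) {
  const float k = 0.7978845608028654f; // sqrt(2/pi)
  return 0.5f * x * (1.0f + tanhf(k * (x + 0.044715f * x * x * x)));
}

// Pure-half path: tanh(t) = (exp(2t) - 1) / (exp(2t) + 1) built from hexp,
// mirroring the gelu_tanh_approximate(half) overload in gelu/gelu.cu.
__device__ __forceinline__ half gelu_tanh_hexp(half x) {
  const half one = __float2half(1.0f);
  const half two = __float2half(2.0f);
  const half hlf = __float2half(0.5f);
  const half k   = __float2half(0.7978845608028654f);
  const half c   = __float2half(0.044715f);
  half inner = k * (x + c * x * x * x);
  half e = hexp(inner * two);
  return hlf * x * (one + (e - one) / (e + one));
}

__global__ void compare(const half* x, float* err, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float ref   = gelu_tanh_f32(__half2float(x[i]));   // float-cast path (what PyTorch effectively does)
    float hpath = __half2float(gelu_tanh_hexp(x[i]));  // hexp-based half path (as in the f16 kernels)
    err[i] = fabsf(ref - hpath);
  }
}

int main() {
  const int n = 1024;
  half*  x   = nullptr;
  float* err = nullptr;
  cudaMallocManaged(&x,   n * sizeof(half));
  cudaMallocManaged(&err, n * sizeof(float));
  // Sample [-4, 4): stays inside the MIN/MAX_EXP_F16 clamp the kernels apply,
  // so hexp does not overflow in half precision.
  for (int i = 0; i < n; ++i) x[i] = __float2half(-4.0f + 8.0f * i / n);
  compare<<<(n + 255) / 256, 256>>>(x, err, n);
  cudaDeviceSynchronize();
  float max_err = 0.0f;
  for (int i = 0; i < n; ++i) if (err[i] > max_err) max_err = err[i];
  std::printf("max |float-cast - hexp-based| GELU difference: %g\n", max_err);
  cudaFree(x);
  cudaFree(err);
  return 0;
}
```

Running a sketch like this gives a quick view of how large the hexp-induced error is relative to the float-cast variant that the README's type-cast modification (`HALF_GELU_OPS(__half2float(v))`) switches to.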