[GELU] Add f32/x4, f16/x2/x8/x8pack kernel. #66

Merged 5 commits on Oct 11, 2024
README.md (6 additions, 0 deletions)
@@ -42,6 +42,12 @@
| ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
| ✔️ [gelu_f32](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️|
| ✔️ [gelu_f32x4](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️|
| ✔️ [gelu_f16](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
| ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
| ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
| ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
gelu/.gitignore (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp

gelu/README.md (163 additions, 0 deletions)
@@ -0,0 +1,163 @@
# GELU

## 0x00 Overview

The following kernels are included:

- [X] gelu_f32_kernel
- [X] gelu_f32x4_kernel (float4 vectorized version)
- [X] gelu_f16_kernel
- [X] gelu_f16x2_kernel (half2 vectorized)
- [X] gelu_f16x8_kernel (unpacked version)
- [X] gelu_f16x8_pack_kernel (packed version; see the sketch after this list)
- [X] PyTorch bindings
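
As a rough illustration of what "pack" means here (the kernel and helper names below are made up; the actual code in gelu.cu may differ): the unpacked x8 variant issues several separate half2 loads per thread, while the packed variant moves all 8 half values through a single 128-bit load and store.

```c++
// Sketch only: one 128-bit load/store carries 8 halves at once (the "pack" trick).
#include <cuda_fp16.h>

__device__ __forceinline__ half gelu_placeholder(half x) {
  // placeholder GELU (exact form via float), standing in for whichever formula the kernel uses
  float f = __half2float(x);
  return __float2half(0.5f * f * (1.0f + erff(f * 0.70710678f)));
}

__global__ void gelu_f16x8_pack_sketch(const half *x, half *y, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 7 < N) {
    __align__(16) half reg[8];
    // one 128-bit load brings in 8 halves at once
    *reinterpret_cast<float4 *>(reg) = *reinterpret_cast<const float4 *>(x + idx);
#pragma unroll
    for (int i = 0; i < 8; ++i) reg[i] = gelu_placeholder(reg[i]);
    // one 128-bit store writes all 8 results back
    *reinterpret_cast<float4 *>(y + idx) = *reinterpret_cast<float4 *>(reg);
  }
}
```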


## Testing

For the half-precision (half) GELU, CUDA's half-precision math does not provide a tanh intrinsic, so the operation has to be rebuilt from hexp, which introduces a noticeably larger error (this could perhaps be addressed at the assembly/PTX level; a sketch of the hexp-based tanh is given after the test output below). PyTorch, by contrast, computes it by converting the data type. Verifying this is simple: edit the f16 code paths in the .cu file to add an explicit cast to float:

```c++
y[idx] = HALF_GELU_OPS(__half2float(v)); // line 96
reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x)); // line 109 , line 110
reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y));
```
The test results are shown below (not every input triggers the error, so a case that does is shown; after the change, out_f16 and out_f16x2 match the torch results):
```bash
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.08196318, -0.1613517], time:0.13425708ms
out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms
out_f32_th: [-0.08196313, -0.1613517], time:0.08195782ms
-------------------------------------------------------------------------------------
out_f16: [-0.08197021, -0.16137695], time:0.12120271ms
out_f16x2: [-0.08197021, -0.16137695], time:0.12122369ms
out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms
out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms
out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms
-------------------------------------------------------------------------------------
```
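
For reference, here is a hedged sketch of how tanh can be rebuilt from hexp in half precision, using tanh(t) = (exp(2t) - 1) / (exp(2t) + 1); the repo's actual HALF_GELU_OPS may be organized differently, but keeping every step in half arithmetic like this is where the extra error comes from.

```c++
#include <cuda_fp16.h>

// tanh(t) = (exp(2t) - 1) / (exp(2t) + 1), built from hexp because there is no htanh
__device__ __forceinline__ half half_tanh_via_hexp(half t) {
  half e2t = hexp(__hmul(t, __float2half(2.0f)));
  return __hdiv(__hsub(e2t, __float2half(1.0f)),
                __hadd(e2t, __float2half(1.0f)));
}

// 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) )), entirely in half precision
__device__ __forceinline__ half half_gelu_tanh_sketch(half x) {
  const half k0 = __float2half(0.7978845608f);  // sqrt(2/pi)
  const half k1 = __float2half(0.044715f);
  half inner = __hmul(k0, __hadd(x, __hmul(k1, __hmul(__hmul(x, x), x))));
  return __hmul(__hmul(__float2half(0.5f), x),
                __hadd(__float2half(1.0f), half_tanh_via_hexp(inner)));
}
```
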
Related references:
- [pytorch-c10-BFloat16.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h)
- [math ptx](https://github.com/pavanky/math_ptx)

In addition, mirroring torch, the float version implements GELU under both the tanh and none (exact) approximations; changing a macro in gelu.cu selects which variant is compiled. The two formulas are sketched below.
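
A minimal sketch of the two float-precision variants just described; the macro name GELU_TANH_APPROX is invented here, and gelu.cu may use a different switch.

```c++
// "none": exact GELU, 0.5 * x * (1 + erf(x / sqrt(2)))
__device__ __forceinline__ float gelu_none_f32(float x) {
  return 0.5f * x * (1.0f + erff(x * 0.70710678f));
}

// "tanh": 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
__device__ __forceinline__ float gelu_tanh_f32(float x) {
  float inner = 0.7978845608f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
}

// hypothetical compile-time switch between the two variants
#ifdef GELU_TANH_APPROX
#define GELU_OPS gelu_tanh_f32
#else
#define GELU_OPS gelu_none_f32
#endif
```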

```bash
# Build for the Ada architecture only; if unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer
export TORCH_CUDA_ARCH_LIST=Ada
python3 gelu.py
```

Output (without the cast to float, so the half variants show the error):

```bash
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: [-0.13358943, -0.06881647], time:0.01621890ms
out_f32x4: [-0.13358943, -0.06881647], time:0.01278400ms
out_f32_th: [-0.13358943, -0.06881647], time:0.00897789ms
-------------------------------------------------------------------------------------
out_f16: [-0.13378906, -0.06884766], time:0.00663781ms
out_f16x2: [-0.13378906, -0.06884766], time:0.00366306ms
out_f16x8: [-0.13378906, -0.06884766], time:0.00343323ms
out_f16x8pack: [-0.13378906, -0.06884766], time:0.00331473ms
out_f16_th: [-0.13354492, -0.06884766], time:0.00907278ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: [1.38783729, -0.06707606], time:0.02223682ms
out_f32x4: [1.38783729, -0.06707606], time:0.02367806ms
out_f32_th: [1.38783729, -0.06707606], time:0.00959325ms
-------------------------------------------------------------------------------------
out_f16: [1.38769531, -0.06713867], time:0.00834370ms
out_f16x2: [1.38769531, -0.06713867], time:0.00784707ms
out_f16x8: [1.38769531, -0.06713867], time:0.00499964ms
out_f16x8pack: [1.38769531, -0.06713867], time:0.00461078ms
out_f16_th: [1.38769531, -0.06707764], time:0.00895357ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: [0.47386399, 0.05760021], time:0.04273629ms
out_f32x4: [0.47386399, 0.05760021], time:0.05011940ms
out_f32_th: [0.47386405, 0.05760022], time:0.00933146ms
-------------------------------------------------------------------------------------
out_f16: [0.47387695, 0.05761719], time:0.01495123ms
out_f16x2: [0.47387695, 0.05761719], time:0.01039743ms
out_f16x8: [0.47387695, 0.05761719], time:0.00936055ms
out_f16x8pack: [0.47387695, 0.05761719], time:0.00845838ms
out_f16_th: [0.47387695, 0.05758667], time:0.00918818ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: [1.3562144, 0.40408486], time:0.03009892ms
out_f32x4: [1.3562144, 0.40408486], time:0.02289677ms
out_f32_th: [1.3562144, 0.40408486], time:0.00921512ms
-------------------------------------------------------------------------------------
out_f16: [1.35644531, 0.40405273], time:0.01173806ms
out_f16x2: [1.35644531, 0.40405273], time:0.00565076ms
out_f16x8: [1.35644531, 0.40405273], time:0.00502610ms
out_f16x8pack: [1.35644531, 0.40405273], time:0.00457048ms
out_f16_th: [1.35644531, 0.40429688], time:0.00904894ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: [-0.16498716, -0.15077244], time:0.04273534ms
out_f32x4: [-0.16498716, -0.15077244], time:0.04386163ms
out_f32_th: [-0.16498716, -0.15077244], time:0.00913596ms
-------------------------------------------------------------------------------------
out_f16: [-0.16516113, -0.15075684], time:0.01495862ms
out_f16x2: [-0.16516113, -0.15075684], time:0.01407337ms
out_f16x8: [-0.16516113, -0.15075684], time:0.00796247ms
out_f16x8pack: [-0.16516113, -0.15075684], time:0.00734925ms
out_f16_th: [-0.16503906, -0.15075684], time:0.00917435ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.03888749, 0.32139146], time:0.08363676ms
out_f32x4: [-0.03888749, 0.32139146], time:0.09505510ms
out_f32_th: [-0.03888749, 0.32139146], time:0.04022837ms
-------------------------------------------------------------------------------------
out_f16: [-0.03887939, 0.3215332], time:0.02813959ms
out_f16x2: [-0.03887939, 0.3215332], time:0.01906514ms
out_f16x8: [-0.03887939, 0.3215332], time:0.01664281ms
out_f16x8pack: [-0.03887939, 0.3215332], time:0.01474833ms
out_f16_th: [-0.03887939, 0.32128906], time:0.01357365ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: [-0.13875209, 1.08477271], time:0.05790567ms
out_f32x4: [-0.13875209, 1.08477271], time:0.04317236ms
out_f32_th: [-0.13875209, 1.08477271], time:0.00910425ms
-------------------------------------------------------------------------------------
out_f16: [-0.13903809, 1.08496094], time:0.02198315ms
out_f16x2: [-0.13903809, 1.08496094], time:0.00964355ms
out_f16x8: [-0.13903809, 1.08496094], time:0.00780869ms
out_f16x8pack: [-0.13903809, 1.08496094], time:0.00729132ms
out_f16_th: [-0.13879395, 1.08496094], time:0.00926042ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: [0.82045084, -0.0894338], time:0.08363843ms
out_f32x4: [0.82045084, -0.0894338], time:0.08431888ms
out_f32_th: [0.82045084, -0.0894338], time:0.03837347ms
-------------------------------------------------------------------------------------
out_f16: [0.8203125, -0.08947754], time:0.02813506ms
out_f16x2: [0.8203125, -0.08947754], time:0.02643061ms
out_f16x8: [0.8203125, -0.08947754], time:0.01383305ms
out_f16x8pack: [0.8203125, -0.08947754], time:0.01273918ms
out_f16_th: [0.82080078, -0.0894165], time:0.01357722ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: [-0.06997654, -0.16092129], time:0.19113564ms
out_f32x4: [-0.06997654, -0.16092129], time:0.20371628ms
out_f32_th: [-0.06997654, -0.16092129], time:0.20496607ms
-------------------------------------------------------------------------------------
out_f16: [-0.07012939, -0.16113281], time:0.05451322ms
out_f16x2: [-0.07012939, -0.16113281], time:0.03633785ms
out_f16x8: [-0.07012939, -0.16113281], time:0.03115463ms
out_f16x8pack: [-0.07012939, -0.16113281], time:0.02735877ms
out_f16_th: [-0.07000732, -0.16088867], time:0.03889561ms
-------------------------------------------------------------------------------------
```