From c45329225aee510bcf81c233e05bdb2ae41c9900 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:28:19 +0800 Subject: [PATCH 01/16] hip --- wkv6_state/hip/wkv6state_hip_v1.hip | 305 ++++++++++++++++++++++++++++ wkv6_state/hip/wkv6state_op.cpp | 22 ++ wkv6_state/run.py | 4 +- 3 files changed, 329 insertions(+), 2 deletions(-) create mode 100644 wkv6_state/hip/wkv6state_hip_v1.hip create mode 100644 wkv6_state/hip/wkv6state_op.cpp diff --git a/wkv6_state/hip/wkv6state_hip_v1.hip b/wkv6_state/hip/wkv6state_hip_v1.hip new file mode 100644 index 0000000..2685794 --- /dev/null +++ b/wkv6_state/hip/wkv6state_hip_v1.hip @@ -0,0 +1,305 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + for (int j = 0; j < _N_; j++) { + state[j] = _s[j]; + } + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_]; + for (int j = 0; j < _N_; j++) { + state[j] = _s[j*_N_]; + swwww[j] = 1.0; + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) 
+ { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } + + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& w = swwww[j]; + sssss[j] += gyy * w * r[j]; + w *= w_[j]; + } + } + for (int j = 0; j < _N_; j++) + _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = sssss[j]; +} + +template +__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _s += h*_N_*_N_ + i; + + __shared__ float v[_N_], gy[_N_]; + float saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t-C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * v[j]; + } + sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]); + } + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float r = float(_r[t_1]); + const float w = exp(_w[t_0]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * _s[j*_N_]; + } + sbbbb[0] = sum; + } + + float sss = sbbbb[0]; + _gw[t_0] = F(sss * _w[t_0]); + + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float w = exp(_w[t_0]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + _s[j*_N_]) * w; + sum += s * gy[j]; + } + sss += sbbbb[1] - (sum * float(_r[t_1])); + _gw[t_1] = F(sss * _w[t_1]); + } + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = exp(_w[t-C]); + const float k = float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + k * v[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_0)/C] - (sum * 
float(_r[t]));
+        _gw[t] = F(sss * _w[t]);
+    }
+    _gw[t_T_1] = 0;
+}
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
+}
+
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gz)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gz);
+    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
+}
diff --git a/wkv6_state/hip/wkv6state_op.cpp b/wkv6_state/hip/wkv6state_op.cpp
new file mode 100644
index 0000000..ef0f1d2
--- /dev/null
+++ b/wkv6_state/hip/wkv6state_op.cpp
@@ -0,0 +1,22 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *y);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
+
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv6b forward");
+    m.def("backward", &backward, "wkv6b backward");
+}
+
+TORCH_LIBRARY(wkv6b, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/wkv6_state/run.py b/wkv6_state/run.py
index aad0522..cca8b3d 100644
--- a/wkv6_state/run.py
+++ b/wkv6_state/run.py
@@ -74,8 +74,8 @@ def val(x):
 # CUDA Kernel
 ########################################################################################################
 
-wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"],
-                verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"])
+wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_cuda_{CUDA_KERNEL_VERSION}.hip"],
+                verbose=True, extra_cuda_cflags=["--use_fast_math", "-O3", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"])
 
 class WKV_6STATE(torch.autograd.Function):
     @staticmethod

From e385c506e72a37b1e650f1816a2450d87f796e50 Mon Sep 17 00:00:00 2001
From: YuChuXi <81864000+YuChuXi@users.noreply.github.com>
Date: Sun, 12 May 2024 16:30:02 +0800
Subject: [PATCH 02/16] hip

---
 wkv6_state/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wkv6_state/run.py b/wkv6_state/run.py
index cca8b3d..da83f5e 100644
--- a/wkv6_state/run.py
+++ b/wkv6_state/run.py
@@ -74,7 +74,7 @@ def val(x):
 # CUDA
Kernel ######################################################################################################## -wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_cuda_{CUDA_KERNEL_VERSION}.hip"], +wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], verbose=True, extra_cuda_cflags=["--use_fast_math", "-O3", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6STATE(torch.autograd.Function): From 044e3c3bad26e7a1275c3f058d4386d0c8ff6c5b Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:30:49 +0800 Subject: [PATCH 03/16] hip --- wkv6_state/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv6_state/run.py b/wkv6_state/run.py index da83f5e..9db12d4 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -75,7 +75,7 @@ def val(x): ######################################################################################################## wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], - verbose=True, extra_cuda_cflags=["--use_fast_math", "-O3", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6STATE(torch.autograd.Function): @staticmethod From a70934a9f11adfc9606150b6b8129e41b1d7c077 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:34:29 +0800 Subject: [PATCH 04/16] hip --- wkv6_state/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv6_state/run.py b/wkv6_state/run.py index 9db12d4..dcf8c10 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -75,7 +75,7 @@ def val(x): ######################################################################################################## wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], - verbose=True, extra_cuda_cflags=["-O3", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6STATE(torch.autograd.Function): @staticmethod From 9c45dc71c57ba500929ab167c20eba9572270ffc Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:48:06 +0800 Subject: [PATCH 05/16] aaa --- wkv/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv/run.py b/wkv/run.py index 33b1a29..a6410b8 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -84,7 +84,7 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization'], extra_cflags=['/wd4624']) + verbose=True, extra_cuda_cflags=["--O3", "-xhip", "--hipstdpar"], extra_cflags=['/wd4624']) class WKV(torch.autograd.Function): @staticmethod From 8ad1c64f4b127d45375dbcfc1837c6fc0fa3d8ec Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:48:47 +0800 Subject: [PATCH 06/16] a --- wkv/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv/run.py b/wkv/run.py index a6410b8..4e5231f 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -84,7 +84,7 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): from 
torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["--O3", "-xhip", "--hipstdpar"], extra_cflags=['/wd4624']) + verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) class WKV(torch.autograd.Function): @staticmethod From 67ec9ccab8bc869b16065bc8685f359eb785041b Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:51:20 +0800 Subject: [PATCH 07/16] aaaaa --- wkv/run.py | 2 +- wkv6_state/run.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wkv/run.py b/wkv/run.py index 4e5231f..9e0395b 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -84,7 +84,7 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) + verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar", "--hip-link"]) class WKV(torch.autograd.Function): @staticmethod diff --git a/wkv6_state/run.py b/wkv6_state/run.py index dcf8c10..3afc04b 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -75,7 +75,7 @@ def val(x): ######################################################################################################## wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6STATE(torch.autograd.Function): @staticmethod From ff5f21a24f0470fe70f3ef6e8905834dce7ad6a1 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 16:58:16 +0800 Subject: [PATCH 08/16] a? 
--- wkv6_state/hip/wkv6state_hip_v1.hip | 305 ---------------------------- wkv6_state/hip/wkv6state_op.cpp | 22 -- wkv6_state/run.py | 2 +- 3 files changed, 1 insertion(+), 328 deletions(-) delete mode 100644 wkv6_state/hip/wkv6state_hip_v1.hip delete mode 100644 wkv6_state/hip/wkv6state_op.cpp diff --git a/wkv6_state/hip/wkv6state_hip_v1.hip b/wkv6_state/hip/wkv6state_hip_v1.hip deleted file mode 100644 index 2685794..0000000 --- a/wkv6_state/hip/wkv6state_hip_v1.hip +++ /dev/null @@ -1,305 +0,0 @@ -#include -#include -#include "ATen/ATen.h" -typedef at::BFloat16 bf16; - -template -__global__ void kernel_forward(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, - F *__restrict__ const _y) -{ - const int b = blockIdx.x / H; - const int h = blockIdx.x % H; - const int i = threadIdx.x; - _u += h*_N_; - _s += h*_N_*_N_ + i*_N_; - - __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; - float state[_N_]; - - __syncthreads(); - u[i] = float(_u[i]); - __syncthreads(); - for (int j = 0; j < _N_; j++) { - state[j] = _s[j]; - } - - for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) - { - __syncthreads(); - w[i] = exp(_w[t]); - r[i] = float(_r[t]); - k[i] = float(_k[t]); - __syncthreads(); - - const float v = float(_v[t]); - float y = 0; - - #pragma unroll - for (int j = 0; j < _N_; j+=4) - { - const float4& r_ = (float4&)(r[j]); - const float4& k_ = (float4&)(k[j]); - const float4& w_ = (float4&)(w[j]); - const float4& u_ = (float4&)(u[j]); - float4& s = (float4&)(state[j]); - float4 x; - - x.x = k_.x * v; - x.y = k_.y * v; - x.z = k_.z * v; - x.w = k_.w * v; - - y += r_.x * (u_.x * x.x + s.x); - y += r_.y * (u_.y * x.y + s.y); - y += r_.z * (u_.z * x.z + s.z); - y += r_.w * (u_.w * x.w + s.w); - - s.x = s.x * w_.x + x.x; - s.y = s.y * w_.y + x.y; - s.z = s.z * w_.z + x.z; - s.w = s.w * w_.w + x.w; - } - _y[t] = F(y); - } -} - -template -__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, - F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs) -{ - const int b = blockIdx.x / H; - const int h = blockIdx.x % H; - const int i = threadIdx.x; - _u += h*_N_; - _s += h*_N_*_N_ + i; - - __shared__ float u_[_N_]; - __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; - __syncthreads(); - u_[i] = float(_u[i]); - __syncthreads(); - - const float u = u_[i]; - - float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_]; - for (int j = 0; j < _N_; j++) { - state[j] = _s[j*_N_]; - swwww[j] = 1.0; - } - - const int t_0 = b*T*C + h*_N_ + i; - const int t_T_1 = t_0 + (T-1)*C; - const int t_T = t_0 + T*C; - - float gu = 0; - for (int t = t_0; t < t_T; t += C) - { - __syncthreads(); - v[i] = float(_v[t]); - gy[i] = float(_gy[t]); - __syncthreads(); - - const float k = float(_k[t]); - const float w = exp(_w[t]); - float gr = 0, gu_ = 0; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = state[j]; - float x = k * v[j]; - - gr += (u * x + s) * gy[j]; - gu_ += x * gy[j]; - s = s * w + x; - } - _gr[t] = F(gr); - gu += float(_r[t]) * gu_; - } - _gu[b*C + h*_N_ + 
i] = F(gu); - - for (int t = t_T_1; t >= t_0; t -= C) - { - __syncthreads(); - v[i] = float(_v[t]); - gy[i] = float(_gy[t]); - __syncthreads(); - - const float rr = float(_r[t]); - const float w = exp(_w[t]); - float gk = 0; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = scccc[j]; - float x = rr * gy[j]; - - gk += (u * x + s) * v[j]; - s = x + s * w; - } - _gk[t] = F(gk); - } - - for (int t = t_T_1; t >= t_0; t -= C) - { - __syncthreads(); - r[i] = float(_r[t]); - k[i] = float(_k[t]); - w_[i] = exp(_w[t]); - __syncthreads(); - - const float gyy = float(_gy[t]); - float gv = 0; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = sdddd[j]; - float x = gyy * r[j]; - - gv += (u_[j] * x + s) * k[j]; - s = x + s * w_[j]; - } - _gv[t] = F(gv); - } - - for (int t = t_0; t < t_T; t += C) - { - __syncthreads(); - r[i] = float(_r[t]); - w_[i] = exp(_w[t]); - __syncthreads(); - - const float gyy = float(_gy[t]); - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& w = swwww[j]; - sssss[j] += gyy * w * r[j]; - w *= w_[j]; - } - } - for (int j = 0; j < _N_; j++) - _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = sssss[j]; -} - -template -__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, - F *__restrict__ const _gw) -{ - const int b = blockIdx.x / H; - const int h = blockIdx.x % H; - const int i = threadIdx.x; - _s += h*_N_*_N_ + i; - - __shared__ float v[_N_], gy[_N_]; - float saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0}; - - const int t_0 = b*T*C + h*_N_ + i; - const int t_1 = t_0 + C; - const int t_2 = t_0 + 2*C; - const int t_T_1 = t_0 + (T-1)*C; - - for (int t = t_T_1; t > t_1; t -= C) - { - __syncthreads(); - gy[i] = float(_gy[t]); - v[i] = float(_v[t-2*C]); - __syncthreads(); - - const float r = float(_r[t]); - const float w = exp(_w[t-C]); - float sum = 0.0f; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = saaaa[j]; - s = (s + r * gy[j]) * w; - sum += s * v[j]; - } - sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]); - } - { - __syncthreads(); - gy[i] = float(_gy[t_1]); - __syncthreads(); - - const float r = float(_r[t_1]); - const float w = exp(_w[t_0]); - float sum = 0.0f; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = saaaa[j]; - s = (s + r * gy[j]) * w; - sum += s * _s[j*_N_]; - } - sbbbb[0] = sum; - } - - float sss = sbbbb[0]; - _gw[t_0] = F(sss * _w[t_0]); - - { - __syncthreads(); - gy[i] = float(_gy[t_1]); - __syncthreads(); - - const float w = exp(_w[t_0]); - float sum = 0.0f; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = scccc[j]; - s = (s + _s[j*_N_]) * w; - sum += s * gy[j]; - } - sss += sbbbb[1] - (sum * float(_r[t_1])); - _gw[t_1] = F(sss * _w[t_1]); - } - for (int t = t_2; t < t_T_1; t += C) - { - __syncthreads(); - gy[i] = float(_gy[t]); - v[i] = float(_v[t-2*C]); - __syncthreads(); - - const float w = exp(_w[t-C]); - const float k = float(_k[t-2*C]); - float sum = 0.0f; - - #pragma unroll - for (int j = 0; j < _N_; j++) - { - float& s = scccc[j]; - s = (s + k * v[j]) * w; - sum += s * gy[j]; - } - sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t])); - _gw[t] = F(sss * _w[t]); - } - _gw[t_T_1] = 0; -} - -void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *y) -{ - assert(H*_N_ == 
C);
-    assert(_N_%4 == 0);
-    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
-}
-
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gz)
-{
-    assert(H*_N_ == C);
-    assert(_N_%4 == 0);
-    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gz);
-    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
-}
diff --git a/wkv6_state/hip/wkv6state_op.cpp b/wkv6_state/hip/wkv6state_op.cpp
deleted file mode 100644
index ef0f1d2..0000000
--- a/wkv6_state/hip/wkv6state_op.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-#include <torch/extension.h>
-#include "ATen/ATen.h"
-typedef at::BFloat16 bf16;
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *y);
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
-
-void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
-    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
-}
-void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
-    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
-}
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &forward, "wkv6b forward");
-    m.def("backward", &backward, "wkv6b backward");
-}
-
-TORCH_LIBRARY(wkv6b, m) {
-    m.def("forward", forward);
-    m.def("backward", backward);
-}
diff --git a/wkv6_state/run.py b/wkv6_state/run.py
index 3afc04b..92a4b24 100644
--- a/wkv6_state/run.py
+++ b/wkv6_state/run.py
@@ -74,7 +74,7 @@ def val(x):
 # CUDA Kernel
 ########################################################################################################
 
-wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"],
+wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cuda"],
                 verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"])
 
 class WKV_6STATE(torch.autograd.Function):
     @staticmethod

From 818933814ad831184e0f899b0afb1d804cd7cf88 Mon Sep 17 00:00:00 2001
From: YuChuXi <81864000+YuChuXi@users.noreply.github.com>
Date: Sun, 12 May 2024 16:58:48 +0800
Subject: [PATCH 09/16] o!
--- wkv6_state/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv6_state/run.py b/wkv6_state/run.py index 92a4b24..4e3bd5c 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -74,7 +74,7 @@ def val(x): # CUDA Kernel ######################################################################################################## -wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cuda"], +wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6STATE(torch.autograd.Function): From 3e817fa752e4d3f507caf607331785b615ea7073 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 18:31:55 +0800 Subject: [PATCH 10/16] hip --- wkv6/run.py | 19 +++++++++++++++---- wkv6_state/run.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/wkv6/run.py b/wkv6/run.py index d15c2ef..3a11688 100644 --- a/wkv6/run.py +++ b/wkv6/run.py @@ -66,9 +66,15 @@ def get_err_ratio(x, y): def val(x): return x.detach().float().cpu().numpy() -wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], - verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) - +run_rocm = True + +if run_rocm: + wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}"]) +else: + wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], + verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) + class WKV_5(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, H, r, k, v, w, u): @@ -122,7 +128,12 @@ def RUN_CUDA_5(B, T, C, H, r, k, v, w, u): # CUDA Kernel ######################################################################################################## -wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], +if run_rocm: + wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + +else: + wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) class WKV_6(torch.autograd.Function): diff --git a/wkv6_state/run.py b/wkv6_state/run.py index 4e3bd5c..d4071c9 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -74,9 +74,16 @@ def val(x): # CUDA Kernel ######################################################################################################## -wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) +run_rocm = True + +if run_rocm: + wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", 
f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) +else: + wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], + verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + class WKV_6STATE(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, H, r, k, v, w, u, s): From e12cc803e8382bb2f1b6aa0b7c9432154e7ec34a Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 18:37:15 +0800 Subject: [PATCH 11/16] a --- wkv/run.py | 2 +- wkv6/run.py | 4 ++-- wkv6_state/run.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/wkv/run.py b/wkv/run.py index 9e0395b..4e5231f 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -84,7 +84,7 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar", "--hip-link"]) + verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) class WKV(torch.autograd.Function): @staticmethod diff --git a/wkv6/run.py b/wkv6/run.py index 3a11688..d3fe226 100644 --- a/wkv6/run.py +++ b/wkv6/run.py @@ -70,7 +70,7 @@ def val(x): if run_rocm: wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}"]) else: wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) @@ -130,7 +130,7 @@ def RUN_CUDA_5(B, T, C, H, r, k, v, w, u): if run_rocm: wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], diff --git a/wkv6_state/run.py b/wkv6_state/run.py index d4071c9..3f98e2c 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -78,7 +78,7 @@ def val(x): if run_rocm: wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) From 37d161b02150fb08ecaac981dc1ba4bbaceaa0f4 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 18:42:37 +0800 Subject: [PATCH 12/16] ccc --- 
wkv6/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv6/run.py b/wkv6/run.py index d3fe226..fb06ce0 100644 --- a/wkv6/run.py +++ b/wkv6/run.py @@ -1,5 +1,5 @@ import os -os.environ["CUDA_VISIBLE_DEVICES"] = "7" +# os.environ["CUDA_VISIBLE_DEVICES"] = "7" import torch from torch.utils.cpp_extension import load from torch.nn import functional as F From 857910da3c447340c7e05566ddc1b739071c1006 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Sun, 12 May 2024 18:44:41 +0800 Subject: [PATCH 13/16] hip --- wkv/hip/wkv_hip_v2.hip | 132 ++++++++++++ wkv/hip/wkv_op.cpp | 22 ++ wkv6/hip/wkv5_hip_v1b2.hip | 204 ++++++++++++++++++ wkv6/hip/wkv5_op.cpp | 23 +++ wkv6/hip/wkv6_hip_v1.hip | 225 ++++++++++++++++++++ wkv6/hip/wkv6_op.cpp | 23 +++ wkv6_state/hip/wkv6state_hip_v1.hip | 307 ++++++++++++++++++++++++++++ wkv6_state/hip/wkv6state_op.cpp | 23 +++ 8 files changed, 959 insertions(+) create mode 100644 wkv/hip/wkv_hip_v2.hip create mode 100644 wkv/hip/wkv_op.cpp create mode 100644 wkv6/hip/wkv5_hip_v1b2.hip create mode 100644 wkv6/hip/wkv5_op.cpp create mode 100644 wkv6/hip/wkv6_hip_v1.hip create mode 100644 wkv6/hip/wkv6_op.cpp create mode 100644 wkv6_state/hip/wkv6state_hip_v1.hip create mode 100644 wkv6_state/hip/wkv6state_op.cpp diff --git a/wkv/hip/wkv_hip_v2.hip b/wkv/hip/wkv_hip_v2.hip new file mode 100644 index 0000000..7061bd5 --- /dev/null +++ b/wkv/hip/wkv_hip_v2.hip @@ -0,0 +1,132 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#include +#include + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + F p = 0, q = 0, o = -65500; + // p and q are running sums divided by exp(o) (to avoid overflows) + for (int i = 0; i < T; i++) + { + const int ii = i * C; + + F no = max(o, u + k[ii]); + F A = exp(o - no); + F B = exp(u + k[ii] - no); + y[ii] = (A * p + B * v[ii]) / (A * q + B); + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F y[4096], z[4096], zexp[4096]; + + F gw = 0, gu = 0; + F p = 0, q = 0; + F dpdw = 0, dqdw = 0; + F o = -65500; + for (int i = 0; i < T; i++) + { + const int ii = i * C; + F no = max(o, k[ii] + u); + F A = exp(o - no); + F B = exp(k[ii] + u - no); + 
+ F num = A * p + B * v[ii]; + F iden = 1 / (A * q + B); + + y[i] = num * iden; + z[i] = iden; + zexp[i] = k[ii] + u - no; + + gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; + gu += gy[ii] * (v[ii] - y[i]) * B * iden; + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + dpdw = A * (p + dpdw); + dqdw = A * (q + dqdw); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } + + F gp = 0, gq = 0; + o = -65500; + for (int i = T - 1; i >= 0; i--) + { + const int ii = i * C; + F A = gy[ii] * z[i] * exp(zexp[i]); + F B = exp(k[ii] + o); + gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); + gv[ii] = A + B * gp; + + F no = max(w + o, zexp[i] - k[ii] - u); + A = exp(w + o - no); + B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); + gp = A * gp + B; + gq = A * gq - B * y[i]; + o = no; + } + + // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] += gw * _w[_c]; + _gu[_offsetBC] += gu; +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) +{ + dim3 threadsPerBlock( min(C, 32) ); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + hipLaunchKernelGGL(( kernel_forward), dim3(numBlocks), dim3(threadsPerBlock), 0, 0, B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) +{ + dim3 threadsPerBlock( min(C, 32) ); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + hipLaunchKernelGGL(( kernel_backward), dim3(numBlocks), dim3(threadsPerBlock), 0, 0, B, T, C, w, u, k, v, gy, gw, gu, gk, gv); +} diff --git a/wkv/hip/wkv_op.cpp b/wkv/hip/wkv_op.cpp new file mode 100644 index 0000000..c3e1541 --- /dev/null +++ b/wkv/hip/wkv_op.cpp @@ -0,0 +1,22 @@ +// !!! This is a file automatically generated by hipify!!! +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/wkv6/hip/wkv5_hip_v1b2.hip b/wkv6/hip/wkv5_hip_v1b2.hip new file mode 100644 index 0000000..b347ab3 --- /dev/null +++ b/wkv6/hip/wkv5_hip_v1b2.hip @@ -0,0 +1,204 @@ +// !!! This is a file automatically generated by hipify!!! 
+#include "hip/hip_runtime.h" +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + w[i] = _w[i]; + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + __w += h*_N_; + + __shared__ float w_[_N_], u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; + __syncthreads(); + w_[i] = _w[i]; + u_[i] = float(_u[i]); + __syncthreads(); + + const float w = w_[i]; + const float ww = __w[i]; + const float u = u_[i]; + + float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + float gw = 0, gu = 0; + const int t000 = b*T*C + h*_N_ + i; + const int t111 = (b+1)*T*C + h*_N_ + i; + const int t222 = t111 - 2*C; + + for (int t = t000; t < t111; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t000; t < t222; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t + 2*C]); + __syncthreads(); + + const float k = float(_k[t]); + float gw_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float& s2 = sbbbb[j]; + float x = k * v[j]; + + float tmp = w * (x + s); + s = tmp; + s2 = tmp + w * s2; + gw_ += s2 * gy[j]; + } + gw += float(_r[t + 2*C]) * gw_; + } + _gw[b*C + h*_N_ + i] = F(ww * gw); + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + v[i] = 
float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_forward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_backward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); +} diff --git a/wkv6/hip/wkv5_op.cpp b/wkv6/hip/wkv5_op.cpp new file mode 100644 index 0000000..01a6c2c --- /dev/null +++ b/wkv6/hip/wkv5_op.cpp @@ -0,0 +1,23 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv5 forward"); + m.def("backward", &backward, "wkv5 backward"); +} + +TORCH_LIBRARY(wkv5, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/wkv6/hip/wkv6_hip_v1.hip b/wkv6/hip/wkv6_hip_v1.hip new file mode 100644 index 0000000..519696f --- /dev/null +++ b/wkv6/hip/wkv6_hip_v1.hip @@ -0,0 +1,225 @@ +// !!! This is a file automatically generated by hipify!!! 
+#include "hip/hip_runtime.h" +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + float buf[_T_*_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_2 = t_0 + (T-2)*C; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_0; t < t_T_2; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + const int tt = (t-t_0)/C*_N_; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float x = k * v[j]; + + float tmp = w * s + x; + s = tmp; + buf[tt + j] = tmp; + } + } + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t]); + float sum = 0.0f; + const 
int tt = (t-t_2)/C*_N_; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sbbbb[j]; + float x = r * gy[j]; + + float tmp = w * s + x; + s = tmp; + sum += buf[tt + j] * tmp; + } + _gw[t-C] = F(sum * _w[t-C] * exp(_w[t-C])); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_forward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_backward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu); +} diff --git a/wkv6/hip/wkv6_op.cpp b/wkv6/hip/wkv6_op.cpp new file mode 100644 index 0000000..ac45ef0 --- /dev/null +++ b/wkv6/hip/wkv6_op.cpp @@ -0,0 +1,23 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6 forward"); + m.def("backward", &backward, "wkv6 backward"); +} + +TORCH_LIBRARY(wkv6, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/wkv6_state/hip/wkv6state_hip_v1.hip b/wkv6_state/hip/wkv6state_hip_v1.hip new file mode 100644 index 0000000..7527931 --- /dev/null +++ b/wkv6_state/hip/wkv6state_hip_v1.hip @@ -0,0 +1,307 @@ +// !!! This is a file automatically generated by hipify!!! 
+#include "hip/hip_runtime.h" +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + for (int j = 0; j < _N_; j++) { + state[j] = _s[j]; + } + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_]; + for (int j = 0; j < _N_; j++) { + state[j] = _s[j*_N_]; + swwww[j] = 1.0; + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = 
float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } + + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& w = swwww[j]; + sssss[j] += gyy * w * r[j]; + w *= w_[j]; + } + } + for (int j = 0; j < _N_; j++) + _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = sssss[j]; +} + +template +__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _s += h*_N_*_N_ + i; + + __shared__ float v[_N_], gy[_N_]; + float saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t-C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * v[j]; + } + sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]); + } + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float r = float(_r[t_1]); + const float w = exp(_w[t_0]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * _s[j*_N_]; + } + sbbbb[0] = sum; + } + + float sss = sbbbb[0]; + _gw[t_0] = F(sss * _w[t_0]); + + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float w = exp(_w[t_0]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + _s[j*_N_]) * w; + sum += s * gy[j]; + } + sss += sbbbb[1] - (sum * float(_r[t_1])); + _gw[t_1] = F(sss * _w[t_1]); + } + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = exp(_w[t-C]); + const float k = float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + k * v[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t])); + _gw[t] = F(sss * _w[t]); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_forward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, z, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gz) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + hipLaunchKernelGGL(( kernel_backward_111), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, 
w, u, z, gy, gr, gk, gv, gu, gz); + hipLaunchKernelGGL(( kernel_backward_222), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, z, gy, gw); +} diff --git a/wkv6_state/hip/wkv6state_op.cpp b/wkv6_state/hip/wkv6state_op.cpp new file mode 100644 index 0000000..17bbb0c --- /dev/null +++ b/wkv6_state/hip/wkv6state_op.cpp @@ -0,0 +1,23 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6b forward"); + m.def("backward", &backward, "wkv6b backward"); +} + +TORCH_LIBRARY(wkv6b, m) { + m.def("forward", forward); + m.def("backward", backward); +} From 9ca999893a401efcb7ab7372e9318ffa6d675540 Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 18:52:19 +0800 Subject: [PATCH 14/16] hip --- wkv/run.py | 12 ++++++++++-- wkv6/run.py | 4 ++-- wkv6_state/run.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/wkv/run.py b/wkv/run.py index 4e5231f..6cca0d2 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -83,8 +83,16 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): ###################################################################################################### from torch.utils.cpp_extension import load -wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], - verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) + +run_rocm = False + +if run_rocm: + wkv_cuda = load(name="wkv", sources=["hip/wkv_op.cpp", f"hip/wkv_hip_v{CUDA_KERNEL_VERSION}.hip"], + verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) +else: + wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], + verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization']) + class WKV(torch.autograd.Function): @staticmethod diff --git a/wkv6/run.py b/wkv6/run.py index fb06ce0..d5c5b44 100644 --- a/wkv6/run.py +++ b/wkv6/run.py @@ -69,7 +69,7 @@ def val(x): run_rocm = True if run_rocm: - wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], + wkv5_cuda = load(name="wkv5", sources=["hip/wkv5_op.cpp", f"cuda/wkv5_hip_v1b2.hip"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}"]) else: wkv5_cuda = 
load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], @@ -129,7 +129,7 @@ def RUN_CUDA_5(B, T, C, H, r, k, v, w, u): ######################################################################################################## if run_rocm: - wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], + wkv6_cuda = load(name="wkv6", sources=["hip/wkv6_op.cpp", f"hip/wkv6_hip_{CUDA_KERNEL_VERSION}.hip"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: diff --git a/wkv6_state/run.py b/wkv6_state/run.py index 3f98e2c..c54ca9b 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -77,7 +77,7 @@ def val(x): run_rocm = True if run_rocm: - wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], + wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], From 757aafb5a20be692d0a11c74428f8a7ef37a279f Mon Sep 17 00:00:00 2001 From: YuChuXi <81864000+YuChuXi@users.noreply.github.com> Date: Sun, 12 May 2024 19:02:32 +0800 Subject: [PATCH 15/16] dd --- wkv6_state/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wkv6_state/run.py b/wkv6_state/run.py index c54ca9b..29fd0ba 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -78,7 +78,7 @@ def val(x): if run_rocm: wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], - verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) + verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) From e77d4def60ebd17aa829b9efe43c84a0c2508d29 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Sun, 12 May 2024 19:12:43 +0800 Subject: [PATCH 16/16] sss --- wkv/run.py | 2 +- wkv6/run.py | 4 ++-- wkv6_state/run.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/wkv/run.py b/wkv/run.py index 6cca0d2..4c150ea 100644 --- a/wkv/run.py +++ b/wkv/run.py @@ -87,7 +87,7 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve): run_rocm = False if run_rocm: - wkv_cuda = load(name="wkv", sources=["hip/wkv_op.cpp", f"hip/wkv_hip_v{CUDA_KERNEL_VERSION}.hip"], + wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"hip/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"]) else: wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"], diff --git a/wkv6/run.py b/wkv6/run.py index d5c5b44..fb06ce0 100644 --- a/wkv6/run.py +++ b/wkv6/run.py @@ -69,7 +69,7 @@ def val(x): run_rocm = True if run_rocm: - wkv5_cuda = load(name="wkv5", sources=["hip/wkv5_op.cpp", f"cuda/wkv5_hip_v1b2.hip"], + wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", 
f"-D_N_={HEAD_SIZE}"]) else: wkv5_cuda = load(name="wkv5", sources=["cuda/wkv5_op.cpp", f"cuda/wkv5_cuda_v1b2.cu"], @@ -129,7 +129,7 @@ def RUN_CUDA_5(B, T, C, H, r, k, v, w, u): ######################################################################################################## if run_rocm: - wkv6_cuda = load(name="wkv6", sources=["hip/wkv6_op.cpp", f"hip/wkv6_hip_{CUDA_KERNEL_VERSION}.hip"], + wkv6_cuda = load(name="wkv6", sources=["cuda/wkv6_op.cpp", f"cuda/wkv6_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: diff --git a/wkv6_state/run.py b/wkv6_state/run.py index 29fd0ba..d4071c9 100644 --- a/wkv6_state/run.py +++ b/wkv6_state/run.py @@ -77,7 +77,7 @@ def val(x): run_rocm = True if run_rocm: - wkv6state_cuda = load(name="wkv6state", sources=["hip/wkv6state_op.cpp", f"hip/wkv6state_hip_{CUDA_KERNEL_VERSION}.hip"], + wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"], verbose=True, extra_cuda_cflags=["-O3", "--hipstdpar", "-xhip", "--hip-link", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"]) else: wkv6state_cuda = load(name="wkv6state", sources=["cuda/wkv6state_op.cpp", f"cuda/wkv6state_cuda_{CUDA_KERNEL_VERSION}.cu"],