Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增加ROCm支持 #20

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions wkv/hip/wkv_hip_v2.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <stdio.h>
#include <assert.h>

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;

F p = 0, q = 0, o = -65500;
// p and q are running sums divided by exp(o) (to avoid overflows)
for (int i = 0; i < T; i++)
{
const int ii = i * C;

F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;

F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;

F y[4096], z[4096], zexp[4096];

F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = -65500;
for (int i = 0; i < T; i++)
{
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);

F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);

y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;

gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}

F gp = 0, gq = 0;
o = -65500;
for (int i = T - 1; i >= 0; i--)
{
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;

F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}

// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y)
{
dim3 threadsPerBlock( min(C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
hipLaunchKernelGGL(( kernel_forward), dim3(numBlocks), dim3(threadsPerBlock), 0, 0, B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv)
{
dim3 threadsPerBlock( min(C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
hipLaunchKernelGGL(( kernel_backward), dim3(numBlocks), dim3(threadsPerBlock), 0, 0, B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
22 changes: 22 additions & 0 deletions wkv/hip/wkv_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}
12 changes: 10 additions & 2 deletions wkv/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,16 @@ def RUN_PYTORCH(B, T, C, w, u, k, v, time_curve):
######################################################################################################

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization'], extra_cflags=['/wd4624'])

run_rocm = False

if run_rocm:
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"hip/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"],
verbose=True, extra_cuda_cflags=["-O3", "-xhip", "--hipstdpar"])
else:
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization'])


class WKV(torch.autograd.Function):
@staticmethod
Expand Down
204 changes: 204 additions & 0 deletions wkv6/hip/wkv5_hip_v1b2.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
F *__restrict__ const _y)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_w += h*_N_;
_u += h*_N_;

__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
float state[_N_] = {0};

__syncthreads();
w[i] = _w[i];
u[i] = float(_u[i]);
__syncthreads();

for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
{
__syncthreads();
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();

const float v = float(_v[t]);
float y = 0;

#pragma unroll
for (int j = 0; j < _N_; j+=4)
{
const float4& r_ = (float4&)(r[j]);
const float4& k_ = (float4&)(k[j]);
const float4& w_ = (float4&)(w[j]);
const float4& u_ = (float4&)(u[j]);
float4& s = (float4&)(state[j]);
float4 x;

x.x = k_.x * v;
x.y = k_.y * v;
x.z = k_.z * v;
x.w = k_.w * v;

y += r_.x * (u_.x * x.x + s.x);
y += r_.y * (u_.y * x.y + s.y);
y += r_.z * (u_.z * x.z + s.z);
y += r_.w * (u_.w * x.w + s.w);

s.x = s.x * w_.x + x.x;
s.y = s.y * w_.y + x.y;
s.z = s.z * w_.z + x.z;
s.w = s.w * w_.w + x.w;
}
_y[t] = F(y);
}
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_w += h*_N_;
_u += h*_N_;
__w += h*_N_;

__shared__ float w_[_N_], u_[_N_];
__shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
__syncthreads();
w_[i] = _w[i];
u_[i] = float(_u[i]);
__syncthreads();

const float w = w_[i];
const float ww = __w[i];
const float u = u_[i];

float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};

float gw = 0, gu = 0;
const int t000 = b*T*C + h*_N_ + i;
const int t111 = (b+1)*T*C + h*_N_ + i;
const int t222 = t111 - 2*C;

for (int t = t000; t < t111; t += C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();

const float k = float(_k[t]);
float gr = 0, gu_ = 0;

#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
float x = k * v[j];

gr += (u * x + s) * gy[j];
gu_ += x * gy[j];
s = s * w + x;
}
_gr[t] = F(gr);
gu += float(_r[t]) * gu_;
}
_gu[b*C + h*_N_ + i] = F(gu);

for (int t = t000; t < t222; t += C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t + 2*C]);
__syncthreads();

const float k = float(_k[t]);
float gw_ = 0;

#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = saaaa[j];
float& s2 = sbbbb[j];
float x = k * v[j];

float tmp = w * (x + s);
s = tmp;
s2 = tmp + w * s2;
gw_ += s2 * gy[j];
}
gw += float(_r[t + 2*C]) * gw_;
}
_gw[b*C + h*_N_ + i] = F(ww * gw);

for (int t = t111 - C; t >= t000; t -= C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();

const float rr = float(_r[t]);
float gk = 0;

#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
float x = rr * gy[j];

gk += (u * x + s) * v[j];
s = x + s * w;
}
_gk[t] = F(gk);
}

for (int t = t111 - C; t >= t000; t -= C)
{
__syncthreads();
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();

const float gyy = float(_gy[t]);
float gv = 0;

#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = sdddd[j];
float x = gyy * r[j];

gv += (u_[j] * x + s) * k[j];
s = x + s * w_[j];
}
_gv[t] = F(gv);
}
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
hipLaunchKernelGGL(( kernel_forward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, u, y);
}

void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
hipLaunchKernelGGL(( kernel_backward), dim3(dim3(B * H)), dim3(dim3(_N_)), 0, 0, B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
}
Loading