138 changes: 138 additions & 0 deletions flagscale/train/models/rwkv/cuda/wkv7_cuda.cu
@@ -0,0 +1,138 @@
#include <cuda_bf16.h>
#include <assert.h>

using bf = __nv_bfloat16;
__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
__device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }

typedef bf * __restrict__ F_;

// Forward pass. One thread block per (batch, head) pair; thread i owns row i
// of the CxC state matrix S. With the decay parameterized as w = exp(-exp(w_in))
// (so 0 < w < 1), each step computes
//     S_t = S_{t-1} diag(w_t) + (S_{t-1} a_t) b_t^T + v_t k_t^T,   y_t = S_t q_t.
// The intermediate products (S_{t-1} a_t) are written to sa_, and the state is
// checkpointed (transposed) into s_ every _CHUNK_LEN_ steps for the backward pass.
__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;

    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    for (int t = 0; t < T; t++) {
        int ind = bb*T*H*C + t*H*C + hh*C + i;
        __syncthreads();
        q[i] = to_float(q_[ind]);
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        float sa = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            sa += a[j] * state[j];
        }
        sa_[ind] = sa;

        float v = to_float(v_[ind]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            float& s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);

        if ((t+1) % _CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                s_[base + j*C] = state[j];
            }
        }
    }
}

// Backward pass, iterating t from T-1 down to 0. The transposed state S^T is
// reloaded from its checkpoint at each chunk boundary and then rolled back one
// step at a time by inverting the forward recurrence (dividing by w is safe
// since w = exp(-exp(.)) > 0). dstate and dstateT accumulate dL/dS and its
// transpose; dSb_shared broadcasts each row's dot product <dstate, b> across
// the block for the da and dstateT updates.
__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;

    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
    float qi, wi, ki, ai, bi, dyi;

    for (int t = T-1; t >= 0; t--) {
        int ind = bb*T*H*C + t*H*C + hh*C + i;
        __syncthreads();
        q[i] = qi = to_float(q_[ind]);
        float wi_fac = -__expf(to_float(w_[ind]));
        w[i] = wi = __expf(wi_fac);
        k[i] = ki = to_float(k_[ind]);
        a[i] = ai = to_float(a_[ind]);
        b[i] = bi = to_float(b_[ind]);
        v[i] = to_float(v_[ind]);
        dy[i] = dyi = to_float(dy_[ind]);
        sa[i] = sa_[ind];
        __syncthreads();

        if ((t+1) % _CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                stateT[j] = s_[base + j];
            }
        }

        float dq = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            dq += stateT[j]*dy[j];
        }
        dq_[ind] = to_bf(dq);

        float iwi = 1.0f/wi;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
            dstate[j] += dyi * q[j];
            dstateT[j] += qi * dy[j];
        }

        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            dw += dstateT[j]*stateT[j];
            dk += dstateT[j]*v[j];
            dv += dstate[j]*k[j];
            dSb += dstate[j]*b[j];
            db += dstateT[j]*sa[j];
        }
        dw_[ind] = to_bf(dw * wi * wi_fac);
        dk_[ind] = to_bf(dk);
        dv_[ind] = to_bf(dv);
        db_[ind] = to_bf(db);

        __syncthreads();
        dSb_shared[i] = dSb;
        __syncthreads();

        float da = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            da += stateT[j]*dSb_shared[j];
        }
        da_[ind] = to_bf(da);

        #pragma unroll
        for (int j = 0; j < C; j++) {
            dstate[j] = dstate[j]*w[j] + dSb * a[j];
            dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
        }
    }
}

// Host-side launchers. Note the argument naming: the z and a pointers here are
// forwarded to the kernels' a_ and b_ parameters.
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
    forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
}
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
    assert(T%_CHUNK_LEN_ == 0);
    backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
}
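
The _C_ (head size) and _CHUNK_LEN_ macros are intentionally left undefined in the source above and must be supplied when the extension is compiled. A minimal build sketch in Python using torch.utils.cpp_extension; the values 64 and 16 are illustrative assumptions, not mandated by this file:

from torch.utils.cpp_extension import load

HEAD_SIZE = 64   # becomes _C_: channels per head, and the thread-block size
CHUNK_LEN = 16   # becomes _CHUNK_LEN_: T must be a multiple of this

load(
    name="wind_backstepping",
    sources=["wkv7_cuda.cu", "wkv7_op.cpp"],
    is_python_module=False,  # ops are exposed via TORCH_LIBRARY, not pybind
    verbose=True,
    extra_cuda_cflags=[f"-D_C_={HEAD_SIZE}", f"-D_CHUNK_LEN_={CHUNK_LEN}"],
)
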
139 changes: 139 additions & 0 deletions flagscale/train/models/rwkv/cuda/wkv7_hip.hip
@@ -0,0 +1,139 @@
#include "hip/hip_runtime.h"
#include <hip/hip_bf16.h>
#include <assert.h>

using bf = __hip_bfloat16;
__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
__device__ inline bf to_bf(const float & u) { return __float2bfloat16(u); }

typedef bf * __restrict__ F_;

// HIP port of wkv7_cuda.cu. The kernel bodies are identical to the CUDA
// version; only the bfloat16 type, conversion intrinsics, and launch syntax
// differ. See the CUDA file for the recurrence being computed.
__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;

    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    for (int t = 0; t < T; t++) {
        int ind = bb*T*H*C + t*H*C + hh*C + i;
        __syncthreads();
        q[i] = to_float(q_[ind]);
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        float sa = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            sa += a[j] * state[j];
        }
        sa_[ind] = sa;

        float v = to_float(v_[ind]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            float& s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);

        if ((t+1) % _CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                s_[base + j*C] = state[j];
            }
        }
    }
}

__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;

    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
    float qi, wi, ki, ai, bi, dyi;

    for (int t = T-1; t >= 0; t--) {
        int ind = bb*T*H*C + t*H*C + hh*C + i;
        __syncthreads();
        q[i] = qi = to_float(q_[ind]);
        float wi_fac = -__expf(to_float(w_[ind]));
        w[i] = wi = __expf(wi_fac);
        k[i] = ki = to_float(k_[ind]);
        a[i] = ai = to_float(a_[ind]);
        b[i] = bi = to_float(b_[ind]);
        v[i] = to_float(v_[ind]);
        dy[i] = dyi = to_float(dy_[ind]);
        sa[i] = sa_[ind];
        __syncthreads();

        if ((t+1) % _CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                stateT[j] = s_[base + j];
            }
        }

        float dq = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            dq += stateT[j]*dy[j];
        }
        dq_[ind] = to_bf(dq);

        float iwi = 1.0f/wi;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
            dstate[j] += dyi * q[j];
            dstateT[j] += qi * dy[j];
        }

        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            dw += dstateT[j]*stateT[j];
            dk += dstateT[j]*v[j];
            dv += dstate[j]*k[j];
            dSb += dstate[j]*b[j];
            db += dstateT[j]*sa[j];
        }
        dw_[ind] = to_bf(dw * wi * wi_fac);
        dk_[ind] = to_bf(dk);
        dv_[ind] = to_bf(dv);
        db_[ind] = to_bf(db);

        __syncthreads();
        dSb_shared[i] = dSb;
        __syncthreads();

        float da = 0;
        #pragma unroll
        for (int j = 0; j < C; j++) {
            da += stateT[j]*dSb_shared[j];
        }
        da_[ind] = to_bf(da);

        #pragma unroll
        for (int j = 0; j < C; j++) {
            dstate[j] = dstate[j]*w[j] + dSb * a[j];
            dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
        }
    }
}

void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
    hipLaunchKernelGGL((forward_kernel), dim3(H,B), dim3(_C_), 0, 0, T,H,w,q,k,v,z,a,y,s,sa);
}
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
    assert(T%_CHUNK_LEN_ == 0);
    hipLaunchKernelGGL((backward_kernel), dim3(H,B), dim3(_C_), 0, 0, T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
}
30 changes: 30 additions & 0 deletions flagscale/train/models/rwkv/cuda/wkv7_op.cpp
@@ -0,0 +1,30 @@
#include <torch/extension.h>

struct __nv_bfloat16;
using bf = __nv_bfloat16;

void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);

void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
}

void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);

void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
              torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
                  (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
}

TORCH_LIBRARY(wind_backstepping, m) {
    m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
    m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
}

TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
    m.impl("forward", &forward);
    m.impl("backward", &backward);
}
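
With the extension built as sketched after wkv7_cuda.cu above, the ops become callable as torch.ops.wind_backstepping.forward/backward. A usage sketch with illustrative sizes; the input layout [B, T, H, C] and the checkpoint/sa shapes follow from the kernels' indexing, and the outputs must be preallocated since the schemas above declare them mutable:

import torch

B, T, H, C, CHUNK_LEN = 2, 32, 4, 64, 16  # illustrative; T % CHUNK_LEN == 0
dev, bf16 = "cuda", torch.bfloat16
w, q, k, v, z, a = (torch.randn(B, T, H, C, dtype=bf16, device=dev) for _ in range(6))
y  = torch.empty(B, T, H, C, dtype=bf16, device=dev)
s  = torch.empty(B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=dev)  # state checkpoints
sa = torch.empty(B, T, H, C, dtype=torch.float32, device=dev)                  # saved (S a) products
torch.ops.wind_backstepping.forward(w, q, k, v, z, a, y, s, sa)
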
30 changes: 30 additions & 0 deletions flagscale/train/models/rwkv/cuda/wkv7_op.hip
@@ -0,0 +1,30 @@
#include <torch/extension.h>

struct __hip_bfloat16;
using bf = __hip_bfloat16;

void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);

void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
}

void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);

void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
              torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
                  (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
}

TORCH_LIBRARY(wind_backstepping_hip, m) {
    m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
    m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
}

// ROCm builds of PyTorch expose HIP devices through the CUDA dispatch key,
// so registering this backend under CUDA is intentional.
TORCH_LIBRARY_IMPL(wind_backstepping_hip, CUDA, m) {
    m.impl("forward", &forward);
    m.impl("backward", &backward);
}
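
The HIP binding registers under a separate library name, so a caller supporting both backends would pick the op namespace at runtime. A sketch, reusing the tensors from the usage example above:

import torch

ops = torch.ops.wind_backstepping_hip if torch.version.hip else torch.ops.wind_backstepping
ops.forward(w, q, k, v, z, a, y, s, sa)
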