diff --git a/flagscale/train/models/rwkv/cuda/wkv7_cuda.cu b/flagscale/train/models/rwkv/cuda/wkv7_cuda.cu
new file mode 100644
index 000000000..f11e5e851
--- /dev/null
+++ b/flagscale/train/models/rwkv/cuda/wkv7_cuda.cu
@@ -0,0 +1,138 @@
+#include <cuda_bf16.h>
+#include <assert.h>
+
+using bf = __nv_bfloat16;
+__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
+__device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
+
+typedef bf * __restrict__ F_;
+
+__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+    float state[C] = {0};
+    __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+    for (int t = 0; t < T; t++) {
+        int ind = bb*T*H*C + t*H*C + hh * C + i;
+        __syncthreads();
+        q[i] = to_float(q_[ind]);
+        w[i] = __expf(-__expf(to_float(w_[ind])));
+        k[i] = to_float(k_[ind]);
+        a[i] = to_float(a_[ind]);
+        b[i] = to_float(b_[ind]);
+        __syncthreads();
+
+        float sa = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            sa += a[j] * state[j];
+        }
+        sa_[ind] = sa;
+
+        float v = to_float(v_[ind]);
+        float y = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            float& s = state[j];
+            s = s * w[j] + sa * b[j] + k[j] * v;
+            y += s * q[j];
+        }
+        y_[ind] = to_bf(y);
+
+        if ((t+1)%_CHUNK_LEN_ == 0) {
+            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
+#pragma unroll
+            for (int j = 0; j < C; j++) {
+                s_[base + j*C] = state[j];
+            }
+        }
+    }
+}
+
+__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+    float qi, wi, ki, ai, bi, dyi;
+
+    for (int t = T-1; t >= 0; t--) {
+        int ind = bb*T*H*C + t*H*C + hh * C + i;
+        __syncthreads();
+        q[i] = qi = to_float(q_[ind]);
+        float wi_fac = -__expf(to_float(w_[ind]));
+        w[i] = wi = __expf(wi_fac);
+        k[i] = ki = to_float(k_[ind]);
+        a[i] = ai = to_float(a_[ind]);
+        b[i] = bi = to_float(b_[ind]);
+        v[i] = to_float(v_[ind]);
+        dy[i] = dyi = to_float(dy_[ind]);
+        sa[i] = sa_[ind];
+        __syncthreads();
+
+        if ((t+1)%_CHUNK_LEN_ == 0) {
+            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
+#pragma unroll
+            for (int j = 0; j < C; j++) {
+                stateT[j] = s_[base + j];
+            }
+        }
+
+        float dq = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dq += stateT[j]*dy[j];
+        }
+        dq_[ind] = to_bf(dq);
+
+        float iwi = 1.0f/wi;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+            dstate[j] += dyi * q[j];
+            dstateT[j] += qi * dy[j];
+        }
+
+        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dw += dstateT[j]*stateT[j];
+            dk += dstateT[j]*v[j];
+            dv += dstate[j]*k[j];
+            dSb += dstate[j]*b[j];
+            db += dstateT[j]*sa[j];
+        }
+        dw_[ind] = to_bf(dw * wi * wi_fac);
+        dk_[ind] = to_bf(dk);
+        dv_[ind] = to_bf(dv);
+        db_[ind] = to_bf(db);
+
+        __syncthreads();
+        dSb_shared[i] = dSb;
+        __syncthreads();
+
+        float da = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            da += stateT[j]*dSb_shared[j];
+        }
+        da_[ind] = to_bf(da);
+
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dstate[j] = dstate[j]*w[j] + dSb * a[j];
+            dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
+        }
+    }
+}
+
+void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
+    forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
+}
+void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
+    assert(T%_CHUNK_LEN_ == 0);
+    backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
+}
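For orientation, each (batch, head) block above advances a C x C state S by one step per token: sa_i = sum_j a_j * S_ij, then S_ij <- S_ij * w_j + sa_i * b_j + v_i * k_j with w decayed as exp(-exp(w)), and y_i = sum_j S_ij * q_j; S is also checkpointed into `s_` every `_CHUNK_LEN_` steps for the backward pass. Below is a minimal pure-PyTorch sketch of the same forward math (the function name is ours; bf16 output casting and the `s`/`sa` side outputs are omitted), useful only as a slow reference, not as the kernel's API.

```python
import torch

def wkv7_forward_reference(w, q, k, v, a, b):
    """Illustrative restatement of forward_kernel's recurrence; all inputs are (B, T, H, C)."""
    B, T, H, C = w.shape
    wd = torch.exp(-torch.exp(w.float()))                    # same decay as the kernel
    S = torch.zeros(B, H, C, C, dtype=torch.float32, device=w.device)
    y = torch.zeros(B, T, H, C, dtype=torch.float32, device=w.device)
    for t in range(T):
        qt, kt, vt, at, bt = (x[:, t].float() for x in (q, k, v, a, b))
        sa = torch.einsum('bhij,bhj->bhi', S, at)            # sa = sum_j a[j] * state[j]
        S = (S * wd[:, t, :, None, :]                        # s = s*w + sa*b + k*v
             + sa[..., None] * bt[:, :, None, :]
             + vt[..., None] * kt[:, :, None, :])
        y[:, t] = torch.einsum('bhij,bhj->bhi', S, qt)       # y = sum_j s[j] * q[j]
    return y
```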
diff --git a/flagscale/train/models/rwkv/cuda/wkv7_hip.hip b/flagscale/train/models/rwkv/cuda/wkv7_hip.hip
new file mode 100644
index 000000000..31a143c32
--- /dev/null
+++ b/flagscale/train/models/rwkv/cuda/wkv7_hip.hip
@@ -0,0 +1,139 @@
+#include "hip/hip_runtime.h"
+#include <hip/hip_bf16.h>
+#include <assert.h>
+
+using bf = __hip_bfloat16;
+__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
+__device__ inline bf to_bf(const float & u) { return __float2bfloat16(u); }
+
+typedef bf * __restrict__ F_;
+
+__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+    float state[C] = {0};
+    __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+    for (int t = 0; t < T; t++) {
+        int ind = bb*T*H*C + t*H*C + hh * C + i;
+        __syncthreads();
+        q[i] = to_float(q_[ind]);
+        w[i] = __expf(-__expf(to_float(w_[ind])));
+        k[i] = to_float(k_[ind]);
+        a[i] = to_float(a_[ind]);
+        b[i] = to_float(b_[ind]);
+        __syncthreads();
+
+        float sa = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            sa += a[j] * state[j];
+        }
+        sa_[ind] = sa;
+
+        float v = to_float(v_[ind]);
+        float y = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            float& s = state[j];
+            s = s * w[j] + sa * b[j] + k[j] * v;
+            y += s * q[j];
+        }
+        y_[ind] = to_bf(y);
+
+        if ((t+1)%_CHUNK_LEN_ == 0) {
+            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
+#pragma unroll
+            for (int j = 0; j < C; j++) {
+                s_[base + j*C] = state[j];
+            }
+        }
+    }
+}
+
+__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+    float qi, wi, ki, ai, bi, dyi;
+
+    for (int t = T-1; t >= 0; t--) {
+        int ind = bb*T*H*C + t*H*C + hh * C + i;
+        __syncthreads();
+        q[i] = qi = to_float(q_[ind]);
+        float wi_fac = -__expf(to_float(w_[ind]));
+        w[i] = wi = __expf(wi_fac);
+        k[i] = ki = to_float(k_[ind]);
+        a[i] = ai = to_float(a_[ind]);
+        b[i] = bi = to_float(b_[ind]);
+        v[i] = to_float(v_[ind]);
+        dy[i] = dyi = to_float(dy_[ind]);
+        sa[i] = sa_[ind];
+        __syncthreads();
+
+        if ((t+1)%_CHUNK_LEN_ == 0) {
+            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
+#pragma unroll
+            for (int j = 0; j < C; j++) {
+                stateT[j] = s_[base + j];
+            }
+        }
+
+        float dq = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dq += stateT[j]*dy[j];
+        }
+        dq_[ind] = to_bf(dq);
+
+        float iwi = 1.0f/wi;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+            dstate[j] += dyi * q[j];
+            dstateT[j] += qi * dy[j];
+        }
+
+        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dw += dstateT[j]*stateT[j];
+            dk += dstateT[j]*v[j];
+            dv += dstate[j]*k[j];
+            dSb += dstate[j]*b[j];
+            db += dstateT[j]*sa[j];
+        }
+        dw_[ind] = to_bf(dw * wi * wi_fac);
+        dk_[ind] = to_bf(dk);
+        dv_[ind] = to_bf(dv);
+        db_[ind] = to_bf(db);
+
+        __syncthreads();
+        dSb_shared[i] = dSb;
+        __syncthreads();
+
+        float da = 0;
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            da += stateT[j]*dSb_shared[j];
+        }
+        da_[ind] = to_bf(da);
+
+#pragma unroll
+        for (int j = 0; j < C; j++) {
+            dstate[j] = dstate[j]*w[j] + dSb * a[j];
+            dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
+        }
+    }
+}
+
+void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
+    hipLaunchKernelGGL(( forward_kernel), dim3(dim3(H,B)), dim3(dim3(_C_)), 0, 0, T,H,w,q,k,v,z,a,y,s,sa);
+}
+void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
+    assert(T%_CHUNK_LEN_ == 0);
+    hipLaunchKernelGGL(( backward_kernel), dim3(dim3(H,B)), dim3(dim3(_C_)), 0, 0, T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
+}
diff --git a/flagscale/train/models/rwkv/cuda/wkv7_op.cpp b/flagscale/train/models/rwkv/cuda/wkv7_op.cpp
new file mode 100644
index 000000000..fe2981d81
--- /dev/null
+++ b/flagscale/train/models/rwkv/cuda/wkv7_op.cpp
@@ -0,0 +1,30 @@
+#include <torch/extension.h>
+
+struct __nv_bfloat16;
+using bf = __nv_bfloat16;
+
+void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
+
+void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
+    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+    cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
+}
+
+void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
+
+void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
+        torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
+    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+    cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
+            (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
+}
+
+TORCH_LIBRARY(wind_backstepping, m) {
+    m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
+    m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
+}
+
+TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
+    m.impl("forward", &forward);
+    m.impl("backward", &backward);
+}
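The `TORCH_LIBRARY` block above registers both kernels under `torch.ops.wind_backstepping`; the `Tensor(a!)`-style annotations in the schemas mark buffers the ops write in place, which is why the Python caller pre-allocates `y`, `s`, `sa` and the gradient tensors with `torch.empty*`. A hedged sketch of building and reaching the op follows; the `-O3` flag and the `CHUNK_LEN` value are illustrative, and only the `-D_C_` and `-D_CHUNK_LEN_` defines are strictly required by the kernel source.

```python
import torch
from torch.utils.cpp_extension import load

HEAD_SIZE, CHUNK_LEN = 64, 16  # illustrative values; CHUNK_LEN must divide the sequence length T
load(
    name="wind_backstepping",
    sources=[
        "flagscale/train/models/rwkv/cuda/wkv7_cuda.cu",
        "flagscale/train/models/rwkv/cuda/wkv7_op.cpp",
    ],
    is_python_module=False,
    verbose=True,
    extra_cuda_cflags=["-O3", f"-D_C_={HEAD_SIZE}", f"-D_CHUNK_LEN_={CHUNK_LEN}"],
)

# The registered schema is then callable as torch.ops.wind_backstepping.forward(...)
# and torch.ops.wind_backstepping.backward(...), as rwkv_model.py does below.
```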
da) -> ()"); +} + +TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) { + m.impl("forward", &forward); + m.impl("backward", &backward); +} diff --git a/flagscale/train/models/rwkv/cuda/wkv7_op.hip b/flagscale/train/models/rwkv/cuda/wkv7_op.hip new file mode 100644 index 000000000..522e9f082 --- /dev/null +++ b/flagscale/train/models/rwkv/cuda/wkv7_op.hip @@ -0,0 +1,30 @@ +#include + +struct __hip_bfloat16; +using bf = __hip_bfloat16; + +void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa); + +void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) { + int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; + cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr()); +} + +void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da); + +void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy, + torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) { + int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; + cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(), + (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr()); +} + +TORCH_LIBRARY(wind_backstepping_hip, m) { + m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()"); + m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) 
da) -> ()"); +} + +TORCH_LIBRARY_IMPL(wind_backstepping_hip, CUDA, m) { + m.impl("forward", &forward); + m.impl("backward", &backward); +} diff --git a/flagscale/train/models/rwkv/rwkv_model.py b/flagscale/train/models/rwkv/rwkv_model.py index 2952e0f7b..fbadced32 100644 --- a/flagscale/train/models/rwkv/rwkv_model.py +++ b/flagscale/train/models/rwkv/rwkv_model.py @@ -56,7 +56,7 @@ def __nop(ob): # Safe defaults for env vars HEAD_SIZE = int(os.getenv("RWKV_HEAD_SIZE", "64")) -_RWKV_MY_TESTING = os.getenv("RWKV_MY_TESTING", "") +_RWKV_MY_TESTING = os.getenv("RWKV_MY_TESTING", "x070") # Prepare possible CUDA extension only if requested in env var RUN_CUDA_RWKV7g = None @@ -76,7 +76,7 @@ def __nop(ob): ] load( name="wind_backstepping_hip", - sources=["megatron/core/models/rwkv/cuda/wkv7_hip.hip", "megatron/core/models/rwkv/cuda/wkv7_op.hip"], + sources=["flagscale/train/models/rwkv/cuda/wkv7_hip.hip", "flagscale/train/models/rwkv/cuda/wkv7_op.hip"], is_python_module=False, verbose=True, extra_cuda_cflags=flags, @@ -93,7 +93,7 @@ def __nop(ob): ] load( name="wind_backstepping", - sources=["megatron/core/models/rwkv/cuda/wkv7_cuda.cu", "megatron/core/models/rwkv/cuda/wkv7_op.cpp"], + sources=["flagscale/train/models/rwkv/cuda/wkv7_cuda.cu", "flagscale/train/models/rwkv/cuda/wkv7_op.cpp"], is_python_module=False, verbose=True, extra_cuda_cflags=flags, @@ -104,7 +104,21 @@ class WindBackstepping(torch.autograd.Function): def forward(ctx, w, q, k, v, z, b): B, T, H, C = w.shape assert T % CHUNK_LEN == 0 - assert all(i.dtype in [torch.bfloat16, torch.float16] for i in [w, q, k, v, z, b]) + assert all(i.dtype == torch.bfloat16 for i in [w, q, k, v, z, b]) + assert all(i.is_contiguous() for i in [w, q, k, v, z, b]) + y = torch.empty_like(v) + s = torch.empty( + B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device + ) + sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device) + torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa) + ctx.save_for_backward(w, q, k, v, z, b, s, sa) + return y + + @staticmethod + def backward(ctx, dy): + assert all(i.dtype == torch.bfloat16 for i in [dy]) + assert all(i.is_contiguous() for i in [dy]) w, q, k, v, z, b, s, sa = ctx.saved_tensors dw, dq, dk, dv, dz, db = [torch.empty_like(x) for x in [ w, q, k, v, z, b]] @@ -119,24 +133,6 @@ def RUN_CUDA_RWKV7g(q, w, k, v, a, b): q, w, k, v, a, b = [i.view(B, T, HC // 64, 64) for i in [q, w, k, v, a, b]] return WindBackstepping.apply(w, q, k, v, a, b).view(B, T, HC) -else: - # Fallback: when custom kernel is not compiled/loaded, provide a clear fallback. - def RUN_CUDA_RWKV7g(q, w, k, v, a, b): - """ - Fallback CPU/PyTorch implementation placeholder. - This is intentionally simple and will likely be much slower than the optimized kernel. - It ensures the name exists so other code won't NameError. If you expect to run fast - on GPU, compile and enable the native kernel and set RWKV_MY_TESTING to include 'x070'. - """ - # Try a safe, pure-Torch (not optimized) computation that preserves shapes. - # Here we implement a plausible fallback: elementwise combination. - # NOTE: This fallback might not match the optimized kernel semantics exactly. - B, T, HC = q.shape - # reshape into (B,T,H,-1) if possible. We'll attempt dividing by head chunk 64 if matches, - # otherwise keep last dim intact. - return (q * w + k * v + a + b).view(B, T, HC) - - class RWKV_Tmix_x070(nn.Module): def __init__(self, args, layer_id): super().__init__()