-
Notifications
You must be signed in to change notification settings - Fork 10.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Molly Sophia <[email protected]>
- Loading branch information
1 parent
e1168a9
commit 7dd075a
Showing
4 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include "common.cuh" | ||
#include "rwkv-wkv.cuh" | ||
|
||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { | ||
const int tid = threadIdx.x; | ||
const int bid = blockIdx.x; | ||
|
||
const int head_size = CUDA_WKV_BLOCK_SIZE; | ||
const int batch_i = bid / H; | ||
const int head_i = bid % H; | ||
const int state_size = C * head_size; | ||
const int n_seq_tokens = T / B; | ||
|
||
float state[head_size]; | ||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size]; | ||
|
||
#pragma unroll | ||
for (int i = 0; i < head_size; i++) { | ||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; | ||
} | ||
|
||
__syncthreads(); | ||
_tf[tid] = tf[head_i * head_size + tid]; | ||
__syncthreads(); | ||
|
||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { | ||
__syncthreads(); | ||
_k[tid] = k[t]; | ||
_r[tid] = r[t]; | ||
_td[tid] = td[t]; | ||
__syncthreads(); | ||
|
||
const float _v = v[t]; | ||
float y; | ||
for (int j = 0; j < head_size; j += 4) { | ||
const float4& k = (float4&)(_k[j]); | ||
const float4& r = (float4&)(_r[j]); | ||
const float4& tf = (float4&)(_tf[j]); | ||
const float4& td = (float4&)(_td[j]); | ||
float4& s = (float4&)(state[j]); | ||
float4 kv; | ||
|
||
kv.x = k.x * _v; | ||
kv.y = k.y * _v; | ||
kv.z = k.z * _v; | ||
kv.w = k.w * _v; | ||
|
||
y += r.x * (tf.x * kv.x + s.x); | ||
y += r.y * (tf.y * kv.y + s.y); | ||
y += r.z * (tf.z * kv.z + s.z); | ||
y += r.w * (tf.w * kv.w + s.w); | ||
|
||
s.x = s.x * td.x + kv.x; | ||
s.y = s.y * td.y + kv.y; | ||
s.z = s.z * td.z + kv.z; | ||
s.w = s.w * td.w + kv.w; | ||
} | ||
dst[t] = y; | ||
} | ||
|
||
#pragma unroll | ||
for (int i = 0; i < head_size; i++) { | ||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; | ||
} | ||
} | ||
|
||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const float * k_d = (const float *)dst->src[0]->data; | ||
const float * v_d = (const float *)dst->src[1]->data; | ||
const float * r_d = (const float *)dst->src[2]->data; | ||
const float * tf_d = (const float *)dst->src[3]->data; | ||
const float * td_d = (const float *)dst->src[4]->data; | ||
const float * s_d = (const float *)dst->src[5]->data; | ||
|
||
const int64_t B = dst->src[5]->ne[1]; | ||
const int64_t T = dst->src[0]->ne[3]; | ||
const int64_t C = dst->ne[0]; | ||
const int64_t H = dst->src[0]->ne[2]; | ||
|
||
float * dst_d = (float *)dst->data; | ||
|
||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); | ||
GGML_ASSERT(C % H == 0); | ||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); | ||
|
||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include "common.cuh" | ||
|
||
#define CUDA_WKV_BLOCK_SIZE 64 | ||
|
||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters