-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgraph.rs
99 lines (82 loc) · 3.08 KB
/
graph.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use rusty_ggml::prelude::GTensor1;
use super::{map_ops, model::*};
impl LayerNorm {
    /// Applies this layer normalization to `x`.
    ///
    /// The input is passed through the tensor's `norm()` method, the result
    /// is multiplied by `self.weight`, and `self.bias` is added on top.
    pub fn norm_ops<T: AsRef<GTensor1>>(&self, x: T) -> GTensor1 {
        let normalized = x.as_ref().norm();
        let scaled = normalized * &self.weight;
        scaled + &self.bias
    }
}
impl Mix {
    /// Blends the current input `x` with the previous input `last_x`.
    ///
    /// Computes `x * mix + last_x * one_minus(mix)`, where `mix` is the
    /// tensor wrapped by this `Mix` (`self.0`).
    ///
    /// NOTE(review): `map_ops::one_minus` presumably yields `1 - mix`
    /// element-wise — confirm against `map_ops`.
    pub fn mix_ops<TX: AsRef<GTensor1>, TLX: AsRef<GTensor1>>(
        &self,
        x: TX,
        last_x: TLX,
    ) -> GTensor1 {
        let mix = &self.0;
        let current_part = x.as_ref() * mix;
        let previous_part = last_x.as_ref() * map_ops::one_minus(mix);
        current_part + previous_part
    }
}
impl FeedForwardNetwork {
    /// Channel-mixing step of a layer.
    ///
    /// Mixes `x` with `state.cm_last_x` through the `mix_k` and `mix_r`
    /// mixers, derives a sigmoid gate from the receptance weight and a
    /// `relu_squared` key from the key weight, hands `x` to
    /// `state.cm_last_x.copy_from`, and returns the gate multiplied by
    /// `value_weight ^ key`.
    ///
    /// NOTE(review): `^` on tensors is presumably rusty_ggml's matrix
    /// multiplication operator — confirm against the crate docs.
    pub fn channel_mixing_ops(&self, state: &mut RWKVLayerState, x: GTensor1) -> GTensor1 {
        let prev_x = &state.cm_last_x;
        let mixed_k = self.time.mix_k.mix_ops(&x, prev_x);
        let mixed_r = self.time.mix_r.mix_ops(&x, prev_x);

        let gate = map_ops::sigmoid(&self.receptance_weight ^ mixed_r);
        let squared_key = map_ops::relu_squared(&self.key_weight ^ mixed_k);

        // Record the current input as the "last x" for the next call.
        state.cm_last_x.copy_from(x);

        let value = &self.value_weight ^ &squared_key;
        gate * value
    }
}
impl Attention {
    /// Time-mixing step of a layer.
    ///
    /// Reads `tm_last_x`, `tm_aa`, `tm_bb` and `tm_pp` from `state`, mixes
    /// `x` with `tm_last_x` through the `mix_k` / `mix_v` / `mix_r` mixers,
    /// and projects the results with the key, value and receptance weights.
    /// It then computes `wkv` from the previous state and the new key/value,
    /// computes the next `aa` / `bb` / `pp` values, passes `x` and those
    /// values to `copy_from` on the matching state fields, and returns
    /// `output_weight ^ (receptance * wkv)`.
    ///
    /// NOTE(review): `map_ops::sub_exp(a, b)` presumably computes
    /// `exp(a - b)` and `map_ops::max` an element-wise maximum; `^` is
    /// presumably matrix multiplication — confirm against `map_ops` and
    /// rusty_ggml.
    pub fn time_mixing_ops(&self, state: &mut RWKVLayerState, x: GTensor1) -> GTensor1 {
        let prev_x = &state.tm_last_x;
        let prev_aa = &state.tm_aa;
        let prev_bb = &state.tm_bb;
        let prev_pp = &state.tm_pp;

        let mixed_k = self.time.mix_k.mix_ops(&x, prev_x);
        let mixed_v = self.time.mix_v.mix_ops(&x, prev_x);
        let mixed_r = self.time.mix_r.mix_ops(&x, prev_x);

        let receptance = map_ops::sigmoid(&self.receptance_weight ^ &mixed_r);
        let key = &self.key_weight ^ mixed_k;
        let value = &self.value_weight ^ mixed_v;

        // Numerator and denominator of wkv, using `time.first` for the current token.
        let out_ww = &self.time.first + &key;
        let out_qq = map_ops::max(&out_ww, prev_pp);
        let out_e1 = map_ops::sub_exp(prev_pp, &out_qq);
        let out_e2 = map_ops::sub_exp(out_ww, out_qq);
        let numerator = &out_e1 * prev_aa + &out_e2 * &value;
        let denominator = (out_e1 * prev_bb) + out_e2;

        // Next state values, using `time.decay` on the previous `pp`.
        let next_ww = prev_pp + &self.time.decay;
        let next_pp = map_ops::max(&next_ww, &key);
        let next_e1 = map_ops::sub_exp(next_ww, &next_pp);
        let next_e2 = map_ops::sub_exp(key, &next_pp);
        let wkv = numerator / denominator;
        let next_aa = &next_e1 * prev_aa + &next_e2 * value;
        let next_bb = (next_e1 * prev_bb) + next_e2;

        // All reads of the old state are done; hand over the new values.
        state.tm_last_x.copy_from(x);
        state.tm_aa.copy_from(next_aa);
        state.tm_bb.copy_from(next_bb);
        state.tm_pp.copy_from(next_pp);

        &self.output_weight ^ (receptance * wkv)
    }
}
impl RWKVLayer {
    /// Evaluates one layer on `x`, updating `state` along the way.
    ///
    /// First the time-mixing result for `ln_tm`-normalized `x` is added to
    /// `x`; then the channel-mixing result for the `ln_cm`-normalized sum is
    /// added to that sum, which is the returned tensor.
    pub fn evaluate_layer_ops(&self, state: &mut RWKVLayerState, x: GTensor1) -> GTensor1 {
        let att_input = self.ln_tm.norm_ops(&x);
        let att_output = self.att.time_mixing_ops(state, att_input);
        let x = att_output + x;

        let ffn_input = self.ln_cm.norm_ops(&x);
        let ffn_output = self.ffn.channel_mixing_ops(state, ffn_input);
        ffn_output + x
    }
}
impl RWKV {
    /// Evaluates the whole model for `token`.
    ///
    /// Looks up the rows of `self.emb` selected by `token`, takes a view of
    /// the result sized by its element count, runs that through every layer
    /// in order (layer `i` uses `state[i]`), and finally returns
    /// `head_weight ^ ln_out(x)`.
    ///
    /// # Panics
    ///
    /// Indexing `state[i]` panics if `state` has fewer entries than the
    /// model has layers.
    pub fn evaluate_ops(&self, state: &mut [RWKVLayerState], token: GTensor1) -> GTensor1 {
        let embedded = self.emb.get_rows(token);
        let mut x = embedded.view([embedded.elements() as i64], [0]);

        for (lnum, layer) in self.layers.iter().enumerate() {
            x = layer.evaluate_layer_ops(&mut state[lnum], x);
        }

        &self.head_weight ^ self.ln_out.norm_ops(&x)
    }
}