Skip to content

Commit

Permalink
Change Rust view ops to take element rather than byte offsets.
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Apr 10, 2023
1 parent 08b3172 commit 825a33b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
10 changes: 10 additions & 0 deletions ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ impl Context {

/// Creates a 1D view over `a`.
pub fn op_view_1d(&self, a: &Tensor, ne0: usize, offset: usize) -> Tensor {
let offset = offset * a.element_size();
let tensor = unsafe {
ggml_sys::ggml_view_1d(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i64(ne0), offset)
};
Expand All @@ -287,6 +288,10 @@ impl Context {
nb1: usize,
offset: usize,
) -> Tensor {
let elsize = a.element_size();
let offset = offset * elsize;
let nb1 = nb1 * elsize;

let tensor = unsafe {
ggml_sys::ggml_view_2d(
self.ptr.as_ptr(),
Expand All @@ -312,6 +317,11 @@ impl Context {
nb2: usize,
offset: usize,
) -> Tensor {
let elsize = a.element_size();
let offset = offset * a.element_size();
let nb1 = nb1 * elsize;
let nb2 = nb2 * elsize;

let tensor = unsafe {
ggml_sys::ggml_view_3d(
self.ptr.as_ptr(),
Expand Down
17 changes: 7 additions & 10 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1130,9 +1130,6 @@ impl Model {
let n_past = session.n_past;
let n_threads = params.n_threads;

let memk_elsize = session.memory_k.element_size();
let memv_elsize = session.memory_v.element_size();

let Hyperparameters {
n_vocab,
n_ctx,
Expand Down Expand Up @@ -1213,15 +1210,15 @@ impl Model {
let k = ctx0.op_view_1d(
&session.memory_k,
n * n_embd,
(memk_elsize * n_embd) * (il * n_ctx + n_past),
n_embd * (il * n_ctx + n_past),
);

let v = ctx0.op_view_2d(
&session.memory_v,
n,
n_embd,
n_ctx * memv_elsize,
(il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize,
n_ctx,
(il * n_ctx) * n_embd + n_past,
);

// important: storing RoPE-ed version of K in the KV cache!
Expand All @@ -1236,7 +1233,7 @@ impl Model {
&ctx0.op_view_1d(
&session.memory_k,
(n_past + n) * n_embd,
il * n_ctx * memk_elsize * n_embd,
il * n_ctx * n_embd,
),
n_embd / n_head,
n_head,
Expand Down Expand Up @@ -1269,9 +1266,9 @@ impl Model {
n_past + n,
n_embd / n_head,
n_head,
n_ctx * memv_elsize,
n_ctx * memv_elsize * n_embd / n_head,
il * n_ctx * memv_elsize * n_embd,
n_ctx,
n_ctx * n_embd / n_head,
il * n_ctx * n_embd,
);

let k_q_v = ctx0.op_mul_mat(&v, &k_q_soft_max);
Expand Down

0 comments on commit 825a33b

Please sign in to comment.