Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions tools/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit(
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
const build_vit_opts & opts
) {
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that batching is not just for multi-tile encode, but it should eventually allow batching multiple images of same size. that will be important for video processing where we need to process multiple images in the same pass

I will fix this comment along with my refactoring to add the proper architecture for doing so

const int64_t B = inp->ne[2];

if (learned_pos_embd) {
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
}

// flatten batch; unflatten again in attention
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B);

ggml_tensor * inpL = inp;

// pre-layernorm
Expand Down Expand Up @@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit(
cur = ggml_add(ctx0, cur, layer.qkv_b);
}

Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ 0);

Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, n_embd));

Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
// Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos.
Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ 0);

Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, n_embd));

Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, 2 * n_embd));

if (layer.q_norm) {
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
Expand Down Expand Up @@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit(
}
}

Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B);
Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B);
Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B);

if (norm_per_head) {
if (layer.q_norm) {
Expand Down Expand Up @@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit(
cb(Vcur, "Vcur_normed", il);
}

// build_attn returns a flat 2D [n_embd, n_pos*B]
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
cb(cur, "attn_out", il);
Expand Down Expand Up @@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit(
if (model.post_ln_w) {
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
}

// restore the batch dim
GGML_ASSERT(inpL->ne[1] % B == 0);
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B);
return inpL;
}

Expand Down
Loading