diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index bd33f430625a..02e7a3a8f60f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit( std::function add_pos, const build_vit_opts & opts ) { + // batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode) + const int64_t B = inp->ne[2]; + if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); } + // flatten batch; unflatten again in attention + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B); + ggml_tensor * inpL = inp; // pre-layernorm @@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit( cur = ggml_add(ctx0, cur, layer.qkv_b); } - Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ 0); - - Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, n_embd)); - - Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, 2 * n_embd)); + // Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos. + Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ 0); + + Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ ggml_row_size(cur->type, n_embd)); + + Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* nb3 */ cur->nb[1] * n_pos, + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); if (layer.q_norm) { GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]); @@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit( } } - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos); + Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B); + Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B); + Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B); if (norm_per_head) { if (layer.q_norm) { @@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit( cb(Vcur, "Vcur_normed", il); } + // build_attn returns a flat 2D [n_embd, n_pos*B] cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il); cb(cur, "attn_out", il); @@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit( if (model.post_ln_w) { inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1); } + + // restore the batch dim + GGML_ASSERT(inpL->ne[1] % B == 0); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B); return inpL; }