tools/mtmd/mtmd.cpp (99 additions, 58 deletions)
@@ -167,7 +167,7 @@ struct mtmd_image_tokens {
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking

-    mtmd_image_tokens clone() {
+    mtmd_image_tokens clone() const {
         return mtmd_image_tokens{
             nx,
             ny,
@@ -409,12 +409,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     return 0;
 }

-static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
-    if (image_tokens) {
-        delete image_tokens;
-    }
-}
-
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -454,6 +448,23 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }

+float * mtmd_get_output_embd_copy(mtmd_context * ctx, size_t * n_embd_out) {
+    if (ctx->image_embd_v.empty()) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    *n_embd_out = ctx->image_embd_v.size();
+    float * copy = (float *) malloc(*n_embd_out * sizeof(float));
+    if (copy == NULL) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    memcpy(copy, ctx->image_embd_v.data(), ctx->image_embd_v.size() * sizeof(float));
+    return copy;
+}
+
 size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
     for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
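A minimal sketch of how a caller might use the new copy accessor (hypothetical helper; `ctx` is assumed to be an initialized `mtmd_context` on which `mtmd_encode()` has already run). Unlike `mtmd_get_output_embd()`, which returns a pointer into the context's internal buffer, the copy is `malloc`-allocated, so it survives later encode calls and the caller must `free()` it:

```cpp
// Hypothetical usage sketch; ctx is assumed to be set up elsewhere.
static void save_image_embd(mtmd_context * ctx) {
    size_t n_embd = 0;
    float * embd_copy = mtmd_get_output_embd_copy(ctx, &n_embd);
    if (embd_copy == NULL) {
        return; // nothing encoded yet, or malloc failed (n_embd is set to 0)
    }
    // n_embd is the total float count (n_tokens * n_mmproj_embd); this buffer
    // stays valid across subsequent mtmd_encode() calls
    // ... stash or serialize embd_copy here ...
    free(embd_copy); // the copy is malloc'd, so the caller owns and frees it
}
```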
@@ -580,6 +591,69 @@ struct decode_embd_batch {
     }
 };

+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_image_tokens * image_tokens,
+        float * embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (ctx->print_timings) {
+            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        }
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
 int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         struct llama_context * lctx,
         const mtmd_input_chunk * chunk,
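A sketch of the intended call sequence for this extracted helper (hypothetical wrapper; `ctx`, `lctx`, and `image_tokens` are assumed to come from earlier init/tokenize steps):

```cpp
// Hypothetical wrapper showing the encode -> decode flow.
static int32_t eval_image(mtmd_context * ctx, struct llama_context * lctx,
                          const mtmd_image_tokens * image_tokens,
                          llama_pos * n_past, int32_t n_batch) {
    int32_t ret = mtmd_encode(ctx, image_tokens); // fills the context's embd buffer
    if (ret != 0) {
        return ret;
    }
    float * embd = mtmd_get_output_embd(ctx);
    llama_pos new_n_past = 0;
    ret = mtmd_helper_decode_image(ctx, lctx, image_tokens, embd,
                                   *n_past, /*seq_id=*/0, n_batch, &new_n_past);
    if (ret == 0) {
        *n_past = new_n_past; // advanced by mtmd_image_tokens_get_n_pos()
    }
    return ret;
}
```

This mirrors what `mtmd_helper_eval_chunk_single` now does for image chunks, minus the text-batch bookkeeping.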
@@ -591,8 +665,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     int32_t ret;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
@@ -637,57 +709,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         if (ctx->print_timings) {
             LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
         }

-        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-        int32_t i_batch = 0;
-        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
         float * embd = mtmd_get_output_embd(ctx);
-        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-        if (mtmd_decode_use_mrope(ctx)) {
-            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-        } else {
-            batch_embd.set_position_normal(n_past, seq_id);
-        }
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, false);
-            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-        }
-
-        while (i_batch < n_img_batches) { // split into batches
-            int pos_offset = i_batch*n_batch;
-            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_embd_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_set_causal_attn(lctx, true); // restore causal attn
-                llama_batch_free(text_batch);
-                return ret;
-            }
-
-            if (ctx->print_timings) {
-                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-            }
-
-            i_batch++;
-        }
-
-        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-        *new_n_past = n_past;
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, true);
+        ret = mtmd_helper_decode_image(ctx, lctx, image_tokens, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
         }

     } else {
         GGML_ABORT("chunk type not supported");
     }
Expand Down Expand Up @@ -903,6 +931,19 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
return image_tokens->n_tokens();
}

void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
}
}

mtmd_image_tokens * mtmd_image_tokens_copy(const mtmd_image_tokens * image_tokens) {
if (!image_tokens) {
return nullptr;
}
return new mtmd_image_tokens(image_tokens->clone());
}

// test function

mtmd_input_chunks * mtmd_test_create_input_chunks() {
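For reference, a hypothetical sketch of the ownership contract these two exported functions establish (`image_tokens` is assumed to be borrowed from a tokenized chunk):

```cpp
// Hypothetical sketch of the copy/free ownership contract.
mtmd_image_tokens * copy = mtmd_image_tokens_copy(image_tokens); // via clone()
if (copy != NULL) {
    // the copy can outlive the chunk that owns the original, e.g. for
    // KV-cache tracking via mtmd_image_tokens_get_id()
    mtmd_image_tokens_free(copy); // each copy must be freed exactly once
}
mtmd_image_tokens_free(NULL); // NULL is accepted and ignored
```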
tools/mtmd/mtmd.h (25 additions, 5 deletions)
@@ -143,12 +143,14 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
 //
 // the instance will be constructed via mtmd_tokenize()
 // it will be freed along with mtmd_input_chunk
-MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
-MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
-MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
-MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_n_tokens (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_nx       (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_ny       (const mtmd_image_tokens * image_tokens);
+MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
 // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
-MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens);
+MTMD_API llama_pos mtmd_image_tokens_get_n_pos      (const mtmd_image_tokens * image_tokens);
+MTMD_API mtmd_image_tokens * mtmd_image_tokens_copy (const mtmd_image_tokens * image_tokens);
+MTMD_API void mtmd_image_tokens_free                (mtmd_image_tokens * image_tokens);

 // tokenize an input text prompt and an image
 // the prompt must have the input image marker (default: "<__image__>") in it
Expand Down Expand Up @@ -178,6 +180,9 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

// returns a copy of output embeddings from the last encode pass, of size n_embd_out
MTMD_API float * mtmd_get_output_embd_copy(mtmd_context * ctx, size_t * n_embd_out);

/////////////////////////////////////////

//
Expand Down Expand Up @@ -231,6 +236,16 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
bool logits_last,
llama_pos * new_n_past);

// helper function to decode an image whose embeddings have already been calculated
MTMD_API int32_t mtmd_helper_decode_image(mtmd_context *ctx,
struct llama_context *lctx,
const mtmd_image_tokens *image_tokens,
float *embd,
llama_pos n_past,
llama_seq_id seq_id,
int32_t n_batch,
llama_pos *new_n_past);

/////////////////////////////////////////

// test function, to be used in test-mtmd-c-api.c
Expand Down Expand Up @@ -268,6 +283,11 @@ struct mtmd_input_chunk_deleter {
};
using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>;

struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
};
using image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;

struct bitmap {
bitmap_ptr ptr;
bitmap() : ptr(nullptr) {}
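On the C++ side, a brief sketch of what the new smart-pointer alias buys (hypothetical usage; it assumes the alias is referenced from the same scope as the existing `input_chunk_ptr` helper, and that `image_tokens` is valid):

```cpp
// Hypothetical C++ usage: RAII ownership of a copied token set.
{
    image_tokens_ptr tokens(mtmd_image_tokens_copy(image_tokens));
    if (tokens) {
        size_t n = mtmd_image_tokens_get_n_tokens(tokens.get());
        (void) n; // ... use the C API through tokens.get() ...
    }
} // mtmd_image_tokens_free() runs automatically via the deleter
```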