Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 58 additions & 24 deletions tools/mtmd/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,16 @@ struct mtmd_bitmap {
bool is_audio = false; // true if the bitmap is audio
};

// How positions are assigned to media tokens when they are fed to the decoder model.
enum mtmd_pos_type {
    // one position per token: the number of positions equals the number of tokens
    MTMD_POS_TYPE_NORMAL = 0,
    // qwen-vl M-RoPE style: each image takes max(t,h,w) position indexes
    MTMD_POS_TYPE_MROPE  = 1,
};

struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL;
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
Expand All @@ -45,7 +51,7 @@ struct mtmd_image_tokens {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
pos,
batch_f32.clone(),
id
};
Expand Down Expand Up @@ -131,7 +137,7 @@ struct mtmd_context {
int n_threads;
std::string media_marker;
const int n_embd_text;
llama_rope_type decoder_rope;
mtmd_pos_type pos_type;

// these are not token, but strings used to mark the beginning and end of image/audio embeddings
std::string img_beg;
Expand Down Expand Up @@ -168,8 +174,7 @@ struct mtmd_context {
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
media_marker (ctx_params.media_marker),
n_embd_text (llama_model_n_embd_inp(text_model)),
decoder_rope (llama_model_rope_type(text_model))
n_embd_text (llama_model_n_embd_inp(text_model))
{
if (ctx_params.image_marker != nullptr) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
Expand All @@ -179,6 +184,22 @@ struct mtmd_context {
throw std::runtime_error("media_marker must not be empty");
}

auto decoder_rope_type = llama_model_rope_type(text_model);
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{
pos_type = MTMD_POS_TYPE_NORMAL;
} break;
case LLAMA_ROPE_TYPE_MROPE:
case LLAMA_ROPE_TYPE_IMROPE:
{
pos_type = MTMD_POS_TYPE_MROPE;
} break;
default:
throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type));
}

clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu,
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
Expand Down Expand Up @@ -779,12 +800,12 @@ struct mtmd_tokenizer {
// for Qwen2VL, we need this information for M-RoPE decoding positions
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->use_mrope_pos = true;
} else {
// other models, we only need the total number of tokens
image_tokens->nx = n_tokens;
image_tokens->ny = 1;
}
image_tokens->pos = ctx->pos_type;
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmap->id; // optional

Expand Down Expand Up @@ -1016,7 +1037,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
return ctx->image_embd_v.data();
}

bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk) {
auto proj_type = ctx->proj_type_v();
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
proj_type = ctx->proj_type_a();
Expand All @@ -1030,20 +1051,19 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chu
}
}

bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->decoder_rope == LLAMA_ROPE_TYPE_MROPE
|| ctx->decoder_rope == LLAMA_ROPE_TYPE_IMROPE;
bool mtmd_decode_use_mrope(const mtmd_context * ctx) {
return ctx->pos_type;
}

bool mtmd_support_vision(mtmd_context * ctx) {
bool mtmd_support_vision(const mtmd_context * ctx) {
return ctx->ctx_v != nullptr;
}

bool mtmd_support_audio(mtmd_context * ctx) {
bool mtmd_support_audio(const mtmd_context * ctx) {
return ctx->ctx_a != nullptr;
}

int mtmd_get_audio_sample_rate(mtmd_context * ctx) {
int mtmd_get_audio_sample_rate(const mtmd_context * ctx) {
if (!ctx->ctx_a) {
return -1;
}
Expand Down Expand Up @@ -1238,12 +1258,24 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {

// Compute the decoder position of the i-th token of an image.
//
//   image_tokens  image whose `pos` selects the indexing scheme; `nx` is the
//                 number of tokens in x direction (the row width of the token grid)
//   pos_0         position assigned to the first token of the image
//   i             flat index of the token within the image
//
// Aborts via GGML_ABORT when `pos` holds an unknown position type.
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) {
    mtmd_decoder_pos pos;
    switch (image_tokens->pos) {
        case MTMD_POS_TYPE_MROPE:
            {
                // M-RoPE: t is the same for every token of the image,
                // x / y follow the token's column / row inside the grid
                pos.t = pos_0;
                pos.x = pos_0 + (i % image_tokens->nx);
                pos.y = pos_0 + (i / image_tokens->nx);
                pos.z = 0; // unused for now
            } break;
        case MTMD_POS_TYPE_NORMAL:
            {
                // sequential indexing: every component advances with the token index
                pos.t = pos_0 + i;
                pos.x = pos_0 + i;
                pos.y = pos_0 + i;
                pos.z = pos_0 + i;
            } break;
        default:
            GGML_ABORT("invalid position type");
    }
    return pos;
}

Expand All @@ -1252,12 +1284,14 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
}

// Number of decoder positions occupied by an image, according to its position type.
//
// Aborts via GGML_ABORT when `pos` holds an unknown position type.
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
    switch (image_tokens->pos) {
        case MTMD_POS_TYPE_MROPE:
            // for M-RoPE, temporal dimension = max(t,h,w)
            // t is omitted as we don't support video input
            return std::max(image_tokens->nx, image_tokens->ny);
        case MTMD_POS_TYPE_NORMAL:
            // one position per token
            return image_tokens->n_tokens();
        default:
            GGML_ABORT("invalid position type");
    }
}

// test function
Expand Down
10 changes: 5 additions & 5 deletions tools/mtmd/mtmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx);

// whether we need to set non-causal mask before llama_decode
// if chunk is nullptr, we assume the default case where chunk is an image chunk
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk);

// whether the current model uses M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx);

// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
MTMD_API bool mtmd_support_vision(const mtmd_context * ctx);

// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);

// get audio sample rate in Hz, for example 16000 for Whisper
// return -1 if audio is not supported
MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);

// mtmd_bitmap
//
Expand Down
Loading