diff --git a/.gitignore b/.gitignore index 8dc9d7d0b8d..61d783272ad 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,4 @@ a.out.* AGENTS.local.md .pi/SYSTEM.md +llama_cpp_windows.zip diff --git a/README-deepseek-ocr.md b/README-deepseek-ocr.md new file mode 100644 index 00000000000..65b274db229 --- /dev/null +++ b/README-deepseek-ocr.md @@ -0,0 +1,104 @@ +# DeepSeek-OCR Server + +## Quick start + +```bash +./build/bin/llama-server \ + -m "/path/to/deepseek-ocr-q4_k_m.gguf" \ + --mmproj "/path/to/mmproj-deepseek-ocr-f16.gguf" \ + --temp 0 --flash-attn off \ + --chat-template deepseek-ocr \ + -ngl 0 -c 2048 --host 0.0.0.0 --port 8000 +``` + +Flags: +- `-ngl 0` — CPU only (required on GPUs with <6GB VRAM) +- `--flash-attn off` — avoids CUDA OOM on low-VRAM GPUs +- `--chat-template deepseek-ocr` — enables the correct prompt format +- `--no-mmproj-offload` — if you still get OOM, also forces mmproj to CPU + +## API: `/v1/chat/completions` + +``` +POST http://localhost:8000/v1/chat/completions +Content-Type: application/json +``` + +### Request + +```json +{ + "model": "deepseek-ocr", + "max_tokens": 512, + "temperature": 0, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_IMAGE_DATA>"}}, + {"type": "text", "text": "<|grounding|>Convert the document to markdown."} + ] + } + ] +} +``` + +### Response + +```json +{ + "choices": [{ + "message": { + "content": "text[[76, 149, 945, 288]]\n\nequation[[104, 299, 691, 351]]\n\\[latex...\\]" + } + }], + "usage": { + "prompt_tokens": 277, + "completion_tokens": 128, + "total_tokens": 405 + } +} +``` + +## Tips + +**VRAM issues:** The Q4_K_M model + F16 mmproj needs ~2.7GB + compute buffers. On 4GB GPUs (RTX 3050), use `-ngl 0`. For partial GPU offload try `-ngl 12 --flash-attn on -c 1024`. + +**Prompt prefix:** Always include `<|grounding|>` in the text to get document OCR mode. Without it the model may behave like a generic chatbot. 
+ +**Output tags:** The model outputs `<|ref|>type<|/ref|><|det|>[x1,y1,x2,y2]<|/det|>` bounding boxes in CLI mode. The server strips the `<|ref|>`/`<|det|>` wrapping depending on the chat template prefix. + +## Python example + +```python +import requests, base64 + +with open("document.png", "rb") as f: + b64 = base64.b64encode(f.read()).decode() + +r = requests.post("http://localhost:8000/v1/chat/completions", json={ + "model": "deepseek-ocr", + "max_tokens": 512, + "temperature": 0, + "messages": [{ + "role": "user", + "content": [ + {"type": "image_url", + "image_url": {"url": f"data:image/png;base64,{b64}"}}, + {"type": "text", + "text": "<|grounding|>Convert the document to markdown."} + ] + }] +}) + +print(r.json()["choices"][0]["message"]["content"]) +``` + +## Troubleshooting + +| Error | Fix | +|-------|-----| +| `500 failed to process image` | GPU OOM — add `-ngl 0 --no-mmproj-offload` | +| `number of bitmaps (1) does not match number of markers (0)` | Missing `--chat-template deepseek-ocr` flag | +| `GGML_ASSERT(batch.n_tokens > 0)` | Outdated build — rebuild with latest patches | +| Output is `<__media__><\|grounding\|>...` literal text | Missing `--chat-template deepseek-ocr` | diff --git a/common/chat.cpp b/common/chat.cpp index 56873e3a1e9..5ff21e80604 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2474,10 +2474,25 @@ static common_chat_params common_chat_templates_apply_legacy(const struct common } common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) { + const struct common_chat_templates_inputs & inputs) { GGML_ASSERT(tmpls != nullptr); - return inputs.use_jinja ? 
common_chat_templates_apply_jinja(tmpls, inputs) : - common_chat_templates_apply_legacy(tmpls, inputs); + + // if use_jinja is requested, check if the template source is actually a built-in template name + // if so, fall back to the legacy path which resolves built-in templates by name + if (inputs.use_jinja) { + const auto src = common_chat_templates_source(tmpls); + if (!src.empty()) { + // test if this is a built-in template name by attempting to apply it with an empty chat + // if it returns >= 0, it's a built-in template + int32_t test_res = llama_chat_apply_template(src.c_str(), nullptr, 0, false, nullptr, 0); + if (test_res >= 0) { + return common_chat_templates_apply_legacy(tmpls, inputs); + } + } + return common_chat_templates_apply_jinja(tmpls, inputs); + } + + return common_chat_templates_apply_legacy(tmpls, inputs); } common_chat_msg common_chat_parse(const std::string & input, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9727a738ed8..0925bc1d53d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2608,6 +2608,13 @@ struct clip_model_loader { // alloc memory and offload data ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); + if (!ctx_clip.buf && ctx_clip.backend != ctx_clip.backend_cpu) { + LOG_WRN("%s: WARNING: failed to allocate tensors on %s, falling back to CPU\n", __func__, ggml_backend_name(ctx_clip.backend)); + ctx_clip.backend = ctx_clip.backend_cpu; + buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); + ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); + } + GGML_ASSERT(ctx_clip.buf && "failed to allocate tensors on any backend"); ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); for (auto & t : tensors_to_load) { ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); @@ -3341,7 
+3348,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // build the inference graph ggml_backend_sched_reset(ctx->sched.get()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); - ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); + if (!ggml_backend_sched_alloc_graph(ctx->sched.get(), gf)) { + LOG_ERR("%s: failed to allocate compute graph (OOM)\n", __func__); + return false; + } // set inputs const auto & model = ctx->model; diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index dc00edfa82a..0eeda5a72d2 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -91,7 +91,8 @@ const char * get_media_marker() { if (env && env[0] != '\0') { return std::string(env); } - return std::string("<__media_") + random_string() + "__>"; + // must match mtmd_default_marker() so the tokenizer can split on it + return std::string(mtmd_default_marker()); }(); return marker.c_str(); } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1ce7f095827..f70a6e03479 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -823,7 +823,13 @@ struct server_context_impl { if (!mmproj_path.empty()) { mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; + // if user explicitly sets n_gpu_layers to 0, disable mmproj GPU too + bool mmproj_use_gpu = params_base.mmproj_use_gpu; + if (mmproj_use_gpu && params_base.n_gpu_layers == 0) { + LOG_INF("%s: n_gpu_layers=0, disabling mmproj GPU\n", __func__); + mmproj_use_gpu = false; + } + mparams.use_gpu = mmproj_use_gpu; mparams.print_timings = false; mparams.n_threads = params_base.cpuparams.n_threads; mparams.flash_attn_type = params_base.flash_attn_type; @@ -2751,9 +2757,10 @@ struct server_context_impl { n_swa > 0); bool has_mtmd = false; + bool slot_released = false; // check if we should process the image - while (slot.prompt.n_tokens() 
< slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { + while (!slot_released && slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); @@ -2761,7 +2768,8 @@ struct server_context_impl { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); slot.release(); - continue; + slot_released = true; + break; } if (ctx_dft) { @@ -2785,6 +2793,11 @@ struct server_context_impl { has_mtmd = true; } + if (slot_released) { + // released inside mtmd loop, skip the rest + continue; + } + // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { // get next token to process @@ -2842,6 +2855,40 @@ struct server_context_impl { if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; + // If the prompt ended with mtmd chunks and no text tokens remain, + // add the last valid text token to the batch to produce generation logits + if (batch.n_tokens == 0) { + GGML_ASSERT(has_mtmd); + // find the last non-NULL token from the processed prompt + bool found = false; + for (int i = (int)slot.prompt.tokens.size() - 1; i >= 0; i--) { + if (slot.prompt.tokens[i] != LLAMA_TOKEN_NULL) { + common_batch_add(batch, + slot.prompt.tokens[i], + slot.prompt.tokens.pos_next(), + { slot.id }, + slot.task->need_embd()); + slot.prompt.tokens.push_back(slot.prompt.tokens[i]); + found = true; + break; + } + } + if (!found) { + // image-only prompt (no text tokens at all) + // use BOS to produce logits so sampling can proceed + const auto bos = llama_vocab_bos(vocab); + if (bos != LLAMA_TOKEN_NULL) { + common_batch_add(batch, + bos, + 
slot.prompt.tokens.pos_next(), + { slot.id }, + slot.task->need_embd()); + slot.prompt.tokens.push_back(bos); + } + } + SRV_DBG("slot %12.*s: id %2d | task %d | prompt ended with mtmd, added synthetic token to batch\n", 12, __func__, slot.id, slot.task ? slot.task->id : -1); + } + GGML_ASSERT(batch.n_tokens > 0); // extract the logits only for the last token