
Conversation

@theo77186
Contributor

This PR adds cases for head size 72, used by the Qwen3-VL and Gemma 3 vision encoders. I added only the cases for the tile kernel, as was done for head size 40.
The tile-kernel parameters are taken from head size 40, so they will probably need some further tuning. Tested with test-backend-ops on a 3060 and a 4060 Ti. Not tested on AMD, though I have added the AMD cases as well.
Fixes #16950
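
To illustrate the kind of change involved, here is a minimal sketch of a head-size dispatch for a tile-based flash-attention kernel; it is not the actual ggml-cuda source. The kernel name, dispatcher name, and the cols_per_block launch parameter are assumptions made up for the example.

```cpp
// Sketch only: how a head-size dispatch for a tile-based flash-attention kernel
// might gain a new case. Names and launch parameters are illustrative, not the
// real ggml-cuda code.
#include <cuda_runtime.h>

// Hypothetical kernel template, instantiated per head size / columns per block.
template <int head_size, int cols_per_block>
__global__ void fattn_tile_kernel(const float * Q, const float * K, const float * V,
                                  float * dst, const int n_kv) {
    // tile-based attention body omitted in this sketch
    (void) Q; (void) K; (void) V; (void) dst; (void) n_kv;
}

// Hypothetical dispatcher: map the runtime head size onto a compiled instantiation.
static bool launch_fattn_tile(const int head_size,
                              const float * Q, const float * K, const float * V, float * dst,
                              const int n_kv, const dim3 grid, const dim3 block, cudaStream_t stream) {
    switch (head_size) {
        case  40: fattn_tile_kernel< 40, 32><<<grid, block, 0, stream>>>(Q, K, V, dst, n_kv); return true;
        // New in this PR: head size 72 (Qwen3-VL / Gemma 3 vision encoders).
        // Launch parameters are copied from the head-size-40 case for now.
        case  72: fattn_tile_kernel< 72, 32><<<grid, block, 0, stream>>>(Q, K, V, dst, n_kv); return true;
        case 128: fattn_tile_kernel<128, 32><<<grid, block, 0, stream>>>(Q, K, V, dst, n_kv); return true;
        default:  return false; // unsupported head size: caller falls back to a non-FA path
    }
}
```

Correctness of new cases like this is usually checked with the backend test binary restricted to the flash-attention op, e.g. something like `./test-backend-ops -o FLASH_ATTN_EXT`.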

@ngxson
Collaborator

ngxson commented Nov 3, 2025

Nice! Could you also upload the output logs when you start llama-server or llama-mtmd-cli with Qwen3-VL-32B?
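
For reference, a server invocation for this kind of check typically looks like the one below. The exact flags are an assumption pieced together from the startup log later in the thread (model and mmproj file names, 8192 context); it is not a quote of the command that was actually run.

```sh
# assumed invocation, reconstructed from the startup log; not the author's exact command
llama-server -m Qwen3VL-32B-Instruct-Q4_K_M.gguf \
    --mmproj mmproj-Qwen3VL-32B-Instruct-F16.gguf \
    -c 8192 -ngl 99
```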

@JohannesGaessler JohannesGaessler left a comment
Collaborator

These changes look correct going by static code analysis.

@theo77186
Contributor Author

> Nice! Could you also upload the output logs when you start llama-server or llama-mtmd-cli with Qwen3-VL-32B?

Startup logs
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
main: setting n_parallel = 4 and kv_unified = true
build: 6934 (72545ce2b) with cc (Debian 15.2.0-7) 15.2.0 for x86_64-linux-gnu
system info: n_threads = 16, n_threads_batch = 16, total_threads = 32

system_info: n_threads = 16 (n_threads_batch = 16) / 32 | CUDA : ARCHS = 860,890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

main: binding port with default address family
main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 31
main: loading model
srv    load_model: loading model 'Qwen3VL-32B-Instruct-Q4_K_M.gguf'
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4060 Ti) (0000:2b:00.0) - 15001 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA GeForce RTX 3060) (0000:04:00.0) - 11822 MiB free
llama_model_loader: loaded meta data with 32 key-value pairs and 707 tensors from Qwen3VL-32B-Instruct-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3vl
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3Vl 32b Instruct
llama_model_loader: - kv   3:                           general.finetune str              = instruct
llama_model_loader: - kv   4:                           general.basename str              = qwen3vl
llama_model_loader: - kv   5:                         general.size_label str              = 32B
llama_model_loader: - kv   6:                            general.license str              = apache-2.0
llama_model_loader: - kv   7:                               general.tags arr[str,1]       = ["image-text-to-text"]
llama_model_loader: - kv   8:                        qwen3vl.block_count u32              = 64
llama_model_loader: - kv   9:                     qwen3vl.context_length u32              = 262144
llama_model_loader: - kv  10:                   qwen3vl.embedding_length u32              = 5120
llama_model_loader: - kv  11:                qwen3vl.feed_forward_length u32              = 25600
llama_model_loader: - kv  12:               qwen3vl.attention.head_count u32              = 64
llama_model_loader: - kv  13:            qwen3vl.attention.head_count_kv u32              = 8
llama_model_loader: - kv  14:                     qwen3vl.rope.freq_base f32              = 5000000.000000
llama_model_loader: - kv  15:   qwen3vl.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  16:               qwen3vl.attention.key_length u32              = 128
llama_model_loader: - kv  17:             qwen3vl.attention.value_length u32              = 128
llama_model_loader: - kv  18:            qwen3vl.rope.dimension_sections arr[i32,4]       = [24, 20, 20, 0]
llama_model_loader: - kv  19:                 qwen3vl.n_deepstack_layers u32              = 3
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = qwen2
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,151936]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,151936]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,151387]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  25:                tokenizer.ggml.eos_token_id u32              = 151645
llama_model_loader: - kv  26:            tokenizer.ggml.padding_token_id u32              = 151643
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 151643
llama_model_loader: - kv  28:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  29:                    tokenizer.chat_template str              = {%- if tools %}\n    {{- '<|im_start|>...
llama_model_loader: - kv  30:               general.quantization_version u32              = 2
llama_model_loader: - kv  31:                          general.file_type u32              = 15
llama_model_loader: - type  f32:  257 tensors
llama_model_loader: - type q4_K:  385 tensors
llama_model_loader: - type q6_K:   65 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 18.40 GiB (4.82 BPW) 
load: printing all EOG tokens:
load:   - 151643 ('<|endoftext|>')
load:   - 151645 ('<|im_end|>')
load:   - 151662 ('<|fim_pad|>')
load:   - 151663 ('<|repo_name|>')
load:   - 151664 ('<|file_sep|>')
load: special tokens cache size = 26
load: token to piece cache size = 0.9311 MB
print_info: arch             = qwen3vl
print_info: vocab_only       = 0
print_info: n_ctx_train      = 262144
print_info: n_embd           = 20480
print_info: n_layer          = 64
print_info: n_head           = 64
print_info: n_head_kv        = 8
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 8
print_info: n_embd_k_gqa     = 1024
print_info: n_embd_v_gqa     = 1024
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 25600
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: n_expert_groups  = 0
print_info: n_group_used     = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 40
print_info: rope scaling     = linear
print_info: freq_base_train  = 5000000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 262144
print_info: rope_finetuned   = unknown
print_info: mrope sections   = [24, 20, 20, 0]
print_info: model type       = 32B
print_info: model params     = 32.76 B
print_info: general.name     = Qwen3Vl 32b Instruct
print_info: vocab type       = BPE
print_info: n_vocab          = 151936
print_info: n_merges         = 151387
print_info: BOS token        = 151643 '<|endoftext|>'
print_info: EOS token        = 151645 '<|im_end|>'
print_info: EOT token        = 151645 '<|im_end|>'
print_info: PAD token        = 151643 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 151659 '<|fim_prefix|>'
print_info: FIM SUF token    = 151661 '<|fim_suffix|>'
print_info: FIM MID token    = 151660 '<|fim_middle|>'
print_info: FIM PAD token    = 151662 '<|fim_pad|>'
print_info: FIM REP token    = 151663 '<|repo_name|>'
print_info: FIM SEP token    = 151664 '<|file_sep|>'
print_info: EOG token        = 151643 '<|endoftext|>'
print_info: EOG token        = 151645 '<|im_end|>'
print_info: EOG token        = 151662 '<|fim_pad|>'
print_info: EOG token        = 151663 '<|repo_name|>'
print_info: EOG token        = 151664 '<|file_sep|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 64 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 65/65 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   417.30 MiB
load_tensors:        CUDA0 model buffer size = 10249.06 MiB
load_tensors:        CUDA1 model buffer size =  8174.59 MiB
.................................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 4
llama_context: n_ctx         = 8192
llama_context: n_ctx_seq     = 8192
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = true
llama_context: freq_base     = 5000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (8192) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     2.32 MiB
llama_kv_cache:      CUDA0 KV buffer size =  1184.00 MiB
llama_kv_cache:      CUDA1 KV buffer size =   864.00 MiB
llama_kv_cache: size = 2048.00 MiB (  8192 cells,  64 layers,  4/1 seqs), K (f16): 1024.00 MiB, V (f16): 1024.00 MiB
llama_context: Flash Attention was auto, set to enabled
llama_context:      CUDA0 compute buffer size =   194.02 MiB
llama_context:      CUDA1 compute buffer size =   306.75 MiB
llama_context:  CUDA_Host compute buffer size =    26.02 MiB
llama_context: graph nodes  = 2247
llama_context: graph splits = 3
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|im_end|> logit bias = -inf
common_init_from_params: added <|fim_pad|> logit bias = -inf
common_init_from_params: added <|repo_name|> logit bias = -inf
common_init_from_params: added <|file_sep|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 8192
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
clip_model_loader: model name:   Qwen3Vl 32b Instruct
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    352
clip_model_loader: n_kv:         25

clip_model_loader: has vision encoder
clip_ctx: CLIP using CUDA0 backend
load_hparams: projector:          qwen3vl_merger
load_hparams: n_embd:             1152
load_hparams: n_head:             16
load_hparams: n_ff:               4304
load_hparams: n_layer:            27
load_hparams: ffn_op:             gelu
load_hparams: projection_dim:     5120

--- vision hparams ---
load_hparams: image_size:         768
load_hparams: patch_size:         16
load_hparams: has_llava_proj:     0
load_hparams: minicpmv_version:   0
load_hparams: n_merge:            2
load_hparams: n_wa_pattern:       0
load_hparams: image_min_pixels:   8192
load_hparams: image_max_pixels:   2097152

load_hparams: model size:         1141.33 MiB
load_hparams: metadata size:      0.12 MiB
alloc_compute_meta: warmup with image size = 512 x 512
alloc_compute_meta:      CUDA0 compute buffer size =    50.52 MiB
alloc_compute_meta:        CPU compute buffer size =     3.02 MiB
alloc_compute_meta: graph splits = 1, nodes = 853
warmup: flash attention is enabled
srv    load_model: loaded multimodal model, 'mmproj-Qwen3VL-32B-Instruct-F16.gguf'
srv          init: initializing slots, n_slots = 4
slot         init: id  0 | task -1 | new slot, n_ctx = 8192
slot         init: id  1 | task -1 | new slot, n_ctx = 8192
slot         init: id  2 | task -1 | new slot, n_ctx = 8192
slot         init: id  3 | task -1 | new slot, n_ctx = 8192
srv          init: prompt cache is enabled, size limit: 8192 MiB
srv          init: use `--cache-ram 0` to disable the prompt cache
srv          init: for more info see https://github.com/ggml-org/llama.cpp/pull/16391
srv          init: thinking = 0
main: model loaded
main: chat template, chat_template: {%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {%- if messages[0].content is string %}
            {{- messages[0].content }}
        {%- else %}
            {%- for content in messages[0].content %}
                {%- if 'text' in content %}
                    {{- content.text }}
                {%- endif %}
            {%- endfor %}
        {%- endif %}
        {{- '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' }}
        {%- if messages[0].content is string %}
            {{- messages[0].content }}
        {%- else %}
            {%- for content in messages[0].content %}
                {%- if 'text' in content %}
                    {{- content.text }}
                {%- endif %}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- for message in messages %}
    {%- if message.role == "user" %}
        {{- '<|im_start|>' + message.role + '\n' }}
        {%- if message.content is string %}
            {{- message.content }}
        {%- else %}
            {%- for content in message.content %}
                {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
                    {%- set image_count.value = image_count.value + 1 %}
                    {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
                    <|vision_start|><|image_pad|><|vision_end|>
                {%- elif content.type == 'video' or 'video' in content %}
                    {%- set video_count.value = video_count.value + 1 %}
                    {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
                    <|vision_start|><|video_pad|><|vision_end|>
                {%- elif 'text' in content %}
                    {{- content.text }}
                {%- endif %}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role + '\n' }}
        {%- if message.content is string %}
            {{- message.content }}
        {%- else %}
            {%- for content_item in message.content %}
                {%- if 'text' in content_item %}
                    {{- content_item.text }}
                {%- endif %}
            {%- endfor %}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and message.content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {%- if message.content is string %}
            {{- message.content }}
        {%- else %}
            {%- for content in message.content %}
                {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
                    {%- set image_count.value = image_count.value + 1 %}
                    {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
                    <|vision_start|><|image_pad|><|vision_end|>
                {%- elif content.type == 'video' or 'video' in content %}
                    {%- set video_count.value = video_count.value + 1 %}
                    {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
                    <|vision_start|><|video_pad|><|vision_end|>
                {%- elif 'text' in content %}
                    {{- content.text }}
                {%- endif %}
            {%- endfor %}
        {%- endif %}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
, example_format: '<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
'
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv  update_slots: all slots are idle

@ngxson ngxson left a comment
Collaborator

Nice, thanks!

@github-actions github-actions bot added the Nvidia GPU (issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 3, 2025
@ngxson ngxson merged commit 622cd01 into ggml-org:master Nov 3, 2025
67 of 69 checks passed
@theo77186 theo77186 deleted the fattn-hs72 branch November 3, 2025 15:21
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Nov 3, 2025
* origin/master: (169 commits)
opencl: support imrope (ggml-org#16914)
fix: Viewing multiple PDF attachments (ggml-org#16974)
model-conversion : pass config to from_pretrained (ggml-org#16963)
server : add props.model_alias (ggml-org#16943)
ggml: CUDA: add head size 72 for flash-attn (ggml-org#16962)
mtmd: add --image-min/max-tokens (ggml-org#16921)
mtmd: pad mask for qwen2.5vl (ggml-org#16954)
ggml : LoongArch fixes (ggml-org#16958)
sync: minja (glm 4.6 & minmax m2 templates) (ggml-org#16949)
SYCL: optimized repeat_back kernel (3× fewer asm instructions, 2× faster)Feature/sycl repeat back opt (ggml-org#16869)
feat(webui): improve LaTeX rendering with currency detection (ggml-org#16508)
test-backend-ops : fix segfault in moe-expert-reduce test in support mode and coverage (ggml-org#16936)
ci : disable failing riscv cross build (ggml-org#16952)
model: add Janus Pro for image understanding (ggml-org#16906)
clip : use FA (ggml-org#16837)
server : support unified cache across slots (ggml-org#16736)
common : move gpt-oss reasoning processing to init params (ggml-org#16937)
docs: remove llama_sampler_accept reference in sampling sample usage (ggml-org#16920)
CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (ggml-org#16917)
devops: fix failing s390x docker build (ggml-org#16918)
...
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Nov 3, 2025
GittyBurstein pushed a commit to yael-works/llama.cpp that referenced this pull request Nov 5, 2025

Development

Successfully merging this pull request may close these issues.

Eval bug: mtmd: "flash attention is disabled / please report this on github as an issue"
