Support for DeepseekV32ForCausalLM with generic DeepSeek Sparse Attention (DSA) implementation#23346

Merged: CISC merged 7 commits into ggml-org:master from fairydreaming:deepseek-v32-minimal on May 29, 2026

@fairydreaming (Collaborator) commented May 19, 2026

Warning: The DeepSeek V3.2 model conversion currently fails with transformers 5.x (required by requirements.txt after #21617 was merged). Downgrade transformers to 4.x (for example to 4.57.6) to convert the model.
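A conversion wrapper could check for this before loading the checkpoint. The following is a minimal sketch and not part of the PR: the helper names `transformers_major` and `can_convert_deepseek_v32` are hypothetical, and it assumes the installed package reports a dotted version string such as `4.57.6`.

```python
from importlib import metadata


def transformers_major(version: str) -> int:
    """Return the major component of a version string such as '4.57.6'."""
    return int(version.split(".")[0])


def can_convert_deepseek_v32(version: str) -> bool:
    # Per the warning above, conversion currently fails with transformers 5.x.
    return transformers_major(version) < 5


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        if can_convert_deepseek_v32(installed):
            print(f"transformers {installed}: OK for conversion")
        else:
            print(f"transformers {installed}: downgrade to 4.x (e.g. 4.57.6)")
```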

Overview

This PR adds support for DeepseekV32ForCausalLM (DeepSeek V3.2 Exp, DeepSeek V3.2, DeepSeek V3.2 Speciale) models. It implements the lightning indexer and DeepSeek Sparse Attention (DSA) using existing generic GGML operations, without adding any new ops.

This PR is a continuation of PR #21149 (now closed).

Additional information

Covered areas

Areas covered by this PR:

  • conversion: support for DeepseekV32ForCausalLM architecture,
  • ggml-cpu: support for f16 GGML_OP_FILL,
  • memory: refactored llama_kv_cache constructor to include explicit hparams argument,
  • memory: added the llama_kv_cache_dsa class, which aggregates two instances of llama_kv_cache: one for caching MLA latent representations and a second for caching lightning indexer keys,
  • llama: added LLM_ARCH_DEEPSEEK32 architecture (mostly a copy of existing LLM_ARCH_GLM_DSA),
  • llama: implemented sparse attention by masking KQ mask elements corresponding to tokens not selected by the lightning indexer,
  • model: llama_model_deepseek32 implementation (mostly copied from llama_model_glm_dsa and llama_model_deepseek2)
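The masking approach from the list above can be illustrated in plain Python. This is only a sketch of the idea, not the GGML graph from the PR: `build_sparse_kq_mask`, `indexer_scores`, and `top_k` are hypothetical names, the scores are made up, and a simple causal layout (query q sees keys 0..q) is assumed.

```python
import math


def build_sparse_kq_mask(indexer_scores, top_k):
    """Return an additive attention mask (0.0 = keep, -inf = masked).

    indexer_scores[q][k] is a hypothetical relevance score that the
    indexer assigned to key k for query q. Causality is assumed:
    query q may only attend to keys 0..q.
    """
    n_tokens = len(indexer_scores)
    mask = [[-math.inf] * n_tokens for _ in range(n_tokens)]
    for q in range(n_tokens):
        visible = range(q + 1)  # causally visible keys
        # Keep the top_k visible keys with the highest indexer score.
        selected = sorted(visible, key=lambda k: indexer_scores[q][k], reverse=True)[:top_k]
        for k in selected:
            mask[q][k] = 0.0
    return mask


scores = [
    [0.9, 0.0, 0.0, 0.0],
    [0.2, 0.8, 0.0, 0.0],
    [0.7, 0.1, 0.5, 0.0],
    [0.3, 0.9, 0.2, 0.6],
]
mask = build_sparse_kq_mask(scores, top_k=2)
for row in mask:
    print(row)
```

Adding this mask to the KQ logits before the softmax removes the unselected tokens from attention. While a query sees no more than `top_k` keys, the result is identical to the ordinary causal mask, so the sparsity only takes effect on longer sequences.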

Testing

GGUFs for testing (Q8_0/Q4_K_M):

You need over 700GB (Q8_0) or over 400GB (Q4_K_M) of RAM/VRAM to run these models. The generic lightning indexer implementation uses very large compute buffers, so if you encounter out-of-memory errors, reduce the context and/or ubatch size.
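To see why reducing the context or ubatch size helps, consider a rough back-of-the-envelope estimate. It assumes, purely for illustration, that a generic indexer materializes one dense f32 score per (key, query, indexer head) triple. The head count of 64 and the context size below are illustrative assumptions, and whether the actual graph allocates exactly this tensor is not confirmed by the PR text.

```python
def indexer_scores_bytes(n_ctx, n_ubatch, n_indexer_heads, bytes_per_elem=4):
    """Size in bytes of one dense [n_ctx, n_ubatch, n_heads] score tensor."""
    return n_ctx * n_ubatch * n_indexer_heads * bytes_per_elem


GIB = 1024 ** 3

# Illustrative values only (assumptions, not taken from the PR).
full = indexer_scores_bytes(n_ctx=65536, n_ubatch=512, n_indexer_heads=64)
half_ubatch = indexer_scores_bytes(n_ctx=65536, n_ubatch=256, n_indexer_heads=64)

print(f"ubatch 512: {full / GIB:.1f} GiB")
print(f"ubatch 256: {half_ubatch / GIB:.1f} GiB")
```

Under these assumptions the size scales with the product of context length and ubatch size, so halving either one halves the buffer.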

There is also a tiny 16GB 4-layer DeepSeek V3.2 GGUF that does not produce coherent output but may be useful for testing the implementation.

Use the models/templates/deepseek-ai-DeepSeek-V3.2.jinja chat template when testing the models.

Perplexity

I measured perplexity (on wiki.test.raw with 4k chunk size so that indexer does some actual work) of:

  • Q8_0 quant without lightning indexer (dense attention): Final estimate: PPL = 2.9115 +/- 0.0146
  • Q8_0 quant with lightning indexer (sparse attention): Final estimate: PPL = 2.9126 +/- 0.01466
  • NVFP4 quant with lightning indexer (sparse attention): Final estimate: PPL = 3.0727 +/- 0.01577
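A quick way to read these numbers is to compare each difference against the reported uncertainties. The sketch below combines the two error bars in quadrature, which is a simplifying assumption: the runs use the same test data, so their errors are not really independent. The helper names are hypothetical.

```python
# (perplexity, reported +/- error) taken from the list above
dense = (2.9115, 0.0146)
sparse = (2.9126, 0.01466)
nvfp4 = (3.0727, 0.01577)


def ppl_delta(a, b):
    """Return (absolute difference, relative difference in percent) of b vs a."""
    diff = b[0] - a[0]
    return diff, 100.0 * diff / a[0]


def within_combined_error(a, b):
    # Quadrature sum of both error bars (assumes independent errors).
    combined = (a[1] ** 2 + b[1] ** 2) ** 0.5
    return abs(b[0] - a[0]) <= combined


print(ppl_delta(dense, sparse), within_combined_error(dense, sparse))
print(ppl_delta(sparse, nvfp4), within_combined_error(sparse, nvfp4))
```

By this measure the sparse-attention result is indistinguishable from dense attention at Q8_0, while the NVFP4 quant is roughly 5.5% higher than Q8_0 with the indexer enabled.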

Requirements

sszymczy added 3 commits May 19, 2026 07:22
* convert : handle DeepseekV32ForCausalLM architecture

* ggml : support for f16 GGML_OP_FILL

* memory : separate hparams argument in llama_kv_cache constructor

* memory : add llama_kv_cache_dsa memory (KV cache + lightning indexer cache)

* llama : support for LLM_ARCH_DEEPSEEK32

* model : llama_model_deepseek32 implementation
@github-actions (bot) added the model, testing, python, and ggml labels on May 19, 2026
@ggerganov (Member) left a comment

Add a TODO so I don't forget to do the refactor:

diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 0b0a56ce9..649269af6 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -93,6 +93,9 @@ public:
 
     using slot_info_vec_t = std::vector<slot_info>;
 
+    // TODO: refactor the memory instances to not depend on `llama_model`
+    //       instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly
+    //       likely through `struct llama_memory_params`
     llama_kv_cache(
             const llama_model & model,
             const llama_hparams & hparams,

@CISC (Member) left a comment

Support NVFP4 model.

Comment thread src/models/deepseek32.cpp
res->t_embd = cur;

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
Member

Why not build_lora_mm?

Collaborator Author

> Why not build_lora_mm?

I guess nobody ever cared enough to add this to the DeepSeek code that I copied and modified in this PR, so it's kind of inherited.

Are there any standard conventions for which tensor matmuls should be LoRAble and which should be left alone?
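For context on what is at stake in this question: a plain matmul applies only the base weight, while a LoRA-aware matmul also adds the low-rank update of every active adapter. The sketch below shows the general LoRA computation on small Python lists. It is not a description of the internals of llama.cpp's `build_lora_mm`, and all names in it are hypothetical.

```python
def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_matvec(w, x, adapters=()):
    """Compute y = W x + sum_i scale_i * B_i (A_i x).

    Each adapter is a tuple (A, B, scale) where A is the down-projection
    and B is the up-projection.
    """
    y = matvec(w, x)
    for a, b, scale in adapters:
        delta = matvec(b, matvec(a, x))
        y = [yi + scale * di for yi, di in zip(y, delta)]
    return y


w = [[1.0, 0.0], [0.0, 1.0]]  # base weight (identity, for readability)
a = [[1.0, 1.0]]              # rank-1 down-projection, shape 1x2
b = [[0.5], [-0.5]]           # rank-1 up-projection, shape 2x1
x = [2.0, 4.0]

print(lora_matvec(w, x))                   # base weight only
print(lora_matvec(w, x, [(a, b, 2.0)]))    # base weight plus one adapter
```

With no adapters the two paths give the same result, so routing a matmul through the LoRA-aware path only changes the output when an adapter actually targets that tensor.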

@am17an (Contributor) commented May 28, 2026

@fairydreaming sorry for my ignorance, but does the flash model work with this same architecture? That requires way less VRAM and I can also test it out on my machine (I have 128GB vram)

@fairydreaming (Collaborator, Author) commented May 28, 2026

> @fairydreaming sorry for my ignorance, but does the flash model work with this same architecture? That requires way less VRAM and I can also test it out on my machine (I have 128GB vram)

@am17an There is no DeepSeek V3.2 Flash model. I'm currently trying to get NVFP4 quant to work as @CISC suggested, but it's still almost 400GB.

Edit: in case you meant DeepSeek V4 Flash then unfortunately the answer is no, it's something completely different from DeepSeek V3.2.

@am17an (Contributor) commented May 28, 2026

@fairydreaming yes, I mean the DSV4 Flash model. I just read up on it and you're right, it's completely different, but the lightning indexer work you're doing here will be useful there. I will try to work on the Flash model in the meantime.

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
@CISC (Member) commented May 28, 2026

@fairydreaming force-pushed the deepseek-v32-minimal branch from 30fdfe4 to 4643fda on May 28, 2026 15:09
sszymczy and others added 2 commits May 28, 2026 17:12
Co-authored-by: ggerganov <ggerganov@users.noreply.github.com>
@CISC (Member) left a comment

I still have the build_lora_mm question, but otherwise LGTM.

@fairydreaming (Collaborator, Author)

> @fairydreaming GitHub UI messed up EOL again, please normalize to \n: https://github.com/ggml-org/llama.cpp/actions/runs/26582700461/job/78319887350?pr=23346

@CISC Yeah I noticed, force-pushed a fixed commit.
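For anyone hitting the same CI failure, line endings can be normalized with a few lines of Python. This is a generic sketch with a hypothetical helper name, not the fix that was actually force-pushed here.

```python
def normalize_eol(data: bytes) -> bytes:
    """Convert CRLF and lone CR line endings to LF."""
    return data.replace(b"\r\n", b"\n").replace(b"\r", b"\n")


print(normalize_eol(b"line one\r\nline two\r\n"))
```

Working on bytes rather than text keeps the rest of the file content untouched regardless of its encoding.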

@fairydreaming (Collaborator, Author) commented May 28, 2026

@CISC By the way I managed to convert and run nvidia/DeepSeek-V3.2-NVFP4 with your NVFP4 changes and it seems to work fine. Needed only regenerating model.safetensors.index.json as currently it misses NVFP4 scale tensors.
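One way such an index could be regenerated is to read the header of every shard and rebuild the tensor-to-shard map from it. The sketch below relies only on the documented safetensors layout (an 8-byte little-endian header length followed by a JSON header) and on the usual `metadata`/`weight_map` structure of the index file. It is not necessarily how the index was regenerated here, and the function names are hypothetical.

```python
import json
import struct
from pathlib import Path


def read_safetensors_header(path):
    """Read the JSON header of a .safetensors file without loading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        return json.loads(f.read(header_len))


def build_index(model_dir):
    """Map every tensor name found in the shards to the shard that holds it."""
    weight_map = {}
    total_size = 0
    for shard in sorted(Path(model_dir).glob("*.safetensors")):
        header = read_safetensors_header(shard)
        for name, info in header.items():
            if name == "__metadata__":  # optional free-form metadata entry
                continue
            begin, end = info["data_offsets"]
            total_size += end - begin
            weight_map[name] = shard.name
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


def write_index(model_dir):
    """Write model.safetensors.index.json next to the shards and return its path."""
    index = build_index(model_dir)
    out = Path(model_dir) / "model.safetensors.index.json"
    out.write_text(json.dumps(index, indent=2, sort_keys=True))
    return out
```

Calling `write_index("path/to/model")` (a placeholder path) would then overwrite the index so that every tensor present in the shards is listed, including any scale tensors missing from the original index.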

@am17an I thought about DSV4 too, but I still don't have a clear vision of how to integrate it with the llama.cpp memory subsystem without creating a bunch of new specialized classes. But it's definitely a good idea to keep the common parts reusable in both. I suppose one obvious next step is to add a separate lightning indexer GGML op, as it would bring immense compute buffer size reductions. But since DS V3.2 is kind of obsolete now, I can chill a bit and take it easy. Anyway, please keep me posted about any progress. I wish you luck!

@CISC (Member) commented May 28, 2026

> @CISC By the way I managed to convert and run nvidia/DeepSeek-V3.2-NVFP4 with your NVFP4 changes and it seems to work fine. Needed only regenerating model.safetensors.index.json as currently it misses NVFP4 scale tensors.

Weird, but great to hear it works, do you have BW hw, and if so how does performance compare?

@fairydreaming (Collaborator, Author) commented May 28, 2026

> > @CISC By the way I managed to convert and run nvidia/DeepSeek-V3.2-NVFP4 with your NVFP4 changes and it seems to work fine. Needed only regenerating model.safetensors.index.json as currently it misses NVFP4 scale tensors.
>
> Weird, but great to hear it works, do you have BW hw, and if so how does performance compare?

@CISC I have Epyc 9374F with a single RTX PRO 6000 Max-Q (BLACKWELL_NATIVE_FP4 = 1), experts were in RAM. Some llama-bench experiments I did:

Q8_0, --no-op-offload 0

./bin/llama-bench -m ../models/DeepSeek-V3.2-Q8_0.gguf -ncmoe 999 -ngl 99 -fa 1 -ub 512 -p 512 -n 32 -r 1 --no-op-offload 0
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl |  n_cpu_moe | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | -: | --------------: | -------------------: |
| deepseek32 685B.A37B Q8_0      | 678.56 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |           pp512 |         22.17 ± 0.00 |
| deepseek32 685B.A37B Q8_0      | 678.56 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |            tg32 |         10.91 ± 0.00 |

Q8_0, --no-op-offload 1

./bin/llama-bench -m ../models/DeepSeek-V3.2-Q8_0.gguf -ncmoe 999 -ngl 99 -fa 1 -ub 512 -p 512 -n 32 -r 1 --no-op-offload 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl |  n_cpu_moe | fa | nopo |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | -: | ---: | --------------: | -------------------: |
| deepseek32 685B.A37B Q8_0      | 678.56 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |    1 |           pp512 |         42.01 ± 0.00 |
| deepseek32 685B.A37B Q8_0      | 678.56 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |    1 |            tg32 |         10.97 ± 0.00 |

NVFP4, --no-op-offload 0

./bin/llama-bench -m ../models/DeepSeek-V3.2-NVFP4.gguf -ncmoe 999 -ngl 99 -fa 1 -ub 512 -p 512 -n 32 -r 1 --no-op-offload 0
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl |  n_cpu_moe | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | -: | --------------: | -------------------: |
| deepseek32 685B.A37B NVFP4     | 386.79 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |           pp512 |         41.34 ± 0.00 |
| deepseek32 685B.A37B NVFP4     | 386.79 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |            tg32 |          1.82 ± 0.00 |

NVFP4, --no-op-offload 1

./bin/llama-bench -m ../models/DeepSeek-V3.2-NVFP4.gguf -ncmoe 999 -ngl 99 -fa 1 -ub 512 -p 512 -n 32 -r 1 --no-op-offload 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl |  n_cpu_moe | fa | nopo |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | -: | ---: | --------------: | -------------------: |
| deepseek32 685B.A37B NVFP4     | 386.79 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |    1 |           pp512 |          1.99 ± 0.00 |
| deepseek32 685B.A37B NVFP4     | 386.79 GiB |   685.36 B | CUDA       |  99 |        999 |  1 |    1 |            tg32 |          1.83 ± 0.00 |

build: 101bad432 (9403)

From what I understand, NVFP4 has horrible performance on the CPU, and this slows everything down. I added some mul_mat backend op tests and they seem to confirm it:

Q8_0:

./bin/test-backend-ops perf -o "MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)"
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
Testing 2 devices

Backend 1/2: CUDA0
  Device description: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition
  Device memory: 97247 MB (96640 MB free)

ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1):                 319488 runs -     3.18 us/run -   4.19 MFLOP/run -   1.32 TFLOPS
  Backend CUDA0: OK
Backend 2/2: CPU
  Device description: AMD EPYC 9374F 32-Core Processor
  Device memory: 1160411 MB (1160411 MB free)

  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1):                 150732 runs -     6.67 us/run -   4.19 MFLOP/run - 628.48 GFLOPS
  Backend CPU: OK
2/2 backends passed
OK

NVFP4:

./bin/test-backend-ops perf -o "MUL_MAT(type_a=nvfp4,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)"
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
Testing 2 devices

Backend 1/2: CUDA0
  Device description: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition
  Device memory: 97247 MB (96640 MB free)

ggml_backend_cuda_graph_compute: CUDA graph warmup complete
  MUL_MAT(type_a=nvfp4,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1):                294912 runs -     3.44 us/run -   4.19 MFLOP/run -   1.22 TFLOPS
  Backend CUDA0: OK
Backend 2/2: CPU
  Device description: AMD EPYC 9374F 32-Core Processor
  Device memory: 1160411 MB (1160411 MB free)

  MUL_MAT(type_a=nvfp4,type_b=f32,m=4096,n=1,k=512,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1):                 19080 runs -    55.78 us/run -   4.19 MFLOP/run -  75.20 GFLOPS
  Backend CPU: OK
2/2 backends passed
OK

so while Q8_0 on CPU works pretty fast, NVFP4 is like 8 times slower, basically unusable.
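The "8 times" figure can be checked directly from the CPU timings in the two test-backend-ops runs. The sketch below assumes the usual 2·m·n·k operation count for a matrix multiplication, which reproduces the reported 4.19 MFLOP/run.

```python
def mul_mat_flops(m, n, k):
    # One multiply and one add per (m, n, k) triple.
    return 2 * m * n * k


flops = mul_mat_flops(m=4096, n=1, k=512)

# us/run on the CPU backend, copied from the test-backend-ops output above
cpu_us = {"q8_0": 6.67, "nvfp4": 55.78}

gflops = {t: flops / (us * 1e-6) / 1e9 for t, us in cpu_us.items()}
slowdown = cpu_us["nvfp4"] / cpu_us["q8_0"]

print(f"{flops / 1e6:.2f} MFLOP/run")
for t, g in gflops.items():
    print(f"{t}: {g:.1f} GFLOPS")
print(f"NVFP4 is {slowdown:.2f}x slower than Q8_0 on the CPU")
```

The small differences from the GFLOPS values printed by the tool come from the rounding of the us/run figures. On CUDA0 the two types are nearly identical (3.18 us vs 3.44 us per run), so the gap is specific to the CPU path.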

@CISC (Member) commented May 29, 2026

> so while Q8_0 on CPU works pretty fast, NVFP4 is like 8 times slower, basically unusable.

Yeah, it's only useful if you can fit all the NVFP4 tensors on GPU. :(

@am17an (Contributor) commented May 29, 2026

Also, the current NVFP4 CPU path is the "generic" path; an AVX implementation would probably bring it up to par with the rest of the quants.

@CISC merged commit 1f0aa2a into ggml-org:master on May 29, 2026 (32 checks passed)

@JohannesGaessler (Contributor)

This PR broke the CI for test-llama-archs, see #23876.

@CISC (Member) commented May 29, 2026

> This PR broke the CI for test-llama-archs, see #23876.

See #23864 :)

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 29, 2026
* origin/master:
vocab : support tokenizer for LFM2.5-8B-A1B (ggml-org#23826)
graph : ensure DS32 kq_mask_lid is F32 (ggml-org#23864)
server: remove obsolete scripts (ggml-org#23870)
ci : update macos release to use macos-26 runner (ggml-org#23878)
download: add option to skip_download (ggml-org#23059)
mtmd: Add DeepSeekOCR 2 Support (ggml-org#20975)
CUDA: Check PTX version on host side to guard PDL dispatch (ggml-org#23530)
server: bump timeout to 3600s (ggml-org#23842)
model : support for DeepseekV32ForCausalLM with generic DeepSeek Sparse Attention (DSA) implementation (ggml-org#23346)
llama: use f16 mask for FA to save VRAM (ggml-org#23764)
sync : ggml
ggml : bump version to 0.13.1 (ggml/1523)
ngram-mod : Add missing include (ggml-org#23857)
llama: add llm_graph_input_mtp (ggml-org#23643)
app : move licences to llama-app (ggml-org#23824)
cuda : disables launch_fattn PDL enrollment due to compiler bug (ggml-org#23825)
meta : Add missing `buffer` set in allreduce fallback !COMPUTE clear (ggml-org#23480)
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
…se Attention (DSA) implementation (ggml-org#23346)

* llama : support DeepSeek V3.2 model family (with DSA lightning indexer)

* convert : handle DeepseekV32ForCausalLM architecture

* ggml : support for f16 GGML_OP_FILL

* memory : separate hparams argument in llama_kv_cache constructor

* memory : add llama_kv_cache_dsa memory (KV cache + lightning indexer cache)

* llama : support for LLM_ARCH_DEEPSEEK32

* model : llama_model_deepseek32 implementation

* model : merge two scale operations into one in DSA lightning indexer implementation

* chore : remove unused code

* model : support NVFP4 in DeepSeek V3.2

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* memory : refactoring TODO

Co-authored-by: ggerganov <ggerganov@users.noreply.github.com>

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: ggerganov <ggerganov@users.noreply.github.com>
Labels: ggml, model, python, testing

6 participants