
Optimized DeepSeek V2/V3 implementation (MLA) #11446

Draft · wants to merge 10 commits into master

Conversation

@fairydreaming (Collaborator) commented Jan 27, 2025

This PR introduces various optimizations for the DeepSeek V2/V3 implementation.

Note that you need to reconvert the model to use this implementation.

Performance compared to the previous "naive" implementation:

[plots: deepseek-mla, deepseek-lite-mla-pp, deepseek-r1-mla, deepseek-mla-pp]

CUDA performance is worse for short context lengths, but the curve is flatter:

[plots: deepseek-lite-mla, deepseek-lite-cuda-mla-pp]

TODO:

  • remove unused kv_b tensor from the model
  • maybe add support for old model files (compute k_b and v_b during inference with reduced performance)
  • wait for completion of: llama : refactor llama_kv_cache, llama_context and llm_build_context #11213
  • implement MLA KV cache
  • address regressions in prompt processing performance (different permutations of tensors?) - I don't think it's possible, as this implementation is more compute-intensive than the regular attention implementation

@fairydreaming marked this pull request as draft January 28, 2025 11:23
@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@fairydreaming (Collaborator Author)

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@fairydreaming (Collaborator Author)

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@wronkiew I don't have the model uploaded (my upload bandwidth is too low), so you have to download, convert to bf16, convert to GGUF and quantize the original model yourself (or download one that is already converted to bf16, which will save you one step).

@fairydreaming (Collaborator Author)

I spent some time investigating this hint from the DeepSeek V2 paper:

Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^{O}$

At first glance it looks reasonable: each absorbed matrix allows replacing two matrix multiplications with a single one, thus reducing the number of operations.

However, when we look at the dimensions of these matrices, this stops being reasonable. For example, in DeepSeek V2 Lite:

  • $W^{UQ}$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3D and permutation
  • $W^{UK}$ tensor has shape [128, 8192], that is [16, 512, 128] after reshaping to 3D and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]

So (ignoring the head dimension) this allows replacing two multiplications, by a [2048, 128] matrix and a [512, 128] matrix, with a single multiplication by a [512, 2048] matrix. The combined matrix has over 3x the elements of the two individual matrices, so it will take more memory and will actually be slower to multiply than the two multiplications with the smaller matrices.

With $W^{UV}$ and $W^{O}$ it's the same story:

  • $W^{UV}$ tensor has shape [2048, 512], that is [16, 512, 128] after reshaping to 3D and permutation
  • $W^{O}$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3D and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]
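A quick element-count check of that argument, using the DeepSeek V2 Lite shapes listed above (a minimal sketch; the variable names are mine):

```python
# Per-head element counts for the separate vs. "absorbed" projection matrices
# (DeepSeek V2 Lite: 16 heads, d_model = 2048, head dim = 128, kv rank = 512).
n_head, d_model, d_head, kv_rank = 16, 2048, 128, 512

separate = d_model * d_head + kv_rank * d_head   # per-head W^UQ + W^UK slices = 327,680
absorbed = kv_rank * d_model                     # fused matrix                = 1,048,576

print(absorbed / separate)   # ~3.2x more elements to read per head
```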

I also found this blog post: https://github.com/xjdr-alt/mla_blog_translation where they mention:

Compared to performing projection with these particularly large low-rank matrices, it is obviously more advantageous to multiply them successively according to the low-rank decomposition form. Therefore, we believe that this optimization step is not very necessary.

So it looks like a dead end; it won't give us any speed gains.

@divine-taco

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

@fairydreaming (Collaborator Author) commented Jan 31, 2025

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

As I wrote in the PR:

Note that you need to reconvert the model to use this implementation.

Existing GGUFs won't work; you have to convert and quantize one with the code from this PR.

@danielhanchen (Contributor)

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

@fairydreaming (Collaborator Author)

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

I think it's best to wait a bit until this is stable and merged; it's possible that there will be some changes that would cause them to stop working, and you'd have to repeat the conversion again.

@fairydreaming (Collaborator Author)

I updated the token generation performance plots in the PR post, also added some new showing the prompt processing performance. The optimized implementation generally performs WORSE in prompt processing - DeepSeek R1 671B Q4_K_S running on CPU performs only a little worse (~10% with 4k prompt), but DeepSeek V2 Lite Q8_0 running on RTX 4090 performs MUCH WORSE (~30% with 16k prompt) and in both cases the gap widens as the prompt length increases. So it's not all sunshine and rainbows.

Considering all these performance regressions I think the best course of action would be to put the optimized implementation into separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This will prevent issues with existing GGUFs - they would keep working with existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

Comment on lines +6406 to +6409
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;

Member

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

I would first look into improving the FA kernels to support DeepSeek head sizes.

Collaborator Author

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

Hmm? I'm quite sure there's only one ggml_cont() call (excluding the ones for CUDA compatibility that already existed in the previous implementation).

As for the permutes, the idea is to multiply by a matrix whose second dimension equals the number of heads instead of the number of tokens (which is 1) during single-sequence token generation; that increased the performance on a CPU a bit.

So during prompt processing we have 2 permutes and 1 cont. During token generation we have 5 permutes (yeah, that may be a lot) and 0 conts.

Member

Thanks for the correction - I did imagine the extra conts when I saw the permutes.

@ggerganov (Member)

Considering all these performance regressions I think the best course of action would be to put the optimized implementation into separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This will prevent issues with existing GGUFs - they would keep working with existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

@fairydreaming (Collaborator Author)

Considering all these performance regressions I think the best course of action would be to put the optimized implementation into separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This will prevent issues with existing GGUFs - they would keep working with existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

That may not be possible - IMHO an MLA attention implementation that caches "compressed" latent KV representations introduces unavoidable computational overhead due to the need to "decompress" these representations in order to calculate attention scores and attention output. So a "naive" attention implementation that caches full K/V vectors will always use less compute but more memory bandwidth, while caching latent representations results in using more compute but less memory bandwidth. So there can't be a single implementation optimal in all use cases. I'd be happy to be proven wrong about this, though.
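To put rough numbers on that trade-off, here is a sketch of the per-token, per-layer cache sizes implied by each approach, using the DeepSeek V2 Lite dimensions from earlier in the thread (element counts only, not exact llama.cpp buffer sizes):

```python
# "Naive" attention caches full per-head K and V vectors; MLA caches only the
# compressed latent c^KV plus the shared RoPE part k^R, but must re-expand it.
n_head, qk_nope, qk_rope, v_dim, kv_rank = 16, 128, 64, 128, 512

naive = n_head * (qk_nope + qk_rope) + n_head * v_dim   # 5120 values per token
mla   = kv_rank + qk_rope                               #  576 values per token

print(naive / mla)   # ~8.9x less cache to stream, paid for with extra matmuls
```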

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I think there shouldn't be any problems with this, as there is a straightforward direct mapping between the cached representations and full K/V vectors.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

That's fine with me. I'm taking a break from this anyway, got bored with tensor shuffling looking for 0.1 t/s more performance. 😉

@saood06 commented Feb 2, 2025

@fairydreaming
Is there any reason this should cause issues with RPC?
Encountered:

ggml_cuda_compute_forward: cannot compute kqv-31: src0->ne[3] = 1, src1->ne[3] = 2 - fallback to CPU
evaluate_and_capture_cuda_graph: op not supported kqv-31 (MUL_MAT)
[...]\llama.cpp\ggml\src\ggml-cuda\ggml-cuda.cu:2660: GGML_ASSERT(ok) failed

I don't have a quant on hand that I can test without this branch. This branch does give me a nice performance boost for TG at longer contexts, but RPC to CUDA does not work.

@jukofyork (Contributor) commented Feb 12, 2025

OK, I can get the fake LoRA thing working really easily.

For fixed --rank:

> python3 ./fp8_cast_bf16_and_SVD_MLP.py --input-fp8-hf-path DeepSeek-R1 --output-bf16-hf-path DeepSeek-R1-svd-bf16 --rank 512
Converting...
  0%|                                                                                                                                                      | 0/163 [00:00<?, ?it/s]
Processing MLP weight: model.layers.3.mlp.experts.0.down_proj.weight torch.Size([7168, 2048])
- Rank              : 512
- Variance Explained: 77.43%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([7168, 512]), torch.Size([512, 2048])

Processing MLP weight: model.layers.3.mlp.experts.0.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 512
- Variance Explained: 83.98%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([2048, 512]), torch.Size([512, 7168])

Processing MLP weight: model.layers.3.mlp.experts.0.up_proj.weight torch.Size([2048, 7168])
- Rank              : 512
- Variance Explained: 83.64%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([2048, 512]), torch.Size([512, 7168])

Processing MLP weight: model.layers.3.mlp.experts.1.down_proj.weight torch.Size([7168, 2048])
- Rank              : 512
- Variance Explained: 91.89%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([7168, 512]), torch.Size([512, 2048])

Processing MLP weight: model.layers.3.mlp.experts.1.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 512
- Variance Explained: 93.74%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([2048, 512]), torch.Size([512, 7168])

Processing MLP weight: model.layers.3.mlp.experts.1.up_proj.weight torch.Size([2048, 7168])
- Rank              : 512
- Variance Explained: 93.60%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([2048, 512]), torch.Size([512, 7168])

Processing MLP weight: model.layers.3.mlp.experts.2.down_proj.weight torch.Size([7168, 2048])
- Rank              : 512
- Variance Explained: 97.20%
- Compression Ratio : 32.14%
- New Shapes        : torch.Size([7168, 512]), torch.Size([512, 2048])

Or dynamic --min-variance-explained:

> python3 ./fp8_cast_bf16_and_SVD_MLP.py --input-fp8-hf-path DeepSeek-R1 --output-bf16-hf-path DeepSeek-R1-svd-bf16 --min-variance-explained 0.9
Converting...
  0%|                                                                                                                                                      | 0/163 [00:00<?, ?it/s]
Processing MLP weight: model.layers.3.mlp.experts.0.down_proj.weight torch.Size([7168, 2048])
- Rank              : 866
- Variance Explained: 90.00%
- Compression Ratio : 54.37%
- New Shapes        : torch.Size([7168, 866]), torch.Size([866, 2048])

Processing MLP weight: model.layers.3.mlp.experts.0.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 700
- Variance Explained: 90.01%
- Compression Ratio : 43.95%
- New Shapes        : torch.Size([2048, 700]), torch.Size([700, 7168])

Processing MLP weight: model.layers.3.mlp.experts.0.up_proj.weight torch.Size([2048, 7168])
- Rank              : 709
- Variance Explained: 90.02%
- Compression Ratio : 44.51%
- New Shapes        : torch.Size([2048, 709]), torch.Size([709, 7168])

Processing MLP weight: model.layers.3.mlp.experts.1.down_proj.weight torch.Size([7168, 2048])
- Rank              : 454
- Variance Explained: 90.02%
- Compression Ratio : 28.50%
- New Shapes        : torch.Size([7168, 454]), torch.Size([454, 2048])

Processing MLP weight: model.layers.3.mlp.experts.1.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 391
- Variance Explained: 90.01%
- Compression Ratio : 24.55%
- New Shapes        : torch.Size([2048, 391]), torch.Size([391, 7168])

but need to confirm that convert_lora_to_gguf.py will actually work with MoE models where all the experts are separate like this?

@slaren Does convert_lora_to_gguf.py use the same logic as convert_hf_to_gguf.py? I can't see how the dynamic version could ever work (even though it's a valid LoRA via the rank/alpha overrides), but will convert_lora_to_gguf.py even work for the fixed rank version and correctly turn the 256 LoRAs with the same rank into a large single tensor, etc?

I don't want to spend all day doing this to find it's impossible to work with llama.cpp... :)
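For reference, a minimal sketch of the rank-truncated SVD factorization being described (assuming PyTorch; the exact metric definitions in fp8_cast_bf16_and_SVD_MLP.py may differ):

```python
import torch

def low_rank_factor(W: torch.Tensor, rank: int):
    # Truncated SVD: W ~= A @ B with A = U_r * S_r and B = Vh_r (a "fake LoRA" pair).
    U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
    A = U[:, :rank] * S[:rank]                    # [out, rank]
    B = Vh[:rank, :]                              # [rank, in]
    var_explained = (S[:rank] ** 2).sum() / (S ** 2).sum()
    compression   = (A.numel() + B.numel()) / W.numel()
    return A, B, var_explained.item(), compression

W = torch.randn(7168, 2048)                       # e.g. one down_proj expert weight
A, B, ve, cr = low_rank_factor(W, 512)
print(f"rank 512: variance explained {ve:.2%}, compression ratio {cr:.2%}")  # cr = 32.14%
```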

@jukofyork (Contributor)

It looks to me that a rank-64 "fake LoRA" would explain around half the variance:

> python3 ./fp8_cast_bf16_and_SVD_MLP.py --input-fp8-hf-path DeepSeek-R1 --output-bf16-hf-path DeepSeek-R1-svd-bf16 --rank 64 
Converting...
  0%|                                                                                                                                                      | 0/163 [00:00<?, ?it/s]
Processing MLP weight: model.layers.3.mlp.experts.0.down_proj.weight torch.Size([7168, 2048])
- Rank              : 64
- Variance Explained: 29.12%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([7168, 64]), torch.Size([64, 2048])

Processing MLP weight: model.layers.3.mlp.experts.0.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 35.61%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.0.up_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 34.32%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.1.down_proj.weight torch.Size([7168, 2048])
- Rank              : 64
- Variance Explained: 44.20%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([7168, 64]), torch.Size([64, 2048])

Processing MLP weight: model.layers.3.mlp.experts.1.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 48.99%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.1.up_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 48.14%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.2.down_proj.weight torch.Size([7168, 2048])
- Rank              : 64
- Variance Explained: 59.40%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([7168, 64]), torch.Size([64, 2048])

Processing MLP weight: model.layers.3.mlp.experts.2.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 61.83%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.2.up_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 60.23%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.3.down_proj.weight torch.Size([7168, 2048])
- Rank              : 64
- Variance Explained: 55.60%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([7168, 64]), torch.Size([64, 2048])

Processing MLP weight: model.layers.3.mlp.experts.3.gate_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 61.61%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

Processing MLP weight: model.layers.3.mlp.experts.3.up_proj.weight torch.Size([2048, 7168])
- Rank              : 64
- Variance Explained: 57.45%
- Compression Ratio : 4.02%
- New Shapes        : torch.Size([2048, 64]), torch.Size([64, 7168])

and possibly much more further into the LLM (ie: these early layers have the highest information density as found by ikawrakow in his experiments used to write the llama_tensor_get_type() logic).

I think it would be super-worthwhile to try this!

@slaren (Member) commented Feb 12, 2025

@slaren Does convert_lora_to_gguf.py use the same logic as convert_hf_to_gguf.py? I can't see how the dynamic version could ever work (even though it's a valid LoRA via the rank/alpha overrides), but will convert_lora_to_gguf.py even work for the fixed rank version and correctly turn the 256 LoRAs with the same rank into a large single tensor, etc?

I think so, but I am not sure about the details. This was implemented by @ngxson and @compilade.

@jukofyork (Contributor)

@slaren Does convert_lora_to_gguf.py use the same logic as convert_hf_to_gguf.py? I can't see how the dynamic version could ever work (even though it's a valid LoRA via the rank/alpha overrides), but will convert_lora_to_gguf.py even work for the fixed rank version and correctly turn the 256 LoRAs with the same rank into a large single tensor, etc?

I think so, but I am not sure about the details. This was implemented by @ngxson and @compilade.

No problem - I'm off out for a couple of hours, so hopefully will get confirmation by then :)

I think I'm getting more comfortable with the GGML stuff now anyway, so may be able to get this working if it isn't already.

@ngxson (Collaborator) commented Feb 12, 2025

Does convert_lora_to_gguf.py use the same logic as convert_hf_to_gguf.py

Yes it should, except for some edge cases where the conversion does some weird tensor permutations - this is too difficult to keep track of, so we don't have any method to document it for now. But only a very small number of models do that; I don't know whether DeepSeek is one of them.

@compilade (Collaborator) commented Feb 12, 2025

Does convert_lora_to_gguf.py use the same logic as convert_hf_to_gguf.py? I can't see how the dynamic version could ever work (even though it's a valid LoRA via the rank/alpha overrides), but will convert_lora_to_gguf.py even work for the fixed rank version and correctly turn the 256 LoRAs with the same rank into a large single tensor, etc?

@jukofyork
Yes, both use the same logic. convert_lora_to_gguf.py reimplements some operations on tensors so that the LoRA adapter is correctly transformed. Only a subset of operations are implemented, but it's enough for most model architectures (including those which use MoE).

If kv_b_proj.weight is affected by your LoRA, then torch.split is not yet implemented in convert_lora_to_gguf.py, so this will not work as-is for kv_b_proj.weight, but (a subset of) view, transpose, reshape and indexing-based slices do work, so it should be possible to adapt the existing transformation. Although this is probably not a likely tensor to be in a LoRA adapter.

Otherwise, it should work as-is. MoE LoRAs should work since the LoRA refactor (ref: #8332 (comment)), and conversion for DeepSeekV3 seems to stack the experts in the usual way (I did not test it, though).

And you're right, dynamic LoRAs for MoE aren't handled because when stacking the experts with torch.stack, they are assumed to have the same rank.

@jukofyork (Contributor) commented Feb 12, 2025

Thanks guys, I've got it working by just exporting the LoRA adapter as GGUF directly for now:

import os
import numpy as np
import torch
import gguf

def export_lora_gguf(
    path: os.PathLike[str] | str,
    tensors: list[tuple[str, torch.Tensor]],
    alpha: int,
    quant_type: gguf.GGMLQuantizationType
):

    print(f"Initializing GGUFWriter with path: '{path}'")
    writer = gguf.GGUFWriter(path, "deepseek2")
    writer.add_string("general.type", "adapter")
    writer.add_string("adapter.type", "lora")
    writer.add_float32("adapter.lora.alpha", alpha)

    for name, tensor in tensors:
        print(f"- Processing '{name}' with shape {tensor.shape}")

        if quant_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]:
            # Handle float32 and float16 directly
            dtype = np.float32 if quant_type == gguf.GGMLQuantizationType.F32 else np.float16
            writer.add_tensor(name, tensor.cpu().numpy().astype(dtype))
        else:
            # Handle BF16 and Q8_0 through quantization
            quant_tensor = gguf.quants.quantize(tensor.cpu().numpy(), quant_type)
            print(f"-- Original tensor shape: {tensor.shape}")
            print(f"-- Quantized tensor shape: {quant_tensor.shape}")
            writer.add_tensor(name, quant_tensor, raw_shape=quant_tensor.shape, raw_dtype=quant_type)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()

    print("Export completed")

It seems to be read in and working for a "layer 3 only" LoRA that I tested it with:

Initializing GGUFWriter with path: 'DeepSeek-R1-svd-lora-Q8_0.gguf'
- Processing 'blk.3.ffn_up_exps.weight.lora_a' with shape torch.Size([256, 64, 7168])
-- Original tensor shape: torch.Size([256, 64, 7168])
-- Quantized tensor shape: (256, 64, 7616)
- Processing 'blk.3.ffn_up_exps.weight.lora_b' with shape torch.Size([256, 2048, 64])
-- Original tensor shape: torch.Size([256, 2048, 64])
-- Quantized tensor shape: (256, 2048, 68)
- Processing 'blk.3.ffn_gate_exps.weight.lora_a' with shape torch.Size([256, 64, 7168])
-- Original tensor shape: torch.Size([256, 64, 7168])
-- Quantized tensor shape: (256, 64, 7616)
- Processing 'blk.3.ffn_gate_exps.weight.lora_b' with shape torch.Size([256, 2048, 64])
-- Original tensor shape: torch.Size([256, 2048, 64])
-- Quantized tensor shape: (256, 2048, 68)
- Processing 'blk.3.ffn_down_exps.weight.lora_a' with shape torch.Size([256, 64, 2048])
-- Original tensor shape: torch.Size([256, 64, 2048])
-- Quantized tensor shape: (256, 64, 2176)
- Processing 'blk.3.ffn_down_exps.weight.lora_b' with shape torch.Size([256, 7168, 64])
-- Original tensor shape: torch.Size([256, 7168, 64])
-- Quantized tensor shape: (256, 7168, 68)
> ./llama-server --lora DeepSeek-R1-svd-lora-Q8_0.gguf  ...
llama_adapter_lora_init_impl: loading lora adapter from 'DeepSeek-R1-svd-lora-Q8_0.gguf' ...
llama_adapter_lora_init_impl: CPU_Mapped LoRA buffer size =   459.00 MiB
llama_adapter_lora_init_impl: loaded 6 tensors from lora file

Just need to leave it running overnight now to do the full set of SVDs...

I'm going to use F16 and rank-64 for now:

(16/8)×58×3×256×64×(7168 + 2048)/1024^3 = ~49GB

but if it looks promising, I will also try Q8_0:

(8.5/8)×58×3×256×64×(7168 + 2048)/1024^3 = ~26GB

which is pretty small when compared to even the "meme" 1.58bit quants.
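A quick check of those size estimates (assumptions: 58 expert layers, 3 expert tensors per layer, 256 experts, rank 64, dims 7168 and 2048):

```python
layers, tensors_per_layer, experts, rank, d_in, d_out = 58, 3, 256, 64, 7168, 2048
params = layers * tensors_per_layer * experts * rank * (d_in + d_out)

print(params * 2 / 1024**3)        # F16:  ~48.9 GiB
print(params * 8.5 / 8 / 1024**3)  # Q8_0: ~26.0 GiB (8.5 bits per weight)
```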

@jukofyork (Contributor)

I couldn't get the "fake LoRA" to work as it did something very strange:

  • It wasn't mmapped, used a huge amount of extra GPU memory, and ran about 10x slower than expected.
  • llama-perplexity, on the other hand, said it loaded the LoRA but clearly didn't (as there were no weird GPU spikes) and ended up outputting NaNs (due to the rank-64 subspace being missing).

But, I have now managed to integrate it into the compute graph:

[  57/1495]           blk.3.ffn_down_exps.weight - [ 2048,  7168,   256,     1], type =   bf16, converting to q4_0 .. size =  7168.00 MiB ->  2016.00 MiB
[  58/1495]         blk.3.ffn_down_exps_a.weight - [ 2048,    64,   256,     1], type =   bf16, converting to q8_0 .. size =    64.00 MiB ->    34.00 MiB
[  59/1495]         blk.3.ffn_down_exps_b.weight - [   64,  7168,   256,     1], type =   bf16, converting to q8_0 .. size =   224.00 MiB ->   119.00 MiB
[  60/1495]          blk.3.ffn_down_shexp.weight - [ 2048,  7168,     1,     1], type =   bf16, size =   28.000 MB
[  61/1495]           blk.3.ffn_gate_exps.weight - [ 7168,  2048,   256,     1], type =   bf16, converting to q4_0 .. size =  7168.00 MiB ->  2016.00 MiB
[  62/1495]         blk.3.ffn_gate_exps_a.weight - [ 7168,    64,   256,     1], type =   bf16, converting to q8_0 .. size =   224.00 MiB ->   119.00 MiB
[  63/1495]         blk.3.ffn_gate_exps_b.weight - [   64,  2048,   256,     1], type =   bf16, converting to q8_0 .. size =    64.00 MiB ->    34.00 MiB
[  64/1495]            blk.3.ffn_gate_inp.weight - [ 7168,   256,     1,     1], type =    f32, size =    7.000 MB
[  65/1495]          blk.3.ffn_gate_shexp.weight - [ 7168,  2048,     1,     1], type =   bf16, size =   28.000 MB
[  66/1495]                blk.3.ffn_norm.weight - [ 7168,     1,     1,     1], type =    f32, size =    0.027 MB
[  67/1495]             blk.3.ffn_up_exps.weight - [ 7168,  2048,   256,     1], type =   bf16, converting to q4_0 .. size =  7168.00 MiB ->  2016.00 MiB
[  68/1495]           blk.3.ffn_up_exps_a.weight - [ 7168,    64,   256,     1], type =   bf16, converting to q8_0 .. size =   224.00 MiB ->   119.00 MiB
[  69/1495]           blk.3.ffn_up_exps_b.weight - [   64,  2048,   256,     1], type =   bf16, converting to q8_0 .. size =    64.00 MiB ->    34.00 MiB

and it appears to be working fine and only adds a little overhead.

It's pretty horribly hacked in for now:

static struct ggml_tensor * llm_build_moe_ffn(
        struct ggml_context * ctx,
       struct llama_context & lctx,
         struct ggml_tensor * cur,
         struct ggml_tensor * gate_inp,
         struct ggml_tensor * up_exps,
         struct ggml_tensor * gate_exps,
         struct ggml_tensor * down_exps,
         struct ggml_tensor * exp_probs_b,
                    int64_t   n_expert,
                    int64_t   n_expert_used,
            llm_ffn_op_type   type_op,
                       bool   norm_w,
                       bool   scale_w,
                      float   w_scale,
llama_expert_gating_func_type gating_op,
         const llm_build_cb & cb,
                        int   il,
         struct ggml_tensor * up_exps_a = nullptr,
         struct ggml_tensor * up_exps_b = nullptr,
         struct ggml_tensor * gate_exps_a = nullptr,
         struct ggml_tensor * gate_exps_b = nullptr,
         struct ggml_tensor * down_exps_a = nullptr,
         struct ggml_tensor * down_exps_b = nullptr) {
static struct ggml_tensor * llm_build_lora_mm_id(
        struct llama_context & lctx,
         struct ggml_context * ctx0,
          struct ggml_tensor * w,   // struct ggml_tensor * as
          struct ggml_tensor * cur, // struct ggml_tensor * b
          struct ggml_tensor * ids,
          struct ggml_tensor * a = nullptr,
          struct ggml_tensor * b = nullptr) {
    // multiply by the (possibly quantized) base expert weights
    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
    // if low-rank "fake LoRA" factors were stored in the model, add b*(a*cur) on top
    if (a && b) {
        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
            ctx0, b,
            ggml_mul_mat_id(ctx0, a, cur, ids),
            ids
        );
        res = ggml_add(ctx0, res, ab_cur);
    }
    // runtime LoRA adapters loaded with --lora (pre-existing code path)
    for (auto & it : lctx.lora) {
        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }
        const float alpha = it.first->alpha;
        const float rank  = (float) lw->b->ne[0];
        const float scale = alpha ? it.second * alpha / rank : it.second;
        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
            ctx0, lw->b,
            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
            ids
        );
        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }
    return res;
}

and the first test of a rank-64 LoRA seems to actually make it slightly worse:

#### BF16/Q4_0/Q4_0

[1]2.5160,[2]3.3227,[3]2.4058,[4]2.0030,[5]1.8059,[6]1.6632,[7]1.5704,[8]1.5020,[9]1.4516,[10]1.4119,[11]1.3972,[12]1.4372,[13]1.4479,[14]1.5764,[15]1.7091,[16]1.7684

#### BF16/Q4_0/Q4_0 + Q8_0 rank-64 LoRA

[1]2.5476,[2]3.3248,[3]2.3992,[4]2.0021,[5]1.8158,[6]1.6708,[7]1.5785,[8]1.5101,[9]1.4608,[10]1.4206,[11]1.4038,[12]1.4445,[13]1.4551,[14]1.5830,[15]1.7137,[16]1.7727

but I will investigate more tomorrow - there are lots of places a bug could have crept in for the SVD code, or it might just not like being quantised using Q8_0.

@jukofyork (Contributor) commented Feb 14, 2025

and the first test of a rank-64 LoRA seems to actually make it slightly worse:

#### BF16/Q4_0/Q4_0

[1]2.5160,[2]3.3227,[3]2.4058,[4]2.0030,[5]1.8059,[6]1.6632,[7]1.5704,[8]1.5020,[9]1.4516,[10]1.4119,[11]1.3972,[12]1.4372,[13]1.4479,[14]1.5764,[15]1.7091,[16]1.7684

#### BF16/Q4_0/Q4_0 + Q8_0 rank-64 LoRA

[1]2.5476,[2]3.3248,[3]2.3992,[4]2.0021,[5]1.8158,[6]1.6708,[7]1.5785,[8]1.5101,[9]1.4608,[10]1.4206,[11]1.4038,[12]1.4445,[13]1.4551,[14]1.5830,[15]1.7137,[16]1.7727

but I will investigate more tomorrow - there are lots of places a bug could have crept in for the SVD code, or it might just not like being quantised using Q8_0.

It's not the quantising, as with BF16 it gets exactly the same results. It seems that this "reversed LQER" method just doesn't work and actually makes it harder for the Q4_0 quantiser to quantise the residual.

I've found a way to do "proper" LQER (that doesn't require all the expert tensors to be chopped back up and transposed), but it relies on the Python gguf.quants.quantize --> gguf.quants.dequantize round-trip being deterministic and identical to the C++ code, and it will only work on Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0, as the rest of the quants don't have gguf.quants.quantize() implemented natively in the Python gguf library.

It's going to take a couple of days to run because of the numpy calls this uses, but I will report back in the LQER discussion if I have any success with it:

#8831

as it's not really specific to DeepSeek V3/R1 nor MLA.


I'll have a look next week to see if I can find what causes the float16 overflow with the attn_k_b.weight tensor, and whether I can scale it and then redistribute the factor into the layer_norm gamma parameter that is used before or after it.

@saood06 commented Feb 15, 2025

@jukofyork

Forgot to post this, table comparison of all your quants alongside mine (including an IQ1 based quant I had tested). I do use an imatrix (but not on the new split tensor as it hasn't been applied since the imatrix.dat predates it). Your quants do beat mine, but I think they are all larger.

| Quant | [1] | [2] | [3] | [4] | [5] | [6] | [7] | [8] | [9] | [10] | [11] | [12] | [13] | [14] | [15] | [16] | SUM |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| My IQ1_S | 3.7099 | 4.6162 | 3.5438 | 3.4199 | 3.5375 | 3.5710 | 3.5428 | 3.6748 | 3.7417 | 3.6724 | 3.7879 | 3.9602 | 4.0477 | 4.1439 | 4.2809 | 4.1981 | 61.4487 |
| My V1 | 2.5944 | 3.3242 | 2.4001 | 1.9949 | 1.8067 | 1.6666 | 1.5704 | 1.5055 | 1.4559 | 1.4154 | 1.3999 | 1.4404 | 1.4500 | 1.5786 | 1.7101 | 1.7729 | 29.0860 |
| My V2 | 2.5474 | 3.3247 | 2.4001 | 2.0029 | 1.8181 | 1.6716 | 1.5734 | 1.5084 | 1.4592 | 1.4194 | 1.4035 | 1.4376 | 1.4476 | 1.5734 | 1.7047 | 1.7654 | 29.0574 |
| My V3 | 2.5551 | 3.3239 | 2.3980 | 1.9980 | 1.8057 | 1.6631 | 1.5676 | 1.5029 | 1.4525 | 1.4122 | 1.3963 | 1.4421 | 1.4516 | 1.5784 | 1.7089 | 1.7692 | 29.0255 |
| BF16/Q4_0/Q4_0 | 2.5160 | 3.3227 | 2.4058 | 2.0030 | 1.8059 | 1.6632 | 1.5704 | 1.5020 | 1.4516 | 1.4119 | 1.3972 | 1.4372 | 1.4479 | 1.5764 | 1.7091 | 1.7684 | 28.9887 |
| BF16/Q4_0/Q4_0 + imatrix | 2.4996 | 3.3182 | 2.3944 | 1.9934 | 1.8041 | 1.6605 | 1.5667 | 1.4976 | 1.4491 | 1.4110 | 1.3963 | 1.4279 | 1.4390 | 1.5674 | 1.6989 | 1.7584 | 28.8825 |
| BF16/Q4_0/Q8_0 | 2.5046 | 3.2991 | 2.3829 | 1.9872 | 1.7991 | 1.6562 | 1.5628 | 1.4979 | 1.4485 | 1.4099 | 1.3955 | 1.4280 | 1.4409 | 1.5679 | 1.6980 | 1.7582 | 28.8367 |
| BF16/Q5_K/Q5_K | 2.5143 | 3.3036 | 2.3746 | 1.9854 | 1.7920 | 1.6478 | 1.5561 | 1.4888 | 1.4393 | 1.4002 | 1.3845 | 1.4178 | 1.4293 | 1.5569 | 1.6882 | 1.7480 | 28.7268 |
| BF16/Q4_K/Q6_K | 2.5266 | 3.3006 | 2.3780 | 1.9832 | 1.7932 | 1.6461 | 1.5550 | 1.4902 | 1.4404 | 1.3994 | 1.3840 | 1.4207 | 1.4321 | 1.5584 | 1.6898 | 1.7498 | 28.7475 |
| BF16/Q5_K/Q6_K | 2.5030 | 3.2798 | 2.3704 | 1.9793 | 1.7866 | 1.6453 | 1.5536 | 1.4883 | 1.4388 | 1.3993 | 1.3838 | 1.4188 | 1.4298 | 1.5565 | 1.6874 | 1.7464 | 28.6671 |

Edit: Included the IQ1_S

@jukofyork (Contributor) commented Feb 16, 2025

I've still not found a good way to fix the float16 overflow for attn_k_b, but I have found a way to boost token generation quite significantly:

You need to add this to llama_tensor_get_type():

https://github.com/ggml-org/llama.cpp/blob/master/src/llama-quant.cpp#L122

static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< 
    if (name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
        new_type = GGML_TYPE_F32;
    }
    else
    // <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< 

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
.
.
.

recompile, and then re-quantise so that these two tensors get overwritten to use F32.

You can also set it to GGML_TYPE_BF16, but it's very small anyway (only a little larger than the ffn_gate_inp.weight tensors, which are kept as F32 already), and some backends/CPUs might not like BF16 (and I'm not even sure the CUDA backend doesn't just upcast it anyway...).


You can't use GGML_TYPE_F16 for attn_k_b or it will currently overflow. I assume this is because an F16 accumulator goes out of the +/- 65504 range. The only way I can see to fix this currently would be to scale down attn_kv_a_norm.weight by a constant factor (to store the 512-value compressed KV cache as smaller values) and then scale kq_scale by the inverse:

const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
.
.
.
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

but there are so many things getting sliced and permuted here that I'm not confident this wouldn't miss something, and it's a pretty ugly hack anyway.

There's no point in using ggml_mul_mat_set_prec to bump to F32, as the tensor is minuscule compared to the rest of the model (I tried this before and it didn't seem to work anyway), but if you could get F16 to work it would likely have (much) higher FLOPS on some GPUs than F32...

Which brings me to the reason why not quantising attn_k_b and attn_v_b gives such a big gain:


All the weights in all the other tensors of the model get accessed only a single time per token during token generation, which is why quantising them can actually speed up token generation by trading a small amount of dequantising compute for higher effective memory throughput...

BUT: attn_k_b and attn_v_b aren't actually used like this. Instead of a single O(1) access per token, they are accessed O(n_tokens) times, and the current code, with its clever/optimised kernels that perform matrix multiplies whilst keeping the tensor in its quantised form, does a huge amount of extra unnecessary work - which (I think) gets more and more significant as the context length increases (it might not, depending on whether the rows are dequantised once and reused for the whole batch).

I think this fix (which may actually be worth adding to the PR as an F32 override, in the same way ffn_gate_inp.weight etc. already get one), together with using 6 instead of 8 experts, likely explains much of the token-generation speed difference between KTransformers and llama.cpp.

I would think that this fix may have an even bigger effect on CPU-based systems as the sizes:

[   4/1147]                blk.0.attn_k_b.weight - [  128, 65536,     1,     1], type =    f32, size =   32.000 MB
[  13/1147]                blk.0.attn_v_b.weight - [  512, 16384,     1,     1], type =    f32, size =   32.000 MB

are so small they probably fit in CPU cache.

@jukofyork (Contributor)

@jukofyork

Forgot to post this, table comparison of all your quants alongside mine (including an IQ1 based quant I had tested). I do use an imatrix (but not on the new split tensor as it hasn't been applied since the imatrix.dat predates it). Your quants do beat mine, but I think they are all larger.

I've now dequantised to an F32 version of the original FP8 (ie: 0.6TB DeepSeek-R1 --> 2.4TB DeepSeek-R1-f32 --> 2.4TB DeepSeek-R1-mla-f32.gguf) and am going to test all these quant mixes made using DeepSeek-R1-mla-f32.gguf:

    // ### JUK ###
    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K          // Q2_K_XM : ~3.0 bits per weight for experts (16×3×3.5 + 42×(3.5 + 2×2.5))/(3×58) | 
        || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S     // Q2_K_XL : ~3.5 bits per weight for experts (16×3×4.5 + 42×(4.5 + 2×2.5))/(3×58) | 
        || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S     // Q3_K_XM : ~4.0 bits per weight for experts (16×3×4.5 + 42×(4.5 + 2×3.5))/(3×58) | 
        || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M     // Q3_K_XL : ~4.5 bits per weight for experts (16×3×5.5 + 42×(5.5 + 2×3.5))/(3×58) | 
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S     // Q4_K_XM : ~5.0 bits per weight for experts (16*3*5.5 + 42*(5.5 + 2×4.5))/(3*58) | 404 GiB (5.16 BPW)
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M     // Q4_K_XL : ~5.5 bits per weight for experts (16*3*6.5 + 42*(6.5 + 2×4.5))/(3*58) | 446 GiB
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {  // Q5_K_XM : ~6.0 bits per weight for experts (16*3*6.5 + 42*(6.5 + 2×5.5))/(3*58) | 483 GiB (6.16 BPW)
        if (name.find("_exps") != std::string::npos) {
            int i_layer;
            if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name.c_str()));
            }
            if (name.find("ffn_down") != std::string::npos || i_layer <= 10 || i_layer >= 53) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
                    new_type = GGML_TYPE_Q3_K;  // Q2_K_XM
                }
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) {
                    new_type = GGML_TYPE_Q4_K;  // Q2_K_XL & Q3_K_XM
                }
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
                    new_type = GGML_TYPE_Q5_K;  // Q3_K_XL & Q4_K_XM
                }
                else /* if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) */ { 
                    new_type = GGML_TYPE_Q6_K;  // Q4_K_XL & Q5_K_XM
                }
            }
            else {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
                    new_type = GGML_TYPE_Q2_K;  // Q2_K_XM & Q2_K_XL
                }                
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
                    new_type = GGML_TYPE_Q3_K;  // Q3_K_XM & Q3_K_XL
                }
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
                    new_type = GGML_TYPE_Q4_K;  // Q4_K_XM & Q4_K_XL
                }
                else /* if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) */ {
                    new_type = GGML_TYPE_Q5_K;  // Q5_K_XM
                }
            }
        }
        else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
            new_type = GGML_TYPE_F32;  // Also used: type_kr = type_kv = GGML_TYPE_F32
        }
        else {
            new_type = GGML_TYPE_Q8_0;
        }
    }
    else
    // ### JUK ###

and then run llama-perplexity on wiki.test.raw for the full 561 chunks (the Q5_K_XM is running now and only takes around 5.5 hours).


I'm also using an F32 quant for attn_kv_a_mqa and have hacked the PR to use F32 to store the compressed KV-cache:

# Use float32 for the compressed KV-cache.
safe_sed "src/llama-kv-cache.h" "ggml_type type_kr = GGML_TYPE_F16" "ggml_type type_kr = GGML_TYPE_F32"
safe_sed "src/llama-kv-cache.h" "ggml_type type_kv = GGML_TYPE_F16" "ggml_type type_kv = GGML_TYPE_F32"
llama_kv_cache_init:      CUDA0 KV buffer size =   263.51 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =   255.01 MiB
llama_init_from_model: KV self size  =    0.00 MiB, K (f16):    0.00 MiB, V (f16):    0.00 MiB
llama_init_from_model: KV self size  =  274.50 MiB, K^R (f16):   30.50 MiB, c^KV (f16):  244.00 MiB

It says (f16) still but is definitely using F32 and double the size.

The idea being that I can clearly see if this makes much difference to the perplexity scores vs using F16 for the compressed KV-cache when I rerun this for the high-bit quants tomorrow.


I'm also going to run another test using:

        else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
            new_type = GGML_TYPE_F16;
        }

to see what effect this has on the token-generation speed (even if the model just spews garbage or nan for perplexity) and whether it is worth trying to implement my dodgy downscale/upscale hack I mentioned above, as my A6000 GPUs should theoretically have much higher FLOPS for F16.

@jukofyork (Contributor)

@saood06 Here are the first 16 chunks for the Q5_K_XM which is 483 GiB (6.16 BPW) to compare with:

[1]2.4987,[2]3.2797,[3]2.3680,[4]1.9800,[5]1.7890,[6]1.6474,[7]1.5545,[8]1.4885,[9]1.4385,[10]1.3998,[11]1.3850,[12]1.4199,[13]1.4320,[14]1.5578,[15]1.6885,[16]1.7487

This is the absolute maximum I can run as trying to push it even a little higher will likely start to use swapfile and/or bring down the OS, so currently this is probably the best estimate of the lower-bound for the full model.

I will post the full perplexity run results in a couple of days.

@jukofyork (Contributor) commented Feb 16, 2025

There's no point in using ggml_mul_mat_set_prec to bump to F32, as the tensor is minuscule compared to the rest of the model (I tried this before and it didn't seem to work anyway), but if you could get F16 to work it would likely have (much) higher FLOPS on some GPUs than F32...

It looks like the CUDA GGML code only uses this (found by searching for "op_params[0]" 😱) if using CuBLAS:

  • ggml_cuda_op_mul_mat_cublas and ggml_cuda_mul_mat_batched_cublas both reference it.

Or for matrix-vector products:

  • ggml_cuda_mul_mat_vec and ggml_cuda_op_mul_mat_vec both reference it.

It's also used in ggml_cuda_op_acc but I can't tell if that is actually used anywhere in ggml-cuda.cu.

What ggml_cuda_op_mul_mat is doing is beyond me: it seems to be converting the inputs to QK8_1 and I can't see any references to F16 in it, so why specifically using F16 for attn_k_b causes an overflow, what happens with BF16, and whether there is some hidden option that could be passed to runtime-dequantise attn_k_b to F32 for the mul_mat to save everyone having to rerun llama-quantise is still a mystery sadly...

@JohannesGaessler (Collaborator)

What ggml_cuda_op_mul_mat is doing is beyond me

The function makes input tensors contiguous and presents them as single-batch matrix multiplications to other kernels. The conversion to q8_1 is only done for kernels that use quantized data.

@jukofyork (Contributor) commented Feb 16, 2025

What ggml_cuda_op_mul_mat is doing is beyond me

The function makes input tensors contiguous and presents them as single-batch matrix multiplications to other kernels. The conversion to q8_1 is only done for kernels that use quantized data.

I don't know if it's worth looking at yet as this is still a draft PR, but it should be quite easy to replicate the slowdown I saw using Q8_0 via the smaller deepseek-v2-lite model.

I just looked at my logs for running the first 16 chunks of llama-perplexity and it seems to affect prompt processing too:

BF16 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q8_0:  551 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors
llama_model_loader: - type bf16:   61 tensors

perplexity: tokenization took 1190.46 ms
perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 147.01 seconds per pass - ETA 5 hours 43.63 minutes
[1]2.5077,[2]3.2865,[3]2.3722,[4]1.9792,[5]1.7845,[6]1.6442,[7]1.5532,[8]1.4876,[9]1.4382,[10]1.3990,[11]1.3828,[12]1.4122,[13]1.4242,[14]1.5514,[15]1.6815,[16]1.7411,^C

F16 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type  f16:   61 tensors
llama_model_loader: - type q8_0:  551 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 146.17 seconds per pass - ETA 5 hours 41.67 minutes
[1]nan,[2]nan,[3]nan,[4]nan,^C

Q8_0 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q8_0:  612 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 330.34 seconds per pass - ETA 12 hours 52.15 minutes
[1]2.4944,[2]3.2788,[3]2.3639,[4]1.9761,[5]1.7833,[6]1.6414,[7]1.5508,[8]1.4850,[9]1.4362,[10]1.3974,[11]1.3819,[12]1.4160,[13]1.4275,[14]1.5542,[15]1.6846,[16]1.7440,^C

Which is nearly 2.5x longer for Q8_0.

This isn't quite the same custom quant, but uses F32 for attn_k_b and looks to have similar timing to F16 and BF16:

llama_model_loader: - type  f32:  544 tensors
llama_model_loader: - type q8_0:  429 tensors
llama_model_loader: - type q5_K:   84 tensors
llama_model_loader: - type q6_K:   90 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 146.66 seconds per pass - ETA 5 hours 42.80 minutes
[1]2.4987,[2]3.2797,[3]2.3680,[4]1.9800,[5]1.7890,[6]1.6474,[7]1.5545,[8]1.4885

I can't find the logs for the token generation but the speed difference wasn't as bad as this: something like 2.8 tokens per second for Q8_0 and 3.6-3.8 tokens per second for BF16.

The F16 overflow might be specific to deepseek-v3 or deepseek-r1 only, but as well as causing nan in llama-perplexity (as the example above shows), it just writes <think> and then the same word over and over.

I went as far as testing the magnitude of all the attn_k_b tensors and IIRC none even had a magnitude over 256. It's also multiplying from the compressed KV cache, which has been passed through a layer_norm before being stored, so I'm at a loss to see what else could be overflowing. I did try adding a ggml_mul_mat_set_prec() call right before all the mul_mat() calls before I narrowed the overflow down to attn_k_b, but from my brief skimming of the code this afternoon I don't think that is actually used here?

@JohannesGaessler (Collaborator)

Generally speaking, the KQ matrix is susceptible to overflow. So it is preferable to use BF16 or FP32 accumulators for its calculation. However, I was never able to get FP16 matrix multiplication with FP32 accumulation to work with cuBLAS. The documentation says it should be possible but the kernel fails to launch when I try it. Currently the precision argument for KQ is not used for cuBLAS GEMM. For a FP16 K matrix FP16 accumulation is used unconditionally.

@JohannesGaessler (Collaborator)

I think I misremembered. After looking at the documentation again I think the problem was that FP16, FP16 -> FP32 GEMM is supported but the performance was so much worse that there was basically no point in using it.

@fairydreaming (Collaborator Author)

I investigated possible reasons for poor scaling of token generation when using DeepSeek V3/R1 on dual CPU systems.

My current working hypothesis is that the DeepSeek V3/R1 expert FFN matrices are so small (7168 x 2048) that the overhead of using two CPUs for matrix-vector multiplication during token generation negates almost all performance gains.

I suppose this is the reason why ktransformers folks in their v3.0-preview have two copies of experts in memory, one for each CPU.

I'm going to create a NUMA-aware matrix vector multiplication benchmark to verify this hypothesis.

I thought about possible solutions. One would be to assign the experts in each layer to N sets, where N is the number of CPUs, and then use the top n_expert_used/N experts from each set during inference. In this solution each CPU would handle only its assigned local experts and there would be no communication overhead. But it can result in non-optimal expert choices; I'm not sure how it would affect model performance.
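A rough sketch of that per-CPU expert selection idea (plain NumPy, not llama.cpp code; all names are mine):

```python
import numpy as np

def numa_local_topk(router_logits: np.ndarray, n_cpus: int, n_expert_used: int):
    # Split experts into one contiguous set per CPU and pick the top
    # n_expert_used/n_cpus experts independently within each set, so every
    # CPU only touches its own locally stored expert weights.
    n_expert = router_logits.shape[-1]
    per_set  = n_expert // n_cpus
    k_local  = n_expert_used // n_cpus
    chosen = []
    for c in range(n_cpus):
        lo = c * per_set
        local = router_logits[lo:lo + per_set]
        top = np.argpartition(local, -k_local)[-k_local:]
        chosen.extend((lo + top).tolist())
    return sorted(chosen)

# e.g. DeepSeek V3/R1: 256 routed experts, 8 used per token, 2 CPUs -> 4 local each
print(numa_local_topk(np.random.randn(256), n_cpus=2, n_expert_used=8))
```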

@jukofyork (Contributor) commented Feb 19, 2025

    // ######
    if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS        // IQ4_XXS : 4.25 BPW for experts | 345 GiB (4.41 BPW) (-28.4%) | PPL = 3.3850 +/- 0.01877 (+1.51%) | 15.05 tokens per second ( +8.0%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_0       // Q4_0_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3944 +/- 0.01885 (+1.95%) | 14.17 tokens per second ( +1.6%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S     // Q4_K_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3724 +/- 0.01866 (+0.66%) | 18.81 tokens per second (+34.9%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S     // Q5_K_XS : 5.5 BPW for experts  | 441 GiB (5.63 BPW) ( -8.6%) | PPL = 3.3546 +/- 0.01852 (+0.16%) | 13.84 tokens per second ( -0.7%)
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M     // Q4_K_XM : ~5.0 BPW for experts | 404 GiB (5.16 BPW) (-16.2%) | PPL = 3.3666 +/- 0.01863 (+0.48%) | 15.82 tokens per second (+13.5%)
        || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL     // Q4_K_XL : ~5.5 BPW for experts | 446 GiB (5.69 BPW) ( -7.6%) | PPL = 3.3614 +/- 0.01858 (+0.33%) | 16.03 tokens per second (+15.0%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M     // Q5_K_XM : ~6.0 BPW for experts | 483 GiB (6.16 BPW)          | PPL = 3.3504 +/- 0.01849          | 13.94 tokens per second
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_1       // Q5_K_XH : 5.0 BPW for experts  | 403 GiB (5.15 BPW)          | PPL = 3.3695 +/- 0.01864 (+0.57%) | 15.90 tokens per second (+14.1%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q6_K       // Q6_K_XH : 6.0 BPW for experts  | 481 GiB (6.15 BPW) (-16.2%) | PPL = 3.3548 +/- 0.01853 (+0.13%) | 13.87 tokens per second ( -0.5%)
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        ) {                                       // iQ4_K_XS (Q4_K_XS using Bartowski imatrix for experts only)  : PPL = 3.3734 +/- 0.01866 (+0.69%) | 18.76 tokens per second (+34.6%)
        if (name.find("_exps") != std::string::npos) {
            int i_layer;
            if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name.c_str()));
            }
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                new_type = GGML_TYPE_IQ4_XS;    // IQ4_XXS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0) {
                new_type = GGML_TYPE_Q4_0;      // Q4_0_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
                new_type = GGML_TYPE_Q4_K;      // Q4_K_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
                new_type = GGML_TYPE_Q5_K;      // Q5_K_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_1) {
                new_type = (i_layer <= 31 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K);  // Q5_K_XH first and last 29 experts
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
                new_type = (i_layer <= 31 ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K);  // Q6_K_XH first and last 29 experts
            }
            else if (name.find("ffn_down") != std::string::npos || i_layer <= 10 || i_layer >= 53) {  // First 8 and last 8 experts (ie: 16/58 experts)
                if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
                    new_type = GGML_TYPE_Q5_K;  // Q4_K_XM
                }
                else { 
                    new_type = GGML_TYPE_Q6_K;  // Q4_K_XL & Q5_K_XM
                }
            }
            else {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
                    new_type = GGML_TYPE_Q5_K;  // Q5_K_XM
                }
                else {
                    new_type = GGML_TYPE_Q4_K;  // Q4_K_XM & Q4_K_XL
                }
            }
        }
        else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
            new_type = GGML_TYPE_F32;  // Also used: type_kr = type_kv = GGML_TYPE_F32
        }
        else {
            new_type = GGML_TYPE_Q8_0;
        }
    }
    else
    // ######

I've finished the testing of the custom quants:

  • Using Q4_K for all expert tensors seems a clear winner (the 365GB model would likely fit on 2 x M2 Ultra 192GB too).
  • Using Q4_0 and IQ4_XS gave particularly bad performance in comparison (Q4_0 surprisingly in terms of tokens/s too).
  • Using Bartowski's imatrix for experts made no measurable difference for Q4_K.
  • Mixing different quants for the first/last tensors and bumping up_proj had very little gain (ie: might as well just use Q5_K).
  • Not included, but found that any mixtures involving Q3_K really start to hurt performance badly.

Just running one last test on pure Q4_K to see if using type_kr = type_kv = GGML_TYPE_F16 vs type_kr = type_kv = GGML_TYPE_F32 makes any difference.

EDIT:

// Q4_K_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3724 +/- 0.01866 (+0.66%)
// Q4_K_XS using type_kr = type_kv = GGML_TYPE_F16 & 44 threads : PPL = 3.3728 +/- 0.01866 (+0.67%)

No difference.

@jukofyork
Copy link
Contributor

> Generally speaking, the KQ matrix is susceptible to overflow, so it is preferable to use BF16 or FP32 accumulators for its calculation. However, I was never able to get FP16 matrix multiplication with FP32 accumulation to work with cuBLAS. The documentation says it should be possible, but the kernel fails to launch when I try it. Currently the precision argument for KQ is not used for cuBLAS GEMM; for an FP16 K matrix, FP16 accumulation is used unconditionally.

I think the safest option then is probably to use F32 as an override for now.

@Thomas-MMJ
Copy link

In Daniel's 1.58-bit quantization, he kept the shared expert at a higher precision than the routed experts.

@jukofyork
Copy link
Contributor

        // whether to use n_tokens as the matrix dimension during multiplication or n_head
        // n_tokens is higher during prompt processing, this allows to optimize for this case
        bool pp_opt = n_tokens > n_head;

I think this might be causing some weird problem in the CUDA back-end where a different code-path is taken.

If I leave it at the default and use this 127-token test prompt:

> Varis adjusted the noose, its hemp fibers grinding beneath his calluses. “Last chance,” he said, voice like gravel dragged through mud. “Confess, and your soul stays your own.”
>
> Jurl laughed—a wet, gurgling sound. “You’re knee-deep in it, Coldwater. ” The thing inside him twisted the boy’s lips into a grin too wide for his face. “The Great Wolf’s howlin’ again. The Dead’s Gate’s rusted through… ”

Turn this into the opening chapter of a Grimdark trilogy.

The model won't say the actual phrases and it feels "off" - like there is something wrong with the attention mechanism (it sometimes "sort of" says the phrases, but not quite, and often not at all).

If I fix the flag to always be true, eg:

    bool pp_opt = true;

Then all of a sudden the model starts to say those phrases and seems way better at writing in general (I suspect this triggers a different code-path - possibly something to do with the matrix-vector vs matrix-matrix stuff I remember seeing the other day?).

If I fix the flag to always be false eg:

    bool pp_opt = false;

Then, running llama-perplexity, I get almost (but not quite) the same PPL as the default (ie: where n_tokens > n_head --> 512 > 128 --> pp_opt = true always), so I think the problematic code-path is only hit for a batch size of exactly 1 and isn't related to the actual series of ggml_permute and ggml_cont calls this flag selects in llama.cpp::build_deepseek2().

So I thought I'd try running llama-perplexity with bool pp_opt = false to test this idea, and weirdly:

perplexity: 607.12 seconds per pass - ETA 23 hours 39.13 minutes
[1]2.4873,[2]3.2776,[3]2.3693,[4]1.9780

It actually seems to get a better PPL for these first few values (sorry, there's no way I can run the 24 hours to completion), and the difference is almost the size of the error bar of the full PPL calculated with the default setting.

I don't know how else to help diagnose what's going on 😕

Could it be that the 127-token test prompt is not a multiple of 32 and when it gets permuted it's causing some problem there?

@slaren
Copy link
Member

slaren commented Feb 19, 2025

@jukofyork If you think that some operation is producing wrong results with CUDA, an easy way to test that would be to add a test case to test-backend-ops. It should be fairly straightforward, you would need to add a test case for the relevant operations and with the same shapes and types as are used with this model, in make_test_cases_eval.
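
For anyone picking this up later, a rough sketch of what such a test case might look like. This is only my guess: it assumes test_mul_mat's constructor takes (type_a, type_b, m, n, k, batch dims, broadcast factors) as in tests/test-backend-ops.cpp at the time of writing, and the MLA-like shapes (a 2D F16 "cache" with a 512-wide contracted dimension, broadcast over 16 heads) are illustrative rather than taken from this PR:

    // in make_test_cases_eval(): an F16 cache tensor [512, 1024] multiplied by F32
    // activations [512, n_tokens, 16], broadcast over the 16 heads - once with a
    // prompt-sized batch and once with a single token (the generation case)
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,
                                             1024 /* n_kv */, 127 /* n_tokens */, 512 /* kv_lora_rank */,
                                             {1, 1}, {16, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,
                                             1024, 1, 512, {1, 1}, {16, 1}));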

@jukofyork
Copy link
Contributor

@slaren @JohannesGaessler @fairydreaming

I've got a little further now and think it's the same overflow problem that affected the float16 tensors - it's just that with pp_opt set it must cause more severe problems and/or some kind of catastrophic cancellation due to the rows/columns being swapped.

Both the existing attention implementations use set_prec(cur, GGML_PREC_F32) here:

    if (cparams.flash_attn) {
        GGML_UNUSED(model);
        GGML_UNUSED(n_ctx);

        // split cached v into n_head heads (not transposed)
        struct ggml_tensor * v =
            ggml_view_3d(ctx, kv.v_l[il],
                    n_embd_head_v, n_kv, n_head_kv,
                    ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
                    ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
                    0);
        cb(v, "v", il);

        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
    } else {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        cb(kq, "kq", il);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

and by trial and error with --temp 0.0, I've found that these 3 also need their precision raised to F32 for the MLA implementation:

                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32); // ***
                cb(kq_nope, "kq_nope", il);
                struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32); // ***
                cb(kq_pe, "kq_pe", il);
                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32); // ***
                cb(kqv_compressed, "kqv_compressed", il);

but this needs to be compiled with -DGGML_CUDA_FORCE_CUBLAS=1 to allow those to be set.

This was with bool pp_opt = true; fixed, and it still gives different output with bool pp_opt = n_tokens > n_head, but it's not as obviously broken as it was before.

When it comes time to merge the official MLA implementation, I think this needs to be tested more thoroughly than I am able to.

@JohannesGaessler
Copy link
Collaborator

> but this needs to be compiled with -DGGML_CUDA_FORCE_CUBLAS=1 to allow those to be set.

The MMQ kernels always use FP32 for the accumulators; if there are numerical issues they must be due to extremal values in the inputs, since FP16 is used for the scales of the quantized data.

@slaren
Copy link
Member

slaren commented Feb 19, 2025

ggml_cuda_mul_mat_batched_cublas always converts src1 to F16 regardless of the value of ggml_mul_mat_set_prec, and that may be a problem. This function is used in most KV operations.

We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a f16 x f32 -> f32 matrix multiplication, and I am not sure that cuBLAS can do that.

@jukofyork
Copy link
Contributor

I think just changing bool pp_opt = n_tokens > n_head; to bool pp_opt = true; has fixed whatever I was getting. It possibly runs a tiny bit slower though (~0.25 tokens/second).

Here's my full script that merges the PRs and applies all the hacks (including the commented out ones I'm not using):

#!/bin/bash

function safe_sed() {
    local file=$1
    local pattern=$2
    local replacement=$3

    # Check if pattern exists
    if ! sed -n "s/${pattern}/${replacement}/p" "$file" | grep -q .; then
        echo "Error: Pattern not found in $file: $pattern"
        return 1
    fi

    # Create backup
    cp "$file" "$file.bak"

    # Perform the replacement
    sed -i "s/${pattern}/${replacement}/g" "$file"

    # Show diff
    echo "Changes in $file:"
    diff "$file.bak" "$file"

    # Clean up
    rm "$file.bak"

    echo "Successfully replaced in $file"
    echo "-------------------"
}

rm -rf llama.cpp

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
git remote add fairydreaming https://github.com/fairydreaming/llama.cpp.git
git remote add sl https://github.com/ggerganov/llama.cpp.git
git fetch fairydreaming
git fetch sl
git checkout -b merged_features

# For MLA compressed KV-cache
git merge --no-edit fairydreaming/deepseek2-mla-exp

# To save having to wait ages for the warmup (~2.5x less wait)
git merge --no-edit fairydreaming/experts-warmup

# To allow the use of --override-tensor exps=CPU (and --override-tensor attn_kv_b=CPU)
git merge --no-edit sl/sl/custom-tensor-offload

# Allocate the minimum possible for the unused KV-cache.
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa\*kv_size);" "ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, 1);"
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa\*kv_size);" "ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, 1);"

# Don't offload to GPU.
safe_sed "ggml/src/ggml-cuda/ggml-cuda.cu" "const int min_batch_size = 32" "const int min_batch_size = 9999999"

safe_sed "src/llama.cpp" "bool pp_opt = n_tokens > n_head;" "bool pp_opt = true;"

#safe_sed "src/llama.cpp" "kv_cache, q_nope2);" "kv_cache, q_nope2);\n                ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kr_cache, q_pe);" "kr_cache, q_pe);\n                ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kv_cache_trans, kq);" "kv_cache_trans, kq);\n                ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32);"

# Use float32 for the compressed KV-cache.
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kr = GGML_TYPE_F16" "ggml_type type_kr = GGML_TYPE_F32"
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kv = GGML_TYPE_F16" "ggml_type type_kv = GGML_TYPE_F32"

# Hack llama_tensor_get_type() to use our chosen custom quant.
safe_sed "src/llama-quant.cpp" \
  "llama_tensor_get_type(qs, new_type, tensor, ftype);" \
  "name.find(\"_exps\") != std::string::npos ? name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K : GGML_TYPE_BF16;"

# Must set GGML_SCHED_MAX_COPIES=1 for use with --override-tensor exps=CPU
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_RPC=ON
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=9999999
cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_FORCE_CUBLAS=1

cmake --build build --config Release -- -j 44

Which gets run using:

numactl --interleave=all ./llama.cpp/build/bin/llama-server --host 192.168.1.111 --port 8080 \
  --model ./DeepSeek-R1-mla-Q5_K_XL.gguf --chat-template deepseek3 --alias "DeepSeek-R1-mla-Q5_K_XL" --ctx_size 32768 \
  --n-gpu-layers 99 --override-tensor exps=CPU --override-tensor attn_kv_b=CPU --numa distribute \
  --temp 0.6 --min-p 0.0 --top-p 1.0 --top-k 0 --threads 30 --threads-batch 44

The quant is set on a single line in the script; expanded, it is equivalent to:

    // ######
    if (name.find("_exps") != std::string::npos) {
        if (name.find("ffn_down") != std::string::npos) {
            new_type = GGML_TYPE_Q6_K;
        }
        else {
            new_type = GGML_TYPE_Q5_K;
        }
    }
    else {
        new_type = GGML_TYPE_BF16;
    }
    else
    // ######

and gave this on wiki.test.raw:

Q5_K_XL : 479.64 GiB (6.13 BPW) | PPL = 3.3499 +/- 0.01849 | 19.72 tokens per second

I can't see the thought tags on openrouter, but this custom BF16/Q6_K/Q5_K mix appears to be working as well as any they are hosting now (the official deepseek endpoint on openrouter just seems to always be down, so I can't test against it), and gives similar responses.

@JohannesGaessler
Copy link
Collaborator

> We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a f16 x f32 -> f32 matrix multiplication, and I am not sure that cuBLAS can do that.

The PTX documentation has a table with the data types that are supported by tensor cores. In all cases the input matrices must have the same data type, so if the KV cache stays FP16 the activations must be converted to FP16. Alternative approaches would be to use BF16, which has the same numerical range as FP32, or to convert the FP16 data to TF32 in SRAM (the latter is to my knowledge not supported by cuBLAS; I did not check CUTLASS). Both BF16 and TF32 need Ampere or newer. In terms of speed: (FP16, FP16 -> FP16) > (FP16, FP16 -> FP32) > (TF32, TF32 -> FP32).
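
For reference, the documented-but-problematic combination mentioned above (FP16 inputs with FP32 accumulation and output) looks roughly like this as a cublasGemmEx call; this is my own minimal illustration of the API shape, not code from llama.cpp:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // C (FP32, m x n) = A (FP16, m x k) * B (FP16, k x n), column-major,
    // with FP32 accumulation; error checking omitted, alpha/beta are float
    // because the compute type is CUBLAS_COMPUTE_32F
    void gemm_f16_f16_f32(cublasHandle_t handle, const half * A, const half * B, float * C,
                          int m, int n, int k) {
        const float alpha = 1.0f;
        const float beta  = 0.0f;
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     m, n, k,
                     &alpha,
                     A, CUDA_R_16F, m,
                     B, CUDA_R_16F, k,
                     &beta,
                     C, CUDA_R_32F, m,
                     CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    }

As noted earlier in the thread, this combination is listed as supported but reportedly either failed to launch or ran much slower than plain FP16 accumulation in practice.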

@jukofyork
Copy link
Contributor

A quick update on the F16 overflow issue:

I've found that fixing bool pp_opt = true; (which essentially removes all the extra perms) and keeping only the _a and _b tensors set as F16, e.g.:

"static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
    const std::string name = ggml_get_name(tensor);\n\
    if (name.find(\"_exps\") != std::string::npos) {\n\
        return name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K;\n\
    } else if (name.find(\"attn_\") != std::string::npos && name.find(\"_output\") == std::string::npos) {\n\
        return GGML_TYPE_BF16;\n\
    }\n\
    return GGML_TYPE_Q8_0;\n\
}"

It does somewhat work, and no longer gives nan for perplexity or repeats the same word over and over during token generation.

It works quite a bit faster (3.6 tokens/s vs 3.1-3.2 tokens/s) compared to using the same custom quant with BF16 or F32 (probably from not having to do lots of up/down casting), but it still isn't working 100% correctly, as the perplexity run shows:

[1]8.0332,[2]10.0018,[3]8.4663,[4]7.7059,[5]6.9553,[6]6.6773,[7]6.4792,[8]6.8003,[9]6.8766,[10]6.7664,[11]6.7516,[12]7.0069,^C

These should be [1]2.5 and so on.

I've tried compiling with -DGGML_CUDA_FORCE_CUBLAS=1 and then using ggml_mul_mat_set_prec(XXX, GGML_PREC_F32), and this didn't seem to help (one combination actually made the perplexity go up into the 80s!).

Hopefully after the attention refactoring is over and MLA gets looked at again, some of these problems can be ironed out.
