Optimized DeepSeek V2/V3 implementation (MLA) #11446
base: master
Conversation
…kv representations
…nsposing the cache during inference
@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes. |
@wronkiew What model would you like to test? |
V3/R1, Q4_K_S. |
@wronkiew I don't have the model uploaded (my upload bandwidth is too low), you have to download, convert to bf16, convert to gguf and quantize the original model by yourself (or download one that is already converted to bf16, this will save you one step). |
I spent some time investigating this hint from the DeepSeek V2 paper:
At first glance it looks reasonable: each absorbed matrix allows replacing two matrix multiplications with a single one, thus reducing the number of operations. However, once we look at the dimensions of these matrices, it stops being reasonable. For example, in DeepSeek V2 Lite:
So (ignoring the head dimension) this replaces two multiplications, one with a [2048, 128] matrix and one with a [512, 128] matrix, with a single multiplication with a [512, 2048] matrix. The combined matrix has over 3x the elements of the two individual matrices, so it takes more memory and is actually slower to multiply with than the two smaller matrices. With
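A quick back-of-the-envelope check of the sizes above (a standalone sketch using the values from this comment, per head and per token; not code from the PR):

#include <cstdio>

int main() {
    const long n_embd       = 2048; // hidden size (DeepSeek V2 Lite)
    const long head_dim     = 128;  // per-head dimension
    const long kv_lora_rank = 512;  // compressed kv dimension

    // two separate projections: [2048, 128] and [512, 128]
    const long separate = n_embd*head_dim + kv_lora_rank*head_dim; //   327,680
    // single absorbed projection: [512, 2048]
    const long absorbed = kv_lora_rank*n_embd;                     // 1,048,576

    // for a single input vector the multiply-accumulate count equals the
    // element count, so the absorbed matrix costs ~3.2x more compute and memory
    printf("separate: %ld, absorbed: %ld (%.1fx)\n",
           separate, absorbed, (double) absorbed/separate);
    return 0;
}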
I also found this blog post: https://github.com/xjdr-alt/mla_blog_translation where they mention:
So it looks like a dead end, it won't give us any speed gains. |
I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF
|
As I wrote in the PR:
Existing GGUFs won't work, you have to convert and quantize one with the code from this PR. |
Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF? |
I think it's best to wait a bit until this is stable and merged, it's possible that there will be some changes that would cause them to stop working and you'd have to repeat the conversion again. |
I updated the token generation performance plots in the PR post, and also added some new ones showing the prompt processing performance. The optimized implementation generally performs WORSE in prompt processing - DeepSeek R1 671B Q4_K_S running on CPU performs only a little worse (~10% with a 4k prompt), but DeepSeek V2 Lite Q8_0 running on an RTX 4090 performs MUCH WORSE (~30% with a 16k prompt), and in both cases the gap widens as the prompt length increases. So it's not all sunshine and rainbows. Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture ( |
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;
I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.
I would first look into improving the FA kernels to support DeepSeek head sizes.
I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.
Hmm? I'm quite sure there's only one ggml_cont() call (excluding the ones for CUDA compatibility that already existed in the previous implementation).
As for the permutes, the idea is to multiply by a matrix whose second dimension is the number of heads instead of the number of tokens (which is 1 during single-sequence token generation); that increased performance on the CPU a bit.
So during prompt processing we have 2 permutes and 1 cont. During token generation we have 5 permutes (yeah, that may be a lot) and 0 conts.
Thanks for the correction - I did imagine the extra conts when I saw the permutes.
While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great. Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch. I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in |
That may not be possible - IMHO an MLA attention implementation that caches "compressed" latent kv representations introduces unavoidable computational overhead due to the need to "decompress" these representations in order to calculate attention scores and attention output. So a "naive" attention implementation that caches full K/V vectors will always use less compute but more memory bandwidth, while caching latent representations results in using more compute, but less memory bandwidth. So there can't be a single implementation that is optimal in all use cases. I'd be happy to be proven wrong about this, though.
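To make that trade-off concrete, here is a rough per-token, per-layer cache-traffic comparison (a standalone sketch with assumed, roughly DeepSeek-V2-Lite-like dimensions; illustrative only, not taken from the PR):

#include <cstdio>

int main() {
    const long n_kv         = 4096; // cached positions (context so far)
    const long n_head       = 16;   // assumed head count
    const long k_head_dim   = 192;  // assumed: 128 "nope" + 64 rope
    const long v_head_dim   = 128;
    const long kv_lora_rank = 512;
    const long rope_dim     = 64;
    const long elem_bytes   = 2;    // fp16 cache

    // "naive": read the full per-head K and V rows for every cached position
    const long naive = n_kv * n_head * (k_head_dim + v_head_dim) * elem_bytes;
    // latent: read only the shared compressed kv (+ rope part), then pay extra
    // matrix multiplications to "decompress" it
    const long mla   = n_kv * (kv_lora_rank + rope_dim) * elem_bytes;

    printf("naive: %.1f MiB/token, latent: %.1f MiB/token\n",
           naive/1048576.0, mla/1048576.0);
    return 0;
}

With these example numbers the latent cache reads roughly an order of magnitude less data per generated token, at the cost of the extra decompression matmuls - exactly the compute-vs-bandwidth trade-off described above.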
I think there shouldn't be any problems with this, as there is a straightforward direct mapping between the cached representations and full K/V vectors.
That's fine with me. I'm taking a break from this anyway, got bored with tensor shuffling looking for 0.1 t/s more performance. 😉 |
@fairydreaming
I don't have a quant on hand that I can test without this branch, but this branch does give me a nice performance boost for TG at longer contexts; however, RPC to CUDA does not work. |
OK, I can get the fake LoRA thing working really easily. For fixed
Or dynamic
but need to confirm that @slaren Does I don't want to spend all day doing this to find it's impossible to work with |
It looks to me that a rank-64 "fake LoRA" would explain around half the variance:
and possibly much more further into the LLM (ie: these early layers have the highest information density as found by ikawrakow in his experiments used to write the I think it would be super-worthwhile to try this! |
I think so, but I am not sure about the details. This was implemented by @ngxson and @compilade. |
No problem - I'm off out for a couple of hours, so hopefully will get confirmation by then :) I think I'm getting more comfortable with the GGML stuff now anyway, so may be able to get this working if it isn't already. |
Yes it should, except for some edge cases where the conversion does some weird tensor permutations - this is too difficult to keep track of, so we don't have any method to document it for now. But only a very small number of models do that; I don't know whether DeepSeek is one of them. |
@jukofyork If Otherwise, it should work as-is. MoE LoRAs should work since the LoRA refactor (ref: #8332 (comment)), and conversion for DeepSeekV3 seems to stack the experts in the usual way (I did not test it, though). And you're right, dynamic LoRAs for MoE aren't handled because when stacking the experts with |
Thanks guys, I've got it working by just exporting the LoRA adapter as GGUF directly for now:

import os

import gguf
import numpy as np
import torch

def export_lora_gguf(
    path: os.PathLike[str] | str,
    tensors: list[tuple[str, torch.Tensor]],
    alpha: int,
    quant_type: gguf.GGMLQuantizationType
):
    print(f"Initializing GGUFWriter with path: '{path}'")
    writer = gguf.GGUFWriter(path, "deepseek2")

    writer.add_string("general.type", "adapter")
    writer.add_string("adapter.type", "lora")
    writer.add_float32("adapter.lora.alpha", alpha)

    for name, tensor in tensors:
        print(f"- Processing '{name}' with shape {tensor.shape}")
        if quant_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]:
            # Handle float32 and float16 directly
            dtype = np.float32 if quant_type == gguf.GGMLQuantizationType.F32 else np.float16
            writer.add_tensor(name, tensor.cpu().numpy().astype(dtype))
        else:
            # Handle BF16 and Q8_0 through quantization
            quant_tensor = gguf.quants.quantize(tensor.cpu().numpy(), quant_type)
            print(f"-- Original tensor shape: {tensor.shape}")
            print(f"-- Quantized tensor shape: {quant_tensor.shape}")
            writer.add_tensor(name, quant_tensor, raw_shape=quant_tensor.shape, raw_dtype=quant_type)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()

    print("Export completed")

It seems to be read in and working for a "layer 3 only" LoRA that I tested it with:
Just need to leave it running overnight now to do the full set of SVDs... I'm going to use
but if it looks promising, I will also try
which is pretty small when compared to even the "meme" 1.58bit quants. |
I couldn't get the "fake LoRA" to work as it did something very strange:
But, I have now managed to integrate it into the compute graph:
and it appears to be working fine and only adds a little overhead. It's pretty horribly hacked in for now:

static struct ggml_tensor * llm_build_moe_ffn(
struct ggml_context * ctx,
struct llama_context & lctx,
struct ggml_tensor * cur,
struct ggml_tensor * gate_inp,
struct ggml_tensor * up_exps,
struct ggml_tensor * gate_exps,
struct ggml_tensor * down_exps,
struct ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
bool norm_w,
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
const llm_build_cb & cb,
int il,
struct ggml_tensor * up_exps_a = nullptr,
struct ggml_tensor * up_exps_b = nullptr,
struct ggml_tensor * gate_exps_a = nullptr,
struct ggml_tensor * gate_exps_b = nullptr,
struct ggml_tensor * down_exps_a = nullptr,
struct ggml_tensor * down_exps_b = nullptr) {

static struct ggml_tensor * llm_build_lora_mm_id(
struct llama_context & lctx,
struct ggml_context * ctx0,
struct ggml_tensor * w, // struct ggml_tensor * as
struct ggml_tensor * cur, // struct ggml_tensor * b
struct ggml_tensor * ids,
struct ggml_tensor * a = nullptr,
struct ggml_tensor * b = nullptr) {
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
if (a && b) {
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
ctx0, b,
ggml_mul_mat_id(ctx0, a, cur, ids),
ids
);
res = ggml_add(ctx0, res, ab_cur);
}
for (auto & it : lctx.lora) {
struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
if (lw == nullptr) {
continue;
}
const float alpha = it.first->alpha;
const float rank = (float) lw->b->ne[0];
const float scale = alpha ? it.second * alpha / rank : it.second;
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
ctx0, lw->b,
ggml_mul_mat_id(ctx0, lw->a, cur, ids),
ids
);
ab_cur = ggml_scale(ctx0, ab_cur, scale);
res = ggml_add(ctx0, res, ab_cur);
}
return res;
}

and the first test of a rank-64 LoRA seems to actually make it slightly worse:
but I will investigate more tomorrow - there are lots of places a bug could have crept in for the SVD code, or it might just not like being quantised using |
It's not the quantising as with I've found a way to do "proper" LQER (that doesn't require all the expert tensors to be chopped back up and transposed) but it relies on the python It's going to take a couple of days to run because of the as it's not really specific to I'll have a look next week to see if I can find what causes the overflow using |
Forgot to post this, table comparison of all your quants alongside mine (including an IQ1 based quant I had tested). I do use an imatrix (but not on the new split tensor as it hasn't been applied since the imatrix.dat predates it). Your quants do beat mine, but I think they are all larger.
Edit: Included the IQ1_S |
I've still not found a good way to fix the

You need to add this to https://github.com/ggml-org/llama.cpp/blob/master/src/llama-quant.cpp#L122:

static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
const std::string name = ggml_get_name(tensor);
// TODO: avoid hardcoded tensor names - use the TN_* constants
const llm_arch arch = qs.model.arch;
const auto tn = LLM_TN(arch);
auto use_more_bits = [](int i_layer, int n_layers) -> bool {
return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
};
const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
if (n_expert > 1) {
// Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
// sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
// for getting the current layer as I initially thought, and we need to resort to parsing the
// tensor name.
if (sscanf(name, "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
}
if (i_layer < 0 || i_layer >= n_layer) {
throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
}
}
return std::make_pair(i_layer, n_layer);
};
// <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<<
if (name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
new_type = GGML_TYPE_F32;
}
else
// <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<< <<<<<<<<<<
// for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
// with the quantization of the output tensor
if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
.
.
.

recompile, and then re-quantise so that these two tensors get overwritten to use

You can also set it to

You can't use

const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
.
.
.
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il); but there are that many things getting sliced and permuted here, I'm not confident this wouldn't miss something and it's a pretty ugly hack anyway. There no point in using Which brings me to the reason why not quantising All the other weights in all the other tensors in the model only get accessed a single time per token for token generation, and hence why quantising these can actually speed up the token generation by trading a small amount of dequantising compute for higher effective memory throughput... BUT: I think between this fix (which may actually be worth adding to the PR as an I would think that this fix may have an even bigger effect on CPU-based systems as the sizes:
are so small they probably fit in CPU cache. |
I've now dequantised to a

// ### JUK ###
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K // Q2_K_XM : ~3.0 bits per weight for experts (16×3×3.5 + 42×(3.5 + 2×2.5))/(3×58) |
|| ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S // Q2_K_XL : ~3.5 bits per weight for experts (16×3×4.5 + 42×(4.5 + 2×2.5))/(3×58) |
|| ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S // Q3_K_XM : ~4.0 bits per weight for experts (16×3×4.5 + 42×(4.5 + 2×3.5))/(3×58) |
|| ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M // Q3_K_XL : ~4.5 bits per weight for experts (16×3×5.5 + 42×(5.5 + 2×3.5))/(3×58) |
|| ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S // Q4_K_XM : ~5.0 bits per weight for experts (16*3*5.5 + 42*(5.5 + 2×4.5))/(3*58) | 404 GiB (5.16 BPW)
|| ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M // Q4_K_XL : ~5.5 bits per weight for experts (16*3*6.5 + 42*(6.5 + 2×4.5))/(3*58) | 446 GiB
|| ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) { // Q5_K_XM : ~6.0 bits per weight for experts (16*3*6.5 + 42*(6.5 + 2×5.5))/(3*58) | 483 GiB (6.16 BPW)
if (name.find("_exps") != std::string::npos) {
int i_layer;
if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name.c_str()));
}
if (name.find("ffn_down") != std::string::npos || i_layer <= 10 || i_layer >= 53) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
new_type = GGML_TYPE_Q3_K; // Q2_K_XM
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) {
new_type = GGML_TYPE_Q4_K; // Q2_K_XL & Q3_K_XM
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
new_type = GGML_TYPE_Q5_K; // Q3_K_XL & Q4_K_XM
}
else /* if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) */ {
new_type = GGML_TYPE_Q6_K; // Q4_K_XL & Q5_K_XM
}
}
else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
new_type = GGML_TYPE_Q2_K; // Q2_K_XM & Q2_K_XL
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
new_type = GGML_TYPE_Q3_K; // Q3_K_XM & Q3_K_XL
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
new_type = GGML_TYPE_Q4_K; // Q4_K_XM & Q4_K_XL
}
else /* if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) */ {
new_type = GGML_TYPE_Q5_K; // Q5_K_XM
}
}
}
else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
new_type = GGML_TYPE_F32; // Also used: type_kr = type_kv = GGML_TYPE_F32
}
else {
new_type = GGML_TYPE_Q8_0;
}
}
else
// ### JUK ###

and then run

I'm also using a

# Use float32 for the compressed KV-cache.
safe_sed "src/llama-kv-cache.h" "ggml_type type_kr = GGML_TYPE_F16" "ggml_type type_kr = GGML_TYPE_F32"
safe_sed "src/llama-kv-cache.h" "ggml_type type_kv = GGML_TYPE_F16" "ggml_type type_kv = GGML_TYPE_F32"
It says

The idea being that I can clearly see if this makes much difference to the perplexity scores vs using

I'm also going to run another test using:

else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
new_type = GGML_TYPE_F16;
}

to see what effect this has on the token-generation speed (even if the model just spews garbage or |
@saood06 Here are the first 16 chunks for the
This is the absolute maximum I can run as trying to push it even a little higher will likely start to use swapfile and/or bring down the OS, so currently this is probably the best estimate of the lower-bound for the full model. I will post the full perplexity run results in a couple of days. |
It looks like the CUDA GGML code only uses this (found by searching for
Or for matrix-vector products:
It's also used in What |
The function makes input tensors contiguous and presents them as single-batch matrix multiplications to other kernels. The conversion to q8_1 is only done for kernels that use quantized data. |
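A schematic of the "contiguous + single-batch" idea in plain ggml terms (assumed shapes and a hypothetical helper name; not the actual CUDA wrapper):

// fold the batch dimension of a possibly non-contiguous operand into its
// columns so that one matrix-matrix product covers all batches
static struct ggml_tensor * mul_mat_flatten_batch(
        struct ggml_context * ctx,
        struct ggml_tensor  * w,   // [k, m]
        struct ggml_tensor  * x) { // [k, n, batch], possibly non-contiguous
    struct ggml_tensor * xc = ggml_cont(ctx, x); // force a contiguous copy
    struct ggml_tensor * x2 = ggml_reshape_2d(ctx, xc, xc->ne[0], xc->ne[1]*xc->ne[2]);
    return ggml_mul_mat(ctx, w, x2);             // result: [m, n*batch]
}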
I don't know if it's worth looking at yet as this is still a draft PR, but it should be quite easy to replicate the slowdown I saw using I just looked at my logs for running the first 16 chunks of
|
Generally speaking, the KQ matrix is susceptible to overflow. So it is preferable to use BF16 or FP32 accumulators for its calculation. However, I was never able to get FP16 matrix multiplication with FP32 accumulation to work with cuBLAS. The documentation says it should be possible but the kernel fails to launch when I try it. Currently the precision argument for KQ is not used for cuBLAS GEMM. For a FP16 K matrix FP16 accumulation is used unconditionally. |
I think I misremembered. After looking at the documentation again I think the problem was that FP16, FP16 -> FP32 GEMM is supported but the performance was so much worse that there was basically no point in using it. |
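For reference, a minimal sketch of what an FP16-input / FP32-accumulation GEMM call looks like through cublasGemmEx (hypothetical wrapper name, not code from llama.cpp; this is the variant described above as supported but slow):

#include <cublas_v2.h>
#include <cuda_fp16.h>

// C (FP32) = A (FP16) * B (FP16), accumulated in FP32 via CUBLAS_COMPUTE_32F
static cublasStatus_t gemm_f16_in_f32_acc(
        cublasHandle_t handle, int m, int n, int k,
        const half * A, int lda,
        const half * B, int ldb,
        float      * C, int ldc) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                        &alpha,
                        A, CUDA_R_16F, lda,
                        B, CUDA_R_16F, ldb,
                        &beta,
                        C, CUDA_R_32F, ldc,
                        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}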
I investigated possible reasons for the poor scaling of token generation with DeepSeek V3/R1 on dual-CPU systems. My current working hypothesis is that the DeepSeek V3/R1 expert FFN matrices are so small (7168 x 2048) that the overhead of using two CPUs for matrix-vector multiplication during token generation negates almost all performance gains. I suppose this is the reason why the ktransformers folks in their v3.0-preview keep two copies of the experts in memory, one for each CPU. I'm going to create a NUMA-aware matrix-vector multiplication benchmark to verify this hypothesis. I thought about possible solutions. One would be to assign the experts in each layer to N sets, where N is the number of CPUs, and then use the top n_expert_used/N experts from each set during inference (see the sketch below). In this solution each CPU would handle only its assigned local experts and there would be no communication overhead. But it could result in non-optimal expert choices, and I'm not sure how that would affect the model's quality. |
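A minimal sketch of that per-CPU selection rule (hypothetical helper, assuming 256 routed experts and 8 used as in DeepSeek V3; not llama.cpp code): split the experts into one contiguous set per NUMA node and take the top n_expert_used/N experts independently within each set.

#include <algorithm>
#include <cstdio>
#include <vector>

static std::vector<int> pick_experts_numa(const std::vector<float> & scores,
                                          int n_nodes, int n_expert_used) {
    const int n_expert = (int) scores.size();
    const int per_node = n_expert / n_nodes;      // experts owned by each node
    const int k_node   = n_expert_used / n_nodes; // picks taken from each node
    std::vector<int> picked;
    for (int node = 0; node < n_nodes; ++node) {
        std::vector<int> ids(per_node);
        for (int i = 0; i < per_node; ++i) ids[i] = node*per_node + i;
        // keep the k_node locally highest-scoring experts of this node
        std::partial_sort(ids.begin(), ids.begin() + k_node, ids.end(),
            [&](int a, int b) { return scores[a] > scores[b]; });
        picked.insert(picked.end(), ids.begin(), ids.begin() + k_node);
    }
    return picked;
}

int main() {
    std::vector<float> scores(256);
    for (int i = 0; i < 256; ++i) scores[i] = (float) ((i*37) % 101); // dummy router scores
    for (int id : pick_experts_numa(scores, /*n_nodes=*/2, /*n_expert_used=*/8)) {
        printf("%d ", id);
    }
    printf("\n");
    return 0;
}

Each node then only ever touches its own experts, at the price of possibly picking a locally-best expert over a globally better one - which is the accuracy concern raised above.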
// ######
if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS // IQ4_XXS : 4.25 BPW for experts | 345 GiB (4.41 BPW) (-28.4%) | PPL = 3.3850 +/- 0.01877 (+1.51%) | 15.05 tokens per second ( +8.0%)
|| ftype == LLAMA_FTYPE_MOSTLY_Q4_0 // Q4_0_XS : 4.5 BPW for experts | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3944 +/- 0.01885 (+1.95%) | 14.17 tokens per second ( +1.6%)
|| ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S // Q4_K_XS : 4.5 BPW for experts | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3724 +/- 0.01866 (+0.66%) | 18.81 tokens per second (+34.9%)
|| ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S // Q5_K_XS : 5.5 BPW for experts | 441 GiB (5.63 BPW) ( -8.6%) | PPL = 3.3546 +/- 0.01852 (+0.16%) | 13.84 tokens per second ( -0.7%)
// -----------------------------------------------------------------------------------------------------------------------------------
|| ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M // Q4_K_XM : ~5.0 BPW for experts | 404 GiB (5.16 BPW) (-16.2%) | PPL = 3.3666 +/- 0.01863 (+0.48%) | 15.82 tokens per second (+13.5%)
|| ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL // Q4_K_XL : ~5.5 BPW for experts | 446 GiB (5.69 BPW) ( -7.6%) | PPL = 3.3614 +/- 0.01858 (+0.33%) | 16.03 tokens per second (+15.0%)
|| ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M // Q5_K_XM : ~6.0 BPW for experts | 483 GiB (6.16 BPW) | PPL = 3.3504 +/- 0.01849 | 13.94 tokens per second
// -----------------------------------------------------------------------------------------------------------------------------------
|| ftype == LLAMA_FTYPE_MOSTLY_Q5_1 // Q5_K_XH : 5.0 BPW for experts | 403 GiB (5.15 BPW) | PPL = 3.3695 +/- 0.01864 (+0.57%) | 15.90 tokens per second (+14.1%)
|| ftype == LLAMA_FTYPE_MOSTLY_Q6_K // Q6_K_XH : 6.0 BPW for experts | 481 GiB (6.15 BPW) (-16.2%) | PPL = 3.3548 +/- 0.01853 (+0.13%) | 13.87 tokens per second ( -0.5%)
// -----------------------------------------------------------------------------------------------------------------------------------
) { // iQ4_K_XS (Q4_K_XS using Bartowski imatrix for experts only) : PPL = 3.3734 +/- 0.01866 (+0.69%) | 18.76 tokens per second (+34.6%)
if (name.find("_exps") != std::string::npos) {
int i_layer;
if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name.c_str()));
}
if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
new_type = GGML_TYPE_IQ4_XS; // IQ4_XXS
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0) {
new_type = GGML_TYPE_Q4_0; // Q4_0_XS
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
new_type = GGML_TYPE_Q4_K; // Q4_K_XS
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
new_type = GGML_TYPE_Q5_K; // Q5_K_XS
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_1) {
new_type = (i_layer <= 31 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K); // Q5_K_XH first and last 29 experts
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
new_type = (i_layer <= 31 ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K); // Q6_K_XH first and last 29 experts
}
else if (name.find("ffn_down") != std::string::npos || i_layer <= 10 || i_layer >= 53) { // First 8 and last 8 experts (ie: 16/58 experts)
if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
new_type = GGML_TYPE_Q5_K; // Q4_K_XM
}
else {
new_type = GGML_TYPE_Q6_K; // Q4_K_XL & Q5_K_XM
}
}
else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
new_type = GGML_TYPE_Q5_K; // Q5_K_XM
}
else {
new_type = GGML_TYPE_Q4_K; // Q4_K_XM & Q4_K_XL
}
}
}
else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
new_type = GGML_TYPE_F32; // Also used: type_kr = type_kv = GGML_TYPE_F32
}
else {
new_type = GGML_TYPE_Q8_0;
}
}
else
// ######

I've finished the testing of the custom quants:
Just running one last test on pure EDIT:
No difference, |
I think the safest option then is probably to use |
In Daniel's 1.58 quantization he kept the shared expert at a higher resolution than the routed experts. |
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;

I think this might be causing some weird problem in the CUDA back-end where a different code-path is taken. If I leave it as default and use this 127-token test prompt:
The model won't say the actual phrases and it feels "off" - like there is something wrong with the attention mechanism (it sometimes "sort of" says the phrases, but not quite and often not at all). If I fix the flag to always be true, eg:

bool pp_opt = true;

Then all of a sudden the model starts to say those phrases and seems way better at writing in general (I suspect this triggers a different code-path - possibly something to do with the matrix-vector vs matrix-matrix stuff I remember seeing the other day?). If I fix the flag to always be false, eg:

bool pp_opt = false;

Then run

So I thought I'd try running with
It actually seems to get better PPL for these first few values (sorry no way I can run the 24h to completion) and the difference is almost the size of the error bar from the full PPL calculated over the default setting. I don't know how else to help diagnose what's going on 😕 Could it be that the 127-token test prompt is not a multiple of 32 and when it gets permuted it's causing some problem there? |
@jukofyork If you think that some operation is producing wrong results with CUDA, an easy way to test that would be to add a test case to |
@slaren @JohannesGaessler @fairydreaming I've got a little further now and think it's the same overflow problem that affected

Both the existing attention implementations use

if (cparams.flash_attn) {
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

and by trial and error with

struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32); // ***
cb(kq_nope, "kq_nope", il);

struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32); // ***
cb(kq_pe, "kq_pe", il);

struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32); // ***
cb(kqv_compressed, "kqv_compressed", il);

but this needs to be compiled with This was with When it comes time to merge the official MLA implementation, then I think this needs to be tested more thoroughly than I can do. |
The MMQ kernels always use FP32 for the accumulators, if there are numerical issues they must be due to extremal values in the inputs since FP16 is used for the scales of the quantized data. |
We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a |
I think just changing

Here's my full script that merges the PRs and applies all the hacks (including the commented-out ones I'm not using):

#!/bin/bash
function safe_sed() {
local file=$1
local pattern=$2
local replacement=$3
# Check if pattern exists
if ! sed -n "s/${pattern}/${replacement}/p" "$file" | grep -q .; then
echo "Error: Pattern not found in $file: $pattern"
return 1
fi
# Create backup
cp "$file" "$file.bak"
# Perform the replacement
sed -i "s/${pattern}/${replacement}/g" "$file"
# Show diff
echo "Changes in $file:"
diff "$file.bak" "$file"
# Clean up
rm "$file.bak"
echo "Successfully replaced in $file"
echo "-------------------"
}
rm -rf llama.cpp
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
git remote add fairydreaming https://github.com/fairydreaming/llama.cpp.git
git remote add sl https://github.com/ggerganov/llama.cpp.git
git fetch fairydreaming
git fetch sl
git checkout -b merged_features
# For MLA compressed KV-cache
git merge --no-edit fairydreaming/deepseek2-mla-exp
# To save having to wait ages for the warmup (~2.5x less wait)
git merge --no-edit fairydreaming/experts-warmup
# To allow the use of --override-tensor exps=CPU (and --override-tensor attn_kv_b=CPU)
git merge --no-edit sl/sl/custom-tensor-offload
# Allocate the minimum possible for the unused KV-cache.
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa\*kv_size);" "ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, 1);"
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa\*kv_size);" "ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, 1);"
# Don't offload to GPU.
safe_sed "ggml/src/ggml-cuda/ggml-cuda.cu" "const int min_batch_size = 32" "const int min_batch_size = 9999999"
safe_sed "src/llama.cpp" "bool pp_opt = n_tokens > n_head;" "bool pp_opt = true;"
#safe_sed "src/llama.cpp" "kv_cache, q_nope2);" "kv_cache, q_nope2);\n ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kr_cache, q_pe);" "kr_cache, q_pe);\n ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kv_cache_trans, kq);" "kv_cache_trans, kq);\n ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32);"
# Use float32 for the compressed KV-cache.
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kr = GGML_TYPE_F16" "ggml_type type_kr = GGML_TYPE_F32"
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kv = GGML_TYPE_F16" "ggml_type type_kv = GGML_TYPE_F32"
# Hack llama_tensor_get_type() to use our chosen custom quant.
safe_sed "src/llama-quant.cpp" \
"llama_tensor_get_type(qs, new_type, tensor, ftype);" \
"name.find(\"_exps\") != std::string::npos ? name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K : GGML_TYPE_BF16;"
# Must set GGML_SCHED_MAX_COPIES=1 for use with --override-tensor exps=CPU
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_RPC=ON
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=9999999
cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_FORCE_CUBLAS=1
cmake --build build --config Release -- -j 44

Which gets run using:

numactl --interleave=all ./llama.cpp/build/bin/llama-server --host 192.168.1.111 --port 8080 \
--model ./DeepSeek-R1-mla-Q5_K_XL.gguf --chat-template deepseek3 --alias "DeepSeek-R1-mla-Q5_K_XL" --ctx_size 32768 \
--n-gpu-layers 99 --override-tensor exps=CPU --override-tensor attn_kv_b=CPU --numa distribute \
--temp 0.6 --min-p 0.0 --top-p 1.0 --top-k 0 --threads 30 --threads-batch 44

The quant is in the script on 1 line:

// ######
if (name.find("_exps") != std::string::npos) {
if (name.find("ffn_down") != std::string::npos) {
new_type = GGML_TYPE_Q6_K;
}
else {
new_type = GGML_TYPE_Q5_K;
}
}
else {
new_type = GGML_TYPE_BF16;
}
else
// ######

and gave this on
I can't see the thought tags on openrouter, but this custom |
The PTX documentation has a table with the data types that are supported by tensor cores. In all cases the input matrices must have the same data type. So if the KV cache stays FP16 the activations must be converted to FP16. Alternative approaches would be to use BF16 which has the same numerical range as FP32 or to convert the FP16 data to TF32 in SRAM (this is to my knowledge not supported by cuBLAS, I did not check CUTLASS). Both BF16 and TF32 need Ampere or newer. In terms of speed |
A quick update on the

I've found that fixing

"static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
const std::string name = ggml_get_name(tensor);\n\
if (name.find(\"_exps\") != std::string::npos) {\n\
return name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K;\n\
} else if (name.find(\"attn_\") != std::string::npos && name.find(\"_output\") == std::string::npos) {\n\
return GGML_TYPE_BF16;\n\
}\n\
return GGML_TYPE_Q8_0;\n\

It does somewhat work, and no longer gives

It works quite a bit faster (3.6 tokens/s vs 3.1-3.2 tokens/s) compared to using the same custom quant with
These should be

I've tried using

Hopefully after the attention refactoring is over and MLA gets looked at again, some of these problems can be ironed out. |
This PR introduces various optimizations for the DeepSeek V2/V3 implementation:
Note that you need to reconvert the model to use this implementation.
Performance compared to the previous "naive" implementation:
CUDA performance is worse for short context lengths, but the curve is flatter:
TODO:
address regressions in prompt processing performance (different permutations of tensors?) - I don't think it's possible, as this implementation is more compute-intensive compared to the regular attention implementation