Q8_KV: 8-bit quantization type targeting the KV cache #208
Conversation
We get 225.7 t/s for L3-8B. In comparison, q8_0 without run-time repacking is at 169 t/s.
We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr.
This required quite a few fixes in ggml and llama.cpp:
* ggml: do not calculate row size as n/block_size*type_size. I had removed most of it when implementing the quants with per-row scale, but it was still lurking in ggml_copy. Not sure if these were the last remnants of ggml-style row sizes, or if there are still places left.
* llama.cpp: get rid of the 1D K-cache assumption. Create and manage the K cache as a 2D tensor so we can have per-row metadata as needed by q8_KV (see the sketch below).

Using q8_KV for the K cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA we get an 18% speedup for PP-8192 compared to a q8_0 K cache.
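To make the "K cache as a 2D tensor" point concrete, here is a minimal, hypothetical sketch against the mainline ggml API (not the PR's actual code; GGML_TYPE_Q8_0 stands in for the new type since its enum is not shown here, and the sizes are made up). The point is that with one cache slot per tensor row, ggml_row_size() accounts for any per-row metadata, so byte offsets are never hand-computed as (n/block_size)*type_size.

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    const int64_t n_embd_k_gqa = 512;   // head_size * n_head_kv, made-up numbers
    const int64_t kv_size      = 4096;  // number of cache slots, made-up

    struct ggml_init_params params = {
        /*.mem_size   =*/ 32u * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 2D instead of 1D: one cache slot per tensor row, so ggml_row_size() can
    // include whatever per-row metadata the cache type carries.
    // GGML_TYPE_Q8_0 is only a placeholder; the interesting case is a type with
    // a per-row scale, where the naive (n/block_size)*type_size formula breaks.
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, n_embd_k_gqa, kv_size);

    // Correct byte offset of slot i is simply i * ggml_row_size(...).
    const size_t row_bytes = ggml_row_size(k->type, n_embd_k_gqa);
    printf("bytes per cache row: %zu, total: %zu\n", row_bytes, row_bytes * (size_t)kv_size);

    ggml_free(ctx);
    return 0;
}
```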
We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well.
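For readers wondering what "interleaving on the fly" could look like, here is a simplified, hypothetical sketch (the tile width and scratch-buffer handling are my own choices, not the PR's repacking code): 8 consecutive int8 rows are gathered into a row-interleaved layout in a small scratch buffer right before the GEMM, instead of being stored pre-interleaved.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Gather 8 consecutive int8 rows into a row-interleaved layout: 16 values from
// row 0, 16 from row 1, ..., 16 from row 7, then the next 16 of row 0, and so on.
// row_size is assumed to be a multiple of 16 in this sketch.
static void repack_8_rows(const int8_t * src, int64_t row_size, std::vector<int8_t> & dst) {
    const int64_t tile = 16;  // arbitrary tile width for this sketch
    dst.resize(8 * row_size);
    int8_t * out = dst.data();
    for (int64_t j = 0; j < row_size; j += tile) {
        for (int r = 0; r < 8; ++r) {
            const int8_t * row = src + r * row_size;
            for (int64_t k = 0; k < tile; ++k) *out++ = row[j + k];
        }
    }
}

int main() {
    const int64_t row_size = 128;
    std::vector<int8_t> rows(8 * row_size);
    for (size_t i = 0; i < rows.size(); ++i) rows[i] = (int8_t)(i % 101);
    std::vector<int8_t> packed;
    repack_8_rows(rows.data(), row_size, packed);  // scratch buffer, reused per 8-row strip
    printf("packed %zu bytes\n", packed.size());
    return 0;
}
```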
On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s). This makes no sense whatsoever, as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block handling removed (so one would think it would be faster).
This is faster: 350 t/s. Why? Much better than the 290 t/s we had before, but still slower than the 370 t/s for q8_k_r8.
I do think a viable application that could leverage this benefit is compression (an open PR exists on mainline: ggml-org/llama.cpp#9633), but even better would be a multimodal base model with an Evabyte or Evabyte-like architecture (multimodal, MTP, and no tokenizer since it is byte based), potentially adding MLA, MoE, and Muon like modern architectures. There are optimizations you could do: for example, if it is an archive format, you could take advantage by batching files, as they could be independent streams; it may even be worth including LoRAs that can be applied dynamically. I don't know who would do this, as it is a large undertaking (especially if you make the model), and I'm not sure it would be fast enough on enough systems to be viable.
What is Q8_KV?

It is 8-bit quantization with a single scale per tensor row (so, no blocks at all). That may not be accurate enough for model quantization, but using it for KV cache quantization seems plausible, considering that here rows are defined by the head size, so they contain 64, 80, 96, 128, 192, or 256 elements for all LLMs currently in circulation. We are not looking for KV cache size reduction but rather for improving inference performance for long contexts. This is especially relevant for MLA (DeepSeek), as in FA the kernels are highly optimized, so large improvements may not really be possible.

Caveat: everything is CPU only; there is no CUDA or Metal implementation.
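To make the "single scale per row, no blocks" idea concrete, here is a minimal, hypothetical sketch of Q8_KV-style row quantization (the rounding and storage details are illustrative guesses, not the PR's exact format):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// One float scale for the whole row, then one int8 per element: x[i] ≈ d * q[i].
struct q8_kv_row {
    float               d;
    std::vector<int8_t> q;
};

static q8_kv_row quantize_row_q8_kv_like(const float * x, int64_t n) {
    float amax = 0.0f;
    for (int64_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(x[i]));
    q8_kv_row row;
    row.d = amax / 127.0f;                               // single per-row scale
    const float id = row.d > 0.0f ? 1.0f / row.d : 0.0f;
    row.q.resize(n);
    for (int64_t i = 0; i < n; ++i) row.q[i] = (int8_t) std::lround(x[i] * id);
    return row;
}

int main() {
    // A row of head_size = 128 values, as a K-cache row would be.
    std::vector<float> x(128);
    for (size_t i = 0; i < x.size(); ++i) x[i] = 0.01f * (float)i - 0.5f;
    q8_kv_row row = quantize_row_q8_kv_like(x.data(), (int64_t)x.size());
    printf("scale = %g, first quant = %d\n", row.d, (int)row.q[0]);
    return 0;
}
```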
The following changes are made:
* Q8_KV and Q8_KV_R8 are added. Q8_KV_R8 is Q8_KV with 8 interleaved rows, analogous to Q8_K_R8, the so far fastest quantization type for prompt processing. On AVX2/Zen4, Q8_KV_R8 is slightly slower than Q8_K_R8, which is somewhat surprising.
* It is now possible to use Q8_KV quants in the K cache. This required various fixes in llama.cpp and ggml. There were still places left where the number of bytes needed to store a row of size N is computed as (N/B)*T, where B is the type block size and T is the type size. This of course fails when the row has extra metadata. There is the function ggml_row_size(ggml_type type, int64_t N) to compute this, but I had missed a few places when adding the IQK quants. It also turned out that in quite a few places ggml_row_size() is not used correctly. E.g., for the KV cache we find ggml_row_size(type_k, head_size*num_heads) instead of ggml_row_size(type_k, head_size)*num_heads (see the sketch after this list). The same issue was also present in the MoE matrix multiplication function. Beyond what was needed for Q8_KV, I didn't put too much effort into hunting down all places of incorrect ggml_row_size() usage.
* Q8_KV can also be used in FA. Here we get a minor speedup compared to Q8_0 (1-2% at 16k tokens).
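To see why the two ggml_row_size() forms in the second bullet are not interchangeable, here is a small self-contained illustration. The "n int8 values plus one 4-byte scale per row" layout is a made-up stand-in for a Q8_KV-like type, not the actual format; for plain block types the two expressions happen to coincide, which is how such bugs stay hidden.

```cpp
#include <cstdint>
#include <cstdio>

// Bytes needed to store one row of n elements for a hypothetical type that
// carries one 4-byte float scale per row in addition to the int8 values.
static size_t row_size_q8_kv_like(int64_t n) { return (size_t)n * sizeof(int8_t) + sizeof(float); }

int main() {
    const int64_t head_size = 128, num_heads = 8;

    size_t wrong = row_size_q8_kv_like(head_size * num_heads);            // one scale for all heads: 1028 bytes
    size_t right = row_size_q8_kv_like(head_size) * (size_t)num_heads;    // one scale per head row:  1056 bytes

    printf("wrong = %zu, right = %zu\n", wrong, right);
    return 0;
}
```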
A quantization type such as Q8_KV has the distinct advantage of making the results of matrix multiplications 100% reproducible and independent of the hardware the calculation is being done on: the row x column dot products are performed using integer arithmetic, and only at the end is the row scale applied, so the number of threads used and the order of summation do not affect the final result. I know there is interest in that sort of thing, but I leave further exploration for another day.
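A minimal sketch of that reproducibility argument (my own illustration, not code from the PR): the accumulation is pure integer arithmetic, which is exact and associative, so splitting the sum across threads or reordering it cannot change the result, and the float scales touch the value only once at the very end.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Dot product of two rows quantized with a single per-row scale each.
// x and y are assumed to have the same length.
static float dot_q8_kv_like(float d_x, const std::vector<int8_t> & x,
                            float d_y, const std::vector<int8_t> & y) {
    int32_t acc = 0;                              // exact integer accumulation:
    for (size_t i = 0; i < x.size(); ++i) {       // summation order and thread count
        acc += (int32_t)x[i] * (int32_t)y[i];     // cannot change this value
    }
    return d_x * d_y * (float)acc;                // scales applied once, at the end
}

int main() {
    std::vector<int8_t> x = { 12, -7, 3, 127 }, y = { -5, 9, 44, 1 };
    printf("%g\n", dot_q8_kv_like(0.02f, x, 0.01f, y));
    return 0;
}
```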
After all this, here is a comparison between the main branch and this PR for DeepSeek-Lite (acting as a surrogate for DeepSeek-R1) with MLA enabled. The V cache is bf16, the model is quantized with IQ4_XS, and the calculation is on a Ryzen-7950X CPU. The main branch uses Q8_0 for the K cache, the PR uses Q8_KV.

Here is a perplexity comparison between Q8_0 and Q8_KV used for model and K cache quantization for DeepSeek-Lite with a context of 512 tokens. PPL(fp16) = 6.7612.

I.e., using Q8_KV for K-cache quantization leads to a very minor loss of accuracy (certainly much better than Q6_0), but using Q8_KV to quantize the model weights results in much more significant accuracy loss.

Update
I have added the last 2 rows to the above table. In Q8_KV* the output and token embedding tensors are quantized with Q8_0, so most of the accuracy loss comes from these two tensors (and they have negligible impact on performance). I have also rerun the performance tests after merging PR #210. Here are the updated results: