
llamafile : improve moe prompt eval speed on cpu #6840

Open

jart wants to merge 1 commit into master from the moe branch

Conversation

@jart (Contributor) commented Apr 23, 2024

This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Experts" models. On my Threadripper, Mixtral's 8x7b F16 weights now process prompts 2x faster. I'm also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0, which is also supported by tinyBLAS. MoE models spend the majority of their time inside MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm was not able to help them before. llamafile_mixmul works by decomposing the mixmul operation into approximately two sgemm calls.
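For intuition, here is a rough sketch of the gather-GEMM-scatter idea behind turning a MUL_MAT_ID-style mixture-of-experts multiply into dense matrix multiplications. It is not the actual llamafile_mixmul() code: `sgemm_ref()` and the other names are placeholders, real kernels operate on quantized blocks rather than plain floats, and for simplicity each token is routed to a single expert here, whereas real MoE layers use the top-k experts.

```cpp
#include <cstddef>
#include <vector>

// Naive reference GEMM standing in for a tinyBLAS sgemm call:
// out[r][j] = dot(activation row r, weight row j).
static void sgemm_ref(const float *act, const float *wt, float *out,
                      std::size_t rows, std::size_t k, std::size_t m) {
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t j = 0; j < m; ++j) {
            float acc = 0.0f;
            for (std::size_t i = 0; i < k; ++i)
                acc += act[r * k + i] * wt[j * k + i];
            out[r * m + j] = acc;
        }
}

// For each expert: gather the tokens routed to it, run one dense GEMM
// against that expert's [m x k] weight matrix, then scatter the results
// back to each token's original position.
static void mixmul_sketch(const std::vector<const float *> &expert_weights,
                          const float *activations,  // [n_tokens x k]
                          const int *expert_ids,     // chosen expert per token
                          float *out,                // [n_tokens x m]
                          std::size_t n_tokens, std::size_t k, std::size_t m) {
    for (std::size_t e = 0; e < expert_weights.size(); ++e) {
        std::vector<float> batch;       // gathered activation rows
        std::vector<std::size_t> rows;  // their original token indices
        for (std::size_t t = 0; t < n_tokens; ++t)
            if (expert_ids[t] == static_cast<int>(e)) {
                batch.insert(batch.end(), activations + t * k, activations + (t + 1) * k);
                rows.push_back(t);
            }
        if (rows.empty())
            continue;
        std::vector<float> result(rows.size() * m);
        sgemm_ref(batch.data(), expert_weights[e], result.data(), rows.size(), k, m);
        for (std::size_t i = 0; i < rows.size(); ++i)
            for (std::size_t j = 0; j < m; ++j)
                out[rows[i] * m + j] = result[i * m + j];
    }
}
```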

jart force-pushed the moe branch 2 times, most recently from def794c to 828f3fe on April 23, 2024 at 05:44
@hiepxanh

nice to see this PR <3 Thank you so much

@USBhost commented Apr 23, 2024

Does it also help the other K quants?

@jart (Contributor, Author) commented Apr 23, 2024

@USBhost Unfortunately no. The K quants were designed to exploit under-utilization of CPU resources when doing matvecs. I tried copying and pasting the Q5_K_M code into a tinyBLAS 2-d block-tiling kernel, but the compiler wasn't able to unroll it in a way that offered performance gains through instruction level parallelism. I've only been able to make the simpler quants work. It's a shame because I really like Q5_K_M, so it'd be great to see Iwan Kawrakow develop a new quant specifically for block-tiling.
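For readers unfamiliar with the term, this is a minimal sketch of what a 2-d block-tiling (register-tiling) GEMM kernel looks like for plain floats; it is not tinyBLAS code, and the 3x4 tile size is arbitrary. The point is that the tile gives the compiler twelve independent FMA chains to interleave, which is the instruction-level parallelism that the K-quant dequantization code apparently defeats.

```cpp
#include <cstddef>

// C[i][j] += sum_p A[i][p] * B[j][p] for a 3x4 output tile.
// A is row-major [3 x k] with leading dimension lda, B is [4 x k] with ldb,
// C is the output tile with leading dimension ldc.
static void gemm_tile_3x4(const float *A, const float *B, float *C,
                          std::size_t lda, std::size_t ldb, std::size_t ldc,
                          std::size_t k) {
    float acc[3][4] = {};  // 12 independent accumulator chains live in registers
    for (std::size_t p = 0; p < k; ++p)
        for (std::size_t i = 0; i < 3; ++i) {
            float a = A[i * lda + p];          // each A value is reused 4 times
            for (std::size_t j = 0; j < 4; ++j)
                acc[i][j] += a * B[j * ldb + p];
        }
    for (std::size_t i = 0; i < 3; ++i)
        for (std::size_t j = 0; j < 4; ++j)
            C[i * ldc + j] += acc[i][j];
}
```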

@USBhost commented Apr 23, 2024

> @USBhost Unfortunately no. The K quants were designed to exploit under-utilization of CPU resources when doing matvecs. I tried copying and pasting the Q5_K_M code into a tinyBLAS 2-d block-tiling kernel, but the compiler wasn't able to unroll it in a way that offered performance gains through instruction level parallelism. I've only been able to make the simpler quants work. It's a shame because I really like Q5_K_M, so it'd be great to see Iwan Kawrakow develop a new quant specifically for block-tiling.

I see, thanks for the explanation.

Side note: I would love a doc that explains the speed differences between plain Q4_0 vs Q4_1 vs the K quants, because I keep seeing the simple ones getting buffs.

@jart (Contributor, Author) commented Apr 23, 2024

The tinyBLAS code upstreamed by Mozilla's llamafile project makes prompt processing go very fast for F32, F16, Q4_0, and Q8_0.

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 1B F16 | 2.05 GiB | 1.10 B | CPU | 96 | pp 512 | 2048.86 ± 6.52 |
| llama 1B F16 | 2.05 GiB | 1.10 B | CPU | 96 | tg 4 | 52.01 ± 0.06 |
| llama 1B all F32 | 4.10 GiB | 1.10 B | CPU | 96 | pp 512 | 1946.19 ± 21.26 |
| llama 1B all F32 | 4.10 GiB | 1.10 B | CPU | 96 | tg 4 | 39.75 ± 0.15 |
| llama 1B Q2_K - Medium | 411.41 MiB | 1.10 B | CPU | 96 | pp 512 | 1273.08 ± 18.57 |
| llama 1B Q2_K - Medium | 411.41 MiB | 1.10 B | CPU | 96 | tg 4 | 69.55 ± 0.22 |
| llama 1B Q3_K - Large | 563.42 MiB | 1.10 B | CPU | 96 | pp 512 | 1109.08 ± 7.66 |
| llama 1B Q3_K - Large | 563.42 MiB | 1.10 B | CPU | 96 | tg 4 | 66.83 ± 0.41 |
| llama 1B Q3_K - Medium | 522.30 MiB | 1.10 B | CPU | 96 | pp 512 | 1161.17 ± 7.58 |
| llama 1B Q3_K - Medium | 522.30 MiB | 1.10 B | CPU | 96 | tg 4 | 67.49 ± 0.10 |
| llama 1B Q3_K - Small | 475.51 MiB | 1.10 B | CPU | 96 | pp 512 | 1052.25 ± 161.64 |
| llama 1B Q3_K - Small | 475.51 MiB | 1.10 B | CPU | 96 | tg 4 | 68.30 ± 0.19 |
| llama 1B Q4_0 | 606.53 MiB | 1.10 B | CPU | 96 | pp 512 | 1418.41 ± 10.54 |
| llama 1B Q4_0 | 606.53 MiB | 1.10 B | CPU | 96 | tg 4 | 65.78 ± 0.24 |
| llama 1B Q4_1 | 668.18 MiB | 1.10 B | CPU | 96 | pp 512 | 884.68 ± 3.74 |
| llama 1B Q4_1 | 668.18 MiB | 1.10 B | CPU | 96 | tg 4 | 64.62 ± 0.05 |
| llama 1B Q4_K - Medium | 636.18 MiB | 1.10 B | CPU | 96 | pp 512 | 1197.76 ± 11.42 |
| llama 1B Q4_K - Medium | 636.18 MiB | 1.10 B | CPU | 96 | tg 4 | 66.04 ± 0.34 |
| llama 1B Q4_K - Small | 609.53 MiB | 1.10 B | CPU | 96 | pp 512 | 1200.22 ± 10.06 |
| llama 1B Q4_K - Small | 609.53 MiB | 1.10 B | CPU | 96 | tg 4 | 66.38 ± 0.27 |
| llama 1B Q5_0 | 729.84 MiB | 1.10 B | CPU | 96 | pp 512 | 1058.68 ± 10.52 |
| llama 1B Q5_0 | 729.84 MiB | 1.10 B | CPU | 96 | tg 4 | 63.10 ± 0.36 |
| llama 1B Q5_1 | 791.50 MiB | 1.10 B | CPU | 96 | pp 512 | 718.18 ± 127.77 |
| llama 1B Q5_1 | 791.50 MiB | 1.10 B | CPU | 96 | tg 4 | 62.07 ± 0.67 |
| llama 1B Q5_K - Medium | 745.11 MiB | 1.10 B | CPU | 96 | pp 512 | 1055.78 ± 5.80 |
| llama 1B Q5_K - Medium | 745.11 MiB | 1.10 B | CPU | 96 | tg 4 | 64.01 ± 0.16 |
| llama 1B Q5_K - Small | 729.84 MiB | 1.10 B | CPU | 96 | pp 512 | 1048.20 ± 3.90 |
| llama 1B Q5_K - Small | 729.84 MiB | 1.10 B | CPU | 96 | tg 4 | 64.32 ± 0.27 |
| llama 1B Q6_K | 860.86 MiB | 1.10 B | CPU | 96 | pp 512 | 995.96 ± 183.61 |
| llama 1B Q6_K | 860.86 MiB | 1.10 B | CPU | 96 | tg 4 | 62.67 ± 0.22 |
| llama 1B Q8_0 | 1.09 GiB | 1.10 B | CPU | 96 | pp 512 | 1430.38 ± 9.86 |
| llama 1B Q8_0 | 1.09 GiB | 1.10 B | CPU | 96 | tg 4 | 59.71 ± 0.14 |

Measured on AMD Ryzen Threadripper PRO 7995WX with TinyLlama 1.1B. This PR ensures those performance wins will happen for MoE models too.

@jart (Contributor, Author) commented Apr 24, 2024

Note: I'm still in the process of testing this change and verifying it's correct on all compilers and architectures.

jart force-pushed the moe branch 2 times, most recently from 26ab943 to 89991a1 on April 25, 2024 at 00:56
@jart (Contributor, Author) commented Apr 25, 2024

OK, I've worked out the remaining kinks. This code just shipped as part of the llamafile 0.8 release. Thanks to this change, I'm seeing a 2x prompt eval speed increase across the board. My Threadripper now runs Mixtral 2x faster. My M2 Ultra runs Mixtral 2x faster on CPU. This change even pumps the Raspberry Pi 5 up to 78 tok/sec on non-MoE F16 models, in case you want to buy a bag full of the things to build your next supercomputer. PTAL.

@ikawrakow (Contributor)

> It's a shame because I really like Q5_K_M, so it'd be great to see Iwan Kawrakow develop a new quant specifically for block-tiling.

@jart

I became intrigued by your assumption that block-tiling is required to speed up prompt processing for k-quants, so I spent some time optimizing k-quant CPU matrix multiplications. I'm running on a 16-core Ryzen 7950X CPU, so I have only done a better AVX2 implementation. The baseline for this CPU (using your PR) with a 7B LLaMA is:

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | pp 512 | 119.75 ± 0.38 |
| llama 7B Q4_1 | 3.95 GiB | 6.74 B | CPU | 16 | pp 512 | 63.80 ± 0.22 |
| llama 7B Q5_0 | 4.33 GiB | 6.74 B | CPU | 16 | pp 512 | 59.50 ± 0.11 |
| llama 7B Q5_1 | 4.72 GiB | 6.74 B | CPU | 16 | pp 512 | 56.28 ± 0.08 |

Q4_0 is much faster than the other legacy quants thanks to your tinyBLAS.

Here is what I get for k-quants:

| model | size | params | test | t/s (master) | t/s (optimized) | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q2_K - Small | 2.16 GiB | 6.74 B | pp 512 | 116.35 ± 0.08 | 162.66 ± 1.21 | 1.398 ± 0.009 |
| llama 7B Q2_K - Medium | 2.36 GiB | 6.74 B | pp 512 | 100.49 ± 0.15 | 149.31 ± 0.48 | 1.486 ± 0.006 |
| llama 7B Q3_K - Small | 2.75 GiB | 6.74 B | pp 512 | 82.01 ± 0.11 | 132.68 ± 0.21 | 1.618 ± 0.003 |
| llama 7B Q3_K - Medium | 3.07 GiB | 6.74 B | pp 512 | 89.16 ± 0.10 | 136.53 ± 0.19 | 1.531 ± 0.003 |
| llama 7B Q4_K - Small | 3.59 GiB | 6.74 B | pp 512 | 104.45 ± 0.15 | 144.87 ± 0.31 | 1.387 ± 0.003 |
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | pp 512 | 101.50 ± 0.28 | 145.77 ± 0.33 | 1.436 ± 0.005 |
| llama 7B Q5_K - Small | 4.33 GiB | 6.74 B | pp 512 | 73.72 ± 0.29 | 124.05 ± 0.14 | 1.682 ± 0.007 |
| llama 7B Q5_K - Medium | 4.45 GiB | 6.74 B | pp 512 | 74.84 ± 0.13 | 126.52 ± 0.14 | 1.691 ± 0.004 |
| llama 7B Q6_K | 5.15 GiB | 6.74 B | pp 512 | 81.38 ± 0.09 | 146.30 ± 0.15 | 1.798 ± 0.003 |

Your favorite Q5_K_M quants are faster than Q4_0 with tinyBLAS after these changes :-)

There are 3 ingredients involved in this speedup:

  • Simdify Q8_K quantization. Quantization of activations is single-threaded in ggml. Quantization to Q8_0, needed by the legacy quants, is already simdified, but quantization to Q8_K, required by k-quants, is not. This does not matter for token generation, but it gives a 5-6% speedup for prompt processing.
  • Further tweak the k-quant dot-product kernels to reduce or eliminate dependencies between computational steps. I guess this is what you call "instruction level parallelism".
  • Last, but certainly not least, make use of the ability to do 2x2 matrix multiplications (rather than just dot products) that is already available in ggml (2 weight rows times 2 activation columns; a rough sketch of this idea follows below). This gives a 20-40% speedup, depending on how costly it is to set up the bits as needed for the multiplication with the Q8_K quants.

We see Q4_K and Q5_K being ~2.2 times faster than their respective legacy counterparts Q4_1 and Q5_1.
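To make the third ingredient concrete, the following is a scalar illustration of the 2x2 idea (two weight rows times two activation columns per pass), not ikawrakow's actual AVX2 k-quant code: the four independent accumulators shorten dependency chains, and each dequantized weight value is unpacked once and reused against both activation columns.

```cpp
#include <cstddef>

// out[r][c] = dot(weight row r, activation column c) for a 2x2 tile.
static void dot_2x2(const float *w0, const float *w1,  // two (dequantized) weight rows
                    const float *x0, const float *x1,  // two activation columns
                    std::size_t k, float out[2][2]) {
    float s00 = 0, s01 = 0, s10 = 0, s11 = 0;  // four independent chains
    for (std::size_t i = 0; i < k; ++i) {
        float a0 = w0[i], a1 = w1[i];          // weights unpacked once...
        s00 += a0 * x0[i];                     // ...reused for both columns
        s01 += a0 * x1[i];
        s10 += a1 * x0[i];
        s11 += a1 * x1[i];
    }
    out[0][0] = s00; out[0][1] = s01;
    out[1][0] = s10; out[1][1] = s11;
}
```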

jart force-pushed the moe branch 2 times, most recently from f1a134a to c34c472 on April 26, 2024 at 15:47
@jart (Contributor, Author) commented Apr 26, 2024

That's outstanding news, @ikawrakow! I can't wait to see your code. Now I won't need to recommend the legacy quantization formats. Am I correct in understanding that you used .nrows? Have you tried copying your optimized code into a tinyBLAS kernel? If you do that, then the K quants might be able to surpass F16 and BF16 at evaluation.

@ikawrakow (Contributor)

@jart

Yes, I used .nrows. I did that because it was not obvious to me how I could plug it into tinyBLAS. Shall I make a PR to your repository so you can do the integration into tinyBLAS?

@jart (Contributor, Author) commented Apr 26, 2024

@ikawrakow Receiving a PR from you would honor the llamafile project. What you'd want to do is create a copy of tinyBLAS_Q0_ARM named like tinyBLAS_K5_ARM and then have your vec_dot code replace these specific lines. You'd then tune its mnpack() method to be smaller until eventually there's no stack spillage.

github-actions bot commented Apr 26, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 542 iterations 🚀

Details (posted for performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8617.47ms p(95)=20606.03ms fails=, finish reason: stop=493 truncated=49
  • Prompt processing (pp): avg=104.44tk/s p(95)=449.45tk/s
  • Token generation (tg): avg=32.11tk/s p(95)=47.66tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=moe commit=2d4740434ece6c0438fb97b46cd0808db875da3e

[Benchmark charts omitted: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, and llamacpp:requests_processing time series for bench-server-baseline on Standard_NC4as_T4_v3 (duration=10m, 542 iterations).]

jart force-pushed the moe branch 2 times, most recently from e717fec to e1c02a7 on May 12, 2024 at 03:20
@lemmi commented May 12, 2024

Here's a benchmark of an AMD V3C48 (a Zen 3 part) with Mistral 7B Instruct v0.2. (I had to throw out some code that used the X86_HAVE(F16C) check to make it compile, though.)

PR:

| model | size | params | backend | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | pp512 | 23.25 ± 0.17 |
| mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | tg128 | 3.04 ± 0.01 |
| mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | pp512 | 23.10 ± 0.10 |
| mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | tg128 | 3.02 ± 0.01 |
| mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | pp512 | 20.00 ± 0.06 |
| mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | tg128 | 1.52 ± 0.01 |

Without these changes, prompt processing for BF16 clocks in at about 11 t/s (see #7182); the rest stays the same. Good improvement overall :)

(I'm still a bit confused as to why F16 performs so much better than BF16 without tinyBLAS, and whether there is still something left on the table, but at least this way there is no compromise in using BF16 now.)

(A review comment on ggml.c was marked outdated and resolved.)
@jart (Contributor, Author) commented May 12, 2024

@lemmi Where you're going to see the biggest changes is when running Mixtral (rather than Mistral), because MoE models use MUL_MAT_ID (that's where they spend most of their clock cycles), and until this change we had no BLAS support whatsoever for MUL_MAT_ID. As for BF16 vs. F16: this change introduces tinyBLAS support for BF16 (which is a very recently introduced data type), so finally having BLAS-like performance for BF16 will naturally help it catch up with F16, and then surpass it on znver4.
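As an aside on the BF16 point: bf16 is simply the upper 16 bits of an IEEE-754 float32, so widening it for a float GEMM is just a shift, which is part of why a BLAS-style kernel can treat it much like F32. A minimal conversion sketch (my own illustration, not code from this PR):

```cpp
#include <cstdint>
#include <cstring>

// Widen a bf16 value to float32: the 16 stored bits become the high half of
// the float's bit pattern; the discarded low mantissa bits are zero.
static inline float bf16_to_f32(std::uint16_t h) {
    std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}
```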

@ggerganov (Owner)

Take a look at the failing CI run: https://github.com/ggerganov/llama.cpp/actions/runs/9052429493/job/24870086560?pr=6840

D:\a\llama.cpp\llama.cpp\sgemm.cpp(827,59): error C3861: 'MM256_SET_M128I': identifier not found [D:\a\llama.cpp\llama.cpp\build\ggml.vcxproj]

ggerganov added the "merge ready" label (indicates that this may be ready to merge soon and is just holding out in case of objections) on May 15, 2024
@ggerganov (Owner)

@jart I think the following patch should fix the CI:

diff --git a/ggml-impl.h b/ggml-impl.h
index d85b152b..85d3f23f 100644
--- a/ggml-impl.h
+++ b/ggml-impl.h
@@ -17,6 +17,9 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 /**
  * Converts brain16 to float32.
  *
diff --git a/ggml-quants.c b/ggml-quants.c
index 00334c5f..3677b2db 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -22,9 +22,6 @@
 
 #define UNUSED GGML_UNUSED
 
-// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
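For context on what the macro in this patch computes, here is a small standalone check (my own example, compiled with -mavx; it is not part of the patch): MM256_SET_M128I(a, b) builds a 256-bit vector with b in the low 128-bit lane and a in the high lane, matching _mm256_set_m128i(a, b).

```cpp
#include <cstdio>
#include <immintrin.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main() {
    __m128i lo = _mm_set1_epi32(1);        // destined for bits 0..127
    __m128i hi = _mm_set1_epi32(2);        // destined for bits 128..255
    __m256i v  = MM256_SET_M128I(hi, lo);
    int out[8];
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(out), v);
    for (int i = 0; i < 8; ++i)
        std::printf("%d ", out[i]);        // prints: 1 1 1 1 2 2 2 2
    std::printf("\n");
    return 0;
}
```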

github-actions bot added the "ggml" label (changes relating to the ggml tensor library for machine learning) on May 22, 2024
@jart (Contributor, Author) commented May 22, 2024

@ggerganov I've updated this change with your suggestion. Test flake looks unrelated. PTAL.

@ggerganov (Owner)

On M2 Ultra I'm observing some TG regression for Q8_0 and Q4_0 MoE when running fully on the CPU:

./scripts/compare-commits.sh master pr/6840 \
  -m models/mixtral-instruct-8x7b-fast/ggml-model-f16.gguf \
  -m models/mixtral-instruct-8x7b-fast/ggml-model-q8_0.gguf \
  -m models/mixtral-instruct-8x7b-fast/ggml-model-q4_0.gguf -t 16 -ngl 0
| CPU | Model | Model Size [GiB] | Test | t/s master | t/s pr/6840 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| M2 Ultra | llama 8x7B F16 | 86.99 | pp512 | 30.54 | 47.72 | 1.56 |
| M2 Ultra | llama 8x7B F16 | 86.99 | tg128 | 8.75 | 7.66 | 0.88 |
| M2 Ultra | llama 8x7B F16 | 86.99 | pp512+tg128 | 20.42 | 22.85 | 1.12 |
| M2 Ultra | llama 8x7B Q4_0 | 48.25 | pp512 | 44.39 | 63.95 | 1.44 |
| M2 Ultra | llama 8x7B Q4_0 | 48.25 | tg128 | 25.23 | 23.23 | 0.92 |
| M2 Ultra | llama 8x7B Q4_0 | 48.25 | pp512+tg128 | 36.62 | 44.83 | 1.22 |
| M2 Ultra | llama 8x7B Q8_0 | 46.22 | pp512 | 49.20 | 71.55 | 1.45 |
| M2 Ultra | llama 8x7B Q8_0 | 46.22 | tg128 | 15.32 | 15.29 | 1.00 |
| M2 Ultra | llama 8x7B Q8_0 | 46.22 | pp512+tg128 | 32.50 | 39.38 | 1.21 |

@jart Do you observe this regression on your M2 Ultra?

@jart (Contributor, Author) commented May 23, 2024

Thanks for pointing that out. I just reproduced the same regression. This change doesn't appear to be helpful for text generation, so I've disabled it there. PTAL.

@mofosyne (Collaborator) commented Jun 9, 2024

@jart just a heads up that this was marked as merge ready, but CI is not passing. If the failures aren't related to your code changes, you may want to rebase against the latest known-working CI in master; as I recall, we had issues with CI in the master branch around that time.

Commit message:

This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Experts" models. On my Threadripper, the Mixtral 8x7b F16 weights now process prompts 2x faster. I am also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0, which is also supported by tinyBLAS. MoE models spend most of their time in MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm() was not able to help them before. The new code works by decomposing the mixmul operation into fast 2-d llamafile_sgemm() calls. This also adds BF16 support to tinyBLAS.
mofosyne removed the "merge ready" label on Jul 19, 2024
Labels: enhancement (New feature or request) · ggml (changes relating to the ggml tensor library for machine learning) · Review Complexity: Medium (generally requires more time to grok but manageable by beginner-to-medium expertise)

Projects: None yet

7 participants