
CUDA: mul_mat_vec_q for batch sizes > 1 #5351

Merged

Conversation

JohannesGaessler
Collaborator

On master the mul_mat_vec_q kernel only supports a batch size of 1. This PR implements support for batch sizes up to 8 in a minimally invasive way. In this range mul_mat_vec_q is universally faster; for larger batch sizes it depends on the quantization format. To keep things simple I am therefore only enabling the new implementation for batch sizes <= 8 (which should be enough for techniques like speculative decoding). I think the optimal solution would be to rewrite the mul_mat_vec_q kernel in a way that maximizes memory bandwidth (along with further optimization of the competing mul_mat_q kernels), but that is a larger project. As part of this PR I have also deduplicated the code for calling mul_mat_vec_q. On my systems the performance changes as follows:

| GPU | Model | Batch size | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|---|
| RTX 3090 | llama 7B Q2_K_M | 1 | pp512 | 103.75 | 102.80 | 0.99 |
| RTX 3090 | llama 7B Q2_K_M | 2 | pp512 | 99.39 | 177.01 | 1.78 |
| RTX 3090 | llama 7B Q2_K_M | 4 | pp512 | 195.77 | 248.26 | 1.27 |
| RTX 3090 | llama 7B Q2_K_M | 8 | pp512 | 246.80 | 297.62 | 1.21 |
| RTX 3090 | llama 7B Q3_K_S | 1 | pp512 | 98.36 | 97.37 | 0.99 |
| RTX 3090 | llama 7B Q3_K_S | 2 | pp512 | 89.10 | 170.57 | 1.91 |
| RTX 3090 | llama 7B Q3_K_S | 4 | pp512 | 175.77 | 243.07 | 1.38 |
| RTX 3090 | llama 7B Q3_K_S | 8 | pp512 | 221.98 | 292.72 | 1.32 |
| RTX 3090 | llama 7B Q4_0 | 1 | pp512 | 131.80 | 130.71 | 0.99 |
| RTX 3090 | llama 7B Q4_0 | 2 | pp512 | 197.91 | 241.99 | 1.22 |
| RTX 3090 | llama 7B Q4_0 | 4 | pp512 | 386.23 | 371.20 | 0.96 |
| RTX 3090 | llama 7B Q4_0 | 8 | pp512 | 498.03 | 480.78 | 0.97 |
| RTX 3090 | llama 7B Q4_1 | 1 | pp512 | 124.75 | 123.22 | 0.99 |
| RTX 3090 | llama 7B Q4_1 | 2 | pp512 | 187.37 | 234.17 | 1.25 |
| RTX 3090 | llama 7B Q4_1 | 4 | pp512 | 363.74 | 379.89 | 1.04 |
| RTX 3090 | llama 7B Q4_1 | 8 | pp512 | 493.88 | 471.49 | 0.95 |
| RTX 3090 | llama 7B Q4_K_S | 1 | pp512 | 124.42 | 123.73 | 0.99 |
| RTX 3090 | llama 7B Q4_K_S | 2 | pp512 | 145.04 | 204.73 | 1.41 |
| RTX 3090 | llama 7B Q4_K_S | 4 | pp512 | 283.69 | 285.35 | 1.01 |
| RTX 3090 | llama 7B Q4_K_S | 8 | pp512 | 341.05 | 352.97 | 1.03 |
| RTX 3090 | llama 7B Q5_0 | 1 | pp512 | 114.12 | 112.97 | 0.99 |
| RTX 3090 | llama 7B Q5_0 | 2 | pp512 | 97.08 | 216.30 | 2.23 |
| RTX 3090 | llama 7B Q5_0 | 4 | pp512 | 190.99 | 354.03 | 1.85 |
| RTX 3090 | llama 7B Q5_0 | 8 | pp512 | 301.19 | 433.77 | 1.44 |
| RTX 3090 | llama 7B Q5_1 | 1 | pp512 | 109.46 | 108.28 | 0.99 |
| RTX 3090 | llama 7B Q5_1 | 2 | pp512 | 120.85 | 210.62 | 1.74 |
| RTX 3090 | llama 7B Q5_1 | 4 | pp512 | 236.95 | 327.31 | 1.38 |
| RTX 3090 | llama 7B Q5_1 | 8 | pp512 | 359.41 | 468.41 | 1.30 |
| RTX 3090 | llama 7B Q5_K_S | 1 | pp512 | 112.22 | 111.32 | 0.99 |
| RTX 3090 | llama 7B Q5_K_S | 2 | pp512 | 110.10 | 201.15 | 1.83 |
| RTX 3090 | llama 7B Q5_K_S | 4 | pp512 | 216.51 | 283.01 | 1.31 |
| RTX 3090 | llama 7B Q5_K_S | 8 | pp512 | 267.93 | 347.67 | 1.30 |
| RTX 3090 | llama 7B Q6_K | 1 | pp512 | 90.84 | 90.19 | 0.99 |
| RTX 3090 | llama 7B Q6_K | 2 | pp512 | 100.04 | 145.84 | 1.46 |
| RTX 3090 | llama 7B Q6_K | 4 | pp512 | 197.02 | 236.99 | 1.20 |
| RTX 3090 | llama 7B Q6_K | 8 | pp512 | 254.85 | 306.45 | 1.20 |
| RTX 3090 | llama 7B Q8_0 | 1 | pp512 | 87.15 | 86.56 | 0.99 |
| RTX 3090 | llama 7B Q8_0 | 2 | pp512 | 118.66 | 166.80 | 1.41 |
| RTX 3090 | llama 7B Q8_0 | 4 | pp512 | 232.79 | 292.20 | 1.26 |
| RTX 3090 | llama 7B Q8_0 | 8 | pp512 | 314.13 | 404.77 | 1.29 |
| RX 6800 | llama 7B Q2_K_M | 1 | pp512 | 59.11 | 59.45 | 1.01 |
| RX 6800 | llama 7B Q2_K_M | 2 | pp512 | 11.67 | 89.29 | 7.65 |
| RX 6800 | llama 7B Q2_K_M | 4 | pp512 | 23.27 | 108.07 | 4.64 |
| RX 6800 | llama 7B Q2_K_M | 8 | pp512 | 46.35 | 147.49 | 3.18 |
| RX 6800 | llama 7B Q3_K_S | 1 | pp512 | 56.74 | 56.73 | 1.00 |
| RX 6800 | llama 7B Q3_K_S | 2 | pp512 | 10.93 | 84.17 | 7.70 |
| RX 6800 | llama 7B Q3_K_S | 4 | pp512 | 21.80 | 102.01 | 4.68 |
| RX 6800 | llama 7B Q3_K_S | 8 | pp512 | 43.39 | 144.84 | 3.34 |
| RX 6800 | llama 7B Q4_0 | 1 | pp512 | 65.28 | 64.94 | 0.99 |
| RX 6800 | llama 7B Q4_0 | 2 | pp512 | 34.46 | 127.55 | 3.70 |
| RX 6800 | llama 7B Q4_0 | 4 | pp512 | 68.61 | 212.68 | 3.10 |
| RX 6800 | llama 7B Q4_0 | 8 | pp512 | 136.14 | 280.72 | 2.06 |
| RX 6800 | llama 7B Q4_1 | 1 | pp512 | 61.42 | 61.22 | 1.00 |
| RX 6800 | llama 7B Q4_1 | 2 | pp512 | 31.48 | 120.18 | 3.82 |
| RX 6800 | llama 7B Q4_1 | 4 | pp512 | 62.75 | 215.77 | 3.44 |
| RX 6800 | llama 7B Q4_1 | 8 | pp512 | 124.38 | 279.39 | 2.25 |
| RX 6800 | llama 7B Q4_K_S | 1 | pp512 | 56.59 | 56.15 | 0.99 |
| RX 6800 | llama 7B Q4_K_S | 2 | pp512 | 26.56 | 98.99 | 3.73 |
| RX 6800 | llama 7B Q4_K_S | 4 | pp512 | 52.98 | 146.75 | 2.77 |
| RX 6800 | llama 7B Q4_K_S | 8 | pp512 | 105.32 | 166.09 | 1.58 |
| RX 6800 | llama 7B Q5_0 | 1 | pp512 | 58.98 | 58.68 | 0.99 |
| RX 6800 | llama 7B Q5_0 | 2 | pp512 | 28.46 | 115.67 | 4.06 |
| RX 6800 | llama 7B Q5_0 | 4 | pp512 | 56.72 | 209.11 | 3.69 |
| RX 6800 | llama 7B Q5_0 | 8 | pp512 | 112.88 | 257.96 | 2.29 |
| RX 6800 | llama 7B Q5_1 | 1 | pp512 | 54.62 | 54.44 | 1.00 |
| RX 6800 | llama 7B Q5_1 | 2 | pp512 | 26.32 | 106.53 | 4.05 |
| RX 6800 | llama 7B Q5_1 | 4 | pp512 | 52.37 | 198.94 | 3.80 |
| RX 6800 | llama 7B Q5_1 | 8 | pp512 | 103.59 | 237.57 | 2.29 |
| RX 6800 | llama 7B Q5_K_S | 1 | pp512 | 53.91 | 53.48 | 0.99 |
| RX 6800 | llama 7B Q5_K_S | 2 | pp512 | 26.45 | 95.36 | 3.60 |
| RX 6800 | llama 7B Q5_K_S | 4 | pp512 | 52.78 | 132.93 | 2.52 |
| RX 6800 | llama 7B Q5_K_S | 8 | pp512 | 104.92 | 155.61 | 1.48 |
| RX 6800 | llama 7B Q6_K | 1 | pp512 | 52.17 | 51.77 | 0.99 |
| RX 6800 | llama 7B Q6_K | 2 | pp512 | 24.81 | 100.83 | 4.06 |
| RX 6800 | llama 7B Q6_K | 4 | pp512 | 49.43 | 139.27 | 2.82 |
| RX 6800 | llama 7B Q6_K | 8 | pp512 | 98.05 | 142.07 | 1.45 |
| RX 6800 | llama 7B Q8_0 | 1 | pp512 | 44.83 | 44.63 | 1.00 |
| RX 6800 | llama 7B Q8_0 | 2 | pp512 | 34.70 | 87.87 | 2.53 |
| RX 6800 | llama 7B Q8_0 | 4 | pp512 | 69.00 | 172.24 | 2.50 |
| RX 6800 | llama 7B Q8_0 | 8 | pp512 | 136.64 | 206.36 | 1.51 |
| P40 | llama 7B Q2_K_M | 1 | pp512 | 30.88 | 30.91 | 1.00 |
| P40 | llama 7B Q2_K_M | 2 | pp512 | 13.50 | 40.95 | 3.03 |
| P40 | llama 7B Q2_K_M | 4 | pp512 | 26.84 | 64.87 | 2.42 |
| P40 | llama 7B Q2_K_M | 8 | pp512 | 51.91 | 81.14 | 1.56 |
| P40 | llama 7B Q3_K_S | 1 | pp512 | 28.45 | 28.43 | 1.00 |
| P40 | llama 7B Q3_K_S | 2 | pp512 | 13.29 | 40.13 | 3.02 |
| P40 | llama 7B Q3_K_S | 4 | pp512 | 26.39 | 64.51 | 2.44 |
| P40 | llama 7B Q3_K_S | 8 | pp512 | 51.05 | 81.00 | 1.59 |
| P40 | llama 7B Q4_0 | 1 | pp512 | 54.71 | 54.66 | 1.00 |
| P40 | llama 7B Q4_0 | 2 | pp512 | 21.81 | 57.80 | 2.65 |
| P40 | llama 7B Q4_0 | 4 | pp512 | 43.41 | 86.93 | 2.00 |
| P40 | llama 7B Q4_0 | 8 | pp512 | 82.34 | 123.46 | 1.50 |
| P40 | llama 7B Q4_1 | 1 | pp512 | 54.15 | 54.03 | 1.00 |
| P40 | llama 7B Q4_1 | 2 | pp512 | 22.52 | 57.54 | 2.56 |
| P40 | llama 7B Q4_1 | 4 | pp512 | 44.53 | 91.01 | 2.04 |
| P40 | llama 7B Q4_1 | 8 | pp512 | 84.30 | 124.96 | 1.48 |
| P40 | llama 7B Q4_K_S | 1 | pp512 | 50.84 | 50.94 | 1.00 |
| P40 | llama 7B Q4_K_S | 2 | pp512 | 20.60 | 43.42 | 2.11 |
| P40 | llama 7B Q4_K_S | 4 | pp512 | 40.71 | 72.69 | 1.79 |
| P40 | llama 7B Q4_K_S | 8 | pp512 | 77.35 | 92.97 | 1.20 |
| P40 | llama 7B Q5_0 | 1 | pp512 | 47.09 | 46.98 | 1.00 |
| P40 | llama 7B Q5_0 | 2 | pp512 | 20.28 | 51.54 | 2.54 |
| P40 | llama 7B Q5_0 | 4 | pp512 | 40.47 | 82.53 | 2.04 |
| P40 | llama 7B Q5_0 | 8 | pp512 | 76.87 | 115.32 | 1.50 |
| P40 | llama 7B Q5_1 | 1 | pp512 | 47.61 | 47.50 | 1.00 |
| P40 | llama 7B Q5_1 | 2 | pp512 | 21.57 | 49.75 | 2.31 |
| P40 | llama 7B Q5_1 | 4 | pp512 | 42.78 | 84.25 | 1.97 |
| P40 | llama 7B Q5_1 | 8 | pp512 | 81.15 | 121.31 | 1.49 |
| P40 | llama 7B Q5_K_S | 1 | pp512 | 38.87 | 38.90 | 1.00 |
| P40 | llama 7B Q5_K_S | 2 | pp512 | 18.95 | 48.89 | 2.58 |
| P40 | llama 7B Q5_K_S | 4 | pp512 | 37.52 | 67.09 | 1.79 |
| P40 | llama 7B Q5_K_S | 8 | pp512 | 71.46 | 94.45 | 1.32 |
| P40 | llama 7B Q6_K | 1 | pp512 | 35.00 | 35.03 | 1.00 |
| P40 | llama 7B Q6_K | 2 | pp512 | 19.09 | 41.42 | 2.17 |
| P40 | llama 7B Q6_K | 4 | pp512 | 37.11 | 58.58 | 1.58 |
| P40 | llama 7B Q6_K | 8 | pp512 | 71.53 | 71.52 | 1.00 |
| P40 | llama 7B Q8_0 | 1 | pp512 | 33.72 | 33.71 | 1.00 |
| P40 | llama 7B Q8_0 | 2 | pp512 | 21.72 | 42.41 | 1.95 |
| P40 | llama 7B Q8_0 | 4 | pp512 | 42.93 | 56.08 | 1.31 |
| P40 | llama 7B Q8_0 | 8 | pp512 | 81.27 | 102.88 | 1.27 |

I did not test the new XS/XSS quants because I did not yet get around to setting them up.
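
For reference, a minimal sketch of the batch-size dispatch described above; the helper name and signature are illustrative, not the actual llama.cpp call sites:

```cpp
#include <cstdint>

// Upper limit chosen in this PR: only small batches go through mul_mat_vec_q.
constexpr int64_t MMVQ_MAX_BATCH_SIZE = 8;

// ne11 is the number of columns of src1, i.e. the batch size of the matrix
// multiplication; everything else stays on the existing mul_mat_q/cuBLAS paths.
static bool use_mul_mat_vec_q(bool src0_is_quantized, int64_t ne11) {
    return src0_is_quantized && ne11 <= MMVQ_MAX_BATCH_SIZE;
}
```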

@ggerganov
Owner

This is looking very good! I'm doing some V100 tests and am seeing similar gains.

> To keep things simple I am therefore only enabling the new implementation for batch sizes <= 8 (which should be enough for techniques like speculative decoding).

Yes, I think 8 should be fine for most purposes

> I think the optimal solution would be to rewrite the mul_mat_vec_q kernel in a way that maximizes memory bandwidth

I was thinking of writing a ggml_cuda_op_dequantize_mul_mat_vec kernel that dequantizes into shared memory and loads 16x16 fragments (padding with zeros whatever is beyond the batch size). This would have constant speed across all batch sizes up to 16, but it remains to be seen if it can outperform mul_mat_vec_q for very small batches. I'm thinking that even if there is a tiny regression, it might be worth it, so that we have a single set of mat-vec kernels and consistent performance for bs <= 16.
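
A rough sketch of that tile idea, assuming the weight tile has already been dequantized to FP16 (in the actual proposal the dequantization into shared memory would happen in the same staging step); the kernel and parameter names are illustrative, not the proposed ggml code:

```cpp
// Illustrative only (Volta+): one warp stages a 16x16 weight tile and a 16x16
// activation tile into shared memory, zero-padding activation columns beyond
// the batch size, then issues 16x16x16 tensor core MMAs. Assumes k is a
// multiple of 16.
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__global__ void mmv_f16_tile_sketch(
        const half * W,   // [16 x k] dequantized weight rows, row-major
        const half * x,   // [k x batch] activations, row-major
        float      * y,   // [16 x 16] output tile, row-major
        const int k, const int batch) {
    __shared__ __align__(32) half sW[16*16];
    __shared__ __align__(32) half sX[16*16];

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>              c;
    wmma::fill_fragment(c, 0.0f);

    for (int k0 = 0; k0 < k; k0 += 16) {
        // stage one 16x16 slice of W and x; columns of x >= batch are zero-padded
        for (int i = threadIdx.x; i < 16*16; i += 32) {
            const int row = i / 16;
            const int col = i % 16;
            sW[i] = W[row*k + k0 + col];
            sX[i] = col < batch ? x[(k0 + row)*batch + col] : __float2half(0.0f);
        }
        __syncwarp();
        wmma::load_matrix_sync(a, sW, 16);
        wmma::load_matrix_sync(b, sX, 16);
        wmma::mma_sync(c, a, b, c);
        __syncwarp();   // don't restage the tiles while another lane is still loading
    }
    wmma::store_matrix_sync(y, c, 16, wmma::mem_row_major);
}
// launch with a single warp per 16-row tile, e.g. mmv_f16_tile_sketch<<<1, 32>>>(W, x, y, k, batch);
```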

ggerganov requested a review from slaren on February 6, 2024 at 12:51.
@JohannesGaessler
Collaborator Author

> I was thinking of writing a ggml_cuda_op_dequantize_mul_mat_vec kernel that dequantizes into shared memory and loads 16x16 fragments (padding with zeros whatever is beyond the batch size).

I've tried some improved implementations for both dequantize_mul_mat_vec and mul_mat_vec_q where the data is loaded into shared memory in a coalesced way and then only unpacked in shared memory, but I was not able to make these faster than the implementation on master. I think one issue is that on a hardware level shared memory is just a portion of the regular cache that is manually managed. And since the data access patterns are very predictable, VRAM (cache) -> SRAM -> registers just adds extra copies compared to VRAM (cache) -> registers. It may be better to have a kernel that loads the data into registers and then distributes the scales as needed using __shfl_sync. Or loading the data asynchronously from VRAM to SRAM may also work to increase bandwidth (but as I said before, memcpy_async causes annoying build issues).
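
A minimal illustration of the __shfl_sync idea (hypothetical helper, not code from this PR):

```cpp
// One lane reads a block scale from memory; the rest of the warp gets the
// value via a register-to-register shuffle instead of a trip through shared memory.
__device__ float broadcast_scale(const float * scales, const int block_idx) {
    float d = 0.0f;
    if (threadIdx.x % 32 == 0) {
        d = scales[block_idx];              // only lane 0 touches memory
    }
    return __shfl_sync(0xFFFFFFFF, d, 0);   // every lane receives lane 0's register
}
```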

I think one issue with FP16 arithmetic will be shared memory limits. In principle, if you can fit the entire hidden state into shared memory then you should be able to write a kernel that needs to load it only once per streaming multiprocessor. Even for the smallest hidden state of 4096 values, if you have a batch size of 16 you would need 16*4096*2 = 131072 bytes to store it as FP16. But you would only need 73728 bytes to store it as q8_1 and then it would fit into the 102400 bytes of shared memory per SM on Ampere (the A100 has 167936 bytes per SM).

There are also issues with tail effects. Ideally you would distribute the rows as evenly as possible between SMs to maximize GPU utilization but tensor cores restrict you to multiples of 8/16/32 (depending on whether there are at least 32/16/8 hidden state columns). If you just use __dp4a for integer dot products you can assign rows with a granularity of 1.
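
A hedged sketch of the __dp4a point (an illustrative helper, not the mul_mat_vec_q code): one warp handles one row, so rows can be assigned with a granularity of 1.

```cpp
// 'q' and 'v' each pack four int8 values per 32-bit word; n32 is the row length
// in such words. __dp4a requires compute capability 6.1 or newer.
__device__ int warp_dot_q8(const int * q, const int * v, const int n32) {
    int sum = 0;
    for (int i = threadIdx.x % 32; i < n32; i += 32) {
        sum = __dp4a(q[i], v[i], sum);                  // 4-way int8 dot product + accumulate
    }
    for (int mask = 16; mask > 0; mask >>= 1) {
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask);  // warp-level reduction
    }
    return sum;
}
```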

JohannesGaessler merged commit 2c51661 into ggerganov:master on Feb 6, 2024 (53 checks passed).
@ggerganov
Owner

ggerganov commented Feb 6, 2024

Do you have some 13B models handy to see how the results look (I have just 7B on this V100 machine and it will take me some time to set up)?

I suspect that the results might not be so good for larger models.
For example, with Q8_0, the speed after BS=6 actually starts to degrade significantly (see the S_TG column):

LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-q8_0.gguf 4800 0 99 0 50 100 1,2,3,4,5,6,7,8,16,32,64
|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|    50 |    100 |    1 |    150 |    0.079 |   635.44 |    1.311 |    76.26 |    1.390 |   107.91 |
|    50 |    100 |    2 |    300 |    0.076 |  1313.75 |    1.469 |   136.15 |    1.545 |   194.17 |
|    50 |    100 |    3 |    450 |    0.082 |  1823.60 |    1.688 |   177.76 |    1.770 |   254.24 |
|    50 |    100 |    4 |    600 |    0.095 |  2111.73 |    1.659 |   241.09 |    1.754 |   342.11 |
|    50 |    100 |    5 |    750 |    0.096 |  2592.66 |    2.010 |   248.78 |    2.106 |   356.08 |
|    50 |    100 |    6 |    900 |    0.116 |  2582.24 |    2.394 |   250.60 |    2.510 |   358.51 |
|    50 |    100 |    7 |   1050 |    0.127 |  2753.24 |    2.854 |   245.25 |    2.981 |   352.19 |
|    50 |    100 |    8 |   1200 |    0.139 |  2871.77 |    3.713 |   215.43 |    3.853 |   311.47 |
|    50 |    100 |   16 |   2400 |    0.282 |  2837.93 |    5.076 |   315.22 |    5.358 |   447.95 |
|    50 |    100 |   32 |   4800 |    0.582 |  2750.42 |    9.534 |   335.62 |   10.116 |   474.49 |

Before this PR I got:

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|    50 |    100 |    1 |    150 |    0.079 |   634.24 |    1.324 |    75.55 |    1.403 |   106.95 |
|    50 |    100 |    2 |    300 |    0.076 |  1313.34 |    1.980 |   100.99 |    2.056 |   145.88 |
|    50 |    100 |    3 |    450 |    0.082 |  1820.65 |    2.016 |   148.77 |    2.099 |   214.40 |
|    50 |    100 |    4 |    600 |    0.095 |  2112.16 |    2.059 |   194.27 |    2.154 |   278.59 |
|    50 |    100 |    5 |    750 |    0.097 |  2580.89 |    2.532 |   197.50 |    2.629 |   285.33 |
|    50 |    100 |    6 |    900 |    0.116 |  2583.33 |    2.575 |   233.04 |    2.691 |   334.47 |
|    50 |    100 |    7 |   1050 |    0.128 |  2733.12 |    2.639 |   265.29 |    2.767 |   379.51 |
|    50 |    100 |    8 |   1200 |    0.140 |  2867.10 |    2.676 |   298.96 |    2.815 |   426.22 |
|    50 |    100 |   16 |   2400 |    0.283 |  2829.51 |    5.100 |   313.74 |    5.382 |   445.90 |
|    50 |    100 |   32 |   4800 |    0.582 |  2751.47 |    9.545 |   335.24 |   10.127 |   473.98 |

I'm thinking that the larger the model, the stronger this effect might be.

@slaren
Collaborator

slaren commented Feb 6, 2024

llama-bench shows that performance increases consistently with the batch size (for bs 1-8), so what is different with batched-bench? Is it possibly an outlier?

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes

| model | size | params | backend | ngl | n_batch | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 85.21 ± 0.16 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 2 | pp 512 | 120.27 ± 0.27 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 3 | pp 512 | 177.33 ± 0.35 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 4 | pp 512 | 235.63 ± 1.31 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 5 | pp 512 | 220.49 ± 1.11 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 6 | pp 512 | 263.85 ± 1.38 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 7 | pp 512 | 301.79 ± 2.48 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 8 | pp 512 | 346.90 ± 1.63 |

build: 9392ebd (2061)

./llama-bench -m models/13B/ggml-model-Q8_0.gguf -p 1,2,3,4,5,6,7,8 -n 0 -r 100

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 1 | 51.22 ± 3.21 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 2 | 75.90 ± 0.36 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 3 | 112.92 ± 0.60 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 4 | 150.07 ± 0.58 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 5 | 139.42 ± 0.57 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 6 | 166.50 ± 2.82 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 7 | 193.80 ± 1.45 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 8 | 220.15 ± 2.18 |

build: 9392ebd (2061)

@ggerganov
Owner

ggerganov commented Feb 6, 2024

This is using llama-bench:

LLAMA_CUBLAS=1 make -j && ./llama-bench -ngl 99 -m models/openllama-7b-v2/ggml-model-q8_0.gguf -p 512 -b 1,2,3,4,5,6,7,8,16 -n 0

Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | n_batch | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 75.14 ± 0.11 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 2 | pp 512 | 135.18 ± 0.45 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 3 | pp 512 | 179.34 ± 0.44 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 4 | pp 512 | 246.62 ± 0.70 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 5 | pp 512 | 257.17 ± 0.21 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 6 | pp 512 | 259.58 ± 0.39 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 7 | pp 512 | 256.51 ± 0.20 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 8 | pp 512 | 225.86 ± 0.13 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 16 | pp 512 | 360.31 ± 0.76 |

build: dbb795b (2075)

Before this PR:

Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

| model | size | params | backend | ngl | n_batch | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 1 | pp 512 | 75.57 ± 0.04 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 2 | pp 512 | 101.47 ± 0.23 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 3 | pp 512 | 149.69 ± 0.09 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 4 | pp 512 | 198.93 ± 0.25 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 5 | pp 512 | 202.30 ± 0.26 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 6 | pp 512 | 241.35 ± 0.40 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 7 | pp 512 | 277.81 ± 0.66 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 8 | pp 512 | 318.37 ± 0.89 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | 16 | pp 512 | 360.33 ± 0.51 |

build: 8a79c59 (2078)

Btw @slaren, you might be running the wrong commit? 9392ebd is on master.

@slaren
Collaborator

slaren commented Feb 6, 2024

You are right, I forgot to pull before running the test. This is with this PR:

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 1 | 87.15 ± 6.65 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 2 | 166.08 ± 1.41 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 3 | 195.42 ± 3.38 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 4 | 272.84 ± 1.27 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 5 | 242.03 ± 0.69 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 6 | 358.82 ± 1.31 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 7 | 346.97 ± 1.03 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 8 | 381.42 ± 1.01 |

build: 2e9c0bd (2080)

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 1 | 51.63 ± 3.49 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 2 | 98.43 ± 0.52 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 3 | 116.87 ± 0.64 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 4 | 165.70 ± 0.90 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 5 | 133.40 ± 0.22 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 6 | 196.36 ± 0.52 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 7 | 173.17 ± 0.30 |
| llama 13B Q8_0 | 12.88 GiB | 13.02 B | CUDA | 99 | pp 8 | 187.23 ± 0.35 |

build: 2e9c0bd (2080)

@JohannesGaessler
Collaborator Author

> I suspect that the results might not be so good for larger models.

> I'm thinking that the larger the model, the stronger this effect might be.

The size of the model should not make a difference. That only increases the number of blocks that the GPU works on but not the relative speed of the kernels. I suspect it's an issue related to GPU architecture where on Volta the number of registers per thread ends up too high; I'll look into adding launch bounds.
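
For context, a launch-bounds hint looks like the following; the kernel and the values are purely illustrative, not the ones that would go into mul_mat_vec_q:

```cpp
// __launch_bounds__(max_threads_per_block, min_blocks_per_SM) caps register
// usage so the stated number of blocks is guaranteed to fit per SM, which
// reduces tail effects when the grid does not divide evenly across SMs.
__global__ void __launch_bounds__(128, 4)
scale_rows_sketch(float * x, const float * scales, const int ncols) {
    const int row = blockIdx.x;                 // one block per row
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        x[row*ncols + col] *= scales[row];      // trivial body, just to make the kernel concrete
    }
}
```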

@slaren
Collaborator

slaren commented Feb 6, 2024

I also see a performance drop with batch sizes 6-8 with 13B Q8_0 compared to the previous master with a 3090 Ti.

@ggerganov
Owner

> The size of the model should not make a difference.

The model size does seem to matter somehow, because on the 3090 it's faster for 7B Q8_0 but slower for 13B Q8_0 at BS 6-8.

@JohannesGaessler
Collaborator Author

The regression I'm seeing in my data and slaren's is for batch sizes 4->5 and 6->7; 4->6, 6->8, and 4->8 are consistently faster. I'm currently profiling a few runs to make sure there are no other weird things going on at those batch sizes.

@JohannesGaessler
Collaborator Author

I ran:

make clean && make LLAMA_CUBLAS=1 llama-bench
export model_name=llama_2-7b && export quantization=q8_0
for i in 1 2 3 4 5 6 7 8; do nsys profile ./llama-bench --model models/nvme/${model_name}-${quantization}.gguf -n 0 -r 5 -b $i; done

Then I looked at the runtime of mul_mat_vec_q in Nsight Systems:

| GPU | Batch size | Time mul_mat_vec_q [s] | Time * batch_size [s] |
|---|---|---|---|
| RTX 3090 | 1 | 22.599 | 22.599 |
| RTX 3090 | 2 | 12.266 | 24.532 |
| RTX 3090 | 3 | 11.377 | 34.131 |
| RTX 3090 | 4 | 8.248 | 32.992 |
| RTX 3090 | 5 | 10.133 | 50.665 |
| RTX 3090 | 6 | 6.521 | 39.126 |
| RTX 3090 | 7 | 7.051 | 49.357 |
| RTX 3090 | 8 | 6.470 | 51.760 |

There is a performance regression for 4->5 and 6->7 that is not caused by any other kernels. This may be related to pointer arithmetic, because multiplications and divisions by powers of 2 can be replaced with bit shifts, which makes them much faster. In addition there may be GPU architecture-related issues that cause problems on Volta (if I had to guess, the compiler assigns different numbers of registers per thread, so the impact of tail effects is different).
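
To make the power-of-two point concrete, a toy example (not the actual indexing code in the kernel):

```cpp
// When the number of columns is a compile-time power of two, the div/mod in
// the index math compiles to a shift and a mask; values like 5 or 7 force
// much slower integer division sequences.
template <int ncols_y>                       // e.g. 4 or 8 vs. 5 or 7
__device__ __forceinline__ int2 split_index(const int idx) {
    return make_int2(idx / ncols_y, idx % ncols_y);
}
```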

More generally, I think the issue is that the current implementation of mul_mat_vec_q just does not scale well with batch size; I added a column for time*batch_size that should stay constant under the assumption of 100% efficiency, but the efficiency for batch sizes > 2 is already getting significantly worse. I personally do not want to invest a lot of time into the current mul_mat_vec_q implementation because I think it needs to be overhauled anyway. Before I do that I would rather either revert this PR or limit its use to smaller batch sizes where there seem to be no issues.

@ggerganov
Owner

Yup, let's limit the mul_mat_vec_q kernel to bs <= 4 for now
