
metal : GPU "idle-throttling" analysis #10119

Open · wants to merge 1 commit into master
Conversation

ggerganov (Owner)

The Apple Silicon GPUs seem to have some sort of (power-saving?) mechanism which significantly affects performance. Here is some analysis and a demonstration. I'm hoping that someone in the community knows a way to work around this.

The llama-idle tool in this PR performs the following computation:

sleep(t)
decode(1 token)
sleep(t)
decode(1 token)
... repeat and measure the average decode time as a function of t ...

The sleep emulates idle time in an application, during which the GPU does nothing, after which it wakes up to compute some tokens; in this case, for simplicity, a single token. In the ideal case, without any throttling mechanism, the decode time should be constant regardless of the value of t. A sketch of the measurement loop is shown below.
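For illustration, here is a self-contained C++ sketch of that loop (this is not the actual llama-idle source; decode_one_token() is a hypothetical stand-in for a llama_decode call on a single-token batch):

```cpp
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <thread>
#include <vector>

// hypothetical stand-in for a llama_decode call on a single-token batch
static void decode_one_token() {
    // placeholder: the real tool decodes one token through the Metal backend
}

int main() {
    const int n_iters = 10;

    // sweep the idle pause t from 0 to 2200 ms in 200 ms steps
    for (int pause_ms = 0; pause_ms <= 2200; pause_ms += 200) {
        std::vector<double> times_ms;

        for (int i = 0; i < n_iters; ++i) {
            // emulate the application being idle
            std::this_thread::sleep_for(std::chrono::milliseconds(pause_ms));

            // time a single decode
            const auto t0 = std::chrono::steady_clock::now();
            decode_one_token();
            const auto t1 = std::chrono::steady_clock::now();

            times_ms.push_back(std::chrono::duration<double, std::milli>(t1 - t0).count());
        }

        // mean and standard deviation over the iterations
        double sum = 0.0, sum2 = 0.0;
        for (const double t : times_ms) { sum += t; sum2 += t * t; }
        const double avg = sum / n_iters;
        const double dev = std::sqrt(std::max(0.0, sum2 / n_iters - avg * avg));

        printf("iters: %4d, pause: %5d ms, avg decode time: %8.2f +/- %.2f ms\n",
               n_iters, pause_ms, avg, dev);
    }

    return 0;
}
```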

Here are the results on M2 Ultra from this test for 3 different models, increasing the time t from 0 up to 2200 ms:

./llama-idle -m models/llama-3.1-8b-instruct/ggml-model-f16.gguf

iters:   10, pause:     0 ms, avg decode time:    26.34 +/- 0.19 ms
iters:   10, pause:   200 ms, avg decode time:    29.22 +/- 0.60 ms
iters:   10, pause:   400 ms, avg decode time:    29.40 +/- 0.63 ms
iters:   10, pause:   600 ms, avg decode time:    29.36 +/- 0.73 ms
iters:   10, pause:   800 ms, avg decode time:    29.21 +/- 0.70 ms
iters:   10, pause:  1000 ms, avg decode time:    29.53 +/- 0.52 ms
iters:   10, pause:  1200 ms, avg decode time:   226.81 +/- 223.52 ms
iters:   10, pause:  1400 ms, avg decode time:   429.08 +/- 35.20 ms
iters:   10, pause:  1600 ms, avg decode time:   413.67 +/- 54.53 ms
iters:   10, pause:  1800 ms, avg decode time:   463.81 +/- 50.51 ms
iters:   10, pause:  2000 ms, avg decode time:   419.40 +/- 38.01 ms
iters:   10, pause:  2200 ms, avg decode time:   472.38 +/- 105.87 ms

./llama-idle -m models/llama-3.1-8b-instruct/ggml-model-q8_0.gguf

iters:   10, pause:     0 ms, avg decode time:    16.13 +/- 0.19 ms
iters:   10, pause:   200 ms, avg decode time:    19.02 +/- 0.66 ms
iters:   10, pause:   400 ms, avg decode time:    18.92 +/- 0.48 ms
iters:   10, pause:   600 ms, avg decode time:    19.13 +/- 0.87 ms
iters:   10, pause:   800 ms, avg decode time:    19.27 +/- 0.66 ms
iters:   10, pause:  1000 ms, avg decode time:    19.19 +/- 0.52 ms
iters:   10, pause:  1200 ms, avg decode time:   106.04 +/- 92.54 ms
iters:   10, pause:  1400 ms, avg decode time:   225.60 +/- 29.71 ms
iters:   10, pause:  1600 ms, avg decode time:   299.07 +/- 73.57 ms
iters:   10, pause:  1800 ms, avg decode time:   270.54 +/- 23.73 ms
iters:   10, pause:  2000 ms, avg decode time:   278.60 +/- 27.24 ms
iters:   10, pause:  2200 ms, avg decode time:   268.24 +/- 65.90 ms

./llama-idle -m models/llama-3.1-8b-instruct/ggml-model-q4_0.gguf

iters:   10, pause:     0 ms, avg decode time:    11.04 +/- 0.19 ms
iters:   10, pause:   200 ms, avg decode time:    14.45 +/- 0.58 ms
iters:   10, pause:   400 ms, avg decode time:    14.20 +/- 0.65 ms
iters:   10, pause:   600 ms, avg decode time:    14.30 +/- 0.70 ms
iters:   10, pause:   800 ms, avg decode time:    14.49 +/- 0.57 ms
iters:   10, pause:  1000 ms, avg decode time:    23.82 +/- 29.35 ms
iters:   10, pause:  1200 ms, avg decode time:    86.02 +/- 69.61 ms
iters:   10, pause:  1400 ms, avg decode time:   171.72 +/- 32.36 ms
iters:   10, pause:  1600 ms, avg decode time:   170.60 +/- 36.20 ms
iters:   10, pause:  1800 ms, avg decode time:   157.94 +/- 12.47 ms
iters:   10, pause:  2000 ms, avg decode time:   205.16 +/- 55.31 ms
iters:   10, pause:  2200 ms, avg decode time:   215.77 +/- 53.40 ms

It can be seen that after ~1 s of being idle, the next llama_decode takes extra time to compute. Some additional observations:

  • Same behavior occurs on M1 Pro and M2 Ultra
  • Occurs only when using the GPU
  • Does not occur with CUDA
  • The extra time is proportional to the model size
  • This overhead is not from any recent change; I remember observing it a year ago, but only now took the time to look into it in more detail

My best guess is that this is something done by the OS to reduce power consumption. The problem is that it is quite aggressive, and it ruins performance for local applications that run inference every few seconds.

Hopefully there is a way to disable this somehow - any insights would be highly appreciated. For now, the only solution I can come up with is to introduce an optional "heartbeat" callback that calls a dummy Metal kernel every 0.9 s, just to keep the GPU under slight pressure. This is obviously not a great solution, but it would be much better than the current situation. A sketch of the idea is shown after this paragraph.
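For concreteness, here is a rough C++ sketch of such a heartbeat, assuming a submit_dummy_metal_kernel() helper that commits a trivial command buffer through the Metal backend (the helper is hypothetical; no such API exists in the codebase today):

```cpp
#include <atomic>
#include <chrono>
#include <thread>

// hypothetical helper: submit a trivial command buffer to the GPU
// (e.g. dispatch a 1-element dummy kernel and commit it)
static void submit_dummy_metal_kernel() {
    // placeholder: the real version would go through the ggml Metal backend
}

// keeps the GPU under slight pressure by firing a tiny workload every 0.9 s,
// i.e. just under the ~1 s idle threshold observed in the measurements above
class gpu_heartbeat {
public:
    gpu_heartbeat() : stop(false), worker([this] {
        while (!stop.load()) {
            submit_dummy_metal_kernel();
            std::this_thread::sleep_for(std::chrono::milliseconds(900));
        }
    }) {}

    ~gpu_heartbeat() {
        stop.store(true);
        worker.join(); // may block for up to one period - fine for a sketch
    }

private:
    std::atomic<bool> stop;   // declared before 'worker' so it is initialized first
    std::thread       worker;
};
```

An application would construct one gpu_heartbeat for the lifetime of the context; destroying it stops the background thread.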

ggerganov added the "demo" label (Demonstrate some concept or idea, not intended to be merged) on Nov 1, 2024
ngxson (Collaborator) commented Nov 1, 2024

Found this thread from PyTorch, maybe related to this: pytorch/pytorch#124056

CyborgArmy83

Not a lot of people are sifting through these pull requests. Maybe consider adding this to the Discussions, or even somewhere in the README? It sounds like a (potentially) pretty significant thing if we can figure it out.

AndreasKunar (Contributor)

Just to confirm that it also still seems to be present on the M4.

Machine: MacBook Pro, M4 Pro, 20 GPU, 48GB RAM
macOS Sequoia 15.1.1
llama.cpp version: 4288 (43ed389)
built with Apple clang version 16.0.0 (clang-1600.0.26.4) for arm64-apple-darwin24.1.0

./build/bin/llama-idle -m ../models.llama.cpp/Llama/llama-2-7b-chat.F16.gguf 
...
iters:   10, pause:     0 ms, avg decode time:    56.37 +/- 0.43 ms
iters:   10, pause:   200 ms, avg decode time:    62.96 +/- 1.26 ms
iters:   10, pause:   400 ms, avg decode time:    63.28 +/- 1.40 ms
iters:   10, pause:   600 ms, avg decode time:    63.37 +/- 1.30 ms
iters:   10, pause:   800 ms, avg decode time:    63.27 +/- 1.58 ms
iters:   10, pause:  1000 ms, avg decode time:    63.74 +/- 1.59 ms
iters:   10, pause:  1200 ms, avg decode time:   317.82 +/- 104.86 ms
iters:   10, pause:  1400 ms, avg decode time:   337.73 +/- 48.03 ms
iters:   10, pause:  1600 ms, avg decode time:   320.89 +/- 20.60 ms
iters:   10, pause:  1800 ms, avg decode time:   326.02 +/- 14.65 ms
iters:   10, pause:  2000 ms, avg decode time:   315.00 +/- 17.74 ms

japindersi

There is an option in the macOS settings on the Mac mini, under 'Energy', to change to 'High Power' mode. Does it help in any way?

ggerganov (Owner, Author)

This option is not present on the Mac Studio:

[screenshot]
