Further fixes for performance with internal bucketing.#781
Conversation
Calculate kv cache sliding idx for the decode phase only. Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
|
make style also passes. @regisss - please merge this new PR. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Thanks @puneeshkhanna! I'll test it in a couple of hours when my Gaudi2 node is free and then I'll merge 👍 |
|
Some reference commands to verify performance: python ../gaudi_spawn.py --use_deepspeed --world_size 1 run_generation.py --model_name_or_path Llama-2-7b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 2048 --batch_size 60 --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --limit_hpu_graphs --bucket_internal --bucket_size 128 python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py --model_name_or_path Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 2048 --batch_size 60 --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --limit_hpu_graphs --bucket_internal --bucket_size 128 |
Calculate kv cache sliding idx for the decode phase only.
This PR has additional enhancements over #720.
token_idx_cpu is introduced which is an integer rather than a tensor to keep track of buckets. And the switch of buckets happens after the prefill phase.
Assume input tokens as 128 and new tokens as 512.
Without bucketing and slicing changes in these 2 PRs, in the decode phase we used to calculate the attention scores by multiplying with full KV cache of size (512+128).
Now with the changes of these 2 PRs, we l consider KV caches as below:
Decode phases of tokens 128-256 -> sliced KV cache till seq len 256.
Decode phases of tokens 256-384 -> sliced KV cache till seq len 384.
Decode phases of tokens 384-512 -> sliced KV cache till seq len 512.
And so on.
The bucketing changes along with reuse cache gives enhanced performance. See improved performances in below table.
What does this PR do?
Fixes # (issue)
Before submitting