Skip to content

Prefill kvcache upstream#942

Closed
puneeshkhanna wants to merge 6 commits into
huggingface:mainfrom
puneeshkhanna:prefill_kvcache_upstream
Closed

Prefill kvcache upstream#942
puneeshkhanna wants to merge 6 commits into
huggingface:mainfrom
puneeshkhanna:prefill_kvcache_upstream

Conversation

@puneeshkhanna
Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Puneesh Khanna added 2 commits May 2, 2024 08:05
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
…anaAI#161)

* Sampling search UseKV cache till input seq len for prefill phase

* Remove redundant line
@puneeshkhanna puneeshkhanna requested a review from a user May 2, 2024 05:09
@puneeshkhanna puneeshkhanna requested a review from regisss as a code owner May 2, 2024 05:09
@puneeshkhanna
Copy link
Copy Markdown
Contributor Author

@regisss, @libinta, @dvarshney-habana - Please add 1.16 synpase release label to this.

@puneeshkhanna
Copy link
Copy Markdown
Contributor Author

Description of the changes in this PR -

Pad KV cache to full input + new tokens len for decode phase. Delete the KV cache used as inputs by HPU graphs after full prompt generation. Ensure KV cache is not returned as output tensor during decode phase. Deletion of KV cache input tensor used by HPU graphs needs to be protected by PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable. All the changes are protected by bucket internal flag right now.

Updated command (remove --reuse_cache from all existing commands , setting PT_HPUGRAPH_DISABLE_TENSOR_CACHE=1 automatically taken care)
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py --model_name_or_path /mnt/weka/data/pytorch/llama2/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 2048 --max_new_tokens 2048 --batch_size 200 --attn_softmax_bf16 --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --bucket_internal --bucket_size 128

With the changes in this PR, performance in any existing configs remains same but we can scale batch sizes to much much higher numbers since we save a lot of memory during the prefill phase.
As an example with 2K input tokens + 2K new tokens, llama 70B on 8x with flash attention - maximum batch size without PR changes that we can go is around 270 and with the changes in this PR, we can go up to batch size 370.

@libinta libinta added the synapse 1.16_dependency synapse 1.16 dependency label May 2, 2024
)
past_key_value = (past_key, past_value)
# Return list instead of tuple
past_key_value = [past_key, past_value]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could impact tgi as tgi goes through this route

Comment thread optimum/habana/transformers/models/llama/modeling_llama.py
@ssarkar2 ssarkar2 removed the synapse 1.16_dependency synapse 1.16 dependency label May 31, 2024
@libinta libinta closed this Jun 8, 2024
@ssarkar2
Copy link
Copy Markdown
Contributor

merged thru: #1028

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants