Skip to content

Use KV cache till input seq len for prefill phase#154

Merged
5 commits merged into
HabanaAI:habana-mainfrom
puneeshkhanna:prefill_kvcache
Apr 11, 2024
Merged

Use KV cache till input seq len for prefill phase#154
5 commits merged into
HabanaAI:habana-mainfrom
puneeshkhanna:prefill_kvcache

Conversation

@puneeshkhanna
Copy link
Copy Markdown

Pad KV cache to full input + new tokens len for decode phase. Delete the KV cache used as inputs by HPU graphs after full prompt generation. Ensure KV cache is not returned as output tensor during decode phase. Deletion of KV cache input tensor used by HPU graphs needs to be protected by PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
@puneeshkhanna
Copy link
Copy Markdown
Author

puneeshkhanna commented Apr 10, 2024

Updated command (remove --reuse_cache , setting PT_HPUGRAPH_DISABLE_TENSOR_CACHE=1 automatically taken care)

python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py --model_name_or_path /mnt/weka/data/pytorch/llama2/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 2048 --max_new_tokens 2048 --batch_size 200 --attn_softmax_bf16 --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --bucket_internal --bucket_size 128

Also requires pytorch-integration patch - https://gerrit.habana-labs.com/#/c/408363/

Comment thread optimum/habana/transformers/models/llama/modeling_llama.py
@ghost ghost merged commit 60b5d9b into HabanaAI:habana-main Apr 11, 2024
sushildubey171 pushed a commit that referenced this pull request Apr 12, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
astachowiczhabana pushed a commit that referenced this pull request Apr 19, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
astachowiczhabana pushed a commit that referenced this pull request Apr 22, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
astachowiczhabana pushed a commit that referenced this pull request Apr 24, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
astachowiczhabana pushed a commit that referenced this pull request Apr 24, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
puneeshkhanna pushed a commit to puneeshkhanna/optimum-habana-fork that referenced this pull request May 2, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
@astachowiczhabana
Copy link
Copy Markdown

huggingface#1028

This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants