Prefill kvcache upstream #942
Closed
puneeshkhanna wants to merge 6 commits into huggingface:main from puneeshkhanna:prefill_kvcache_upstream
Commits
be99027 Use KV cache till input seq len for prefill phase (#154)
967fa47 Sampling search UseKV cache till input seq len for prefill phase (#161)
3fdd1f6 Fix merge conflict
3770d1c Fix merge conflict
4a04bff Update modeling_llama.py
7af3dce fix review comment
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -384,7 +384,8 @@ def pre_attn_forward( | |
| past_value = torch.zeros( | ||
| key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device | ||
| ) | ||
| - past_key_value = (past_key, past_value) | ||
| + # Return list instead of tuple | ||
| + past_key_value = [past_key, past_value] | ||
|
Collaborator
This could impact TGI, since TGI goes through this code path.
||
| key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) | ||
| value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) | ||
| if token_idx is None: | ||
|
|
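The reviewer's concern about the tuple-to-list change in the hunk above can be illustrated with a small sketch. It uses plain strings in place of tensors, and `unpack_cache` and `strict_consumer` are hypothetical names invented here; nothing in this PR shows how TGI actually consumes `past_key_value`. The point is only that a list allows in-place element replacement, while a caller that checks for a tuple would reject it.

```python
def unpack_cache(past_key_value):
    """Hypothetical consumer that only unpacks; works for tuple and list."""
    key, value = past_key_value
    return key, value


def strict_consumer(past_key_value):
    """Hypothetical consumer that insists on a tuple (assumption, not TGI code)."""
    if not isinstance(past_key_value, tuple):
        raise TypeError(f"expected tuple, got {type(past_key_value).__name__}")
    return unpack_cache(past_key_value)


as_tuple = ("past_key", "past_value")
as_list = ["past_key", "past_value"]

# A list element can be replaced in place.
as_list[0] = "updated_key"

# A tuple element cannot.
try:
    as_tuple[0] = "updated_key"
    tuple_is_mutable = True
except TypeError:
    tuple_is_mutable = False

# Plain unpacking is unaffected by the container type.
list_unpacks = unpack_cache(as_list) == ("updated_key", "past_value")

# A type-checking caller breaks when handed the list.
try:
    strict_consumer(as_list)
    strict_accepts_list = True
except TypeError:
    strict_accepts_list = False
```

Callers that only unpack or index the pair keep working; callers that test the type, hash the value, or compare it against a tuple would need to be checked.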
@@ -459,6 +460,11 @@ def pre_attn_forward( | |
| if not output_attentions: | ||
| attn_weights = None | ||
|
|
||
| + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: | ||
| + # Return only past key value shapes and not the tensors during decode phase (q len is 1) | ||
| + # to avoid making past key values as persistent output tensors of HPU graphs. | ||
| + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) | ||
|
|
||
| return attn_output, attn_weights, past_key_value | ||
|
|
||
| def attention_all_reduce(self, attn_output): | ||
|
|
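The decode-phase condition added in the hunk above can be sketched in isolation. This is a minimal illustration, not the model code: `FakeTensor` and `select_cache_output` are hypothetical names, and the stand-in class carries only a `shape` so the example runs without PyTorch. The condition itself is copied from the diff.

```python
from typing import NamedTuple, Tuple


class FakeTensor(NamedTuple):
    """Hypothetical stand-in for a tensor; only the shape matters here."""
    shape: Tuple[int, ...]


def select_cache_output(past_key_value, reuse_cache, token_idx, cache_idx, q_len):
    """Return shapes instead of tensors when the diff's decode condition holds."""
    if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
        # Decode step (q_len == 1): hand back shapes only.
        return (past_key_value[0].shape, past_key_value[1].shape)
    # Prefill, or any case where the condition fails: return the cache unchanged.
    return past_key_value


kv = [FakeTensor((2, 8, 128, 64)), FakeTensor((2, 8, 128, 64))]

decode_out = select_cache_output(kv, reuse_cache=False, token_idx=17, cache_idx=128, q_len=1)
prefill_out = select_cache_output(kv, reuse_cache=False, token_idx=16, cache_idx=128, q_len=16)
```

In this sketch the decode call yields a pair of shape tuples, while the prefill call returns the original list object, so any caller that reads tensors from the returned value has to handle both cases.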
@@ -800,6 +806,7 @@ def forward( | |
| use_flash_attention, | ||
| flash_attention_recompute, | ||
| flash_attention_causal_mask, | ||
| + None, | ||
| ) | ||
| else: | ||
| layer_outputs = decoder_layer( | ||
|
|
@@ -968,6 +975,7 @@ def prepare_inputs_for_generation( | |
| past_length = 0 | ||
|
|
||
| reuse_cache = kwargs.get("reuse_cache") | ||
| + bucket_internal = kwargs.get("bucket_internal") | ||
| if past_key_values is not None: | ||
| if token_idx is not None: | ||
| input_ids = torch.index_select(input_ids, 1, token_idx - 1) | ||
|
|
@@ -999,8 +1007,9 @@ def prepare_inputs_for_generation( | |
| and cache_length + input_ids.shape[1] > max_cache_length | ||
| ): | ||
| attention_mask = attention_mask[:, -max_cache_length:] | ||
| - elif reuse_cache and token_idx is not None: | ||
| - # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass | ||
| + elif (reuse_cache or bucket_internal) and token_idx is not None: | ||
| + # KV cache is pre allocated with reuse cache or will be padded with bucket internal | ||
| + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. | ||
| input_ids = input_ids[:, :token_idx] | ||
| attention_mask = attention_mask[:, :token_idx] | ||
|
|
||
|
|
||
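The input slicing in the last hunk can be sketched with nested Python lists standing in for the 2-D `input_ids` and `attention_mask` tensors. `slice_prefill_inputs` is a hypothetical helper written for this illustration, and it assumes it is called only for the first forward pass (the diff places this branch in an `elif`, which the sketch does not reproduce). The list comprehension plays the role of `tensor[:, :token_idx]`.

```python
def slice_prefill_inputs(input_ids, attention_mask, token_idx,
                         reuse_cache=False, bucket_internal=False):
    """Trim padded prefill inputs to the first token_idx positions."""
    if (reuse_cache or bucket_internal) and token_idx is not None:
        # Equivalent of input_ids[:, :token_idx] on a 2-D tensor.
        input_ids = [row[:token_idx] for row in input_ids]
        attention_mask = [row[:token_idx] for row in attention_mask]
    return input_ids, attention_mask


PAD = 0  # hypothetical padding id for this example
ids = [[11, 12, 13, PAD, PAD, PAD]]
mask = [[1, 1, 1, 0, 0, 0]]

# With bucket_internal set, the padded tail is dropped before the forward pass.
sliced_ids, sliced_mask = slice_prefill_inputs(ids, mask, token_idx=3, bucket_internal=True)

# With neither flag set, the inputs pass through untouched.
unchanged_ids, unchanged_mask = slice_prefill_inputs(ids, mask, token_idx=3)
```

The effect is that the prefill pass runs over the real prompt length rather than the padded length, which matches the intent stated in the diff's comment.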