enable internal kv bucket in llama #658
if cache_idx is not None and q_len == 1:
    key_states = key_states[:, :, :cache_idx, :]
    value_states = value_states[:, :, :cache_idx, :]
    attention_mask = attention_mask[:, :, :, :cache_idx]
Add a check that attention_mask is not None.
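A minimal sketch of the guard requested above, assuming the slicing is pulled into a standalone helper. The function name `slice_to_bucket` and the tensor shapes are hypothetical and chosen only for illustration; they are not taken from modeling_llama.py.

```python
import torch


def slice_to_bucket(key_states, value_states, attention_mask, cache_idx, q_len):
    """Trim K/V (and the mask, when one is given) to the first cache_idx positions.

    Hypothetical helper: only the slicing expressions come from the diff.
    """
    if cache_idx is not None and q_len == 1:
        key_states = key_states[:, :, :cache_idx, :]
        value_states = value_states[:, :, :cache_idx, :]
        # Guard suggested in review: the mask may be None.
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :, :cache_idx]
    return key_states, value_states, attention_mask


# Illustrative shapes: (batch, heads, kv_len, head_dim) and (batch, 1, q_len, kv_len).
k = torch.zeros(1, 2, 8, 4)
v = torch.zeros(1, 2, 8, 4)
mask = torch.zeros(1, 1, 1, 8)

k4, v4, m4 = slice_to_bucket(k, v, mask, cache_idx=4, q_len=1)
k_none, v_none, m_none = slice_to_bucket(k, v, None, cache_idx=4, q_len=1)
```

With the guard in place, a call without a mask returns the sliced key and value states and leaves the mask as None instead of raising a TypeError.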
@ssarkar2 - Maybe we should remove the original bucketing logic in a separate PR later, to simplify the overall code, once we are convinced that this PR's bucketing logic is best for all cases. By the way, everyone: I will also add a clear-cache option in utils.py (just an API call to release HPU graph memory) in a separate PR, to address some corner cases where memory may increase with the bucketing changes in this PR.
@puneeshkhanna, the original external bucketing is general: it works for any model and does not need model file changes. It might be useful for unknown, new, or unoptimized models. Let me know whether it is worth keeping the general external one, which might work for any model.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
if cache_len and bucket_size > 0:
    idx = torch.div(token_idx - 1, bucket_size, rounding_mode="floor")
    if idx < (cache_len // bucket_size):
        cache_idx = (idx.item() + 1) * bucket_size
@x574chen - One question here: do we need .item()? It will cause a sync back to the CPU and affect the graph. Can it work without .item()?
@x574chen - One more question: can we move lines 823 to 829 into utils.py, pass cache_idx in kwargs, and keep just one line here, cache_idx = kwargs.get("cache_idx")? That would make the bucketing changes easier to apply to other models. allocate_kv_cache() could perhaps return the KV length to utils.py. Sorry for all these late review comments.
All the other changes look good to me. Basically, what I am thinking is that we just pass cache_idx into modeling_llama.py, so that the only additional change there is slicing the KV cache in the attention block code.
The .item() call ensures that cache_idx (one of the model inputs) is an integer, not a tensor. This prevents HPUGraph from calling replay when the value of cache_idx changes. Please correct me if my understanding of how HPU graphs are used is incorrect.
I also tried computing cache_idx every step without .item(), but performance did not appear to change significantly. I have therefore left the .item() call here as it is.
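To make the bucket arithmetic in this thread concrete, here is a plain-Python sketch of the cache_idx computation from the diff, using integer floor division in place of torch.div(..., rounding_mode="floor") and .item(). The function name `bucket_cache_idx` is hypothetical, and returning None when no bucket applies is an assumption: the diff does not show that branch.

```python
def bucket_cache_idx(token_idx, bucket_size, cache_len):
    """Return the KV length to attend over for this step.

    Mirrors the condition in the diff with plain ints. Returning None
    (meaning "use the whole cache") is an assumption, not shown in the PR.
    """
    if not cache_len or bucket_size <= 0:
        return None
    idx = (token_idx - 1) // bucket_size  # bucket holding the current token
    if idx < cache_len // bucket_size:
        return (idx + 1) * bucket_size    # round up to the end of that bucket
    return None


# With bucket_size=128 and a 1024-slot cache, cache_idx only changes
# once every 128 generated tokens.
steps = [bucket_cache_idx(t, 128, 1024) for t in (1, 128, 129, 256, 257, 1024, 1025)]
print(steps)  # [128, 128, 256, 256, 384, 1024, None]
```

Because the result is a Python int that takes one of only cache_len // bucket_size distinct values, it changes rarely across decoding steps, which is the property the reply above relies on.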
assert generation_config.bucket_size > 0
generation_config.kv_cache_fp8 = args.kv_cache_fp8
generation_config.use_flash_attention = args.use_flash_attention
generation_config.bucket_internal = args.bucket_internal
We also need to initialize this in optimum-habana/optimum/habana/transformers/generation/configuration_utils.py. I think CI will fail without that change.
f4dd4ba to f096980
What does this PR do?
To improve throughput when generating many new tokens, split the KV cache into multiples of the bucket size and compute attention over only the buckets filled so far, rather than over the entire KV cache. Below are some results from LLaMA2 7B/70B on Gaudi2:
Add --bucket_size=128 --bucket_internal to the commands to enable the feature.
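As a rough illustration of how these two flags could be wired up, here is a standalone argparse sketch. Only the flag names and the bucket_size > 0 assertion come from the PR; the defaults, the help strings, the `parse_bucket_args` helper, the `if args.bucket_internal` guard around the assertion, and the SimpleNamespace stand-in for the generation config are all assumptions made for this example.

```python
import argparse
from types import SimpleNamespace


def parse_bucket_args(argv):
    """Parse only the two bucketing flags (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    # Defaults below are assumptions, not the script's actual defaults.
    parser.add_argument("--bucket_size", type=int, default=-1,
                        help="Bucket size for the KV cache; <= 0 disables bucketing.")
    parser.add_argument("--bucket_internal", action="store_true",
                        help="Slice the KV cache per bucket inside the model.")
    return parser.parse_args(argv)


args = parse_bucket_args(["--bucket_size=128", "--bucket_internal"])

# Stand-in for the real generation config object.
generation_config = SimpleNamespace(bucket_size=args.bucket_size)
if args.bucket_internal:
    # Internal bucketing only makes sense with a positive bucket size.
    assert generation_config.bucket_size > 0
    generation_config.bucket_internal = args.bucket_internal
```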