
enable internal kv bucket in llama #658

Closed
x574chen wants to merge 0 commits into huggingface:main from x574chen:llama_internal_bucket


@x574chen

What does this PR do?

To improve throughput when generating many new tokens, split the KV cache into buckets whose lengths are multiples of the bucket width, and compute attention over only the buckets filled so far rather than over the entire KV cache. Below are some results for LLaMA2 7B/70B on Gaudi2:

| Model | TP | Input Length | Output Length | BS | Base Throughput | Throughput w/ internal kv bucket |
|---|---|---|---|---|---|---|
| LLaMA v2-7B | 1 | 128 | 2048 | 76 | 1402 | 2217 (bucket=128) |
| LLaMA v2-7B | 2 | 2048 | 2048 | 64 | 1417 | 1672 (bucket=128) |
| LLaMA v2-70B | 4 | 128 | 2048 | 240 | 2834 | 3638 (bucket=256) |

Add --bucket_size=128 --bucket_internal to the commands to enable the feature.
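
As a rough illustration of why slicing the cache helps, the sketch below counts how many KV positions are read across all decode steps, with and without bucketing. It assumes the KV cache is preallocated to input length plus output length and that attention work scales with the number of positions read. `attended_positions` and its parameters are hypothetical names, not part of this PR.

```python
def attended_positions(input_len, output_len, bucket_size=None):
    """Total KV positions read across all decode steps.

    Assumption: the KV cache is preallocated to input_len + output_len.
    Without bucketing, every step reads the whole cache. With bucketing,
    a step reads the filled length rounded up to a bucket boundary.
    """
    cache_len = input_len + output_len
    total = 0
    for step in range(output_len):
        filled = input_len + step + 1  # tokens in the cache at this step
        if bucket_size is None:
            total += cache_len
        else:
            # Round up to the next multiple of bucket_size, capped at the cache.
            rounded = -(-filled // bucket_size) * bucket_size
            total += min(rounded, cache_len)
    return total


full = attended_positions(128, 2048)
bucketed = attended_positions(128, 2048, bucket_size=128)
print(f"positions read, bucketed / full: {bucketed / full:.2f}")
```

This counts memory positions only; it is not a model of the measured throughput in the table, which also depends on batch size, tensor parallelism, and graph compilation.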

@x574chen x574chen requested a review from a user January 23, 2024 14:51
@x574chen x574chen requested a review from regisss as a code owner January 23, 2024 14:51
if cache_idx is not None and q_len == 1:
    key_states = key_states[:, :, :cache_idx, :]
    value_states = value_states[:, :, :cache_idx, :]
    attention_mask = attention_mask[:, :, :, :cache_idx]
Contributor


Add a check that attention_mask is not None before slicing it.
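
A minimal sketch of what such a guard could look like, assuming the usual (batch, heads, kv_len, head_dim) layout for keys and values and a mask whose last dimension is the KV length. `slice_to_bucket` is a hypothetical helper for illustration, not the code in this PR.

```python
import torch


def slice_to_bucket(key_states, value_states, attention_mask, cache_idx, q_len):
    """Trim K/V (and the mask, if present) to the first cache_idx positions."""
    if cache_idx is not None and q_len == 1:
        key_states = key_states[:, :, :cache_idx, :]
        value_states = value_states[:, :, :cache_idx, :]
        # Guard: the mask may be None, in which case there is nothing to slice.
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :, :cache_idx]
    return key_states, value_states, attention_mask


# Toy shapes: K/V are (batch, heads, kv_len, head_dim).
k = torch.zeros(1, 2, 512, 8)
v = torch.zeros(1, 2, 512, 8)
k2, v2, m2 = slice_to_bucket(k, v, None, cache_idx=256, q_len=1)
print(k2.shape, v2.shape, m2)
```

Without the guard, a `None` mask would raise a `TypeError` on the slicing line.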

@puneeshkhanna
Contributor

@ssarkar2 - Maybe we should remove the original bucketing logic in a separate PR later, to simplify the overall code, once we are convinced that this PR's bucketing logic is best for all cases.

Btw everyone - I will also add a clear-cache option in utils.py (just an API call to release HPU graph memory) in a separate PR, to address some corner cases where memory may increase with the bucketing changes in this PR.

@ssarkar2
Contributor

@puneeshkhanna, the original external bucketing is general: it works for any model and does not need model file changes. It might be useful for unknown, new, or unoptimized models.
However, for maximum performance we have to modify model files to use an internal cache (like this one).

Let me know if it's worth keeping the general external one that might work for any model.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

if cache_len and bucket_size > 0:
    idx = torch.div(token_idx - 1, bucket_size, rounding_mode="floor")
    if idx < (cache_len // bucket_size):
        cache_idx = (idx.item() + 1) * bucket_size
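
For intuition, the arithmetic in the snippet above can be traced with plain Python integers. This is a sketch, not the PR's code: `bucket_cache_idx` is a hypothetical helper, and returning `None` once the last bucket boundary is passed (meaning "use the full cache") is an assumption about the fallback, which the snippet does not show.

```python
def bucket_cache_idx(token_idx, bucket_size, cache_len):
    """Return the KV length to attend over, or None for the full cache."""
    if not cache_len or bucket_size <= 0:
        return None
    idx = (token_idx - 1) // bucket_size  # bucket holding the newest token
    if idx < cache_len // bucket_size:
        return (idx + 1) * bucket_size  # end of that bucket
    return None  # assumption: fall back to the whole cache


for token_idx in (1, 128, 129, 300):
    print(token_idx, bucket_cache_idx(token_idx, bucket_size=128, cache_len=2176))
```

The result only changes when `token_idx` crosses a bucket boundary, so with a bucket size of 128 the sliced length stays constant for 128 consecutive decode steps.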
Contributor


@x574chen - One question: do we need to call .item() here? It forces a sync back to the CPU and affects the graph. Can it work without .item()?

Contributor

@puneeshkhanna Feb 2, 2024


@x574chen - One more question: can we move lines 823 to 829 into utils.py, pass cache_idx in kwargs, and keep a single line here, cache_idx = kwargs.get("cache_idx")? That would make the bucketing changes easier to apply to other models. allocate_kv_cache() could perhaps return the KV length to utils.py. Sorry for all these late review comments.

Contributor


The rest of the changes look good to me. Basically, what I'm thinking is that we just pass cache_idx into modeling_llama.py, so the only additional change there is slicing the KV cache in the attention block code.
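
One way to read this suggestion, sketched with stand-in functions rather than the real utils.py and modeling_llama.py code: the generation loop computes cache_idx once per step and passes it through kwargs, and the attention block only reads it and slices. All names below (`compute_cache_idx`, `attention_forward`, `decode_step`) are hypothetical, and Python lists stand in for KV tensors.

```python
def compute_cache_idx(token_idx, bucket_size, kv_len):
    """Generation-side helper; in this sketch it would live in the utils layer."""
    if bucket_size <= 0:
        return None
    end = ((token_idx - 1) // bucket_size + 1) * bucket_size
    # Past the last full bucket: signal "use the whole cache" with None.
    return end if end <= kv_len else None


def attention_forward(keys, values, **kwargs):
    """Model-side: one line reads cache_idx, then the KV cache is sliced."""
    cache_idx = kwargs.get("cache_idx")
    if cache_idx is not None:
        keys = keys[:cache_idx]
        values = values[:cache_idx]
    return len(keys), len(values)


def decode_step(keys, values, token_idx, bucket_size):
    """One decode step: compute the index outside the model, pass it in."""
    cache_idx = compute_cache_idx(token_idx, bucket_size, kv_len=len(keys))
    return attention_forward(keys, values, cache_idx=cache_idx)


keys = list(range(512))
values = list(range(512))
print(decode_step(keys, values, token_idx=130, bucket_size=128))
```

With this split, supporting another model would only require the slicing lines in that model's attention block, which is the simplification the comment is after.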

Contributor

@x574chen Feb 5, 2024


The .item() call ensures that cache_idx (one of the model inputs) is an integer, not a tensor. This prevents HPUGraph from calling replay when the value of cache_idx changes. Please correct me if my understanding of HPU graphs is incorrect.

I also tried computing cache_idx at every step without .item(), but performance did not change significantly. I have therefore left the .item() call in place.

Comment thread examples/text-generation/utils.py Outdated
assert generation_config.bucket_size > 0
generation_config.kv_cache_fp8 = args.kv_cache_fp8
generation_config.use_flash_attention = args.use_flash_attention
generation_config.bucket_internal = args.bucket_internal
Contributor


We also need to initialize this in optimum-habana/optimum/habana/transformers/generation/configuration_utils.py. I think CI will fail without that change.
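
The concern is that an attribute set only by the example script may be missing when the config is constructed elsewhere. A minimal sketch of the pattern, using a stand-in class rather than the real generation config; the class name and the default values are assumptions made for illustration.

```python
class SketchGenerationConfig:
    """Stand-in for a generation config; not the optimum-habana class."""

    def __init__(self, **kwargs):
        # Give every optional knob a default so later reads never raise
        # AttributeError, even when the caller did not set it.
        self.bucket_size = kwargs.get("bucket_size", -1)
        self.bucket_internal = kwargs.get("bucket_internal", False)


default_cfg = SketchGenerationConfig()
enabled_cfg = SketchGenerationConfig(bucket_size=128, bucket_internal=True)
print(default_cfg.bucket_internal, enabled_cfg.bucket_internal)
```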

6 participants