enable internal kv bucket in llama #24
parser.add_argument(
    "--bucket_internal",
    action="store_true",
    help="Split kv sequence into buckets in decode phase. It is useful for long new tokens.",
)
It improves throughput when max_new_tokens is large
if idx < (model_kwargs["kv_cache_len"] // bucket_size):
    cache_idx = (idx.item() + 1) * bucket_size
    model_kwargs["cache_idx"] = cache_idx
@xt574chen - this logic works only when the total generated length is a multiple of the bucket size. For example, take a total length of 2060: for tokens generated between 2048 and 2060, the KV cache is sliced at sequence length 2048, so the KV values between 2048 and 2060 are not considered.
Please find the updated logic below (I spent a lot of time reviewing all the changes today):
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
    idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
    cache_idx = (idx.item() + 1) * bucket_size
    model_kwargs["cache_idx"] = cache_idx
else:
    model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
We can also enhance this a bit further by avoiding the .item() call when the idx tensor has not changed. But let's skip that minor enhancement for now; we can push a separate PR later.
More importantly, the above logic needs to go in first so that the model logic works correctly.
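The bucket arithmetic above can be checked without a device. The following is a minimal sketch that uses plain Python integers in place of the token_idx tensor; the function name bucketed_cache_idx is hypothetical and not part of the PR, and the bucket size of 128 is only an example value.

```python
def bucketed_cache_idx(token_idx: int, kv_cache_len: int, bucket_size: int) -> int:
    """Return how many KV cache slots to attend over for the current token.

    Hypothetical helper mirroring the if/else in the comment above, with
    plain ints standing in for the token_idx tensor.
    """
    # Length covered by whole buckets, e.g. 2048 for kv_cache_len=2060, bucket_size=128.
    full_buckets_len = (kv_cache_len // bucket_size) * bucket_size
    if token_idx <= full_buckets_len:
        # Round up to the end of the bucket that contains this token.
        idx = (token_idx - 1) // bucket_size
        return (idx + 1) * bucket_size
    # Past the last whole bucket: fall back to the full cache length.
    return kv_cache_len


if __name__ == "__main__":
    for token_idx in (1, 128, 129, 2048, 2049, 2060):
        print(token_idx, bucketed_cache_idx(token_idx, 2060, 128))
```

For the 2060 example, tokens 2049 through 2060 get the full cache length instead of staying at 2048.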
@xt574chen Further enhanced code can be as below. I will let you decide the best course of action.
# Declare prev_idx = None outside the while loop.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
    idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
    if idx != prev_idx:
        cache_idx = (idx.item() + 1) * bucket_size
        model_kwargs["cache_idx"] = cache_idx
        prev_idx = idx
else:
    model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
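To illustrate what the prev_idx check saves, here is a rough sketch that walks a decode loop with plain integers and counts how often the bucket end is recomputed (the step that would call .item() in the snippet above). The function simulate_decode and its counter are hypothetical, and the loop bounds are an assumption about how token_idx advances.

```python
def simulate_decode(kv_cache_len: int, bucket_size: int, start_token_idx: int = 1):
    """Walk token_idx through the cache and record cache_idx at every step.

    Hypothetical simulation: returns (recomputes, trace), where recomputes
    counts how many times the bucket end was recalculated.
    """
    prev_idx = None  # declared outside the loop, as the comment suggests
    cache_idx = None
    recomputes = 0
    trace = []
    full_buckets_len = (kv_cache_len // bucket_size) * bucket_size
    for token_idx in range(start_token_idx, kv_cache_len + 1):
        if token_idx <= full_buckets_len:
            idx = (token_idx - 1) // bucket_size
            if idx != prev_idx:
                # Only recompute when the token crosses into a new bucket.
                cache_idx = (idx + 1) * bucket_size
                recomputes += 1
                prev_idx = idx
        else:
            cache_idx = kv_cache_len
        trace.append(cache_idx)
    return recomputes, trace


if __name__ == "__main__":
    recomputes, trace = simulate_decode(2060, 128)
    print(recomputes, len(trace), trace[-1])
```

With a cache length of 2060 and a bucket size of 128, the bucket end is recomputed once per bucket rather than once per token.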
I tested the recommended logic (without the enhancement) quite a bit and it seems to be working fine. @xt574chen - please test from your side and feel free to update anything. But the original code has the issue I highlighted in the earlier comments.
@dvarshney-habana - check comment. Once addressed, we can merge.
@puneeshkhanna updated, thank you!
@xt574chen - Thank you. I hope you also verified the changes and that we are not missing any corner cases. @dvarshney-habana - let's merge it so that we can start testing in nightly jobs too and see the performance impact of bucketing.
@xt574chen - As an example, let's take the below 2 configs
Throughput (including tokenization) = 2924.5992187903807 tokens/second
* enable internal kv bucket in llama
* initialize bucket_internal for CI
* make bucket_internal more clear
* further perf optim while max length is not multiple of bucket size
What does this PR do?
To improve throughput when generating many new tokens, slice the KV cache to a multiple of the bucket size during the decode phase and compute attention over that slice rather than over the entire KV cache.
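As an illustration of why slicing is safe, the sketch below runs single-query dot-product attention in pure Python over a pre-allocated cache, once over every slot and once over only the first cache_idx slots. It is not the model code from this PR: the attend function, the valid_len masking, and the toy cache contents are all assumptions made for the example.

```python
import math


def attend(query, keys, values, valid_len, cache_idx=None):
    """Attend over the first cache_idx slots; slots >= valid_len are masked out.

    Hypothetical single-head attention. Returns (output, slots_visited).
    """
    limit = len(keys) if cache_idx is None else min(cache_idx, len(keys))
    scale = 1.0 / math.sqrt(len(query))
    scores = []
    for pos in range(limit):
        if pos < valid_len:
            scores.append(scale * sum(q * k for q, k in zip(query, keys[pos])))
        else:
            scores.append(float("-inf"))  # unfilled slot, contributes zero weight
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * values[pos][j] for pos, w in enumerate(weights)) / total
           for j in range(dim)]
    return out, limit


# Toy cache: 512 pre-allocated slots, of which 130 hold real tokens.
cache_len, valid_len, bucket_size = 512, 130, 128
keys = [[math.sin(p), math.cos(p)] for p in range(cache_len)]
values = [[float(p), 1.0] for p in range(cache_len)]
query = [0.5, -0.25]

# Round the filled length up to the end of its bucket.
cache_idx = ((valid_len - 1) // bucket_size + 1) * bucket_size

full_out, full_slots = attend(query, keys, values, valid_len)
sliced_out, sliced_slots = attend(query, keys, values, valid_len, cache_idx)
print(full_slots, sliced_slots)
```

Both calls produce the same output because the skipped slots carry zero weight, but the sliced call visits fewer slots.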

Add --bucket_size=128 --bucket_internal to the commands to enable the feature.
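For reference, a minimal sketch of how the two flags could be parsed together. The --bucket_internal definition follows the diff in this PR; the type, default, and help text for --bucket_size are assumptions, since its definition is not shown here.

```python
import argparse

parser = argparse.ArgumentParser(description="Sketch of the bucketing flags")
parser.add_argument(
    "--bucket_size",
    type=int,
    default=-1,  # assumption: a negative value means bucketing is disabled
    help="Bucket size used to slice the KV cache (assumed help text).",
)
parser.add_argument(
    "--bucket_internal",
    action="store_true",
    help="Split kv sequence into buckets in decode phase. It is useful for long new tokens.",
)

# Parse the example flags from the PR description instead of sys.argv.
args = parser.parse_args(["--bucket_size=128", "--bucket_internal"])
print(args.bucket_size, args.bucket_internal)
```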