refine use cache for mpt model #1158
@Jing1Ling can you check the other PR #1126, the code base has changed there |
@yafshar Thank you for the reminder. I believe this patch is a more concise approach. It changes only a few lines of code and achieves lower memory usage with similar throughput.
Why this change causes such a large decrease in memory usage: before the modification, 'key_states' and 'value_states' were saved in 'past_key_value' as views of 'mixed_qkv', so 'mixed_qkv' could not be freed. In other words, with the change in this PR, each transformer layer (block) saves the space of one 'query_states' while decoding the first token (k and v are created but mixed_qkv is released, which saves the space of one q).
I used torch.hpu.synchronize() and torch.hpu.memory_summary() to collect memory statistics at each stage of decoding the first token.
# exit after decoding the first token
python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graph --batch_size 8 --max_input_tokens 128 --max_new_tokens 5 --trim_logits --attn_softmax_bf16 --warmup 3 --n_iterations 1 --bf16
Here is some evidence:
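The aliasing described in this comment can be reproduced on CPU with a small sketch. It is not the code from this PR: the tensor sizes are toy values, the helper `storage_ptr` is a hypothetical name introduced here, and `Tensor.untyped_storage()` assumes PyTorch 2.x. It only illustrates that chunks of a fused QKV tensor share its buffer, while copies into freshly allocated tensors do not.

```python
import torch


def storage_ptr(t: torch.Tensor) -> int:
    """Address of the allocation that backs tensor t (hypothetical helper)."""
    return t.untyped_storage().data_ptr()


batch, seq_len, hidden = 2, 4, 8  # toy sizes, not the real MPT-7B dimensions

# Stand-in for the output of the fused QKV projection.
mixed_qkv = torch.randn(batch, seq_len, 3 * hidden)
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)

# Caching the chunks directly stores views: they alias mixed_qkv's buffer,
# so the whole buffer (query part included) stays alive as long as the cache does.
view_cache = (key_states, value_states)
views_alias_qkv = all(storage_ptr(t) == storage_ptr(mixed_qkv) for t in view_cache)

# Copying into freshly allocated tensors breaks the alias, so mixed_qkv can be
# freed once nothing else references it.
copy_cache = [
    torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
    torch.empty(value_states.shape, dtype=value_states.dtype, device=value_states.device),
]
copy_cache[0][:] = key_states
copy_cache[1][:] = value_states
copies_alias_qkv = any(storage_ptr(t) == storage_ptr(mixed_qkv) for t in copy_cache)

print(views_alias_qkv, copies_alias_qkv)
```

On a real device the saving would show up in the allocator statistics rather than in pointer comparisons; the pointer check is used here only because it runs anywhere.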
@Jing1Ling thanks. |
In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.
@atakaha can you tell me what is missing for |
ssarkar2
left a comment
Do you think any other model might benefit from this? As far as I can tell, mixed qkv is a feature of mpt/dbrx, and of the 2 we only have MPT here right?
Only MPT is here. I think the same approach could be useful for any similar model.
The |
@atakaha thanks. Then please close #1126, let us merge this PR first and if there is any missing functionality we can start a new PR on top of this. Thanks for your contribution |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
past_key_value = (key_states, value_states)
Shouldn't it actually be:
else:
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
No, this is wrong! It was like this before. Please look at the code: in the if branch we update past_key_value in place, so there is no need to destroy it and create it again. Below is the old code:
if token_idx is not None:
    past_key_value[0].index_copy_(2, token_idx - 1, key_states)
    past_key_value[1].index_copy_(2, token_idx - 1, value_states)
    key_states = past_key_value[0]
    value_states = past_key_value[1]
else:
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
ok, thanks for clarification
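The point made in this exchange can be sketched as a standalone function. This is an illustration under assumptions, not the merged code: the function name `update_kv_cache` is hypothetical, the cache layout is assumed to be (batch, heads, sequence, head_dim) so that dim 2 is the sequence axis, and the tuple is rebuilt only in the concatenation branch, as the author describes.

```python
import torch


def update_kv_cache(past_key_value, key_states, value_states, token_idx=None):
    """Hypothetical helper mirroring the two branches discussed above."""
    if token_idx is not None:
        # Static cache: write the new entries in place at position token_idx - 1.
        # The container and its tensors are reused, so nothing is reallocated.
        past_key_value[0].index_copy_(2, token_idx - 1, key_states)
        past_key_value[1].index_copy_(2, token_idx - 1, value_states)
        key_states = past_key_value[0]
        value_states = past_key_value[1]
    else:
        # Growing cache: torch.cat allocates new tensors, so the cache
        # container has to be rebuilt to point at them.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states)
    return key_states, value_states, past_key_value


# Static path: a pre-allocated cache of 8 positions, writing token 5 (index 4).
cache = [torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4)]
new_k = torch.ones(1, 2, 1, 4)
new_v = torch.full((1, 2, 1, 4), 2.0)
k, v, out = update_kv_cache(cache, new_k, new_v, token_idx=torch.tensor([5]))

# Growing path: a cache of 3 positions extended by one new position.
grow_cache = (torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4))
k2, v2, out2 = update_kv_cache(grow_cache, new_k, new_v)
```

In the static path `out` is the same list object that was passed in, which is why reassigning `past_key_value` after the if/else would be redundant there.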
ea36567 to
02c3f99
Co-authored-by: atakaha <akihiro.takahashi@intel.com>
past_key_value = [torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
                  torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device)]
past_key_value[0][:] = key_states[:]
past_key_value[1][:] = value_states[:]
@Jing1Ling Could you please add your description from the PR discussion here as a code comment, for easier understanding and future reference? The code is much easier to understand with a comment than by going through commits and associated PRs.
I changed the past_key_value in this PR from tuple type to list to make it compatible with this merged PR.
refine use cache for mpt model huggingface#1158
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |


What does this PR do?
Modified the kv_cache initialization method and optimized performance.
Co-author: @atakaha
Test command:
Result:
Before submitting