refine use cache for mpt model #1158

Merged: regisss merged 2 commits into huggingface:main from Jing1Ling:mpt_refine_use_cache on Aug 2, 2024

@Jing1Ling Jing1Ling commented Jul 25, 2024

What does this PR do?

Modified the kv_cache initialization method and optimized performance.
Co-author: @atakaha
Test command:

python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graph --batch_size 128  --max_input_tokens 128 --max_new_tokens 128 --trim_logits --attn_softmax_bf16 --warmup 3 --n_iterations 1 --bf16

Result:

| Version | Batch size | Max input tokens | Max new tokens | Throughput, including tokenization (tokens/s) | Memory allocated (GB) | Max memory allocated (GB) |
|---|---|---|---|---|---|---|
| before | 128 | 128 | 128 | 2900 | 37.28 | 64.39 |
| after | 128 | 128 | 128 | 4803 | 29.28 | 48.39 |
| before | 16 | 128 | 1024 | 624 | 26.55 | 41.79 |
| after | 16 | 128 | 1024 | 1215 | 23.04 | 33.77 |
| before | 2 | 1024 | 1024 | 197 | 16.27 | 19.66 |
| after | 2 | 1024 | 1024 | 255 | 15.27 | 17.66 |
| before | 16 | 1024 | 1024 | 384 | 40.78 | 67.86 |
| after | 16 | 1024 | 1024 | 846 | 32.78 | 51.86 |
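
To make the before/after comparison easier to read, here is a small sketch that derives the throughput speedup and the drop in max memory allocated for each configuration. The numbers are copied from the results above. The `results` dictionary and the `summarize` helper are illustrative names and are not part of the PR.

```python
# Rows copied from the results above. Key: (batch size, max input tokens, max new tokens).
# Value: (throughput in tokens/s, max memory allocated in GB) before and after the change.
results = {
    (128, 128, 128): {"before": (2900, 64.39), "after": (4803, 48.39)},
    (16, 128, 1024): {"before": (624, 41.79), "after": (1215, 33.77)},
    (2, 1024, 1024): {"before": (197, 19.66), "after": (255, 17.66)},
    (16, 1024, 1024): {"before": (384, 67.86), "after": (846, 51.86)},
}


def summarize(rows):
    """Return {config: (throughput speedup, max-memory saving in GB)}, rounded to 2 decimals."""
    summary = {}
    for config, row in rows.items():
        tput_before, mem_before = row["before"]
        tput_after, mem_after = row["after"]
        summary[config] = (
            round(tput_after / tput_before, 2),
            round(mem_before - mem_after, 2),
        )
    return summary


summary = summarize(results)
for (batch_size, max_in, max_new), (speedup, saved_gb) in summary.items():
    print(
        f"bs={batch_size} in={max_in} new={max_new}: "
        f"{speedup}x throughput, {saved_gb} GB lower max memory"
    )
```

The derived figures only restate the measurements; they say nothing about configurations that were not benchmarked.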

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@Jing1Ling Jing1Ling requested a review from mandy-li as a code owner July 25, 2024 07:51
@Jing1Ling Jing1Ling changed the title from "refine use cache" to "refine use cache for mpt model" Jul 25, 2024

yafshar commented Jul 25, 2024

@Jing1Ling can you check the other PR #1126, the code base has changed there

@Jing1Ling Jing1Ling mentioned this pull request Jul 26, 2024
3 tasks
@Jing1Ling
Contributor Author

> @Jing1Ling can you check the other PR #1126, the code base has changed there

@yafshar Thank you for the reminder. I believe this patch is a more concise approach: it changes only a few lines of code and achieves lower memory usage with similar throughput.

@Jing1Ling
Contributor Author

Why this change reduces memory usage so much:
'query_states', 'key_states', and 'value_states' are all views of 'mixed_qkv'. When I create a new pair of tensors and copy 'key_states' and 'value_states' into it, 'mixed_qkv' can be released after the attention function returns.

Before the modification, 'key_states' and 'value_states' were saved in 'past_key_value' as views of 'mixed_qkv', so 'mixed_qkv' could not be freed. In other words, with the change in this PR, each transformer layer (block) saves the space of one 'query_states' while decoding the first token: new k and v tensors are created but 'mixed_qkv' is released, which saves the space of one q.
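
The view behaviour described above can be reproduced on CPU with a minimal sketch. It assumes the fused projection is split with `chunk(3, dim=2)`; the shapes and the `storage_ptr` helper are hypothetical, and `untyped_storage()` requires PyTorch 2.0 or later. This is an illustration of the mechanism, not the model code.

```python
import torch

# Hypothetical small shapes; the real model uses much larger ones.
batch, seq_len, hidden = 2, 4, 8

# Stand-in for the output of the fused QKV projection.
mixed_qkv = torch.randn(batch, seq_len, 3 * hidden)

# chunk() returns views: nothing is copied and all three share mixed_qkv's storage.
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)


def storage_ptr(t):
    """Address of the underlying storage (illustrative helper, not part of the PR)."""
    return t.untyped_storage().data_ptr()


shares_storage = storage_ptr(key_states) == storage_ptr(mixed_qkv)

# Caching the k/v views keeps the whole fused buffer (q, k and v) alive.
view_cache_bytes = key_states.untyped_storage().nbytes()

# Copying k and v into fresh tensors decouples the cache from mixed_qkv,
# so the fused buffer can be freed once nothing else references it.
past_key_value = [
    torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
    torch.empty(value_states.shape, dtype=value_states.dtype, device=value_states.device),
]
past_key_value[0][:] = key_states[:]
past_key_value[1][:] = value_states[:]

copy_cache_bytes = sum(t.untyped_storage().nbytes() for t in past_key_value)

print("views share storage with mixed_qkv:", shares_storage)
print("bytes kept alive by cached views:", view_cache_bytes)
print("bytes kept alive by copied cache:", copy_cache_bytes)
```

In this sketch the difference between the two byte counts corresponds to the size of one `query_states`, which is the saving per layer described above.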

I used torch.hpu.synchronize() and torch.hpu.memory_summary() to collect memory statistics at each stage of decoding the first token.

# exit after decoding the first token
python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graph --batch_size 8 --max_input_tokens 128 --max_new_tokens 5 --trim_logits --attn_softmax_bf16 --warmup 3 --n_iterations 1 --bf16

Here is some evidence:

image
image

yafshar commented Jul 26, 2024

@Jing1Ling thanks.
@atakaha does this PR cover everything you added in #1126?

atakaha commented Jul 26, 2024

> @Jing1Ling thanks. @atakaha does this PR cover everything you added in #1126?

In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.

@libinta libinta added the synapse1.17 label (PR that should be available along with Synapse 1.17 but has no dependency on Synapse 1.17 content) Jul 26, 2024

yafshar commented Jul 26, 2024

> @Jing1Ling thanks. @atakaha does this PR cover everything you added in #1126?
>
> In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.

@atakaha can you tell me what is missing for the --reuse_cache implementation? With the current implementation the memory usage seems the same, and I think that is the main point of reuse_cache. So if you add that extra code, what is the benefit? Am I missing anything?

@ssarkar2 ssarkar2 left a comment

Do you think any other model might benefit from this? As far as I can tell, mixed qkv is a feature of mpt/dbrx, and of the 2 we only have MPT here right?

yafshar commented Jul 26, 2024

> Do you think any other model might benefit from this? As far as I can tell, mixed qkv is a feature of mpt/dbrx, and of the 2 we only have MPT here right?

Only MPT is here. I think the same approach could be useful for any similar model.

@yafshar yafshar left a comment

LGTM!

@regisss please take a look at this PR. Minimal code change!

atakaha commented Jul 29, 2024

> @Jing1Ling thanks. @atakaha does this PR cover everything you added in #1126?
>
> In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.
>
> @atakaha can you tell me what is missing for the --reuse_cache implementation? With the current implementation the memory usage seems the same, and I think that is the main point of reuse_cache. So if you add that extra code, what is the benefit? Am I missing anything?

The --reuse_cache implementation follows the approach of other models, such as llama and mistral. This PR gives the same performance improvement as the --reuse_cache implementation.

yafshar commented Jul 29, 2024

> @Jing1Ling thanks. @atakaha does this PR cover everything you added in #1126?
>
> In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.
>
> @atakaha can you tell me what is missing for the --reuse_cache implementation? With the current implementation the memory usage seems the same, and I think that is the main point of reuse_cache. So if you add that extra code, what is the benefit? Am I missing anything?
>
> The --reuse_cache implementation follows the approach of other models, such as llama and mistral. This PR gives the same performance improvement as the --reuse_cache implementation.

@atakaha thanks. Then please close #1126. Let us merge this PR first, and if any functionality is missing we can start a new PR on top of this one. Thanks for your contribution.

atakaha commented Jul 29, 2024

> @Jing1Ling thanks. @atakaha does this PR cover everything you added in #1126?
>
> In terms of performance, yes. In terms of a "--reuse_cache" implementation like the other models have, no.
>
> @atakaha can you tell me what is missing for the --reuse_cache implementation? With the current implementation the memory usage seems the same, and I think that is the main point of reuse_cache. So if you add that extra code, what is the benefit? Am I missing anything?
>
> The --reuse_cache implementation follows the approach of other models, such as llama and mistral. This PR gives the same performance improvement as the --reuse_cache implementation.
>
> @atakaha thanks. Then please close #1126. Let us merge this PR first, and if any functionality is missing we can start a new PR on top of this one. Thanks for your contribution.

@yafshar, #1126 is closed.

key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
past_key_value = (key_states, value_states)
Contributor

shouldn't it be actually:

else:
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)

@yafshar yafshar Jul 30, 2024

> shouldn't it be actually:
>
> else:
>     key_states = torch.cat([past_key_value[0], key_states], dim=2)
>     value_states = torch.cat([past_key_value[1], value_states], dim=2)
> past_key_value = (key_states, value_states)

No, that is wrong; it was like that before. Please look at the code in the if branch: there we update past_key_value in place, so there is no need to destroy it and create it again. Below is the old code:

            if token_idx is not None:
                past_key_value[0].index_copy_(2, token_idx - 1, key_states)
                past_key_value[1].index_copy_(2, token_idx - 1, value_states)
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            else:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states)
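
The difference between the two branches of the old code can be shown with a small CPU sketch. The tensor sizes and the `token_idx` value are hypothetical. `index_copy_` writes into the pre-allocated cache in place, whereas `torch.cat` allocates a new, longer tensor on every step.

```python
import torch

# Hypothetical sizes for illustration only.
batch, heads, max_len, head_dim = 1, 2, 8, 4

# Static path (token_idx is not None): the cache is pre-allocated and updated in place.
key_cache = torch.zeros(batch, heads, max_len, head_dim)
ptr_before = key_cache.data_ptr()

new_key = torch.ones(batch, heads, 1, head_dim)
token_idx = torch.tensor([5])

# Same call shape as in the old code: write the new key at slot token_idx - 1 along dim 2.
key_cache.index_copy_(2, token_idx - 1, new_key)
updated_in_place = key_cache.data_ptr() == ptr_before

# Dynamic path (token_idx is None): torch.cat builds a new tensor that is one step longer.
past_key = torch.zeros(batch, heads, 3, head_dim)
grown_key = torch.cat([past_key, new_key], dim=2)

print("cache updated in place:", updated_in_place)
print("shape after cat:", tuple(grown_key.shape))
```

Because the in-place branch mutates the tensors already held in `past_key_value`, only the concatenation branch produces new tensors that would need to be wrapped again.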

Contributor

ok, thanks for clarification

@Jing1Ling Jing1Ling force-pushed the mpt_refine_use_cache branch from ea36567 to 02c3f99 August 1, 2024 14:58
Co-authored-by: atakaha <akihiro.takahashi@intel.com>
Comment on lines +72 to +75
past_key_value = [torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
                  torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device)]
past_key_value[0][:] = key_states[:]
past_key_value[1][:] = value_states[:]
@pk1d3v pk1d3v Aug 1, 2024

@Jing1Ling Could you please add the description from the PR discussion here as a code comment, for easier understanding and future reference? It is much easier to understand the code from a comment than by going through commits and associated PRs.

Contributor Author

I changed the past_key_value in this PR from tuple type to list to make it compatible with this merged PR.

@libinta libinta added the run-test label (Run CI for PRs from external contributors) Aug 1, 2024
vidyasiv pushed a commit to emascarenhas/optimum-habana that referenced this pull request Aug 1, 2024
vidyasiv added a commit to emascarenhas/optimum-habana that referenced this pull request Aug 2, 2024
@regisss regisss left a comment

LGTM!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss regisss merged commit 9a5fe2c into huggingface:main Aug 2, 2024
@Jing1Ling Jing1Ling deleted the mpt_refine_use_cache branch October 6, 2024 14:27
@Jing1Ling Jing1Ling restored the mpt_refine_use_cache branch October 6, 2024 14:28