
Support bucket_internal for MPT #1137

Merged: regisss merged 1 commit into huggingface:main from pk1d3v:mpt7-bucket-internal on Jul 29, 2024

Conversation

pk1d3v (Contributor) commented on Jul 16, 2024:

What does this PR do?

Adds support for --bucket_internal for MPT model.

Measurements:

| Param | In/Out tokens | Batch size | Throughput (t/s) | Memory allocated |
|---|---|---|---|---|
| --bucket_internal | 128/128 | 128 | 6019.4 | 37.04 GB |
| No --bucket_internal | 128/128 | 128 | 2919.8 | 37.28 GB |
| --bucket_internal | 128/256 | 64 | 4268.5 | 31.28 GB |
| No --bucket_internal | 128/256 | 64 | 3265.1 | 37.03 GB |
| --bucket_internal | 128/512 | 32 | 2529.7 | 28.03 GB |
| No --bucket_internal | 128/512 | 32 | 2543.2 | 43.16 GB |

Command line used: python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens <num> --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --batch_size=<num> --bucket_size 128 --bucket_internal
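Roughly, internal bucketing keeps the KV cache pre-allocated to the maximum length but lets each decode step work only on the prefix up to the next bucket boundary, so HPU graphs can be reused within a bucket. A minimal, hypothetical sketch of that slicing idea (bucketed_kv_slice and its shapes are illustrative, not the actual MPT change in this PR):

```python
import torch

def bucketed_kv_slice(past_key: torch.Tensor, token_idx: int, bucket_size: int) -> torch.Tensor:
    """Return the KV-cache prefix covering the current bucket.

    past_key:    [batch, heads, max_seq_len, head_dim], pre-allocated to the max length
    token_idx:   position currently being generated
    bucket_size: bucket granularity, e.g. 128 (matches --bucket_size)
    """
    # Round the current length up to the next multiple of bucket_size,
    # capped at the pre-allocated cache length.
    bucket_len = min(((token_idx + bucket_size - 1) // bucket_size) * bucket_size,
                     past_key.shape[2])
    return past_key[:, :, :bucket_len, :]

# With bucket_size=128, decode steps 1-128 all attend over a 128-long slice,
# steps 129-256 over a 256-long slice, and so on.
```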

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

pk1d3v requested a review from mandy-li as a code owner on July 16, 2024 at 12:45.
pk1d3v closed this on Jul 16, 2024.
pk1d3v reopened this on Jul 16, 2024.
imangohari1 (Contributor) left a comment:

Looks good so far. I think there should be one change (see the inline suggestion below).
I tested this with mpt-30b as well and it seems to run fine:

python run_generation.py --model_name_or_path mosaicml/mpt-30b --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 128 --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --batch_size 32 --bucket_size 128 --bucket_internal

Result: 715.6375115391292 tokens/second

python run_generation.py --model_name_or_path mosaicml/mpt-30b --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 128 --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --batch_size 32 --bucket_size 128

Result: 471.0239768753105 tokens/second

input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Converting back to tuples as it should be, so there's no type mismatch when calling graph
past_key_values = tuple([tuple(kv) for kv in past_key_values])
elif bucket_internal and token_idx is not None:
imangohari1 (Contributor) commented on the lines above:

I've compared this to llama/qwen2 (links below) and I think this line should be:

Suggested change:
- elif bucket_internal and token_idx is not None:
+ elif (reuse_cache or bucket_internal) and token_idx is not None:

https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/qwen2/modeling_qwen2.py#L890C14-L890C72
https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/llama/modeling_llama.py#L1117C9-L1117C73
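
Roughly, in those models the combined branch trims the prompt inputs when the cache is pre-allocated (reuse_cache) or internally bucketed (bucket_internal). An untested sketch of the idea (slice_prompt_for_static_cache is illustrative, not the exact upstream code at the links above):

```python
import torch

def slice_prompt_for_static_cache(input_ids, attention_mask, token_idx,
                                  reuse_cache=False, bucket_internal=False):
    # With a statically allocated or bucketed KV cache, anything past
    # token_idx is padding and can be dropped before the first forward pass.
    if (reuse_cache or bucket_internal) and token_idx is not None:
        input_ids = input_ids[:, :token_idx]
        if attention_mask is not None:
            attention_mask = attention_mask[:, :token_idx]
    return input_ids, attention_mask
```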

pk1d3v (Contributor, Author) replied:

@imangohari1, thanks for the review!

There's no reuse_cache support for MPT yet; that's why I didn't add it to the condition.

Another contributor replied:

Historical context: reuse_cache came first, and then this change: #1028.

PR #1028 removes the need for reuse_cache, so for new model optimizations I think it is fine to only make changes in line with PR #1028 and leave out the reuse_cache-related changes.

Only in older models, where we already had reuse_cache code, do we accommodate both.

imangohari1 (Contributor) left a comment:

LGTM!

@regisss
Could you take a final look here?

libinta added the run-test (Run CI for PRs from external contributors) and synapse1.17 (PR that should be available along with Synapse 1.17 but has no dependency on Synapse 1.17 content) labels, and removed the review and wip labels, on Jul 24, 2024.
HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

regisss (Collaborator) left a comment:

LGTM!

regisss merged commit ac79d23 into huggingface:main on Jul 29, 2024.
Jing1Ling mentioned this pull request on Aug 1, 2024.
mgonchar mentioned this pull request on Sep 24, 2024.

Labels

run-test: Run CI for PRs from external contributors
synapse1.17: PR that should be available along with Synapse 1.17 but has no dependency on Synapse 1.17 content
