Update Mixtral-8x7B Optimization#836

Closed
jychen21 wants to merge 23 commits into
huggingface:mainfrom
jychen21:update-mixtral-optimizations

Conversation

@jychen21

@jychen21 jychen21 commented Mar 26, 2024

What does this PR do?

  • Update Mixtral-8x7B Optimization:
    reuse_cache / enable FP8 KV Cache / FP8 Attn / bucket_internal ...

  • Support long-sequence prompts

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1  \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_size 128 \
--bucket_internal \
--max_new_tokens 100 \
--bf16 \
--batch_size 1

QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1  \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_internal \
--bucket_size 128 \
--max_new_tokens 100 \
--bf16 \
--fp8 \
--batch_size 2 \
--warmup 1 \
--n_iterations 1 \
--max_input_tokens 32000
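
For readers unfamiliar with `--bucket_size`, the sketch below illustrates the general idea of shape bucketing: padding a sequence length up to the next multiple of the bucket size so that only a limited set of shapes is ever compiled. The helper name `round_up_to_bucket` is hypothetical and this is not the optimum-habana implementation; it only assumes that bucketing grows the padded length in steps of `bucket_size`.

```python
def round_up_to_bucket(length: int, bucket_size: int) -> int:
    """Return the smallest multiple of bucket_size that is >= length.

    Hypothetical helper for illustration; a non-positive bucket_size
    is treated here as "bucketing disabled".
    """
    if bucket_size <= 0:
        return length
    return ((length + bucket_size - 1) // bucket_size) * bucket_size


if __name__ == "__main__":
    # Print the padded length for a few prompt lengths with bucket_size=128.
    for prompt_len in (1, 128, 129, 32000):
        print(prompt_len, "->", round_up_to_bucket(prompt_len, 128))
```

With a bucket size of 128, a 129-token prompt would be padded to 256 under this assumption, while a 32000-token prompt is already a multiple of 128 and needs no padding.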

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@mandy-li mandy-li left a comment

@jychen-habana, as we synced offline:

  1. kv_cache_fp8 is the previous way to support fp8 inference and will be removed soon. FP8 inference for all models should use HQT.
  2. Your current code in this PR causes a regression in HQT measurement.

@mandy-li
Collaborator

@schoi-habana, please provide details of how you optimized Falcon-180b for fp8 so that Jinyan can follow the same approach for this model. Thanks.

Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py Outdated
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py Outdated
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py Outdated
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@schoi-habana
Collaborator

I tested this PR with run_generation.py in 1.16.0 docker. It could fit 30k input tokens but the generated output was empty. Did you check the output?

input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework',)
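
One way to catch this kind of problem automatically is to check that the decoded output extends the prompt. The sketch below is a hypothetical helper (`new_text` is not part of run_generation.py) and assumes, as in the log above, that the decoded output begins with the prompt text.

```python
def new_text(prompt: str, output: str) -> str:
    """Return the part of `output` generated beyond `prompt`.

    Assumes the decoded output starts with the prompt; if it does not,
    the whole output is returned unchanged.
    """
    if output.startswith(prompt):
        return output[len(prompt):]
    return output


prompt = "DeepSpeed is a machine learning framework"
output = "DeepSpeed is a machine learning framework"  # same text as in the log above
if not new_text(prompt, output).strip():
    print("no new tokens were generated")
```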

@schoi-habana
Collaborator

schoi-habana commented Apr 8, 2024

@jychen-habana after you implement ScopedLinearAllreduce, please see if in-place addition in this PR HabanaAI#65 helps this model

@jychen21
Author

jychen21 commented Apr 9, 2024

I tested this PR with run_generation.py in 1.16.0 docker. It could fit 30k input tokens but the generated output was empty. Did you check the output?

input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework',)

In my 1.15 setup, I didn't get this issue.

@jychen21
Author

jychen21 commented Apr 9, 2024

@jychen-habana, as we synced offline:

  1. kv_cache_fp8 is the previous way to support fp8 inference and will be removed soon. FP8 inference for all models should use HQT.
  2. Your current code in this PR causes a regression in HQT measurement.

fixed.

@jychen21
Author

jychen21 commented Apr 9, 2024

@jychen-habana after you implement ScopedLinearAllreduce, please see if in-place addition in this PR HabanaAI#65 helps this model

Sure.

Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py Outdated
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py Outdated
@ZhaiFeiyue ZhaiFeiyue added the run-test Run CI for PRs from external contributors label Apr 15, 2024
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Comment thread optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@mandy-li
Collaborator

@jychen-habana , please post the performance measurements with/without this PR here.

@mandy-li
Collaborator

@jychen-habana, please rebase onto the latest code in the OH main branch.

@mandy-li
Collaborator

@jychen-habana, this PR doesn't work with the Synapse 1.15 release docker when running measurement.

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path /mnt/weka/data/mixtral/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/1e637f2d7cb0a9d6fb1922f305cb784995190a83/ --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --bucket_size 128 --max_new_tokens 128 --batch_size 1 --bf16

Error:

  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 787, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 692, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 518, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 356, in forward
    key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
  File "/usr/local/lib/python3.10/dist-packages/habana_quantization_toolkit/_quant_common/helper_modules.py", line 264, in update
    qinput = self.quant_input_0(cur)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1691, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'PatchedKVCache' object has no attribute 'quant_input_0'
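
The final frame of the traceback comes from `torch.nn.Module.__getattr__`, which is only invoked when normal attribute lookup fails and which raises if the name is not a registered parameter, buffer, or submodule. The sketch below mimics that lookup with a plain-Python stand-in (`ToyModule` and `ToyPatchedKVCache` are hypothetical, not the HQT classes) to show how calling a submodule that was never registered produces a message of this shape. Why `quant_input_0` is missing in measurement mode is not established in this thread.

```python
class ToyModule:
    """Plain-Python stand-in for the submodule lookup in torch.nn.Module."""

    def __init__(self):
        # Bypass our own __setattr__ so the registry itself is a normal attribute.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry; everything else is a normal attribute.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only after normal attribute lookup has failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


class ToyPatchedKVCache(ToyModule):
    def update(self, cur):
        # Fails unless a `quant_input_0` submodule was registered beforehand.
        return self.quant_input_0(cur)


try:
    ToyPatchedKVCache().update([1.0, 2.0])
except AttributeError as err:
    print(err)
```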

@jychen21
Author

jychen21 commented Apr 18, 2024

Do not merge! Will break this PR into small pieces: #898 #901 #903

@jychen21
Author

@jychen-habana, this PR doesn't work with the Synapse 1.15 release docker when running measurement.

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path /mnt/weka/data/mixtral/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/1e637f2d7cb0a9d6fb1922f305cb784995190a83/ --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --bucket_size 128 --max_new_tokens 128 --batch_size 1 --bf16

Error:

  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 787, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 692, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 518, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 356, in forward
    key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
  File "/usr/local/lib/python3.10/dist-packages/habana_quantization_toolkit/_quant_common/helper_modules.py", line 264, in update
    qinput = self.quant_input_0(cur)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1691, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'PatchedKVCache' object has no attribute 'quant_input_0'

Please add --reuse_cache when measuring with bf16. From my understanding, the KV cache needs to be an nn.Module so that it can be measured.

For quantization mode, it's fine to just remove --reuse_cache.

If there is another solution, please let me know.
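
The point about the cache needing to be an `nn.Module` can be illustrated in general terms: tools that instrument a model typically walk its module tree, so only objects registered as submodules are visible to them. The sketch below uses a plain-Python stand-in for that tree walk; `Node`, `KVCache`, `Attention`, and `find_caches` are hypothetical names, and the assumption that `--reuse_cache` is what turns the cache into a module is taken from the comment above rather than verified here.

```python
class Node:
    """Minimal stand-in for a module that tracks its child modules."""

    def __init__(self):
        self.children = {}

    def add(self, name, child):
        self.children[name] = child
        return child

    def named_nodes(self, prefix=""):
        # Depth-first walk, similar in spirit to nn.Module.named_modules().
        yield prefix, self
        for name, child in self.children.items():
            path = f"{prefix}.{name}" if prefix else name
            yield from child.named_nodes(path)


class KVCache(Node):
    """Cache wrapped as a node, so a tree walk can find it."""


class Attention(Node):
    def __init__(self, cache_as_module: bool):
        super().__init__()
        if cache_as_module:
            self.add("k_cache", KVCache())
        else:
            # Cache kept as plain data: invisible to a tree walk.
            self.past_key = []


def find_caches(root: Node) -> list[str]:
    """Names of all KVCache nodes reachable from root."""
    return [name for name, node in root.named_nodes() if isinstance(node, KVCache)]


print(find_caches(Attention(cache_as_module=True)))
print(find_caches(Attention(cache_as_module=False)))
```

Only the first call finds a cache, because in the second case the cache is stored as plain data rather than as a child node.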

@libinta libinta removed the run-test Run CI for PRs from external contributors label Apr 23, 2024
@jychen21 jychen21 closed this May 7, 2024

6 participants