Skip to content

Sasarkar/qwen optimization#1067

Closed
ssarkar2 wants to merge 11 commits into
mainfrom
sasarkar/qwen_opt
Closed

Sasarkar/qwen optimization#1067
ssarkar2 wants to merge 11 commits into
mainfrom
sasarkar/qwen_opt

Conversation

@ssarkar2
Copy link
Copy Markdown
Contributor

@ssarkar2 ssarkar2 commented Jun 11, 2024

What does this PR do?

Description

move fusedsdpa.apply into a separate module, so it can be quantized as documented here
similar to falcon opt here

Also add "no resue cache" change. similar to falcon optimization here. Originally from this

Tests

test 1 (for fp8)

QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python run_generation.py --model_name_or_path /software/data/pytorch/huggingface/Qwen2-7B --use_kv_cache --max_new_tokens 128 --max_input_tokens=128 --bf16 --batch_size 1 --use_hpu_graph --trim_logits --bucket_size 256 --bucket_internal --reuse_cache --use_flash_attention --flash_attention_recompute --flash_attention_causal_mask
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path /software/data/pytorch/huggingface/Qwen2-7B --use_kv_cache --max_new_tokens 4096 --max_input_tokens=2048 --batch_size 12 --use_hpu_graph --trim_logits --bucket_size 128 --bucket_internal --reuse_cache --use_flash_attention --bf16 --flash_attention_recompute --flash_attention_causal_mask

TPS increases from 1278.0 to 1981.4. without flash flags, on main the tps is around 1833

If flash flags is not used, both main and this branch has TPS=1833.1

bs with flash flag branch tps memory
12 no main 1833.1 ?
12 yes main 1278.0 ?
12 no this branch 1833.1 ?
12 yes this branch 1981.4 ?

test 2 (no reuse cache)

HABANA_VISIBLE_MODULES=4,5,6,7 python ../gaudi_spawn.py --master 29502 --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /mnt/weka/data/Qwen/Qwen2-72B --use_hpu_graphs --use_kv_cache --max_input_tokens 2048 --max_new_tokens 2048 --batch_size 64 --attn_softmax_bf16 --trim_logits --bf16 --warmup 2 --n_iterations 3 --limit_hpu_graphs --bucket_internal --bucket_size 128 --reuse_cache

bs reusecache branch tps memory
64 yes main 1310 89.9 gb, 73.48 gb
50 yes this branch 1117.99 81.13gb, 66.68 gb
64 yes this branch 1304.56 90.0 gb, 73.49 gb
64 no this branch 1306.57 94.52 gb, 50.95 gb
100 yes this branch OOM -
100 no this branch 1610 55.79gb, 94.3 gb

test 3 (fp8)
same cmd as test 1, but with larger batches

bs with flash flag branch tps memory
32 yes main 1822.3 22.32 gb
128 yes main 2259.48 66.6 gb
192 yes main 2341.06 3 94.62 gb
128 no main 5855.92 79.96 gb
150 no main 5736.10 92.1 gb, 90.69 gb
180, 192 no main OOM -
150 no current branch 5729.6 92.1gb
192 no current branch OOM -
192 yes current branch 8250.7 91.84

test 4
fp8 and no reuse_cache
first few rows of data are from test 3

bs with flash flag branch reusecache tps memory
150 no main yes 5736.10 92.1 gb, 90.69 gb
192 yes current branch yes 8250.7 91.84
180 yes current branch no wip wip
192 yes current branch no OOM -

TODO .. should bs=192 fp8, without reusecache have OOM'ed?

same cmd as test 1, but with reuse_cache removed, and larger batchsize

QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path /software/data/pytorch/huggingface/Qwen2-7B --use_kv_cache --max_new_tokens 4096 --max_input_tokens=2048 --batch_size 192 --use_hpu_graph --trim_logits --bucket_size 128 --bucket_internal --use_flash_attention --bf16 --flash_attention_recompute --flash_attention_causal_mask

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@ssarkar2 ssarkar2 requested a review from regisss as a code owner June 11, 2024 22:06
@ssarkar2 ssarkar2 changed the title Sasarkar/qwen fp8 optimization Sasarkar/qwen optimization Jun 11, 2024
@ssarkar2 ssarkar2 marked this pull request as draft June 11, 2024 22:09
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread optimum/habana/transformers/models/qwen2/modeling_qwen2.py Outdated
q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
q_tiles = (q_len // self.block_size) if (q_len % self.block_size == 0) else math.ceil(q_len / self.block_size)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: changing signature of gaudi_flash_attn_v1 to match inputs of self.fused_scaled_dot_product_attention. that way the calling site is simplified

@ssarkar2 ssarkar2 marked this pull request as ready for review June 13, 2024 07:03
Comment thread optimum/habana/transformers/models/qwen2/modeling_qwen2.py Outdated
else:
past_key_value = None

flash_attention_fast_softmax = True # TODO pass this along
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you change as pass?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are other models with the same todo. i'll create a PR thsat fixes this todo for all of them in one go

@Morxi
Copy link
Copy Markdown

Morxi commented Jun 17, 2024

Hello, I am very interested in knowing from which generation of Gaudi this test data originates, Gaudi 1/2 or 3?

@ssarkar2
Copy link
Copy Markdown
Contributor Author

Hello, I am very interested in knowing from which generation of Gaudi this test data originates, Gaudi 1/2 or 3?

Gaudi 2

@libinta libinta closed this Jul 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants