-
Notifications
You must be signed in to change notification settings - Fork 271
Sasarkar/qwen optimization #1067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
cee7dd4
initial commit
ssarkar2 4ed76f3
Simplify fusedsdpa section
ssarkar2 65c01ca
kv cache opt
ssarkar2 be2828d
Merge branch 'main' into sasarkar/qwen_opt
ssarkar2 e9fcaa8
Merge branch 'main' into sasarkar/qwen_opt
ssarkar2 f6e1f16
Merge branch 'main' into sasarkar/qwen_opt
ssarkar2 99ffca7
clean up
ssarkar2 eca543f
style
ssarkar2 7ba78d4
remove unused var and add style
ssarkar2 a5b9afc
remove unused var and add style
ssarkar2 4281d6b
Update modeling_qwen2.py
ssarkar2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,6 +82,16 @@ def gaudi_qwen2_rmsnorm_forward(self, hidden_states): | |
| return self.weight * hidden_states.to(input_dtype) | ||
|
|
||
|
|
||
| # FusedScaledDotProductAttention | ||
| class ModuleFusedSDPA(torch.nn.Module): | ||
| def __init__(self, fusedSDPA): | ||
| super().__init__() | ||
| self._hpu_kernel_fsdpa = fusedSDPA | ||
|
|
||
| def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): | ||
| return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) | ||
|
|
||
|
|
||
| class GaudiQwen2MLP(Qwen2MLP): | ||
| def pre_mlp_forward(self, x): | ||
| inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) | ||
|
|
@@ -177,6 +187,8 @@ class GaudiQwen2Attention(Qwen2Attention): | |
| def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): | ||
| super().__init__(config, layer_idx) | ||
|
|
||
| self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) | ||
|
|
||
| self.matmul_qk = Matmul() | ||
| self.matmul_av = Matmul() | ||
| self.k_cache = KVCache() | ||
|
|
@@ -214,24 +226,30 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): | |
| self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) | ||
| return (self.k_cache.cache.shape, self.v_cache.cache.shape) | ||
|
|
||
| def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): | ||
| # TODO test with this function as well | ||
| def gaudi_flash_attn_v1( | ||
| self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode | ||
| ): | ||
| """ | ||
| Gaudi version of Flash Attention V1 to support long sequence at prompt phase | ||
| Causal mask is not supported in this optimization | ||
| """ | ||
| assert not is_casual | ||
| q_len = query_layer.size(-2) | ||
| q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) | ||
| q_padding = q_tiles * q_block_size - q_len | ||
| q_tiles = (q_len // self.block_size) if (q_len % self.block_size == 0) else math.ceil(q_len / self.block_size) | ||
| q_padding = q_tiles * self.block_size - q_len | ||
| query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) | ||
| if attention_mask is not None: | ||
| attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) | ||
|
|
||
| row_o_list = [] | ||
| for i in range(q_tiles): | ||
| s, e = i * q_block_size, (i + 1) * q_block_size | ||
| s, e = i * self.block_size, (i + 1) * self.block_size | ||
| row_q = query_layer[:, :, s:e, :] | ||
| row_mask = attention_mask[:, :, s:e, :] | ||
| attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) | ||
| attn_output_partial = self.fused_scaled_dot_product_attention( | ||
| row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode | ||
| ) | ||
| row_o_list.append(attn_output_partial) | ||
| attn_output = torch.cat(row_o_list, dim=-2) | ||
|
|
||
|
|
@@ -310,7 +328,8 @@ def pre_attn_forward( | |
| past_value = torch.zeros( | ||
| key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device | ||
| ) | ||
| past_key_value = (past_key, past_value) | ||
| # Return list instead of tuple | ||
| past_key_value = [past_key, past_value] | ||
| key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) | ||
| value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) | ||
| if token_idx is None: | ||
|
|
@@ -325,33 +344,22 @@ def pre_attn_forward( | |
| else: | ||
| past_key_value = None | ||
|
|
||
| flash_attention_fast_softmax = True # TODO pass this along | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you change this so the value is passed in instead of being hard-coded?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are other models with the same TODO. I'll create a PR that fixes this TODO for all of them in one go. |
||
| softmax_mode = "fast" if flash_attention_fast_softmax else "None" | ||
| if use_flash_attention and FusedSDPA: | ||
| import habana_frameworks.torch.hpu as ht | ||
|
|
||
| if q_len == 1: | ||
| # next token | ||
| with ht.sdp_kernel(enable_recompute=False): | ||
| attn_output = FusedSDPA.apply( | ||
| query_states, key_states, value_states, attention_mask, 0.0, False, None | ||
| ) | ||
| else: | ||
| # first token | ||
| if flash_attention_causal_mask: | ||
| # causal masking on first token requires inputs to be of the same length | ||
| with ht.sdp_kernel(enable_recompute=flash_attention_recompute): | ||
| attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) | ||
| else: | ||
| with ht.sdp_kernel(enable_recompute=flash_attention_recompute): | ||
| if q_len > 8192: | ||
| attn_output = self.gaudi_flash_attn_v1( | ||
| query_states, key_states, value_states, attention_mask, 0.0, self.block_size | ||
| ) | ||
| htcore.mark_step() | ||
| else: | ||
| attn_output = FusedSDPA.apply( | ||
| query_states, key_states, value_states, attention_mask, 0.0, False, None | ||
| ) | ||
| enable_recompute = q_len == 1 or flash_attention_recompute | ||
| is_causal = q_len > 1 and flash_attention_causal_mask | ||
|
|
||
| attn_mask = None if flash_attention_causal_mask else attention_mask | ||
| args = (query_states, key_states, value_states, attn_mask, 0.0, is_causal, None, softmax_mode) | ||
| with ht.sdp_kernel(enable_recompute=enable_recompute): | ||
| if flash_attention_causal_mask or (not flash_attention_causal_mask and q_len <= 8192): | ||
| attn_output = self.fused_scaled_dot_product_attention(*args) | ||
| else: | ||
| attn_output = self.gaudi_flash_attn_v1(*args) | ||
| htcore.mark_step() | ||
| else: | ||
| query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv( | ||
| query_states, key_states, value_states, attention_mask, self.num_key_value_groups | ||
|
|
@@ -391,6 +399,11 @@ def pre_attn_forward( | |
| if not output_attentions: | ||
| attn_weights = None | ||
|
|
||
| if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: | ||
| # Return only past key value shapes and not the tensors during decode phase (q len is 1) | ||
| # to avoid making past key values as persistent output tensors of HPU graphs. | ||
| past_key_value = (past_key_value[0].shape, past_key_value[1].shape) | ||
|
|
||
| return attn_output, attn_weights, past_key_value | ||
|
|
||
| def attention_all_reduce(self, attn_output): | ||
|
|
@@ -821,6 +834,7 @@ def prepare_inputs_for_generation( | |
| past_length = 0 | ||
|
|
||
| reuse_cache = kwargs.get("reuse_cache") | ||
| bucket_internal = kwargs.get("bucket_internal") | ||
| if past_key_values is not None: | ||
| if token_idx is not None: | ||
| input_ids = torch.index_select(input_ids, 1, token_idx - 1) | ||
|
|
@@ -852,7 +866,7 @@ def prepare_inputs_for_generation( | |
| and cache_length + input_ids.shape[1] > max_cache_length | ||
| ): | ||
| attention_mask = attention_mask[:, -max_cache_length:] | ||
| elif reuse_cache and token_idx is not None: | ||
| elif (reuse_cache or bucket_internal) and token_idx is not None: | ||
| # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass | ||
| input_ids = input_ids[:, :token_idx] | ||
| attention_mask = attention_mask[:, :token_idx] | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: changing the signature of gaudi_flash_attn_v1 to match the inputs of self.fused_scaled_dot_product_attention, so that the call site is simplified.