Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions optimum/habana/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def gaudi_qwen2_rmsnorm_forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


class GaudiQwen2MLP(Qwen2MLP):
def pre_mlp_forward(self, x):
inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
Expand Down Expand Up @@ -177,6 +187,8 @@ class GaudiQwen2Attention(Qwen2Attention):
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)

self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.k_cache = KVCache()
Expand Down Expand Up @@ -214,24 +226,30 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)

def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size):
# TODO test with this function as well
def gaudi_flash_attn_v1(
self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode
):
"""
Gaudi version of Flash Attention V1 to support long sequence at prompt phase
Causal mask is not supported in this optimization
"""
assert not is_casual
q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
q_tiles = (q_len // self.block_size) if (q_len % self.block_size == 0) else math.ceil(q_len / self.block_size)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: changing signature of gaudi_flash_attn_v1 to match inputs of self.fused_scaled_dot_product_attention. that way the calling site is simplified

q_padding = q_tiles * self.block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)

row_o_list = []
for i in range(q_tiles):
s, e = i * q_block_size, (i + 1) * q_block_size
s, e = i * self.block_size, (i + 1) * self.block_size
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :]
attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None)
attn_output_partial = self.fused_scaled_dot_product_attention(
row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode
)
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)

Expand Down Expand Up @@ -310,7 +328,8 @@ def pre_attn_forward(
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
)
past_key_value = (past_key, past_value)
# Return list instead of tuple
past_key_value = [past_key, past_value]
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
if token_idx is None:
Expand All @@ -325,33 +344,22 @@ def pre_attn_forward(
else:
past_key_value = None

flash_attention_fast_softmax = True # TODO pass this along
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you change as pass?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are other models with the same todo. i'll create a PR thsat fixes this todo for all of them in one go

softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
if q_len > 8192:
attn_output = self.gaudi_flash_attn_v1(
query_states, key_states, value_states, attention_mask, 0.0, self.block_size
)
htcore.mark_step()
else:
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
enable_recompute = q_len == 1 or flash_attention_recompute
is_causal = q_len > 1 and flash_attention_causal_mask

attn_mask = None if flash_attention_causal_mask else attention_mask
args = (query_states, key_states, value_states, attn_mask, 0.0, is_causal, None, softmax_mode)
with ht.sdp_kernel(enable_recompute=enable_recompute):
if flash_attention_causal_mask or (not flash_attention_causal_mask and q_len <= 8192):
attn_output = self.fused_scaled_dot_product_attention(*args)
else:
attn_output = self.gaudi_flash_attn_v1(*args)
htcore.mark_step()
else:
query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
Expand Down Expand Up @@ -391,6 +399,11 @@ def pre_attn_forward(
if not output_attentions:
attn_weights = None

if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
past_key_value = (past_key_value[0].shape, past_key_value[1].shape)

return attn_output, attn_weights, past_key_value

def attention_all_reduce(self, attn_output):
Expand Down Expand Up @@ -821,6 +834,7 @@ def prepare_inputs_for_generation(
past_length = 0

reuse_cache = kwargs.get("reuse_cache")
bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
Expand Down Expand Up @@ -852,7 +866,7 @@ def prepare_inputs_for_generation(
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
elif reuse_cache and token_idx is not None:
elif (reuse_cache or bucket_internal) and token_idx is not None:
# With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
Expand Down