diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index f192cf4898..27d2c24716 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -82,6 +82,16 @@ def gaudi_qwen2_rmsnorm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + class GaudiQwen2MLP(Qwen2MLP): def pre_mlp_forward(self, x): inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) @@ -177,6 +187,8 @@ class GaudiQwen2Attention(Qwen2Attention): def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) + self.matmul_qk = Matmul() self.matmul_av = Matmul() self.k_cache = KVCache() @@ -214,24 +226,30 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) - def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): + # TODO test with this function as well + def gaudi_flash_attn_v1( + self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_casual, scale, softmax_mode + ): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase Causal mask is not supported in this optimization """ + assert not is_casual q_len = query_layer.size(-2) - q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - q_padding = q_tiles * q_block_size - q_len + q_tiles = (q_len // self.block_size) if (q_len % self.block_size == 0) else math.ceil(q_len / self.block_size) + q_padding = q_tiles * self.block_size - q_len query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) if attention_mask is not None: attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) row_o_list = [] for i in range(q_tiles): - s, e = i * q_block_size, (i + 1) * q_block_size + s, e = i * self.block_size, (i + 1) * self.block_size row_q = query_layer[:, :, s:e, :] row_mask = attention_mask[:, :, s:e, :] - attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) + attn_output_partial = self.fused_scaled_dot_product_attention( + row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, softmax_mode + ) row_o_list.append(attn_output_partial) attn_output = torch.cat(row_o_list, dim=-2) @@ -310,7 +328,8 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - past_key_value = (past_key, past_value) + # Return list instead of tuple + past_key_value = [past_key, past_value] key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if token_idx is None: @@ -325,33 +344,22 @@ def pre_attn_forward( else: past_key_value = None + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht - if q_len == 1: - # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) - else: - # first token - if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) - else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if q_len > 8192: - attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attention_mask, 0.0, self.block_size - ) - htcore.mark_step() - else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + enable_recompute = q_len == 1 or flash_attention_recompute + is_causal = q_len > 1 and flash_attention_causal_mask + attn_mask = None if flash_attention_causal_mask else attention_mask + args = (query_states, key_states, value_states, attn_mask, 0.0, is_causal, None, softmax_mode) + with ht.sdp_kernel(enable_recompute=enable_recompute): + if flash_attention_causal_mask or (not flash_attention_causal_mask and q_len <= 8192): + attn_output = self.fused_scaled_dot_product_attention(*args) + else: + attn_output = self.gaudi_flash_attn_v1(*args) + htcore.mark_step() else: query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups @@ -391,6 +399,11 @@ def pre_attn_forward( if not output_attentions: attn_weights = None + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) + return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): @@ -821,6 +834,7 @@ def prepare_inputs_for_generation( past_length = 0 reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -852,7 +866,7 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - elif reuse_cache and token_idx is not None: + elif (reuse_cache or bucket_internal) and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx]