diff --git a/examples/text-generation/quantization_config/maxabs_quant_gemma.json b/examples/text-generation/quantization_config/maxabs_quant_gemma.json index e7c6b6ddd2..ada2fa0c85 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_gemma.json +++ b/examples/text-generation/quantization_config/maxabs_quant_gemma.json @@ -4,8 +4,6 @@ "observer": "maxabs", "scale_method": "maxabs_hw", "blocklist": {"types": [], "names": [ - "matmul_qk", - "matmul_av", "lm_head" ]}, "dump_stats_path": "./hqt_output/measure" diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index 40b9b429fe..b420af44f0 100755 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -83,6 +83,36 @@ def gaudi_gemma_repeat_kv( return query_states, key_states, value_states, attention_mask +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + enable_recompute, + ): + import habana_frameworks.torch.hpu as ht + + with ht.sdp_kernel(enable_recompute=enable_recompute): + return self._hpu_kernel_fsdpa.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + ) + + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -145,6 +175,8 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): self.block_size = 4096 self.rotary_emb = GaudiRotaryEmbedding(config=self.config) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.k_proj.weight.device @@ -174,7 +206,9 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) - def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): + def gaudi_flash_attn_v1( + self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size, enable_recompute + ): """ Gaudi version of Flash Attention V1 to support long sequence at prompt phase Causal mask is not supported in this optimization @@ -191,7 +225,9 @@ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mas s, e = i * q_block_size, (i + 1) * q_block_size row_q = query_layer[:, :, s:e, :] row_mask = attention_mask[:, :, s:e, :] - attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) + attn_output_partial = self.fused_scaled_dot_product_attention( + row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, enable_recompute + ) row_o_list.append(attn_output_partial) attn_output = torch.cat(row_o_list, dim=-2) @@ -286,32 +322,42 @@ def pre_attn_forward( past_key_value = None if use_flash_attention and FusedSDPA: - import habana_frameworks.torch.hpu as ht - if q_len == 1: # next token use_recompute = True if os.getenv("QUANT_CONFIG", "") else False - with ht.sdp_kernel(enable_recompute=use_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, use_recompute + ) else: # first token if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None, flash_attention_recompute + ) else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if q_len > 16384: - attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attention_mask, 0.0, self.block_size - ) - htcore.mark_step() - else: - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if q_len > 16384: + attn_output = self.gaudi_flash_attn_v1( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + self.block_size, + flash_attention_recompute, + ) + htcore.mark_step() + else: + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + flash_attention_recompute, + ) else: query_states, key_states, value_states, attention_mask = gaudi_gemma_repeat_kv(