From 264b22bd70d7ff362c3482172f096a601113f8bd Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Wed, 29 May 2024 07:56:23 -0700 Subject: [PATCH 1/2] Falcon optimization: add use_flash_attentiong, flash_attention_recompute, flash_attention_causal_mask add mark step per decoder add fusedsdpa fp8 fix memory issue --- examples/text-generation/README.md | 6 +- .../models/falcon/modeling_falcon.py | 222 ++++++++++++++---- 2 files changed, 183 insertions(+), 45 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index d486ccf873..95b981ff4c 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -142,7 +142,7 @@ python run_generation.py \ --use_kv_cache \ --batch_size 1 \ --max_new_tokens 128 \ - --do_sample + --do_sample ``` To run Falcon-40B inference on 8 Gaudi2 cards, use the following command: @@ -154,7 +154,9 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --use_hpu_graphs \ --use_kv_cache \ --batch_size 1 \ ---do_sample +--do_sample \ +--use_flash_attention \ +--flash_attention_causal_mask ``` > To be able to run gated models like [StarCoder](https://huggingface.co/bigcode/starcoder), you should: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index da8929d64f..068f1facad 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -83,8 +83,15 @@ def apply_customized_rope(q, k, cos, sin, position_ids): return apply_rotary_pos_emb(q, k, cos, sin, position_ids) +def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_states = F.linear(input, self.weight, bias=self.bias) + return hidden_states + + def gaudi_falcon_attention_split_heads( - self, fused_qkv: torch.Tensor + self, + fused_qkv: torch.Tensor, + broadcast: Optional[bool] = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Copied from FalconAttention._split_heads https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/falcon/modeling_falcon.py @@ -109,8 +116,9 @@ def gaudi_falcon_attention_split_heads( key = torch.index_select(qkv, 3, index=torch.tensor([d3], device=qkv.device)) value = torch.index_select(qkv, 3, index=torch.tensor([d3 + 1], device=qkv.device)) - key = torch.broadcast_to(key, query.shape) - value = torch.broadcast_to(value, query.shape) + if broadcast: + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) query, key, value = [x.flatten(2, 3) for x in (query, key, value)] return query, key, value @@ -130,6 +138,16 @@ def gaudi_falcon_attention_split_heads( return query, key, value +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + class Softmax(nn.Module): def __init__(self): super().__init__() @@ -154,6 +172,41 @@ def __init__(self, config: FalconConfig): self.bmm1 = Matmul() self.bmm2 = Matmul() self.softmax = Softmax() + self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads + + def repeat_kv( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, + ): + """ + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: L, S = query.size(-2), key.size(-2) @@ -171,12 +224,14 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa if attn_mask.dtype == torch.bool: attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - attn_weight = self.bmm1(query, key.transpose(-2, -1)) + query, key, value, attn_mask = self.repeat_kv(query, key, value, attn_mask, self.num_key_value_groups) + attn_weight = self.bmm1(query, key.transpose(-2, -1)) attn_weight += attn_mask attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return self.bmm2(attn_weight, value) + attn_output = self.bmm2(attn_weight, value) + return attn_output def update(prev, cur, dim, idx, inp_seq_len): @@ -236,13 +291,30 @@ class GaudiFalconAttention(FalconAttention): - replace F.scaled_dot_product_attention with Habana torch's version for BF16 - use ScaledDotProductAttention for FP8 quantization - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def __init__(self, config: FalconConfig): super().__init__(config) - if os.getenv("QUANT_CONFIG", ""): - self.sdpa = ScaledDotProductAttention(config) + """ + Choice of SDPA: + There are these variables: use_flash_attention and datatype (bf16/fp8) + datatype is determined by presence of QUANT_CONFIG env var, presence of which indicates fp8 + 1. use_flash_attention, fp8: use ModuleFusedSDPA. most optimal + 2. use_flash_attention, bf16: use FusedSDPA + 3. not use_flash_attention, fp8: Use ScaledDotProductAttention, along with QUANT_CONFIG. This is the case before this PR + 4. not use_flash_attention, bf16: F.scaled_dot_product_attention. Slowest option + """ + self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" + + # In the constructor we do not know which one we will need later in the forward, so creating both + # TODO, Does this affect memory usage? + if self.is_fp8: + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) + self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config) self.k_cache = KVCache() self.v_cache = KVCache() @@ -251,7 +323,7 @@ def __init__(self, config: FalconConfig): def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): if self.config.new_decoder_architecture: - cache_shape = (batch_size, self.num_heads, max_seq_len, self.head_dim) + cache_shape = (batch_size, self.num_kv_heads, max_seq_len, self.head_dim) else: cache_shape = (batch_size, 1, max_seq_len, self.head_dim) device = self.query_key_value.weight.device @@ -282,16 +354,22 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None + (query_layer, key_layer, value_layer) = self._split_heads( + fused_qkv, not use_flash_attention and not self.is_fp8 and not train_with_flash_attention + ) batch_size, query_length, _, _ = query_layer.shape @@ -333,7 +411,7 @@ def pre_attn_forward( dtype=self.query_key_value.weight.dtype, device=self.query_key_value.weight.device, ) - layer_past = [past_key, past_value] + layer_past = (past_key, past_value) key_layer = self.k_cache.update( layer_past[0], key_layer, -2, token_idx, self.inp_seq_len ) # k_layer bs*1, q_len, head_dim @@ -354,12 +432,7 @@ def pre_attn_forward( else: kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] - if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1): - # Return only past key value shapes and not the tensors during decode phase (q len is 1) - # to avoid making past key values as persistent output tensors of HPU graphs. - present = (present[0].shape, present[1].shape) - - if alibi is None: + if alibi is None: # both train/inference if output_attentions: attention_scores = query_layer @ key_layer.transpose(-1, -2) attention_scores /= math.sqrt(self.head_dim) @@ -368,13 +441,22 @@ def pre_attn_forward( # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). attn_output = attention_scores @ value_layer else: - if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - attn_output = self.sdpa( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) + if use_flash_attention or train_with_flash_attention: + is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask + if self.is_fp8: + attn_mask = None if is_causal else attention_mask + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + enable_recompute = self.is_fp8 if query_length == 1 else flash_attention_recompute + with sdp_kernel(enable_recompute=enable_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode + ) else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + # TODO very similar to the fp8 case above, could be merged. + with sdp_kernel( + enable_recompute=flash_attention_recompute + ) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, key_layer, @@ -382,22 +464,28 @@ def pre_attn_forward( attention_mask, 0.0, # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, + is_causal and attention_mask is None, ) else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) + if self.is_fp8: + attn_output = self.unfused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + # Performance improvement for HPU if self.training is True and htcore: htcore.mark_step() @@ -415,8 +503,9 @@ def pre_attn_forward( return attn_output, present, _ else: - if self._use_sdpa and not output_attentions and head_mask is None: + if train_with_flash_attention: if FusedSDPA: + # TODO needs to be turned into a module for quantization with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, @@ -513,6 +602,9 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer): - add new args token_idx and position_ids - add token_idx and position_ids into attention inputs - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def __init__(self, config: FalconConfig): @@ -538,6 +630,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: @@ -563,6 +658,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, **kwargs, ) @@ -611,6 +709,9 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ): if self.config.new_decoder_architecture: attention_layernorm_out = self.ln_attn(hidden_states) @@ -632,6 +733,9 @@ def pre_attn( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out @@ -644,6 +748,9 @@ class GaudiFalconModel(FalconModel): - add new args token_idx and position_ids - add token_idx and position_ids into decoder inputs - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -669,6 +776,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -786,8 +896,12 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - htcore.mark_step() + # htcore.mark_step() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # if not self.training and ( + # torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1 + # ): + # htcore.mark_step() if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -803,6 +917,9 @@ def forward( use_cache, output_attentions, None, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, ) else: outputs = block( @@ -817,6 +934,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = outputs[0] @@ -852,6 +972,9 @@ class GaudiFalconForCausalLM(FalconForCausalLM): - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx - add new args reuse_cache + - add use_flash_attention + - add flash_attention_recompute + - add flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -871,7 +994,6 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict: reuse_cache = kwargs.get("reuse_cache") - bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -886,9 +1008,8 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] - elif (reuse_cache or bucket_internal) and token_idx is not None: - # KV cache is pre allocated with reuse cache or will be padded with bucket internal - # hence for the 1st token we can slice the inputs till token idx for the fwd pass. + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] @@ -917,6 +1038,9 @@ def prepare_inputs_for_generation( "token_idx": token_idx, "reuse_cache": reuse_cache, "cache_idx": kwargs.get("cache_idx"), + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), } def forward( @@ -936,6 +1060,9 @@ def forward( reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -944,6 +1071,12 @@ def forward( are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if use_flash_attention: + assert FusedSDPA, "Set use_flash_attention True, but cannot find FusedSDPA. Please import it as from habana_frameworks.torch.hpex.kernels import FusedSDPA or set use_flash_attention to False (at the expense of a possible performance degradation)" + if flash_attention_recompute: + assert use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not" + if flash_attention_causal_mask: + assert use_flash_attention, "flash_attention_causal_mask is set, but use_flash_attention is not" transformer_outputs = self.transformer( input_ids, @@ -959,6 +1092,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = transformer_outputs[0] From fbe5dd1afd1f49759713ebd73f3fd6d18ab240a3 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Wed, 29 May 2024 08:07:21 -0700 Subject: [PATCH 2/2] Update readme. --- examples/text-generation/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 95b981ff4c..519f00e57d 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -142,7 +142,7 @@ python run_generation.py \ --use_kv_cache \ --batch_size 1 \ --max_new_tokens 128 \ - --do_sample + --do_sample ``` To run Falcon-40B inference on 8 Gaudi2 cards, use the following command: