diff --git a/README.md b/README.md index d8ef60d483..4f5f926f8a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,15 @@ HPUs offer fast model training and inference as well as a great price-performanc Check out [this blog post about BLOOM inference](https://huggingface.co/blog/habana-gaudi-2-bloom) and [this post benchmarking Intel Gaudi 2 and NVIDIA A100 GPUs for BridgeTower training](https://huggingface.co/blog/bridgetower) for concrete examples. +## Gaudi Setup + +Please refer to the Intel Gaudi AI Accelerator official [installation guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html). + +> Tests should be run in a Docker container based on Intel Gaudi Docker images. +> +> The current version has been validated for SynapseAI 1.15. + + ## Install the library and get example scripts ### Option 1: Use the latest stable release @@ -237,15 +246,6 @@ If you find any issues while using those, please open an issue or a pull request After training your model, feel free to submit it to the Intel [leaderboard](https://huggingface.co/spaces/Intel/powered_by_intel_llm_leaderboard) which is designed to evaluate, score, and rank open-source LLMs that have been pre-trained or fine-tuned on Intel Hardwares. Models submitted to the leaderboard will be evaluated on the Intel Developer Cloud. The evaluation platform consists of Gaudi Accelerators and Xeon CPUs running benchmarks from the Eleuther AI Language Model Evaluation Harness. -## Gaudi Setup - -Please refer to the Intel Gaudi AI Accelerator official [installation guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html). - -> Tests should be run in a Docker container based on Intel Gaudi Docker images. -> -> The current version has been validated for SynapseAI 1.15. - - ## Development Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions. \ No newline at end of file diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index b669d3e67d..27556a5023 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -170,7 +170,9 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --use_hpu_graphs \ --use_kv_cache \ --batch_size 1 \ ---do_sample +--do_sample \ +--use_flash_attention \ +--flash_attention_causal_mask ``` > To be able to run gated models like [StarCoder](https://huggingface.co/bigcode/starcoder), you should: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index eff21a8090..6cd3f4f2ad 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -88,6 +88,16 @@ def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor: return hidden_states +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + class Softmax(nn.Module): def __init__(self): super().__init__() @@ -112,6 +122,41 @@ def __init__(self, config: FalconConfig): self.bmm1 = Matmul() self.bmm2 = Matmul() self.softmax = Softmax() + self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads + + def repeat_kv( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, + ): + """ + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: L, S = query.size(-2), key.size(-2) @@ -129,12 +174,14 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa if attn_mask.dtype == torch.bool: attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - attn_weight = self.bmm1(query, key.transpose(-2, -1)) + query, key, value, attn_mask = self.repeat_kv(query, key, value, attn_mask, self.num_key_value_groups) + attn_weight = self.bmm1(query, key.transpose(-2, -1)) attn_weight += attn_mask attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return self.bmm2(attn_weight, value) + attn_output = self.bmm2(attn_weight, value) + return attn_output def update(prev, cur, dim, idx, inp_seq_len): @@ -194,20 +241,37 @@ class GaudiFalconAttention(FalconAttention): - replace F.scaled_dot_product_attention with Habana torch's version for BF16 - use ScaledDotProductAttention for FP8 quantization - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + Choice of SDPA: + There are these variables: use_flash_attention and datatype (bf16/fp8) + datatype is determined by presence of QUANT_CONFIG env var, presence of which indicates fp8 + 1. use_flash_attention, fp8: use ModuleFusedSDPA. most optimal + 2. use_flash_attention, bf16: use FusedSDPA + 3. not use_flash_attention, fp8: Use ScaledDotProductAttention, along with QUANT_CONFIG. This is the case before this PR + 4. not use_flash_attention, bf16: F.scaled_dot_product_attention. Slowest option """ def __init__(self, config: FalconConfig): super().__init__(config) - if os.getenv("QUANT_CONFIG", ""): - self.sdpa = ScaledDotProductAttention(config) + self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" + + # In the constructor we do not know which one we will need later in the forward, so creating both + # TODO, Does this affect memory usage? + if self.is_fp8: + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) + self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config) self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 self.max_position_embeddings = config.max_position_embeddings - def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _split_heads( + self, fused_qkv: torch.Tensor, broadcast: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.new_decoder_architecture: batch, seq_len, _ = fused_qkv.shape @@ -226,9 +290,9 @@ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten query = torch.index_select(qkv, 3, index=torch.arange(d3, device=qkv.device)) key = torch.index_select(qkv, 3, index=torch.tensor([d3], device=qkv.device)) value = torch.index_select(qkv, 3, index=torch.tensor([d3 + 1], device=qkv.device)) - - key = torch.broadcast_to(key, query.shape) - value = torch.broadcast_to(value, query.shape) + if broadcast: + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) query, key, value = [x.flatten(2, 3) for x in (query, key, value)] return query, key, value @@ -249,7 +313,7 @@ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): if self.config.new_decoder_architecture: - cache_shape = (batch_size, self.num_heads, max_seq_len, self.head_dim) + cache_shape = (batch_size, self.num_kv_heads, max_seq_len, self.head_dim) else: cache_shape = (batch_size, 1, max_seq_len, self.head_dim) device = self.query_key_value.weight.device @@ -280,16 +344,22 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None + (query_layer, key_layer, value_layer) = self._split_heads( + fused_qkv, not use_flash_attention and not self.is_fp8 and not train_with_flash_attention + ) batch_size, query_length, _, _ = query_layer.shape @@ -357,7 +427,7 @@ def pre_attn_forward( # to avoid making past key values as persistent output tensors of HPU graphs. present = (present[0].shape, present[1].shape) - if alibi is None: + if alibi is None: # both train/inference if output_attentions: attention_scores = query_layer @ key_layer.transpose(-1, -2) attention_scores /= math.sqrt(self.head_dim) @@ -366,13 +436,22 @@ def pre_attn_forward( # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). attn_output = attention_scores @ value_layer else: - if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - attn_output = self.sdpa( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) + if use_flash_attention or train_with_flash_attention: + is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask + if self.is_fp8: + attn_mask = None if is_causal else attention_mask + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + enable_recompute = self.is_fp8 if query_length == 1 else flash_attention_recompute + with sdp_kernel(enable_recompute=enable_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode + ) else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + # TODO very similar to the fp8 case above, could be merged. + with sdp_kernel( + enable_recompute=flash_attention_recompute + ) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, key_layer, @@ -380,22 +459,28 @@ def pre_attn_forward( attention_mask, 0.0, # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, + is_causal and attention_mask is None, ) else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) + if self.is_fp8: + attn_output = self.unfused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + # Performance improvement for HPU if self.training is True and htcore: htcore.mark_step() @@ -413,8 +498,9 @@ def pre_attn_forward( return attn_output, present, _ else: - if self._use_sdpa and not output_attentions and head_mask is None: + if train_with_flash_attention: if FusedSDPA: + # TODO needs to be turned into a module for quantization with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, @@ -511,6 +597,9 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer): - add new args token_idx and position_ids - add token_idx and position_ids into attention inputs - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def __init__(self, config: FalconConfig): @@ -536,6 +625,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: @@ -561,6 +653,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, **kwargs, ) @@ -609,6 +704,9 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ): if self.config.new_decoder_architecture: attention_layernorm_out = self.ln_attn(hidden_states) @@ -630,6 +728,9 @@ def pre_attn( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out @@ -642,6 +743,9 @@ class GaudiFalconModel(FalconModel): - add new args token_idx and position_ids - add token_idx and position_ids into decoder inputs - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -667,6 +771,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -781,7 +888,6 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - htcore.mark_step() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -798,6 +904,9 @@ def forward( use_cache, output_attentions, None, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, ) else: outputs = block( @@ -812,6 +921,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = outputs[0] @@ -847,6 +959,9 @@ class GaudiFalconForCausalLM(FalconForCausalLM): - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx - add new args reuse_cache + - add use_flash_attention + - add flash_attention_recompute + - add flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -912,6 +1027,9 @@ def prepare_inputs_for_generation( "token_idx": token_idx, "reuse_cache": reuse_cache, "cache_idx": kwargs.get("cache_idx"), + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), } def forward( @@ -931,6 +1049,9 @@ def forward( reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -939,6 +1060,12 @@ def forward( are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if use_flash_attention: + assert FusedSDPA, "`use_flash_attention` is True, but cannot find FusedSDPA. Please import it as `from habana_frameworks.torch.hpex.kernels import FusedSDPA` or set use_flash_attention to False (at the expense of a possible performance degradation)." + if flash_attention_recompute: + assert use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not" + if flash_attention_causal_mask: + assert use_flash_attention, "flash_attention_causal_mask is set, but use_flash_attention is not" transformer_outputs = self.transformer( input_ids, @@ -954,6 +1081,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = transformer_outputs[0]