From d5419971c23983f3aa4a1a659220769dcb77e12d Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 3 Oct 2024 16:47:05 -0400 Subject: [PATCH 01/26] fix: fixes for graph breaks Signed-off-by: Abhishek --- .../modeling_flash_attention_utils.py | 21 +++++----- .../models/llama/modeling_llama.py | 38 +++++++++++++++++++ src/transformers/tokenization_utils_base.py | 2 +- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 44e61825dd9c..f2431892385d 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -179,6 +179,8 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) +flash_241 = is_flash_attn_greater_or_equal("2.4.1") +deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" def _flash_attention_forward( query_states: torch.Tensor, @@ -194,6 +196,11 @@ def _flash_attention_forward( use_top_left_mask: bool = False, softcap: Optional[float] = None, deterministic: bool = None, + cu_seqlens_q: Optional[torch.LongTensor] = None, + cu_seqlens_k: Optional[torch.LongTensor] = None, + max_seqlen_in_batch_q: int = 0, + max_seqlen_in_batch_k: int = 0, + batch_size: int = 2, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -232,9 +239,9 @@ def _flash_attention_forward( ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - if is_flash_attn_greater_or_equal("2.4.1"): + if flash_241: if deterministic is None: - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + deterministic = deterministic_g flash_kwargs["deterministic"] = deterministic if softcap is not None: @@ -267,15 +274,7 @@ def _flash_attention_forward( # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. 
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1: - batch_size = query_states.size(0) - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - + elif position_ids is not None and max_seqlen_in_batch_q is not None: attn_output = flash_attn_varlen_func( query_states, key_states, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 99edee6a92a8..3982c905042a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -422,6 +422,10 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: int = 0, + max_length_k: int = 0, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -495,6 +499,11 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) + batch_size=query_states.size(0) + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + attn_output = _flash_attention_forward( query_states, key_states, @@ -506,6 +515,11 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_in_batch_q=max_length_q if isinstance(max_length_q, int) else max_length_q.item(), + max_seqlen_in_batch_k=max_length_k if isinstance(max_length_k, int) else max_length_k.item(), + batch_size=batch_size ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -644,6 +658,10 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: int = 0, + max_length_k: int = 0, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -682,6 +700,10 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + cu_seq_lens_q=cu_seq_lens_q, + cu_seq_lens_k=cu_seq_lens_k, + max_length_q=max_length_q, + max_length_k=max_length_k, **kwargs, ) hidden_states = residual + hidden_states @@ -870,6 +892,10 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: int = 0, + max_length_k: int = 0, ) -> Union[Tuple, 
BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -953,6 +979,10 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + cu_seq_lens_q=cu_seq_lens_q, + cu_seq_lens_k=cu_seq_lens_k, + max_length_q=max_length_q, + max_length_k=max_length_k, ) hidden_states = layer_outputs[0] @@ -1148,6 +1178,10 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: int = 0, + max_length_k: int = 0, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1198,6 +1232,10 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + cu_seq_lens_q=cu_seq_lens_q, + cu_seq_lens_k=cu_seq_lens_k, + max_length_q=max_length_q, + max_length_k=max_length_k, ) hidden_states = outputs[0] diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b5bd4fa1a391..086d6177ceb5 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -813,7 +813,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None} + self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor)} else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. 
This is not supported.") return self From 35b2aa621dd65d9bac62dc208f74a269dafa8666 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 3 Oct 2024 17:25:12 -0400 Subject: [PATCH 02/26] fix: formatting Signed-off-by: Abhishek --- src/transformers/modeling_flash_attention_utils.py | 2 ++ src/transformers/models/llama/modeling_llama.py | 4 ++-- src/transformers/tokenization_utils_base.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f2431892385d..d3d052a005e3 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -179,9 +179,11 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + flash_241 = is_flash_attn_greater_or_equal("2.4.1") deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + def _flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3982c905042a..df5995bc3dbf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -499,7 +499,7 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - batch_size=query_states.size(0) + batch_size = query_states.size(0) query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) @@ -519,7 +519,7 @@ def forward( cu_seqlens_k=cu_seq_lens_k, max_seqlen_in_batch_q=max_length_q if isinstance(max_length_q, int) else max_length_q.item(), max_seqlen_in_batch_k=max_length_k if isinstance(max_length_k, int) else max_length_k.item(), - batch_size=batch_size + batch_size=batch_size, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 086d6177ceb5..9eedb9c5c237 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -813,7 +813,9 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor)} + self.data = { + k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor) + } else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. 
This is not supported.") return self From 5cefb84ef2f86bd49a7aae725557bdfa31e60364 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 3 Oct 2024 18:43:55 -0400 Subject: [PATCH 03/26] fix: import error Signed-off-by: Abhishek --- src/transformers/tokenization_utils_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 9eedb9c5c237..aaa327f82f01 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np +import torch from packaging import version from . import __version__ From aa7b01491a0ae664d9c17d7e82e1c8454e43c950 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 7 Oct 2024 13:05:49 -0400 Subject: [PATCH 04/26] fix: Add Fa2Kwargs Signed-off-by: Abhishek --- .../modeling_flash_attention_utils.py | 16 +++---- .../models/llama/modeling_llama.py | 42 ++++--------------- src/transformers/processing_utils.py | 21 ++++++++++ 3 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index d3d052a005e3..4281f4a305cc 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -198,10 +198,10 @@ def _flash_attention_forward( use_top_left_mask: bool = False, softcap: Optional[float] = None, deterministic: bool = None, - cu_seqlens_q: Optional[torch.LongTensor] = None, - cu_seqlens_k: Optional[torch.LongTensor] = None, - max_seqlen_in_batch_q: int = 0, - max_seqlen_in_batch_k: int = 0, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: int = 0, + max_length_k: int = 0, batch_size: int = 2, ): """ @@ -281,10 +281,10 @@ def _flash_attention_forward( query_states, key_states, value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index df5995bc3dbf..0fbdb630c52a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -49,6 +49,9 @@ logging, replace_return_docstrings, ) +from ...processing_utils import ( + Fa2Kwargs, +) from .configuration_llama import LlamaConfig @@ -422,10 +425,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seq_lens_q: Optional[torch.LongTensor] = None, - cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: int = 0, - max_length_k: int = 0, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -515,11 +515,8 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_in_batch_q=max_length_q if 
isinstance(max_length_q, int) else max_length_q.item(), - max_seqlen_in_batch_k=max_length_k if isinstance(max_length_k, int) else max_length_k.item(), batch_size=batch_size, + **kwargs ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -658,10 +655,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seq_lens_q: Optional[torch.LongTensor] = None, - cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: int = 0, - max_length_k: int = 0, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -700,10 +693,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - cu_seq_lens_q=cu_seq_lens_q, - cu_seq_lens_k=cu_seq_lens_k, - max_length_q=max_length_q, - max_length_k=max_length_k, **kwargs, ) hidden_states = residual + hidden_states @@ -891,11 +880,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - cu_seq_lens_q: Optional[torch.LongTensor] = None, - cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: int = 0, - max_length_k: int = 0, + **fa2_kwargs: Fa2Kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -979,10 +964,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - cu_seq_lens_q=cu_seq_lens_q, - cu_seq_lens_k=cu_seq_lens_k, - max_length_q=max_length_q, - max_length_k=max_length_k, + **fa2_kwargs ) hidden_states = layer_outputs[0] @@ -1178,11 +1160,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - cu_seq_lens_q: Optional[torch.LongTensor] = None, - cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: int = 0, - max_length_k: int = 0, num_logits_to_keep: int = 0, + **fa2_kwargs: Fa2Kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1232,10 +1211,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - cu_seq_lens_q=cu_seq_lens_q, - cu_seq_lens_k=cu_seq_lens_k, - max_length_q=max_length_q, - max_length_k=max_length_k, + **fa2_kwargs, ) hidden_states = outputs[0] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 062dfe311c1d..d241894f88fa 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -23,6 +23,7 @@ import sys import typing import warnings +import torch from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union @@ -77,6 +78,26 @@ else: Unpack = typing_extensions.Unpack +class Fa2Kwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. 
+ """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] + class TextKwargs(TypedDict, total=False): """ From 926481b74d2f990bde9f340f400a05e9db9ad743 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 9 Oct 2024 12:36:57 -0400 Subject: [PATCH 05/26] fix: PR Changes Signed-off-by: Abhishek --- src/transformers/models/llama/modeling_llama.py | 14 +++++++------- src/transformers/processing_utils.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0fbdb630c52a..b78573777d9a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Unpack import torch import torch.nn.functional as F @@ -50,7 +50,7 @@ replace_return_docstrings, ) from ...processing_utils import ( - Fa2Kwargs, + FlashAttentionKwargs, ) from .configuration_llama import LlamaConfig @@ -425,7 +425,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -880,7 +880,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **fa2_kwargs: Fa2Kwargs, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -964,7 +964,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **fa2_kwargs + **flash_attn_kwargs ) hidden_states = layer_outputs[0] @@ -1161,7 +1161,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **fa2_kwargs: Fa2Kwargs, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1211,7 +1211,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - **fa2_kwargs, + **flash_attn_kwargs, ) hidden_states = outputs[0] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index d241894f88fa..3ea3df2cefa6 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -78,7 +78,7 @@ else: Unpack = typing_extensions.Unpack -class Fa2Kwargs(TypedDict, total=False): +class FlashAttentionKwargs(TypedDict, total=False): """ Keyword arguments for Flash Attention with Compile. 
From 20a4dd616d3f617e82f0dd18dd8f59aa2f41eb77 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 19:46:15 -0400 Subject: [PATCH 06/26] PR changes Signed-off-by: Abhishek --- src/transformers/models/llama/modeling_llama.py | 10 ++++------ src/transformers/processing_utils.py | 3 ++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d17b69ae0d7d..871e982b7e06 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional, Tuple, Union, Unpack +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -40,6 +40,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import FlashAttentionKwargs, Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, @@ -48,9 +49,6 @@ logging, replace_return_docstrings, ) -from ...processing_utils import ( - FlashAttentionKwargs, -) from .configuration_llama import LlamaConfig @@ -515,7 +513,7 @@ def forward( use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, batch_size=batch_size, - **kwargs + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -961,7 +959,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **flash_attn_kwargs + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b1eebc2f6618..9a45d534888a 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -23,11 +23,11 @@ import sys import typing import warnings -import torch from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np +import torch import typing_extensions from .dynamic_module_utils import custom_object_save @@ -78,6 +78,7 @@ else: Unpack = typing_extensions.Unpack + class FlashAttentionKwargs(TypedDict, total=False): """ Keyword arguments for Flash Attention with Compile. From 045ef161e7e57c330fef4416f6819c70a444eae6 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 19:52:58 -0400 Subject: [PATCH 07/26] PR changes Signed-off-by: Abhishek --- src/transformers/tokenization_utils_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 994dbb7b3036..f87208f44ac6 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -31,7 +31,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np -import torch from packaging import version from . 
import __version__ From d2796f6f12702ae9688fc4771040aed7a3eb973f Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 20:03:14 -0400 Subject: [PATCH 08/26] PR changes Signed-off-by: Abhishek --- src/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 871e982b7e06..92317c59dd7c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -877,6 +877,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions From 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 20:15:34 -0400 Subject: [PATCH 09/26] PR changes Signed-off-by: Abhishek --- src/transformers/processing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 9a45d534888a..05a9487480bc 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,7 +27,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np -import torch import typing_extensions from .dynamic_module_utils import custom_object_save @@ -93,6 +92,7 @@ class FlashAttentionKwargs(TypedDict, total=False): max_length_k (`int`, *optional*): Maximum sequence length for key state. """ + import torch cu_seq_lens_q: Optional[torch.LongTensor] cu_seq_lens_k: Optional[torch.LongTensor] From 83747b53f08e4418681372fdf3491d0fa5d913af Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 20:30:54 -0400 Subject: [PATCH 10/26] Revert "PR changes" This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4. --- src/transformers/processing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 05a9487480bc..9a45d534888a 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,6 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np +import torch import typing_extensions from .dynamic_module_utils import custom_object_save @@ -92,7 +93,6 @@ class FlashAttentionKwargs(TypedDict, total=False): max_length_k (`int`, *optional*): Maximum sequence length for key state. 
""" - import torch cu_seq_lens_q: Optional[torch.LongTensor] cu_seq_lens_k: Optional[torch.LongTensor] From b642d45b7eab39d2daa70566e0fab2bbf061b86b Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 10 Oct 2024 20:51:43 -0400 Subject: [PATCH 11/26] PR changes Signed-off-by: Abhishek --- src/transformers/models/cohere/modeling_cohere.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a5d3721f5bdb..aa3e6dd89ba3 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -40,6 +40,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import FlashAttentionKwargs, Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, @@ -833,6 +834,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -914,6 +916,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] From d03e673d8132f5342f13904918073147350f7d2d Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 14 Oct 2024 14:23:56 -0400 Subject: [PATCH 12/26] fix: FlashAttentionKwarg Signed-off-by: Abhishek --- .../modeling_flash_attention_utils.py | 23 ++++++++++++++++++- .../models/cohere/modeling_cohere.py | 4 ++-- .../models/llama/modeling_llama.py | 4 ++-- src/transformers/processing_utils.py | 22 ------------------ 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 2dd69298b1ea..a6607e21a55d 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -15,7 +15,7 @@ import inspect import os -from typing import Optional, Tuple +from typing import Optional, Tuple, TypedDict import torch import torch.nn.functional as F @@ -299,3 +299,24 @@ def _flash_attention_forward( ) return attn_output + + +class FlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. 
+ """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index aa3e6dd89ba3..307ab3602695 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -40,7 +40,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...processing_utils import FlashAttentionKwargs, Unpack +from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, @@ -54,7 +54,7 @@ if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward logger = logging.get_logger(__name__) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 92317c59dd7c..5d71d2b34e2a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -30,7 +30,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -40,7 +40,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...processing_utils import FlashAttentionKwargs, Unpack +from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 9a45d534888a..cb2327e5c46b 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,7 +27,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np -import torch import typing_extensions from .dynamic_module_utils import custom_object_save @@ -79,27 +78,6 @@ Unpack = typing_extensions.Unpack -class FlashAttentionKwargs(TypedDict, total=False): - """ - Keyword arguments for Flash Attention with Compile. - - Attributes: - cu_seq_lens_q (`torch.LongTensor`, *optional*) - Gets cumlative sequence length for query state. - cu_seq_lens_k (`torch.LongTensor`, *optional*) - Gets cumlative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. - """ - - cu_seq_lens_q: Optional[torch.LongTensor] - cu_seq_lens_k: Optional[torch.LongTensor] - max_length_q: Optional[int] - max_length_k: Optional[int] - - class TextKwargs(TypedDict, total=False): """ Keyword arguments for text processing. 
For extended documentation, check out tokenization_utils_base methods and From 80e0d5fd27083eeb5d1f052ee290d3b8def19f90 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 14 Oct 2024 14:34:23 -0400 Subject: [PATCH 13/26] fix: FlashAttentionKwarg Signed-off-by: Abhishek --- src/transformers/models/cohere/modeling_cohere.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 307ab3602695..46c6793aadd4 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -34,6 +34,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -54,7 +55,7 @@ if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward + from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) From ca42b8b03ea2ed2438badce0fbc1b518cf2c85cd Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 15 Oct 2024 10:23:32 -0400 Subject: [PATCH 14/26] PR Changes Signed-off-by: Abhishek --- .../modeling_flash_attention_utils.py | 16 +++++++++++++++- src/transformers/models/llama/modeling_llama.py | 6 ------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index a6607e21a55d..92d47fe9672e 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -202,7 +202,6 @@ def _flash_attention_forward( cu_seq_lens_k: Optional[torch.LongTensor] = None, max_length_q: int = 0, max_length_k: int = 0, - batch_size: int = 2, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -277,6 +276,21 @@ def _flash_attention_forward( # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. 
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach elif position_ids is not None and max_length_q is not None: + batch_size = query_states.size(0) + + if cu_seq_lens_q is None or cu_seq_lens_k is None: + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seq_lens_q, cu_seq_lens_q = cu_seq_lens + max_length_q, max_length_k = max_seq_lens + + else: + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + attn_output = flash_attn_varlen_func( query_states, key_states, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5d71d2b34e2a..0d7a8ec03405 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -496,11 +496,6 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - batch_size = query_states.size(0) - query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - attn_output = _flash_attention_forward( query_states, key_states, @@ -512,7 +507,6 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, - batch_size=batch_size, **kwargs, ) From b8d2568f506c4c71d7870bc19752b9ff0e920ff1 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 15 Oct 2024 10:30:09 -0400 Subject: [PATCH 15/26] PR Changes Signed-off-by: Abhishek --- src/transformers/modeling_flash_attention_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 92d47fe9672e..f0126bdcdee8 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -200,8 +200,8 @@ def _flash_attention_forward( deterministic: bool = None, cu_seq_lens_q: Optional[torch.LongTensor] = None, cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: int = 0, - max_length_k: int = 0, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token From ae11c96d66b810df0bef252cb186fd88f5ac5437 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 15 Oct 2024 10:40:56 -0400 Subject: [PATCH 16/26] PR Changes Signed-off-by: Abhishek --- src/transformers/modeling_flash_attention_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f0126bdcdee8..bc4b4dd22c77 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -279,8 +279,8 @@ def _flash_attention_forward( batch_size = query_states.size(0) if cu_seq_lens_q is None or cu_seq_lens_k is None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens 
= prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( + prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) ) cu_seq_lens_q, cu_seq_lens_q = cu_seq_lens From 76c51cad32927d118b84262a348bfc2c74c80761 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 15 Oct 2024 11:25:41 -0400 Subject: [PATCH 17/26] PR Changes Signed-off-by: Abhishek --- src/transformers/modeling_flash_attention_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index bc4b4dd22c77..f354a3f8a382 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -275,7 +275,9 @@ def _flash_attention_forward( # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and max_length_q is not None: + elif position_ids is not None and ( + max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) + ): batch_size = query_states.size(0) if cu_seq_lens_q is None or cu_seq_lens_k is None: From 77c7a3db2744dd2a69a167847e8fd707874728e7 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 16 Oct 2024 17:00:42 -0400 Subject: [PATCH 18/26] PR Changes Signed-off-by: Abhishek --- src/transformers/tokenization_utils_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f87208f44ac6..e5a0195aa1fc 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -815,9 +815,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = { - k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor) - } + self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()} else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") return self From 391715aa5e7a2790d0569f5e645b78fdcb40f048 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Fri, 18 Oct 2024 19:30:00 -0400 Subject: [PATCH 19/26] addition of documentation Signed-off-by: Abhishek --- docs/source/en/llm_optims.md | 93 ++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 16be638498df..0a6a7e15bea0 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained( ) ``` +### Fine-Tuning with torch.compile and Padding-Free Data Collation + +In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. 
This approach can significantly speed up training and reduce computational overhead. + +Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator: + +``` +#################### IMPORTS ################### + +import math +import datasets +import dataclasses +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + TrainingArguments +) +from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM + +#################### MODEL LOADING WITH FLASH ATTENTION ################### + +model_name = "meta-llama/Llama-3.2-1B" +model = AutoModelForCausalLM.from_pretrained( + model_name, + attn_implementation="flash_attention_2" # Enables FlashAttention-2 +) +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + +#################### DATA PREPROCESSING (PADDING-FREE) ################### + +response_template = "\n### Label:" +response_template_ids = tokenizer.encode( + response_template, add_special_tokens=False +)[2:] # Exclude special tokens + +data_collator = DataCollatorForCompletionOnlyLM( + response_template_ids=response_template_ids, + tokenizer=tokenizer, + ignore_index=-100, + padding_free=True # Enables padding-free collation +) + +def format_dataset(example): + return { + "output": example["output"] + tokenizer.eos_token + } + +data_files = {"train": "path/to/dataset"} # Replace with your dataset path +json_dataset = datasets.load_dataset("json", data_files=data_files) +formatted_train_dataset = json_dataset["train"].map(format_dataset) + +################# TRAINING CONFIGURATION ############################ + +train_args = TrainingArguments( + num_train_epochs=5, + per_device_train_batch_size=4, + per_device_eval_batch_size=4, + gradient_accumulation_steps=4, + learning_rate=1e-5, + weight_decay=0.0, + warmup_ratio=0.03, + lr_scheduler_type="cosine", + logging_steps=1, + include_tokens_per_second=True, + save_strategy="epoch", + output_dir="output", + torch_compile=True, # Enables torch.compile + torch_compile_backend="inductor", + torch_compile_mode="default" +) + +# Convert TrainingArguments to SFTConfig +transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)] +transformer_kwargs = { + k: v + for k, v in train_args.to_dict().items() + if k in transformer_train_arg_fields +} +training_args = SFTConfig(**transformer_kwargs) + +####################### FINE-TUNING ##################### + +trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=formatted_train_dataset, + data_collator=data_collator, + dataset_text_field="output", + args=training_args, +) +trainer.train() +``` + ### PyTorch scaled dot product attention Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. 
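The padding-free path documented above works because the collator, instead of padding, flattens the batch and describes example boundaries explicitly. Below is an illustrative sketch (not part of this patch series) of what such a collator produces; only the keyword names — `position_ids`, `cu_seq_lens_q`, `cu_seq_lens_k`, `max_length_q`, `max_length_k` — come from the signatures patched in this PR, while the helper itself is a hypothetical stand-in for a real padding-free collator.

```python
# Illustrative only — not part of this patch series. A hypothetical padding-free collation
# helper that precomputes the FlashAttentionKwargs this PR threads through the model, so
# the flash-attention metadata never has to be derived inside the compiled forward.
import torch


def pack_examples(examples):
    """Flatten variable-length examples into one padding-free row plus FA2 metadata."""
    lengths = torch.tensor([len(e) for e in examples], dtype=torch.int32)
    input_ids = torch.cat(examples).unsqueeze(0)  # shape (1, total_tokens), no pad tokens
    # position ids restart at 0 for every packed example
    position_ids = torch.cat([torch.arange(int(l)) for l in lengths]).unsqueeze(0)
    # cumulative boundaries, e.g. lengths [3, 5, 2] -> [0, 3, 8, 10]
    cu_seq_lens = torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    max_length = int(lengths.max())
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max_length,
        "max_length_k": max_length,
    }


batch = pack_examples(
    [torch.tensor([11, 12, 13]), torch.tensor([21, 22, 23, 24, 25]), torch.tensor([31, 32])]
)
print(batch["cu_seq_lens_q"])  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```

Supplying the cumulative lengths and maxima explicitly is what matters for `torch.compile`: when they are present, the patched `_flash_attention_forward` only reshapes the states and calls `flash_attn_varlen_func`, rather than deriving the metadata from `position_ids` via `prepare_fa2_from_position_ids` inside the compiled region.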
From f23c9555cccd2979ec6da764b1609dcfc907d722 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 21 Oct 2024 19:09:44 -0400 Subject: [PATCH 20/26] change in _flash_attention_forward Signed-off-by: Abhishek --- src/transformers/modeling_flash_attention_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f354a3f8a382..045d2f6d6460 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -285,7 +285,7 @@ def _flash_attention_forward( prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) ) - cu_seq_lens_q, cu_seq_lens_q = cu_seq_lens + cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens max_length_q, max_length_k = max_seq_lens else: From 67c78283da539f24c049ce26708c18d376bea163 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 22 Oct 2024 11:59:35 -0400 Subject: [PATCH 21/26] make fix-copies Signed-off-by: Abhishek --- src/transformers/models/glm/modeling_glm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a458c02a6fed..afbc9256eeb8 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -50,8 +50,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward - -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...processing_utils import Unpack class GlmRMSNorm(nn.Module): @@ -736,6 +736,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -817,6 +818,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] From 8d2ec29995cabedcb9e889606e017a3e7e8eb1d5 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 22 Oct 2024 12:24:59 -0400 Subject: [PATCH 22/26] revert make fix-copies Signed-off-by: Abhishek --- src/transformers/models/glm/modeling_glm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index afbc9256eeb8..a458c02a6fed 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -50,8 +50,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward -from ...processing_utils import Unpack + +from ...modeling_flash_attention_utils import _flash_attention_forward class GlmRMSNorm(nn.Module): @@ -736,7 +736,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions 
is not None else self.config.output_attentions output_hidden_states = ( @@ -818,7 +817,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **flash_attn_kwargs, ) hidden_states = layer_outputs[0] From 5a903da8d453e4d88407bd943be25b5027ed7095 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 23 Oct 2024 16:25:40 +0200 Subject: [PATCH 23/26] fix copies --- src/transformers/models/glm/modeling_glm.py | 14 +++++++++++++- src/transformers/models/glm/modular_glm.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a458c02a6fed..89cbc6da5493 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -38,6 +38,7 @@ ) from ...modeling_utils import PreTrainedModel from ...utils import ( + add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, @@ -51,7 +52,11 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...processing_utils import Unpack + + +_CHECKPOINT_FOR_DOC = "dummy" class GlmRMSNorm(nn.Module): @@ -736,6 +741,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -817,6 +823,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -1221,6 +1228,11 @@ def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 55bf89d1c56b..2adba8b36f48 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -46,6 +46,7 @@ logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "dummy" class GlmRMSNorm(Phi3RMSNorm): pass From 05f9a80dddbe9b6f4258e46481884719d9039125 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 23 Oct 2024 16:49:25 +0200 Subject: [PATCH 24/26] style --- src/transformers/models/glm/modular_glm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 2adba8b36f48..c26477fdc173 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -48,6 +48,7 @@ _CHECKPOINT_FOR_DOC = "dummy" + class GlmRMSNorm(Phi3RMSNorm): pass From a6e2601845b70cddb1fe6e51f9da962d0c8a4bb1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 23 Oct 2024 17:55:09 +0200 Subject: [PATCH 25/26] loss kwargs typing --- .../models/llama/modeling_llama.py | 18 ++++++------------ 
src/transformers/utils/__init__.py | 1 + src/transformers/utils/generic.py | 15 ++++++++++++++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7559442b3bcf..d8f5fb93ba20 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -42,6 +42,7 @@ from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -1107,6 +1108,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):... + class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -1153,7 +1156,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1191,15 +1194,6 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - flash_attn_keys = [ - "cu_seq_lens_q", - "cu_seq_lens_k", - "max_length_q", - "max_length_k", - ] - flash_attn_kwargs = {k: kwargs.pop(k) for k in flash_attn_keys if k in kwargs} - loss_kwargs = kwargs - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1212,7 +1206,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - **flash_attn_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -1226,7 +1220,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index a781389c2fbd..2a10bcaa3c94 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -37,6 +37,7 @@ from .generic import ( ContextManagers, ExplicitEnum, + LossKwargs, ModelOutput, PaddingStrategy, TensorType, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index a5f01fa2e0df..26ec82b20fd4 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -24,7 +24,7 @@ from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Iterable, List, Optional, Tuple +from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict import numpy as np from packaging import version @@ -854,3 +854,16 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +class LossKwargs(TypedDict, total=False): + """ + Keyword arguments to be passed to the loss function + + Attributes: + num_items_in_batch (`int`, *optional*): + Number of items in the batch. It is recommended to pass it when + you are doing gradient accumulation. 
+ """ + + num_items_in_batch: Optional[int] From cb08b6371101dae92867d52b1a3b24405779d46f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 24 Oct 2024 09:27:28 +0200 Subject: [PATCH 26/26] style and pull latest changes --- src/transformers/models/llama/modeling_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d8f5fb93ba20..4d95f01849d6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1108,7 +1108,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):... +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"]
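Taken together, the last two patches leave `LlamaForCausalLM.forward` accepting `**kwargs: Unpack[KwargsForCausalLM]`, where `KwargsForCausalLM` merges `FlashAttentionKwargs` and `LossKwargs`. The snippet below is a self-contained sketch of that pattern; the TypedDict bodies mirror the definitions added in this series, but `ToyForCausalLM` is a hypothetical placeholder, not the real model class.

```python
# Self-contained sketch of the typed-kwargs pattern from PATCH 25/26. The TypedDicts
# mirror the definitions added in this series; ToyForCausalLM is a hypothetical stand-in
# used only to show how one **kwargs dict serves both attention and the loss.
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class ToyForCausalLM:
    def forward(self, input_ids: torch.Tensor, **kwargs: Unpack[KwargsForCausalLM]) -> torch.Tensor:
        # The decoder stack would read the cu_seq_lens_* / max_length_* keys, the loss
        # function would read num_items_in_batch; neither needs to know the other's keys.
        num_items = kwargs.get("num_items_in_batch") or input_ids.numel()
        logits = torch.zeros(input_ids.shape + (8,))  # placeholder for the real forward pass
        return logits.sum() / num_items  # placeholder "loss"


loss = ToyForCausalLM().forward(
    torch.tensor([[11, 12, 13]]),
    cu_seq_lens_q=torch.tensor([0, 3], dtype=torch.int32),
    cu_seq_lens_k=torch.tensor([0, 3], dtype=torch.int32),
    max_length_q=3,
    max_length_k=3,
    num_items_in_batch=3,
)
print(float(loss))  # 0.0
```

The design point is that a single `**kwargs` dict flows from the top-level forward both into the decoder layers (which consume the flash-attention keys) and into `self.loss_function` (which reads `num_items_in_batch`), while `Unpack[KwargsForCausalLM]` still gives static type checkers the full key set.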