diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 16be638498df..0a6a7e15bea0 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -348,6 +348,100 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 ```
 
+### Fine-tuning with torch.compile and padding-free data collation
+
+In addition to optimizing inference, you can also speed up fine-tuning by enabling `torch_compile` in [`TrainingArguments`] and using a padding-free data collator. A padding-free collator packs the examples of a batch into a single sequence instead of padding them to a common length, so no compute is wasted on padding tokens.
+
+The example below fine-tunes a Llama model with `SFTTrainer` from the [TRL](https://huggingface.co/docs/trl) library, with `torch_compile` enabled and a padding-free `DataCollatorForCompletionOnlyLM`. Padding-free collation requires the FlashAttention-2 attention implementation.
+
+```py
+#################### IMPORTS ###################
+
+import torch
+import datasets
+import dataclasses
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TrainingArguments
+)
+from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
+
+#################### MODEL LOADING WITH FLASH ATTENTION ###################
+
+model_name = "meta-llama/Llama-3.2-1B"
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,              # FlashAttention-2 only supports fp16/bf16
+    attn_implementation="flash_attention_2"  # Enables FlashAttention-2
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+#################### DATA PREPROCESSING (PADDING-FREE) ###################
+
+response_template = "\n### Label:"
+response_template_ids = tokenizer.encode(
+    response_template, add_special_tokens=False
+)[2:]  # Drop the leading tokens, whose ids depend on the preceding context
+
+data_collator = DataCollatorForCompletionOnlyLM(
+    response_template=response_template_ids,
+    tokenizer=tokenizer,
+    ignore_index=-100,
+    padding_free=True  # Enables padding-free collation
+)
+
+def format_dataset(example):
+    return {
+        "output": example["output"] + tokenizer.eos_token
+    }
+
+data_files = {"train": "path/to/dataset"}  # Replace with your dataset path
+json_dataset = datasets.load_dataset("json", data_files=data_files)
+formatted_train_dataset = json_dataset["train"].map(format_dataset)
+
+################# TRAINING CONFIGURATION ############################
+
+train_args = TrainingArguments(
+    num_train_epochs=5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    gradient_accumulation_steps=4,
+    learning_rate=1e-5,
+    weight_decay=0.0,
+    warmup_ratio=0.03,
+    lr_scheduler_type="cosine",
+    logging_steps=1,
+    include_tokens_per_second=True,
+    save_strategy="epoch",
+    output_dir="output",
+    torch_compile=True,  # Enables torch.compile
+    torch_compile_backend="inductor",
+    torch_compile_mode="default"
+)
+
+# Convert TrainingArguments to SFTConfig
+transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
+transformer_kwargs = {
+    k: v
+    for k, v in train_args.to_dict().items()
+    if k in transformer_train_arg_fields
+}
+training_args = SFTConfig(**transformer_kwargs)
+
+####################### FINE-TUNING #####################
+
+trainer = SFTTrainer(
+    model=model,
+    tokenizer=tokenizer,
+    train_dataset=formatted_train_dataset,
+    data_collator=data_collator,
+    dataset_text_field="output",
+    args=training_args,
+)
+trainer.train()
+```
+
 ### PyTorch scaled dot product attention
 
 Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
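The padding-free approach above hinges on `position_ids` that restart at 0 for every packed example, which is also what the `_flash_attention_forward` changes in the next file key off. The following is a minimal sketch (plain PyTorch only, not part of this patch) of how such `position_ids` map to the cumulative sequence lengths and maximum length that `flash_attn_varlen_func` consumes and that the new `FlashAttentionKwargs` carry through the model; the boundary logic loosely mirrors `prepare_fa2_from_position_ids`.

```py
import torch

# Two examples packed into one row with no padding: sequence A has 3 tokens,
# sequence B has 5 tokens, and position_ids restart at 0 for each sequence.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])

# Sequence boundaries are wherever position_ids == 0.
flat = position_ids.flatten()
starts = torch.nonzero(flat == 0).flatten()
cu_seq_lens = torch.cat((starts, torch.tensor([flat.numel()]))).to(torch.int32)
max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())

print(cu_seq_lens.tolist())  # [0, 3, 8]
print(max_length)            # 5

# A padding-free collator or trainer could compute these once per batch and pass them
# through FlashAttentionKwargs (cu_seq_lens_q/cu_seq_lens_k, max_length_q/max_length_k),
# letting _flash_attention_forward skip the data-dependent torch.diff(position_ids) check
# that the removed comment notes incurs a CUDA synchronization.
```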
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index da961c6060e4..045d2f6d6460 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -15,7 +15,7 @@
 
 import inspect
 import os
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
 
 import torch
 import torch.nn.functional as F
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
     return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
 
 
+flash_241 = is_flash_attn_greater_or_equal("2.4.1")
+deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+
+
 def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
@@ -194,6 +198,10 @@ def _flash_attention_forward(
     use_top_left_mask: bool = False,
     softcap: Optional[float] = None,
     deterministic: bool = None,
+    cu_seq_lens_q: Optional[torch.LongTensor] = None,
+    cu_seq_lens_k: Optional[torch.LongTensor] = None,
+    max_length_q: Optional[int] = None,
+    max_length_k: Optional[int] = None,
 ):
     """
     Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +240,9 @@ def _flash_attention_forward(
     )
     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
 
-    if is_flash_attn_greater_or_equal("2.4.1"):
+    if flash_241:
         if deterministic is None:
-            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+            deterministic = deterministic_g
         flash_kwargs["deterministic"] = deterministic
 
     if softcap is not None:
@@ -267,24 +275,32 @@ def _flash_attention_forward(
 
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
-    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
+    elif position_ids is not None and (
+        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    ):
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
-            query_states, key_states, value_states, position_ids
-        )
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+        if cu_seq_lens_q is None or cu_seq_lens_k is None:
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
+                prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
+            )
+
+            cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
+            max_length_q, max_length_k = max_seq_lens
+
+        else:
+            query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
+            key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
+            value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
 
         attn_output = flash_attn_varlen_func(
             query_states,
             key_states,
             value_states,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_in_batch_q,
-            max_seqlen_k=max_seqlen_in_batch_k,
+            cu_seqlens_q=cu_seq_lens_q,
+            cu_seqlens_k=cu_seq_lens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
             dropout_p=dropout,
             softmax_scale=softmax_scale,
             causal=causal,
@@ -299,3 +315,24 @@ def _flash_attention_forward(
         )
 
     return attn_output
+
+
+class FlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for Flash Attention with Compile.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`, *optional*):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`, *optional*):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`, *optional*):
+            Maximum sequence length for query state.
+        max_length_k (`int`, *optional*):
+            Maximum sequence length for key state.
+    """
+
+    cu_seq_lens_q: Optional[torch.LongTensor]
+    cu_seq_lens_k: Optional[torch.LongTensor]
+    max_length_q: Optional[int]
+    max_length_k: Optional[int]
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 9aa588be4310..b215fb6561bf 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -33,12 +33,14 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
@@ -832,6 +834,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -913,6 +916,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )
 
             hidden_states = layer_outputs[0]
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index aad4da282b78..6354e20e33fe 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -38,6 +38,7 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -51,7 +52,11 @@
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
+from ...processing_utils import Unpack
+
+
+_CHECKPOINT_FOR_DOC = "dummy"
 
 
 class GlmRMSNorm(nn.Module):
@@ -736,6 +741,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -817,6 +823,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )
 
             hidden_states = layer_outputs[0]
@@ -1222,6 +1229,11 @@ def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
     @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index 55bf89d1c56b..c26477fdc173 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -46,6 +46,8 @@
 logger = logging.get_logger(__name__)
 
 
+_CHECKPOINT_FOR_DOC = "dummy"
+
 class GlmRMSNorm(Phi3RMSNorm):
     pass
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 617ef38e4ae3..4d95f01849d6 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,8 +39,10 @@
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
+    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -422,6 +424,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -506,6 +509,7 @@ def forward(
             sliding_window=getattr(self, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            **kwargs,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -870,6 +874,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -951,6 +956,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )
 
             hidden_states = layer_outputs[0]
@@ -1102,6 +1108,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask
 
 
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
 class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -1148,7 +1157,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1198,6 +1207,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = outputs[0]
@@ -1211,7 +1221,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 16c05a14028e..4f3187d510fa 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -815,7 +815,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
+            self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index a781389c2fbd..2a10bcaa3c94 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -37,6 +37,7 @@
 from .generic import (
     ContextManagers,
     ExplicitEnum,
+    LossKwargs,
     ModelOutput,
     PaddingStrategy,
     TensorType,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index a5f01fa2e0df..26ec82b20fd4 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -24,7 +24,7 @@
 from dataclasses import fields, is_dataclass
 from enum import Enum
 from functools import partial, wraps
-from typing import Any, ContextManager, Iterable, List, Optional, Tuple
+from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
 
 import numpy as np
 from packaging import version
@@ -854,3 +854,16 @@ def wrapper(*args, **kwargs):
         return wrapper
 
     return decorator
+
+
+class LossKwargs(TypedDict, total=False):
+    """
+    Keyword arguments to be passed to the loss function.
+
+    Attributes:
+        num_items_in_batch (`int`, *optional*):
+            Number of items in the batch. It is recommended to pass it when
+            you are doing gradient accumulation.
+    """
+
+    num_items_in_batch: Optional[int]
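The `LossKwargs` and `FlashAttentionKwargs` TypedDicts are combined with `Unpack` so that the `**kwargs` pass-throughs stay plain dictionaries at runtime while becoming statically checkable. Below is a minimal self-contained sketch of that pattern (illustrative only, with a hypothetical `forward` function; not part of this patch):

```py
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack  # available in `typing` itself on Python 3.11+


# Mirrors the TypedDicts introduced in this patch.
class FlashAttentionKwargs(TypedDict, total=False):
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]


# Same composition trick as KwargsForCausalLM in modeling_llama.py.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


def forward(input_ids: torch.Tensor, **kwargs: Unpack[KwargsForCausalLM]) -> torch.Tensor:
    # total=False makes every key optional, so read them with .get().
    num_items = kwargs.get("num_items_in_batch")  # would be forwarded to the loss
    cu_seq_lens_q = kwargs.get("cu_seq_lens_q")   # would be forwarded to flash attention
    return input_ids


# Callers pass the kwargs flat; a type checker that understands Unpack (mypy, pyright)
# flags unknown keys such as forward(x, max_len_q=16) or wrongly typed values.
```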