diff --git a/.github/workflows/build_linux_wheels.yaml b/.github/workflows/build_linux_wheels.yaml index 7b30a64ad6..bcfe639531 100644 --- a/.github/workflows/build_linux_wheels.yaml +++ b/.github/workflows/build_linux_wheels.yaml @@ -36,6 +36,8 @@ jobs: with: repository: pytorch/torchtune ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main package-name: torchtune build-matrix: ${{ needs.generate-matrix.outputs.matrix }} pre-script: .github/scripts/pre_build_script.sh diff --git a/docs/source/index.rst b/docs/source/index.rst index 318c82b3e2..d62ad77b63 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -113,6 +113,7 @@ torchtune tutorials. recipes/recipes_overview recipes/lora_finetune_single_device recipes/qat_distributed + recipes/dpo .. toctree:: :glob: diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst new file mode 100644 index 0000000000..5fdb455a35 --- /dev/null +++ b/docs/source/recipes/dpo.rst @@ -0,0 +1,75 @@ +.. _dpo_recipe_label: + +==================================== +Direct Preference Optimization +==================================== + +This recipe supports several `Direct Preference Optimization `_ (DPO)-style fine-tuning techniques. +These techniques aim to steer (or `align `_) a model towards some desirable behaviours. +For example, a common goal is to train language models to produce safe and honest outputs, +or to be `helpful and harmless `_. + +To see the best results when using this recipe, it may be helpful to first fine-tune your model with using supervised fine-tuning to ensure your model is +on-distribution for the domain you're interested in. To do this, check out our other fine-tuning recipes in the :ref:`recipe overview ` which +support a variety of SFT paradigms. + +After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B: + +.. note:: + + You may need to be granted access to the Llama model you're interested in. See + :ref:`here ` for details on accessing gated repositories. + + +.. code-block:: bash + + tune download meta-llama/Meta-Llama-3.1-8B-Instruct \ + --ignore-patterns "original/consolidated.00.pth" + --HF_TOKEN + + # run on a single device + tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device + + # run on two gpus + tune run --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo + +It's easy to get started with this recipe with your dataset of choice, including custom local datasets, +and datasets from Hugging Face. Check out our primer on :ref:`preference datasets ` to +see how to do this. + +For this recipe we include different DPO-style losses: + +* :class:`Direct Preference Optimization ` (DPO) loss [#]_. The DPO loss function + increases the relative log-probabilities of preferred to un-preferred responses, whilst using log probabilities + from a reference model to prevent policy degradation during training. Alongside RLHF, this is the most commonly used + alignment technique and is used to train a growing number of state-of-the-art LLMs e.g. Llama3.1, Gemma 2, Qwen2, etc. + This is a good starting point for alignment fine-tuning. +* :class:`Statistical Rejection Sampling Optimization ` (RSO) or "hinge" loss [#]_. + RSO builds on concepts from support vector machines and DPO, applying a margin-based approach that penalizes + low-quality responses while ensuring a significant gap between chosen and un-chosen log probabilities. + +To use any of these, simply use the ``loss`` config entry or flag through the :ref:`cli_label`: + +.. code-block:: bash + + tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device \ + loss=torchtune.modules.loss.RSOLoss \ + gamma=0.5 + +.. todo (@SalmanMohammadi) point to an example repo for SimPO + +For a deeper understanding of the different levers you can pull when using this recipe, +see our documentation for the different PEFT training paradigms we support: + +* :ref:`glossary_lora` +* :ref:`glossary_qlora` +* :ref:`glossary_dora` + +Many of our other memory optimization features can be used in this recipe. You can learn more about all of our memory optimization features in our :ref:`memory optimization overview`. + +.. rubric:: References: + +.. [#] Rafailov, R., Sharma, A., Mitchell, E., Manning, C.D., Ermon, S. and Finn, C., 2024. + Direct preference optimization: Your language model is secretly a reward model. Advances in Neural Information Processing Systems, 36. +.. [#] Liu, T., Zhao, Y., Joshi, R., Khalman, M., Saleh, M., Liu, P.J. and Liu, J., 2023. + Statistical rejection sampling improves preference optimization. arXiv preprint arXiv:2309.06657. diff --git a/docs/source/recipes/lora_finetune_single_device.rst b/docs/source/recipes/lora_finetune_single_device.rst index 4b4d476058..ffcca11d53 100644 --- a/docs/source/recipes/lora_finetune_single_device.rst +++ b/docs/source/recipes/lora_finetune_single_device.rst @@ -8,7 +8,7 @@ This recipe supports finetuning on next-token prediction tasks using parameter e such as :ref:`glossary_lora` and :ref:`glossary_qlora`. These techniques significantly reduce memory consumption during training whilst still maintaining competitive performance. -We provide configs which you can get up and running quickly. Here is an example with llama 3.1 8B: +We provide configs which you can get up and running quickly. Here is an example with Llama 3.1 8B: .. note:: diff --git a/docs/source/recipes/recipes_overview.rst b/docs/source/recipes/recipes_overview.rst index a1c4f39ef3..e6e8c9cd63 100644 --- a/docs/source/recipes/recipes_overview.rst +++ b/docs/source/recipes/recipes_overview.rst @@ -28,7 +28,7 @@ Our recipes include: * Single-device full fine-tuning * Distributed full fine-tuning * Distributed LoRA fine-tuning -* Direct Preference Optimization (DPO) +* :ref:`Direct Preference Optimization (DPO) ` * Proximal Policy Optimization (PPO) * :ref:`Distributed Quantization-Aware Training (QAT)`. diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index abf1b43138..0f21b03206 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -83,4 +83,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 + +# Memory management enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index 7543cb5d6f..c6d8d4bbba 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -80,4 +80,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 + +# Memory management enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 2e4e718f62..2d300a4299 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -105,6 +105,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory +# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index a89c01b4c1..45e65b3010 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -104,6 +104,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory +# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_lora_dpo.yaml b/recipes/configs/llama3_1/8B_lora_dpo.yaml new file mode 100644 index 0000000000..6f94b7d09d --- /dev/null +++ b/recipes/configs/llama3_1/8B_lora_dpo.yaml @@ -0,0 +1,92 @@ +# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA DPO alignment please use llama3_1/8B_lora_dpo_single_device + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.stack_exchange_paired_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.05 + lr: 5e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.rlhf.loss.DPOLoss + beta: 0.1 + label_smoothing: 0 + +# Training +epochs: 1 +max_steps_per_epoch: 1000 +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/lora_dpo_output/ +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml new file mode 100644 index 0000000000..638a4efe12 --- /dev/null +++ b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml @@ -0,0 +1,89 @@ +# Config for single device LoRA DPO alignment in lora_dpo_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.stack_exchange_paired_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.05 + lr: 5e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.rlhf.loss.DPOLoss + +# Training +epochs: 1 +max_steps_per_epoch: 1000 +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/lora_dpo_output/ +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/qwen2_5/32B_lora.yaml b/recipes/configs/qwen2_5/32B_lora.yaml index bed3868365..002a2b6a1a 100644 --- a/recipes/configs/qwen2_5/32B_lora.yaml +++ b/recipes/configs/qwen2_5/32B_lora.yaml @@ -97,6 +97,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory +# custom_sharded_layers: ['tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/72B_lora.yaml b/recipes/configs/qwen2_5/72B_lora.yaml index fc7ad2dc7d..52ade6a49f 100644 --- a/recipes/configs/qwen2_5/72B_lora.yaml +++ b/recipes/configs/qwen2_5/72B_lora.yaml @@ -117,6 +117,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory +# custom_sharded_layers: ['tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index aa48920815..2e000cc67a 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -122,15 +122,6 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - if ( - cfg.get("fsdp_cpu_offload", False) - and cfg.optimizer.get("fused", False) - and not utils.torch_version_ge("2.4.0") - ): - raise RuntimeError( - "Using fused optimizer on CPU is only supported in PyTorch nightly." - ) - # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.get("log_every_n_steps", 1) diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index ab37623cc1..7f6b0a8394 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -33,7 +33,6 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.rlhf.loss import SimPOLoss from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -59,6 +58,18 @@ class LoRADPORecipeDistributed(FTRecipeInterface): come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -97,7 +108,6 @@ class LoRADPORecipeDistributed(FTRecipeInterface): The following losses are supported in this recipe: - :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). - :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). - - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -109,6 +119,8 @@ class LoRADPORecipeDistributed(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. ValueError: If world_size is 1 RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -135,8 +147,28 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - # training attributes - self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` @@ -232,6 +264,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -293,6 +327,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], @@ -396,6 +431,12 @@ def _setup_model( lora_unexpected=lora_unexpected, ) # Ensure no params and buffers are on meta device + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + training.validate_no_params_on_meta_device(model) utils.log_rank_zero( log, @@ -581,14 +622,10 @@ def concatenated_forward( # formed by concatenating an equal number of "chosen" and "rejected". len_chosen = concatenated_input_ids.shape[0] // 2 - all_logits = model(concatenated_input_ids) + with self.activations_handling_ctx: + all_logits = model(concatenated_input_ids) - all_log_probs = rlhf.get_batch_log_probs( - all_logits, - concatenated_labels, - # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss` - return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), - ) + all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels) chosen_log_probs = all_log_probs[:len_chosen] rejected_log_probs = all_log_probs[len_chosen:] @@ -647,26 +684,19 @@ def train(self) -> None: # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging del policy_chosen_logits, policy_rejected_logits - if isinstance(self._loss_fn, SimPOLoss): - loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, policy_rejected_log_probs - ) - else: - # reference based losses (e.g. DPO) explicitly regularize the objective fn based on - # the reference model's output - reference-free losses (such as SimPO) don't require this. - with torch.no_grad(), disable_adapter(self._model): - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward(self._model, batch) - loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, - policy_rejected_log_probs, + with torch.no_grad(), disable_adapter(self._model): + ( reference_chosen_log_probs, reference_rejected_log_probs, - ) + _, + _, + ) = self.concatenated_forward(self._model, batch) + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + reference_chosen_log_probs, + reference_rejected_log_probs, + ) loss = loss.mean() reward_accuracies = (chosen_rewards > rejected_rewards).float() diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 53b3b67be5..17f985e75f 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -30,7 +30,6 @@ ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.rlhf.loss import SimPOLoss from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -44,9 +43,11 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface): This recipe supports: - Activation checkpointing. This is enabled by default but is configurable. + - Activation offloading - this is enabled by default and should only be used alongside + activation checkpointing. - Full bf16 training for supported HW architectures. We currently check bf16 support via - the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via - setting `dtype=bf16` in configuration. + the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via + setting `dtype=bf16` in configuration. - Checkpointing: of LoRA adapter parameters and their optimizer states. When resuming from a checkpoint, the adapter parameters are loaded from the checkpoint along with the base model weights. Note that intra-epoch resumption is not supported. @@ -56,7 +57,6 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface): The following losses are supported in this recipe: - :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). - :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). - - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). Assumptions: - Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done @@ -74,6 +74,8 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ @@ -101,6 +103,29 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -190,6 +215,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( @@ -251,6 +277,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, @@ -289,6 +316,11 @@ def _setup_model( lora_unexpected=lora_unexpected, ) + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + log.info(f"Model is initialized with precision {self._dtype}.") # Compile model, if enabled. @@ -443,14 +475,10 @@ def concatenated_forward( # formed by concatenating an equal number of "chosen" and "rejected". len_chosen = concatenated_input_ids.shape[0] // 2 - all_logits = model(concatenated_input_ids) + with self.activations_handling_ctx: + all_logits = model(concatenated_input_ids) - all_log_probs = rlhf.get_batch_log_probs( - all_logits, - concatenated_labels, - # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss` - return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), - ) + all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels) chosen_log_probs = all_log_probs[:len_chosen] rejected_log_probs = all_log_probs[len_chosen:] @@ -503,26 +531,19 @@ def train(self) -> None: # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging del policy_chosen_logits, policy_rejected_logits - if isinstance(self._loss_fn, SimPOLoss): - loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, policy_rejected_log_probs - ) - else: - # reference based losses (e.g. DPO) explicitly regularize the objective fn based on - # the reference model's output - reference-free losses (such as SimPO) don't require this. - with torch.no_grad(), disable_adapter(self._model): - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward(self._model, batch) - loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, - policy_rejected_log_probs, + with torch.no_grad(), disable_adapter(self._model): + ( reference_chosen_log_probs, reference_rejected_log_probs, - ) + _, + _, + ) = self.concatenated_forward(self._model, batch) + loss, chosen_rewards, rejected_rewards = self._loss_fn( + policy_chosen_log_probs, + policy_rejected_log_probs, + reference_chosen_log_probs, + reference_rejected_log_probs, + ) loss = loss.mean() reward_accuracies = (chosen_rewards > rejected_rewards).float() diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 45209814a0..5aef0e2e97 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -273,6 +273,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index ab9cea3eda..e7df34d97c 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -133,15 +133,6 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - if ( - cfg.get("fsdp_cpu_offload", False) - and cfg.optimizer.get("fused", False) - and not utils.torch_version_ge("2.4.0") - ): - raise RuntimeError( - "Using fused optimizer on CPU is only supported in PyTorch nightly." - ) - # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.get("log_every_n_steps", 1) diff --git a/tests/recipes/test_qat_lora_finetune_distributed.py b/tests/recipes/test_qat_lora_finetune_distributed.py index 5be3a2379a..4d7c4b6899 100644 --- a/tests/recipes/test_qat_lora_finetune_distributed.py +++ b/tests/recipes/test_qat_lora_finetune_distributed.py @@ -45,7 +45,7 @@ def _get_test_config_overrides(self): def _fetch_expected_loss_values(self, model_type): loss_values_map = { - "llama3": [11.9325, 11.9325, 11.9325, 11.9369], + "llama3": [11.9835, 11.9694, 11.9615, 11.9383], } return loss_values_map[model_type] @@ -66,6 +66,7 @@ def test_loss( ): ckpt = "llama3_tune" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) cmd = f""" @@ -80,11 +81,12 @@ def test_loss( checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA3 \ metric_logger.filename={log_file} \ - tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ compile={should_compile} \ enable_activation_checkpointing=False \ enable_activation_offloading=False \ + quantizer.groupsize=32 \ """.split() model_config = MODEL_TEST_CONFIGS["llama3_lora"] @@ -154,6 +156,7 @@ def test_training_state_on_resume( save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ enable_activation_offloading=True \ + quantizer.groupsize=32 \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] @@ -182,6 +185,7 @@ def test_training_state_on_resume( metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ enable_activation_offloading=True \ + quantizer.groupsize=32 \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config @@ -228,6 +232,7 @@ def test_save_and_load_merged_weights( tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ enable_activation_offloading=True \ + quantizer.groupsize=32 \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] diff --git a/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py b/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py index 5df17bb877..11e039e66d 100644 --- a/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py +++ b/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py @@ -86,3 +86,12 @@ def test_get_item(self, load_image, load_dataset, tokenizer, test_image_pil): assert Counter(input) == expected_count assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11 assert images == [test_image_pil] + + def test_dataset_fails_with_packed(self, tokenizer): + with pytest.raises( + ValueError, match="Multimodal datasets don't support packing yet." + ): + llava_instruct_dataset( + model_transform=tokenizer, + packed=True, + ) diff --git a/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py b/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py new file mode 100644 index 0000000000..8b12d3a85e --- /dev/null +++ b/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import DummyTokenizer + +from torchtune.datasets.multimodal import multimodal_chat_dataset + + +class TestMultimodalChatDataset: + @pytest.fixture + def tokenizer(self): + return DummyTokenizer() + + def test_dataset_fails_with_packed(self, tokenizer): + with pytest.raises( + ValueError, match="Multimodal datasets don't support packing yet." + ): + multimodal_chat_dataset( + model_transform=tokenizer, source="json", packed=True + ) diff --git a/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py b/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py index ed8ed40ec7..ebc485d8dd 100644 --- a/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py +++ b/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py @@ -79,3 +79,13 @@ def test_get_item(self, load_dataset, tokenizer, test_image_pil): ] assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 24 assert images == [test_image_pil] + + def test_dataset_fails_with_packed(self, tokenizer): + with pytest.raises( + ValueError, match="Multimodal datasets don't support packing yet." + ): + the_cauldron_dataset( + model_transform=tokenizer, + subset="dummy", + packed=True, + ) diff --git a/tests/torchtune/datasets/multimodal/test_vqa_dataset.py b/tests/torchtune/datasets/multimodal/test_vqa_dataset.py index 6ca36d9615..d2b80fdea7 100644 --- a/tests/torchtune/datasets/multimodal/test_vqa_dataset.py +++ b/tests/torchtune/datasets/multimodal/test_vqa_dataset.py @@ -47,3 +47,13 @@ def test_get_item(self, tokenizer): assert prompt == expected_tokens[i] assert label == expected_labels[i] assert isinstance(image[0], PngImageFile) + + def test_dataset_fails_with_packed(self, tokenizer): + with pytest.raises( + ValueError, match="Multimodal datasets don't support packing yet." + ): + vqa_dataset( + model_transform=tokenizer, + source="json", + packed=True, + ) diff --git a/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py index 40a7d02c8c..834e8d78a9 100644 --- a/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py +++ b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py @@ -107,3 +107,14 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input): else: # Check that the input is masked assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16 + + def test_dataset_fails_with_packed(self): + with pytest.raises( + ValueError, + match="Packed is currently not supported for preference datasets", + ): + hh_rlhf_helpful_dataset( + tokenizer=DummyTokenizer(), + train_on_input=True, + packed=True, + ) diff --git a/tests/torchtune/datasets/test_preference_dataset.py b/tests/torchtune/datasets/test_preference_dataset.py index e6bbc264b3..4f5cba7a8d 100644 --- a/tests/torchtune/datasets/test_preference_dataset.py +++ b/tests/torchtune/datasets/test_preference_dataset.py @@ -155,3 +155,17 @@ def test_load_local_json(self): assert expected_chosen_labels[0] == ds[0]["chosen_labels"] assert expected_rejected_labels[0] == ds[0]["rejected_labels"] + + def test_dataset_fails_with_packed(self): + with pytest.raises( + ValueError, + match="Packed is currently not supported for preference datasets.", + ): + preference_dataset( + tokenizer=DummyTokenizer(), + source="json", + data_files=str(ASSETS / "hh_rlhf_tiny.json"), + train_on_input=False, + split="train", + packed=True, + ) diff --git a/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py index f9bcbebc08..6e8a9a4eb8 100644 --- a/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py +++ b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py @@ -100,6 +100,16 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input): # Check that the input is masked assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 52 + def test_dataset_fails_with_packed(self): + with pytest.raises( + ValueError, + match="Packed is currently not supported for preference datasets", + ): + stack_exchange_paired_dataset( + tokenizer=DummyTokenizer(), + packed=True, + ) + class TestStackExchangePairedToMessages: @pytest.fixture diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index daf84be1b7..2a4bf25a8b 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -302,6 +302,10 @@ class Recipe: name="llama2/7B_lora_dpo_single_device", file_path="llama2/7B_lora_dpo_single_device.yaml", ), + Config( + name="llama3_1/8B_lora_dpo_single_device", + file_path="llama3_1/8B_lora_dpo_single_device.yaml", + ), ], supports_distributed=False, ), @@ -313,6 +317,10 @@ class Recipe: name="llama2/7B_lora_dpo", file_path="llama2/7B_lora_dpo.yaml", ), + Config( + name="llama3_1/8B_lora_dpo", + file_path="llama3_1/8B_lora_dpo.yaml", + ), ], supports_distributed=True, ), diff --git a/torchtune/datasets/_preference.py b/torchtune/datasets/_preference.py index 1cc53b3626..dea4eec852 100644 --- a/torchtune/datasets/_preference.py +++ b/torchtune/datasets/_preference.py @@ -89,9 +89,14 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See the Hugging Face `docs `_ for more details. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. Packed is + currently not supported for ``PreferenceDataset`` and a ``ValueError`` will be raised if this is set to True. **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging Face's `API ref `_ for more details. + + Raises: + ValueError: If ``packed`` is True, this feature is not supported for ``PreferenceDataset``. """ def __init__( @@ -101,8 +106,14 @@ def __init__( message_transform: Transform, tokenizer: ModelTokenizer, filter_fn: Optional[Callable] = None, + packed: bool = False, **load_dataset_kwargs: Dict[str, Any], ) -> None: + if packed: + raise ValueError( + "Packed is currently not supported for preference datasets." + ) + self._tokenizer = tokenizer self._message_transform = message_transform self._data = load_dataset(source, **load_dataset_kwargs) diff --git a/torchtune/datasets/multimodal/_llava_instruct.py b/torchtune/datasets/multimodal/_llava_instruct.py index 17174bf1f1..2f218731ba 100644 --- a/torchtune/datasets/multimodal/_llava_instruct.py +++ b/torchtune/datasets/multimodal/_llava_instruct.py @@ -118,6 +118,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: >>> print(f"Batch size: {len(batch)}") >>> Batch size: 8 """ + if packed: + raise ValueError("Multimodal datasets don't support packing yet.") message_transform = ShareGPTToMessages( train_on_input=False, @@ -136,6 +138,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: data_files=data_files, **load_dataset_kwargs, ) - if packed: - raise ValueError("Multimodal datasets don't support packing yet.") + return ds diff --git a/torchtune/datasets/multimodal/_multimodal.py b/torchtune/datasets/multimodal/_multimodal.py index 294e71567a..83673a2e1a 100644 --- a/torchtune/datasets/multimodal/_multimodal.py +++ b/torchtune/datasets/multimodal/_multimodal.py @@ -18,6 +18,7 @@ def multimodal_chat_dataset( source: str, column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, + packed: bool = False, image_tag: Optional[str] = None, image_dir: Optional[str] = None, filter_fn: Optional[Callable] = None, @@ -79,6 +80,7 @@ def multimodal_chat_dataset( new_system_prompt (Optional[str]): if specified, prepend a system message. This can serve as instructions to guide the model response. Setting this will OVERRIDE any system messages already present in the dataset. Default is None. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. image_tag (Optional[str]): placeholder tags in the text content of each message to be replaced by dictionaries indicating to the tokenizer where to place image tokens. If images are present and this is None, then will prepend image tokens to the first user message in the sample by default. If text-only, leave @@ -169,7 +171,14 @@ def multimodal_chat_dataset( Returns: SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset` + + Raises: + ValueError: If ``packed`` is True, they are not supported for multimodal datasets yet. + """ + if packed: + raise ValueError("Multimodal datasets don't support packing yet.") + message_transform = ShareGPTToMessages( train_on_input=False, column_map=column_map, diff --git a/torchtune/datasets/multimodal/_the_cauldron.py b/torchtune/datasets/multimodal/_the_cauldron.py index 899360a619..8887edf827 100644 --- a/torchtune/datasets/multimodal/_the_cauldron.py +++ b/torchtune/datasets/multimodal/_the_cauldron.py @@ -216,6 +216,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: >>> print(f"Batch size: {len(batch)}") >>> Batch size: 8 """ + if packed: + raise ValueError("Multimodal datasets don't support packing yet.") message_transform = TheCauldronToMessages( column_map=column_map, @@ -231,6 +233,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: split=split, **load_dataset_kwargs, ) - if packed: - raise ValueError("Multimodal datasets don't support packing yet.") + return ds diff --git a/torchtune/datasets/multimodal/_vqa.py b/torchtune/datasets/multimodal/_vqa.py index 27e8bbe1d4..ce991f07ec 100644 --- a/torchtune/datasets/multimodal/_vqa.py +++ b/torchtune/datasets/multimodal/_vqa.py @@ -18,6 +18,7 @@ def vqa_dataset( image_dir: str = None, column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, + packed: bool = False, filter_fn: Optional[Callable] = None, split: str = "train", **load_dataset_kwargs: Dict[str, Any], @@ -63,6 +64,7 @@ def vqa_dataset( new_system_prompt (Optional[str]): if specified, prepend a system message. This can serve as instructions to guide the model response. Setting this will OVERRIDE any system messages already present in the dataset. Default is None. + packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See the Hugging Face `docs `_ for more details. @@ -122,7 +124,14 @@ def vqa_dataset( Returns: SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset` + + Raises: + ValueError: If ``packed`` is True, they are not supported for multimodal datasets yet. + """ + if packed: + raise ValueError("Multimodal datasets don't support packing yet.") + message_transform = InputOutputToMessages( column_map=column_map, new_system_prompt=new_system_prompt, image_dir=image_dir ) diff --git a/torchtune/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py index 29f66a20c3..b19e0d93ca 100644 --- a/torchtune/rlhf/loss/dpo.py +++ b/torchtune/rlhf/loss/dpo.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torchtune.utils._logging import deprecated class DPOLoss(nn.Module): @@ -160,6 +161,7 @@ def forward( return losses, chosen_rewards, rejected_rewards +@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.") class SimPOLoss(nn.Module): """ SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.