-
Notifications
You must be signed in to change notification settings - Fork 471
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into flux_ae
- Loading branch information
Showing
33 changed files
with
521 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
.. _dpo_recipe_label: | ||
|
||
==================================== | ||
Direct Preference Optimization | ||
==================================== | ||
|
||
This recipe supports several `Direct Preference Optimization <https://arxiv.org/abs/2305.18290>`_ (DPO)-style fine-tuning techniques. | ||
These techniques aim to steer (or `align <https://en.wikipedia.org/wiki/AI_alignment>`_) a model towards some desirable behaviours. | ||
For example, a common goal is to train language models to produce safe and honest outputs, | ||
or to be `helpful and harmless <https://arxiv.org/abs/2204.05862>`_. | ||
|
||
To see the best results when using this recipe, it may be helpful to first fine-tune your model with using supervised fine-tuning to ensure your model is | ||
on-distribution for the domain you're interested in. To do this, check out our other fine-tuning recipes in the :ref:`recipe overview <recipes_overview_label>` which | ||
support a variety of SFT paradigms. | ||
|
||
After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B: | ||
|
||
.. note:: | ||
|
||
You may need to be granted access to the Llama model you're interested in. See | ||
:ref:`here <download_llama_label>` for details on accessing gated repositories. | ||
|
||
|
||
.. code-block:: bash | ||
tune download meta-llama/Meta-Llama-3.1-8B-Instruct \ | ||
--ignore-patterns "original/consolidated.00.pth" | ||
--HF_TOKEN <HF_TOKEN> | ||
# run on a single device | ||
tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device | ||
# run on two gpus | ||
tune run --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo | ||
It's easy to get started with this recipe with your dataset of choice, including custom local datasets, | ||
and datasets from Hugging Face. Check out our primer on :ref:`preference datasets <preference_dataset_usage_label>` to | ||
see how to do this. | ||
|
||
For this recipe we include different DPO-style losses: | ||
|
||
* :class:`Direct Preference Optimization <torchtune.rlhf.loss.DPOLoss>` (DPO) loss [#]_. The DPO loss function | ||
increases the relative log-probabilities of preferred to un-preferred responses, whilst using log probabilities | ||
from a reference model to prevent policy degradation during training. Alongside RLHF, this is the most commonly used | ||
alignment technique and is used to train a growing number of state-of-the-art LLMs e.g. Llama3.1, Gemma 2, Qwen2, etc. | ||
This is a good starting point for alignment fine-tuning. | ||
* :class:`Statistical Rejection Sampling Optimization <torchtune.rlhf.loss.RSOLoss>` (RSO) or "hinge" loss [#]_. | ||
RSO builds on concepts from support vector machines and DPO, applying a margin-based approach that penalizes | ||
low-quality responses while ensuring a significant gap between chosen and un-chosen log probabilities. | ||
|
||
To use any of these, simply use the ``loss`` config entry or flag through the :ref:`cli_label`: | ||
|
||
.. code-block:: bash | ||
tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device \ | ||
loss=torchtune.modules.loss.RSOLoss \ | ||
gamma=0.5 | ||
.. todo (@SalmanMohammadi) point to an example repo for SimPO | ||
For a deeper understanding of the different levers you can pull when using this recipe, | ||
see our documentation for the different PEFT training paradigms we support: | ||
|
||
* :ref:`glossary_lora` | ||
* :ref:`glossary_qlora` | ||
* :ref:`glossary_dora` | ||
|
||
Many of our other memory optimization features can be used in this recipe. You can learn more about all of our memory optimization features in our :ref:`memory optimization overview<memory_optimization_overview_label>`. | ||
|
||
.. rubric:: References: | ||
|
||
.. [#] Rafailov, R., Sharma, A., Mitchell, E., Manning, C.D., Ermon, S. and Finn, C., 2024. | ||
Direct preference optimization: Your language model is secretly a reward model. Advances in Neural Information Processing Systems, 36. | ||
.. [#] Liu, T., Zhao, Y., Joshi, R., Khalman, M., Saleh, M., Liu, P.J. and Liu, J., 2023. | ||
Statistical rejection sampling improves preference optimization. arXiv preprint arXiv:2309.06657. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py | ||
# using a Llama2 7B model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" | ||
# | ||
# To launch on 2 devices, run the following command from root: | ||
# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo | ||
# | ||
# You can add specific overrides through the command line. For example | ||
# to override the checkpointer directory while launching training | ||
# you can run: | ||
# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> | ||
# | ||
# This config works best when the model is being fine-tuned on 2+ GPUs. | ||
# For single device LoRA DPO alignment please use llama3_1/8B_lora_dpo_single_device | ||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.llama3_1.lora_llama3_1_8b | ||
lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] | ||
apply_lora_to_mlp: True | ||
apply_lora_to_output: False | ||
lora_rank: 8 # higher increases accuracy and memory | ||
lora_alpha: 16 # usually alpha=2*rank | ||
lora_dropout: 0.0 | ||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama3.llama3_tokenizer | ||
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model | ||
max_seq_len: null | ||
|
||
checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
checkpoint_files: [ | ||
model-00001-of-00004.safetensors, | ||
model-00002-of-00004.safetensors, | ||
model-00003-of-00004.safetensors, | ||
model-00004-of-00004.safetensors | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
model_type: LLAMA3 | ||
resume_from_checkpoint: False | ||
save_adapter_weights_only: False | ||
|
||
# Dataset and Sampler | ||
dataset: | ||
_component_: torchtune.datasets.stack_exchange_paired_dataset | ||
seed: null | ||
shuffle: True | ||
batch_size: 4 | ||
|
||
# Optimizer and Scheduler | ||
optimizer: | ||
_component_: torch.optim.AdamW | ||
fused: True | ||
weight_decay: 0.05 | ||
lr: 5e-4 | ||
lr_scheduler: | ||
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torchtune.rlhf.loss.DPOLoss | ||
beta: 0.1 | ||
label_smoothing: 0 | ||
|
||
# Training | ||
epochs: 1 | ||
max_steps_per_epoch: 1000 | ||
gradient_accumulation_steps: 8 # Use to increase virtual batch size | ||
compile: False # pytorch compile, set to true for better perf/memory | ||
|
||
# Logging | ||
output_dir: /tmp/lora_dpo_output/ | ||
metric_logger: | ||
_component_: torchtune.training.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: True | ||
|
||
# Environment | ||
device: cuda | ||
dtype: bf16 | ||
|
||
# Memory management | ||
enable_activation_checkpointing: True # True reduces memory | ||
enable_activation_offloading: False # True reduces memory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Config for single device LoRA DPO alignment in lora_dpo_single_device.py | ||
# using a Llama2 7B model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" | ||
# | ||
# To launch on a single device, run the following command from root: | ||
# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device | ||
# | ||
# You can add specific overrides through the command line. For example | ||
# to override the checkpointer directory while launching training | ||
# you can run: | ||
# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> | ||
# | ||
# This config works only for training on single device. | ||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.llama3_1.lora_llama3_1_8b | ||
lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] | ||
apply_lora_to_mlp: True | ||
apply_lora_to_output: False | ||
lora_rank: 8 # higher increases accuracy and memory | ||
lora_alpha: 16 # usually alpha=2*rank | ||
lora_dropout: 0.0 | ||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama3.llama3_tokenizer | ||
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model | ||
max_seq_len: null | ||
|
||
checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
checkpoint_files: [ | ||
model-00001-of-00004.safetensors, | ||
model-00002-of-00004.safetensors, | ||
model-00003-of-00004.safetensors, | ||
model-00004-of-00004.safetensors | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
model_type: LLAMA3 | ||
resume_from_checkpoint: False | ||
save_adapter_weights_only: False | ||
|
||
# Dataset and Sampler | ||
dataset: | ||
_component_: torchtune.datasets.stack_exchange_paired_dataset | ||
seed: null | ||
shuffle: True | ||
batch_size: 4 | ||
|
||
# Optimizer and Scheduler | ||
optimizer: | ||
_component_: torch.optim.AdamW | ||
fused: True | ||
weight_decay: 0.05 | ||
lr: 5e-4 | ||
lr_scheduler: | ||
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torchtune.rlhf.loss.DPOLoss | ||
|
||
# Training | ||
epochs: 1 | ||
max_steps_per_epoch: 1000 | ||
gradient_accumulation_steps: 8 # Use to increase virtual batch size | ||
compile: False # pytorch compile, set to true for better perf/memory | ||
|
||
# Logging | ||
output_dir: /tmp/lora_dpo_output/ | ||
metric_logger: | ||
_component_: torchtune.training.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: True | ||
|
||
# Environment | ||
device: cuda | ||
dtype: bf16 | ||
|
||
# Memory management | ||
enable_activation_checkpointing: True # True reduces memory | ||
enable_activation_offloading: False # True reduces memory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.