Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer


# Define and parse arguments.
Expand Down Expand Up @@ -167,7 +167,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
)

# 4. initialize training arguments:
training_args = TrainingArguments(
training_args = DPOConfig(
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
max_steps=script_args.max_steps,
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import DPOConfig, DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config


@dataclass
Expand Down Expand Up @@ -116,7 +116,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
parser = HfArgumentParser((ScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()

################
Expand Down
10 changes: 5 additions & 5 deletions tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from trl import DPOTrainer, is_peft_available
from trl import DPOConfig, DPOTrainer, is_peft_available

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
Expand Down
116 changes: 103 additions & 13 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer, FDivergenceType

from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft

Expand Down Expand Up @@ -92,7 +92,7 @@ def _init_dummy_dataset(self):
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):

def test_dpo_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -230,7 +230,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):

def test_dpo_trainer_padding_token_is_none(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_dpo_trainer_padding_token_is_none(self):

def test_dpo_trainer_w_dataset_num_proc(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_dpo_trainer_w_dataset_num_proc(self):
@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_dpo_lora_save(self):
model_peft = get_peft_model(model, lora_config)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -407,7 +407,7 @@ def test_dpo_lora_bf16_autocast_llama(self):
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_e
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_dpo_lora_tags(self):
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand Down Expand Up @@ -563,7 +563,7 @@ def test_dpo_tags(self):
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
Expand All @@ -587,3 +587,93 @@ def test_dpo_tags(self):
)

assert trainer.model.model_tags == trainer._tag_names

def test_dpo_loss_alpha_div_f(self):
    """The alpha-divergence DPO loss should stay finite on extreme log-prob inputs."""
    pretrained_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
    tok = AutoTokenizer.from_pretrained(pretrained_id)
    base_model = AutoModelForCausalLM.from_pretrained(pretrained_id)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Configure training with the alpha-divergence regularizer.
        cfg = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=4,
            learning_rate=9e-1,
            evaluation_strategy="steps",
            f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,
            f_alpha_divergence_coef=0.5,
        )

        dataset = self._init_dummy_dataset()

        dpo_trainer = DPOTrainer(
            model=base_model,
            ref_model=None,
            beta=0.1,
            args=cfg,
            tokenizer=tok,
            train_dataset=dataset,
            eval_dataset=dataset,
        )

        # Hand-crafted policy/reference log-probs, including large magnitudes.
        pi_chosen = torch.FloatTensor([410.0, 0.1])
        pi_rejected = torch.FloatTensor([810.5, 0.2])
        ref_chosen = torch.FloatTensor([-610.0, -0.1])
        ref_rejected = torch.FloatTensor([110.6, 0.5])

        losses, _, _ = dpo_trainer.dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected)

        # No NaN/inf anywhere in the per-example losses.
        assert torch.isfinite(losses).cpu().numpy().all()

def test_dpo_loss_js_div_f(self):
    """The JS-divergence DPO loss should stay finite on extreme log-prob inputs."""
    pretrained_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
    tok = AutoTokenizer.from_pretrained(pretrained_id)
    base_model = AutoModelForCausalLM.from_pretrained(pretrained_id)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Configure training with the JS-divergence regularizer.
        cfg = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=4,
            learning_rate=9e-1,
            evaluation_strategy="steps",
            f_divergence_type=FDivergenceType.JS_DIVERGENCE.value,
            f_alpha_divergence_coef=0.5,
        )

        dataset = self._init_dummy_dataset()

        dpo_trainer = DPOTrainer(
            model=base_model,
            ref_model=None,
            beta=0.1,
            args=cfg,
            tokenizer=tok,
            train_dataset=dataset,
            eval_dataset=dataset,
        )

        # Hand-crafted policy/reference log-probs, including large magnitudes.
        pi_chosen = torch.FloatTensor([410.0, 0.1])
        pi_rejected = torch.FloatTensor([95.5, 0.2])
        ref_chosen = torch.FloatTensor([-610.0, -0.1])
        ref_rejected = torch.FloatTensor([5.5, 0.5])

        losses, _, _ = dpo_trainer.dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected)

        # No NaN/inf anywhere in the per-example losses.
        assert torch.isfinite(losses).cpu().numpy().all()
3 changes: 3 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOConfig,
DPOTrainer,
FDivergenceConstants,
FDivergenceType,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
if is_diffusers_available():
from .ddpo_trainer import DDPOTrainer

from .dpo_trainer import DPOTrainer
from .dpo_trainer import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType
from .iterative_sft_trainer import IterativeSFTTrainer
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
Expand Down
46 changes: 46 additions & 0 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional

from transformers import TrainingArguments


class FDivergenceType(Enum):
    """Types of f-divergence regularization used to measure divergence between the policy and reference model."""

    # Standard DPO behavior: reverse KL divergence.
    REVERSE_KL = "reverse_kl"
    # Jensen-Shannon divergence.
    JS_DIVERGENCE = "js_divergence"
    # Alpha-divergence; the alpha coefficient is configurable (see `DPOConfig.f_alpha_divergence_coef`).
    ALPHA_DIVERGENCE = "alpha_divergence"


class FDivergenceConstants:
    """Dictionary keys and default values for f-divergence parameters."""

    # Key used to look up the alpha coefficient in an f-divergence parameter dict.
    ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
    # Default alpha coefficient when none is supplied.
    ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0


@dataclass
class DPOConfig(TrainingArguments):
    """
    DPOConfig collects all training arguments related to the [`DPOTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
            The type of f-divergence regularization function used to compute the divergence between the policy and
            the reference model.
        f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
            The alpha coefficient in the alpha-divergence (u^-alpha) regularization function for the DPO loss. Only
            used when `f_divergence_type` is `FDivergenceType.ALPHA_DIVERGENCE`.
    """

    # Both fields use `field(metadata={"help": ...})` so HfArgumentParser surfaces them in --help;
    # bare-string "attribute docstrings" are plain statements and are ignored by the parser.
    f_divergence_type: Optional[FDivergenceType] = field(
        default=FDivergenceType.REVERSE_KL,
        metadata={
            "help": "the type of f-divergence regularization function to compute divergence between policy and reference model"
        },
    )
    f_alpha_divergence_coef: float = field(
        default=1.0,
        metadata={"help": "the alpha coef in alpha-divergence(u^-alpha) regularization function for DPO loss"},
    )
Loading