Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions trl/experimental/online_dpo/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
> - `gradient_checkpointing`: Defaults to `True` instead of `False`.
> - `bf16`: Defaults to `True` if `fp16` is not set, instead of `False`.
> - `learning_rate`: Defaults to `5e-7` instead of `5e-5`.
> - `remove_unused_columns`: Defaults to `False` instead of `True`.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
Expand All @@ -168,6 +169,10 @@ class may differ from those in [`~transformers.TrainingArguments`].
default=5e-7,
metadata={"help": "The initial learning rate for AdamW."},
)
remove_unused_columns: bool = field(
default=False,
metadata={"help": "Whether or not to automatically remove the columns unused by the model forward method."},
)

reward_model_path: str | None = field(
default=None,
Expand Down
79 changes: 2 additions & 77 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import textwrap
from collections.abc import Callable
from contextlib import nullcontext
from functools import wraps
from pathlib import Path
from typing import Any

Expand All @@ -32,7 +31,7 @@
from datasets import Dataset
from packaging.version import Version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data import IterableDataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
Expand All @@ -42,12 +41,11 @@
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
is_bitsandbytes_available,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_utils import EvalPrediction, seed_worker
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils import is_flash_attn_2_available, is_peft_available, is_sagemaker_mp_enabled

Expand Down Expand Up @@ -609,79 +607,6 @@ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokeniz
batch = {f"prompt_{key}": value for key, value in batch.items()}
return batch

# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
@wraps(Trainer.get_train_dataloader)
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

train_dataset = self.train_dataset
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
@wraps(Trainer.get_eval_dataloader)
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")

# If we have persistent workers, don't do a fork bomb especially as eval datasets
# don't change during training
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
if (
hasattr(self, "_eval_dataloaders")
and dataloader_key in self._eval_dataloaders
and self.args.dataloader_persistent_workers
):
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

eval_dataset = (
self.eval_dataset[eval_dataset]
if isinstance(eval_dataset, str)
else eval_dataset
if eval_dataset is not None
else self.eval_dataset
)
data_collator = self.data_collator

dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = eval_dataloader
else:
self._eval_dataloaders = {dataloader_key: eval_dataloader}

return self.accelerator.prepare(eval_dataloader)

def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
Expand Down
Loading