Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 10 additions & 35 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Any, Protocol

import datasets
import numpy as np
import pandas as pd
import torch
Expand All @@ -42,7 +40,7 @@
from packaging.version import Version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from torch.utils.data import Sampler
from transformers import (
AutoModelForSequenceClassification,
AutoProcessor,
Expand All @@ -55,8 +53,7 @@
is_trackio_available,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
from transformers.utils import is_peft_available, is_rich_available

from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response
from ..data_utils import (
Expand Down Expand Up @@ -849,37 +846,15 @@ def _set_signature_columns_if_needed(self):
# `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
# splitting internally.
# Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
# modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line
# apart from the super method, ensuring easier maintenance in the future.
# modification.
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

dataloader_params = {
"batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)

dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return self._get_dataloader(
dataset=self.train_dataset,
description="Training",
batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change
sampler_fn=self._get_train_sampler,
is_training=True,
)

def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
# Returns a sampler that
Expand Down
45 changes: 10 additions & 35 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@
from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Any

import datasets
import numpy as np
import pandas as pd
import torch
Expand All @@ -37,7 +35,7 @@
from packaging.version import Version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from torch.utils.data import Sampler
from transformers import (
AutoModelForSequenceClassification,
AutoProcessor,
Expand All @@ -50,8 +48,7 @@
is_trackio_available,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
from transformers.utils import is_peft_available, is_rich_available

from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
from ..extras.profiling import profiling_context, profiling_decorator
Expand Down Expand Up @@ -608,37 +605,15 @@ def _set_signature_columns_if_needed(self):
# `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
# splitting internally.
# Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
# modification. As a result, some parts of the method aren't relevant to RLOO, but we keep them to stay one line
# apart from the super method, ensuring easier maintenance in the future.
# modification.
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

dataloader_params = {
"batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)

dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return self._get_dataloader(
dataset=self.train_dataset,
description="Training",
batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change
sampler_fn=self._get_train_sampler,
is_training=True,
)

def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
# Returns a sampler that
Expand Down
Loading