From 01b88cde1bcdd8e97b1a459e5a1345c08ce95483 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 12 Mar 2026 11:09:36 +0100 Subject: [PATCH] Simplify get_train_dataloader in GRPO and RLOO --- trl/trainer/grpo_trainer.py | 45 +++++++++---------------------------- trl/trainer/rloo_trainer.py | 45 +++++++++---------------------------- 2 files changed, 20 insertions(+), 70 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3ebd3c649ae..e5b3d21e96f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -25,11 +25,9 @@ from collections import defaultdict, deque from collections.abc import Callable from contextlib import nullcontext -from functools import partial from pathlib import Path from typing import Any, Protocol -import datasets import numpy as np import pandas as pd import torch @@ -42,7 +40,7 @@ from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.utils.data import DataLoader, Sampler +from torch.utils.data import Sampler from transformers import ( AutoModelForSequenceClassification, AutoProcessor, @@ -55,8 +53,7 @@ is_trackio_available, is_wandb_available, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available, is_rich_available +from transformers.utils import is_peft_available, is_rich_available from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response from ..data_utils import ( @@ -849,37 +846,15 @@ def _set_signature_columns_if_needed(self): # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the # splitting internally. # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line - # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line - # apart from the super method, ensuring easier maintenance in the future. + # modification. def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change + sampler_fn=self._get_train_sampler, + is_training=True, + ) def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: # Returns a sampler that diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e149518cca5..e47c0d8ae2a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -21,11 +21,9 @@ from collections import defaultdict, deque from collections.abc import Callable from contextlib import nullcontext -from functools import partial from pathlib import Path from typing import Any -import datasets import numpy as np import pandas as pd import torch @@ -37,7 +35,7 @@ from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.utils.data import DataLoader, Sampler +from torch.utils.data import Sampler from transformers import ( AutoModelForSequenceClassification, AutoProcessor, @@ -50,8 +48,7 @@ is_trackio_available, is_wandb_available, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available, is_rich_available +from transformers.utils import is_peft_available, is_rich_available from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages from ..extras.profiling import profiling_context, profiling_decorator @@ -608,37 +605,15 @@ def _set_signature_columns_if_needed(self): # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the # splitting internally. # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line - # modification. As a result, some parts of the method aren't relevant to RLOO, but we keep them to stay one line - # apart from the super method, ensuring easier maintenance in the future. + # modification. def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change + sampler_fn=self._get_train_sampler, + is_training=True, + ) def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: # Returns a sampler that