Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/axolotl/core/trainers/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member

import warnings
from functools import partial
from typing import Any

import datasets
Expand Down Expand Up @@ -58,6 +59,42 @@ class AxolotlGRPOTrainer(

_tag_names = ["trl", "grpo", "axolotl"]

def get_train_dataloader(self):
    """Return the accelerator-prepared training ``DataLoader``.

    The loader's batch size is
    ``self._train_batch_size * self.args.steps_per_generation``; the
    original inline note marks this multiplication as the change this
    override makes.

    Raises:
        ValueError: if ``self.train_dataset`` is ``None``.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    dataset = self.train_dataset
    collate_fn = self.data_collator
    if isinstance(dataset, datasets.Dataset):
        # ``datasets.Dataset`` inputs are passed through the column-removal
        # helper directly.
        dataset = self._remove_unused_columns(dataset, description="training")
    else:
        # Any other dataset type gets the column-removing collator instead.
        collate_fn = self._get_collator_with_removed_columns(
            collate_fn, description="training"
        )

    loader_kwargs = dict(
        # One batch spans ``steps_per_generation`` per-step batches.
        batch_size=self._train_batch_size * self.args.steps_per_generation,
        collate_fn=collate_fn,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    if not is_iterable:
        # Sampler, drop_last, worker seeding and prefetching are only set
        # when the dataset is not an ``IterableDataset``.
        loader_kwargs.update(
            {
                "sampler": self._get_train_sampler(),
                "drop_last": self.args.dataloader_drop_last,
                "worker_init_fn": partial(
                    seed_worker,
                    num_workers=self.args.dataloader_num_workers,
                    rank=self.args.process_index,
                ),
                "prefetch_factor": self.args.dataloader_prefetch_factor,
            }
        )

    loader = DataLoader(dataset, **loader_kwargs)
    return self.accelerator.prepare(loader)

Comment on lines +62 to +97

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Inconsistency with AxolotlGRPOSequenceParallelTrainer implementation.

The `AxolotlGRPOSequenceParallelTrainer` class has its own `get_train_dataloader` method (line 276) that handles `worker_init_fn` differently: it passes `seed_worker` directly rather than wrapping it in `partial`. This inconsistency could lead to different behavior between the two trainer classes.

Consider aligning both implementations to use the same approach for worker_init_fn. If the partial approach is the correct fix, the AxolotlGRPOSequenceParallelTrainer should be updated accordingly:

# In AxolotlGRPOSequenceParallelTrainer.get_train_dataloader method around line 255
- dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["worker_init_fn"] = partial(
+     seed_worker,
+     num_workers=self.args.dataloader_num_workers,
+     rank=self.args.process_index,
+ )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def get_train_dataloader(self):
# … earlier in the method …
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["worker_init_fn"] = partial(
+ seed_worker,
+ num_workers=self.args.dataloader_num_workers,
+ rank=self.args.process_index,
+ )
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
🤖 Prompt for AI Agents
In src/axolotl/core/trainers/grpo/trainer.py between lines 62 and 97, the
get_train_dataloader method uses partial to wrap seed_worker for the
worker_init_fn parameter, while the AxolotlGRPOSequenceParallelTrainer class
uses seed_worker directly. To fix this inconsistency, review both
implementations and decide on one approach for worker_init_fn; if partial is
preferred, update the AxolotlGRPOSequenceParallelTrainer's get_train_dataloader
method to use partial with the same arguments, ensuring consistent behavior
across both trainer classes.


class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
Expand Down