fix worker_init_fn signature handling#2769
Conversation
WalkthroughA new method, Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant AxolotlGRPOTrainer
participant Dataset
participant DataLoader
participant Accelerator
User->>AxolotlGRPOTrainer: Call get_train_dataloader()
AxolotlGRPOTrainer->>AxolotlGRPOTrainer: Validate train_dataset
AxolotlGRPOTrainer->>Dataset: Remove unused columns (if applicable)
AxolotlGRPOTrainer->>AxolotlGRPOTrainer: Compute batch size (batch_size * steps_per_generation)
AxolotlGRPOTrainer->>DataLoader: Construct DataLoader with parameters
AxolotlGRPOTrainer->>Accelerator: Prepare DataLoader
Accelerator-->>AxolotlGRPOTrainer: Return prepared DataLoader
AxolotlGRPOTrainer-->>User: Return DataLoader
Poem
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/core/trainers/grpo/trainer.py(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (8)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: pre-commit
🔇 Additional comments (3)
src/axolotl/core/trainers/grpo/trainer.py (3)
6-6: LGTM! Import added for worker initialization fix.The
partialimport is correctly added to support the worker_init_fn signature handling fix.
77-80: Batch size multiplication logic looks correct.The multiplication by
steps_per_generationaligns with GRPO's requirement to process multiple generations per prompt in a single batch.
89-93:Details
❓ Verification inconclusive
Verify the seed_worker function signature.
The
partialusage suggests a fix for incorrect parameter passing toseed_worker. However, I need to verify thatnum_workersandrankare the correct parameters thatseed_workerexpects.Let me check the
seed_workerfunction signature to ensure the parameters are correct:
🏁 Script executed:
#!/bin/bash # Description: Check the seed_worker function signature from transformers # Expected: Function signature showing the expected parameters ast-grep --pattern 'def seed_worker($$$)'Length of output: 43
#!/bin/bash
set -euo pipefailSearch for the seed_worker definition
rg -n "def seed_worker" -C3
Also search all references to seed_worker to see how it's used
rg -n "seed_worker" -C3
#!/bin/bash
set -euo pipefailLocate the seed_worker definition and its signature
rg -n "def seed_worker" -C3
Show all usages/imports to infer expected parameters
rg -n "seed_worker" -C3
#!/bin/bash
set -euo pipefailLocate the seed_worker definition (if present) and its signature
rg -n "^def seed_worker" -C5 .
Find all import sites and usages to infer expected parameters
rg -n "seed_worker" -C5 .
Confirm
seed_workersignature for DataLoaderworker_init_fnI couldn’t locate the
seed_workerimplementation in the repo to verify that it acceptsnum_workersandrankas keyword arguments. Please double-check its signature and adjust thepartial(...)call if the parameter names differ.• File:
src/axolotl/core/trainers/grpo/trainer.py
Lines: 89–93
| def get_train_dataloader(self): | ||
| if self.train_dataset is None: | ||
| raise ValueError("Trainer: training requires a train_dataset.") | ||
|
|
||
| train_dataset = self.train_dataset | ||
| data_collator = self.data_collator | ||
| if isinstance(train_dataset, datasets.Dataset): | ||
| train_dataset = self._remove_unused_columns( | ||
| train_dataset, description="training" | ||
| ) | ||
| else: | ||
| data_collator = self._get_collator_with_removed_columns( | ||
| data_collator, description="training" | ||
| ) | ||
|
|
||
| dataloader_params = { | ||
| "batch_size": self._train_batch_size | ||
| * self.args.steps_per_generation, # < this is the change | ||
| "collate_fn": data_collator, | ||
| "num_workers": self.args.dataloader_num_workers, | ||
| "pin_memory": self.args.dataloader_pin_memory, | ||
| "persistent_workers": self.args.dataloader_persistent_workers, | ||
| } | ||
|
|
||
| if not isinstance(train_dataset, torch.utils.data.IterableDataset): | ||
| dataloader_params["sampler"] = self._get_train_sampler() | ||
| dataloader_params["drop_last"] = self.args.dataloader_drop_last | ||
| dataloader_params["worker_init_fn"] = partial( | ||
| seed_worker, | ||
| num_workers=self.args.dataloader_num_workers, | ||
| rank=self.args.process_index, | ||
| ) | ||
| dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor | ||
|
|
||
| return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Inconsistency with AxolotlGRPOSequenceParallelTrainer implementation.
The AxolotlGRPOSequenceParallelTrainer class has a different get_train_dataloader method (line 276) that handles worker_init_fn differently - it uses seed_worker directly without partial. This inconsistency could lead to different behavior between the two trainer classes.
Consider aligning both implementations to use the same approach for worker_init_fn. If the partial approach is the correct fix, the AxolotlGRPOSequenceParallelTrainer should be updated accordingly:
# In AxolotlGRPOSequenceParallelTrainer.get_train_dataloader method around line 255
- dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["worker_init_fn"] = partial(
+ seed_worker,
+ num_workers=self.args.dataloader_num_workers,
+ rank=self.args.process_index,
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def get_train_dataloader(self): | |
| if self.train_dataset is None: | |
| raise ValueError("Trainer: training requires a train_dataset.") | |
| train_dataset = self.train_dataset | |
| data_collator = self.data_collator | |
| if isinstance(train_dataset, datasets.Dataset): | |
| train_dataset = self._remove_unused_columns( | |
| train_dataset, description="training" | |
| ) | |
| else: | |
| data_collator = self._get_collator_with_removed_columns( | |
| data_collator, description="training" | |
| ) | |
| dataloader_params = { | |
| "batch_size": self._train_batch_size | |
| * self.args.steps_per_generation, # < this is the change | |
| "collate_fn": data_collator, | |
| "num_workers": self.args.dataloader_num_workers, | |
| "pin_memory": self.args.dataloader_pin_memory, | |
| "persistent_workers": self.args.dataloader_persistent_workers, | |
| } | |
| if not isinstance(train_dataset, torch.utils.data.IterableDataset): | |
| dataloader_params["sampler"] = self._get_train_sampler() | |
| dataloader_params["drop_last"] = self.args.dataloader_drop_last | |
| dataloader_params["worker_init_fn"] = partial( | |
| seed_worker, | |
| num_workers=self.args.dataloader_num_workers, | |
| rank=self.args.process_index, | |
| ) | |
| dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor | |
| return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) | |
| def get_train_dataloader(self): | |
| # … earlier in the method … | |
| if not isinstance(train_dataset, torch.utils.data.IterableDataset): | |
| dataloader_params["sampler"] = self._get_train_sampler() | |
| dataloader_params["drop_last"] = self.args.dataloader_drop_last | |
| - dataloader_params["worker_init_fn"] = seed_worker | |
| + dataloader_params["worker_init_fn"] = partial( | |
| + seed_worker, | |
| + num_workers=self.args.dataloader_num_workers, | |
| + rank=self.args.process_index, | |
| + ) | |
| dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor | |
| return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) |
🤖 Prompt for AI Agents
In src/axolotl/core/trainers/grpo/trainer.py between lines 62 and 97, the
get_train_dataloader method uses partial to wrap seed_worker for the
worker_init_fn parameter, while the AxolotlGRPOSequenceParallelTrainer class
uses seed_worker directly. To fix this inconsistency, review both
implementations and decide on one approach for worker_init_fn; if partial is
preferred, update the AxolotlGRPOSequenceParallelTrainer's get_train_dataloader
method to use partial with the same arguments, ensuring consistent behavior
across both trainer classes.
Codecov ReportAttention: Patch coverage is
📢 Thoughts on this report? Let us know! |
Summary by CodeRabbit