Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions tests/train/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,22 +840,23 @@ def test_validate_generator_output_element_length_mismatch():
validate_generator_output(len(input_batch["prompts"]), generator_output)


def test_build_dataloader_seeding(dummy_config):
"""Test that build_dataloader correctly seeds the dataloader for reproducible shuffling."""
# Create a dataset with multiple distinct items to test shuffling
class MultiItemDataset:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve clarity and indicate that this class is intended for internal use within this test module, consider renaming it to _MultiItemDataset. This follows the Python convention (PEP 8) for internal-use names and prevents it from being accidentally used elsewhere.

Suggested change
class MultiItemDataset:
class _MultiItemDataset:
References
  1. PEP 8 suggests using a single leading underscore for internal-use functions, methods, and attributes to signal they are not part of the public API of the module.

def __init__(self, size=10):
self.data = [f"item_{i}" for i in range(size)]

def __len__(self):
return len(self.data)

# Create a dataset with multiple distinct items to test shuffling
class MultiItemDataset:
def __init__(self, size=10):
self.data = [f"item_{i}" for i in range(size)]
def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)
def collate_fn(self, batch):
return batch

def __getitem__(self, idx):
return self.data[idx]

def collate_fn(self, batch):
return batch
def test_build_dataloader_seeding(dummy_config):
"""Test that build_dataloader correctly seeds the dataloader for reproducible shuffling."""

dataset = MultiItemDataset(size=20)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please update the instantiation to use the new class name _MultiItemDataset as suggested for the class definition.

Suggested change
dataset = MultiItemDataset(size=20)
dataset = _MultiItemDataset(size=20)


Expand Down
Loading