diff --git a/tests/train/test_trainer_utils.py b/tests/train/test_trainer_utils.py index 0177ffbd5f..f18d0d402c 100644 --- a/tests/train/test_trainer_utils.py +++ b/tests/train/test_trainer_utils.py @@ -840,22 +840,23 @@ def test_validate_generator_output_element_length_mismatch(): validate_generator_output(len(input_batch["prompts"]), generator_output) -def test_build_dataloader_seeding(dummy_config): - """Test that build_dataloader correctly seeds the dataloader for reproducible shuffling.""" +# Create a dataset with multiple distinct items to test shuffling +class MultiItemDataset: + def __init__(self, size=10): + self.data = [f"item_{i}" for i in range(size)] + + def __len__(self): + return len(self.data) - # Create a dataset with multiple distinct items to test shuffling - class MultiItemDataset: - def __init__(self, size=10): - self.data = [f"item_{i}" for i in range(size)] + def __getitem__(self, idx): + return self.data[idx] - def __len__(self): - return len(self.data) + def collate_fn(self, batch): + return batch - def __getitem__(self, idx): - return self.data[idx] - def collate_fn(self, batch): - return batch +def test_build_dataloader_seeding(dummy_config): + """Test that build_dataloader correctly seeds the dataloader for reproducible shuffling.""" dataset = MultiItemDataset(size=20)