From 8cd24d565704f40440bb33aae0aeb028a8598b42 Mon Sep 17 00:00:00 2001 From: "Lavender =^..^=" Date: Tue, 18 Nov 2025 15:15:01 -0800 Subject: [PATCH 1/3] StatefulDistributedSampler add epoch to state_dict --- torchdata/stateful_dataloader/sampler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index cacb1d12c..358fbb62e 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -181,6 +181,7 @@ def __iter__(self): class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler): _YIELDED = "yielded" + _EPOCH = "epoch" def __init__( self, @@ -196,6 +197,7 @@ def __init__( self.next_yielded = None def __iter__(self): + print(f"Calling __iter__... {self.yielded=} {self.next_yielded=}") self.yielded = 0 if self.next_yielded is not None: self.yielded = self.next_yielded @@ -206,9 +208,12 @@ def __iter__(self): yield idx def state_dict(self) -> Dict[str, Any]: - return {self._YIELDED: self.yielded} + return {self._YIELDED: self.yielded, self._EPOCH: self.epoch, "NEXT_YIELDED": self.next_yielded} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if self._EPOCH not in state_dict: + raise ValueError("Invalid state_dict") + self.set_epoch(state_dict[self._EPOCH]) if self._YIELDED not in state_dict: raise ValueError("Invalid state_dict") if state_dict[self._YIELDED] < 0: From 23e0ddd7aa6ad03ea872b6a80ce38ac401a2056b Mon Sep 17 00:00:00 2001 From: Lavender Date: Tue, 18 Nov 2025 16:26:31 -0800 Subject: [PATCH 2/3] update tests with new state dict format --- test/stateful_dataloader/test_sampler.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/stateful_dataloader/test_sampler.py b/test/stateful_dataloader/test_sampler.py index 7b172c8c0..9fade8502 100644 --- a/test/stateful_dataloader/test_sampler.py +++ b/test/stateful_dataloader/test_sampler.py @@ -96,10 +96,10 @@ def test_sampler_state_dict(self): def test_sampler_load_state_dict(self): sampler = StatefulDistributedSampler(self.dataset, num_replicas=10, rank=0) - sampler.load_state_dict({"yielded": 3}) + sampler.load_state_dict({"epoch": 0, "yielded": 3}) self.assertEqual(sampler.next_yielded, 3) with self.assertRaises(ValueError): - sampler.load_state_dict({"yielded": -1}) + sampler.load_state_dict({"epoch": 0, "yielded": -1}) def test_sampler_next_yielded(self): @@ -108,7 +108,12 @@ def test_sampler_next_yielded(self): next(iterator) # advance the iterator self.assertEqual(sampler.yielded, 1) self.assertIsNone(sampler.next_yielded) - sampler.load_state_dict({StatefulDistributedSampler._YIELDED: 5}) + sampler.load_state_dict( + { + StatefulDistributedSampler._EPOCH: 0, + StatefulDistributedSampler._YIELDED: 5, + } + ) self.assertEqual(sampler.next_yielded, 5) iterator = iter(sampler) next(iterator) # advance the iterator again From 998171600b4da1a300f1e161dea195380a810e26 Mon Sep 17 00:00:00 2001 From: Lavender Date: Mon, 1 Dec 2025 23:13:13 -0800 Subject: [PATCH 3/3] remove testing print --- torchdata/stateful_dataloader/sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 358fbb62e..4ae37eac9 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -197,7 +197,6 @@ def __init__( self.next_yielded = None def __iter__(self): - print(f"Calling __iter__... {self.yielded=} {self.next_yielded=}") self.yielded = 0 if self.next_yielded is not None: self.yielded = self.next_yielded