diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index af60776a3f1c5..92fe3fb91549b 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -374,6 +374,8 @@ def __init__(self, loader): # see _try_put_indices self._thread_lock = threading.Lock() + self._base_seed = np.random.randint(low=0, high=sys.maxsize) + # init workers and indices queues and put 2 indices in each indices queue self._init_workers() for _ in range(self._outstanding_capacity): @@ -406,7 +408,8 @@ def _init_workers(self): self._data_queue, self._workers_done_event, self._auto_collate_batch, self._collate_fn, self._drop_last, self._worker_init_fn, i, - self._num_workers, self._use_shared_memory)) + self._num_workers, self._use_shared_memory, + self._base_seed)) worker.daemon = True worker.start() self._workers.append(worker) diff --git a/python/paddle/fluid/dataloader/worker.py b/python/paddle/fluid/dataloader/worker.py index 0c3ec898aadfd..06ea7ef9d72a3 100644 --- a/python/paddle/fluid/dataloader/worker.py +++ b/python/paddle/fluid/dataloader/worker.py @@ -257,7 +257,7 @@ def mix(x, y): def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, auto_collate_batch, collate_fn, drop_last, init_fn, worker_id, - num_workers, use_shared_memory): + num_workers, use_shared_memory, base_seed): try: # NOTE: [ mmap files clear ] When the child process exits unexpectedly, # some shared memory objects may have been applied for but have not yet @@ -272,15 +272,20 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, try: import numpy as np import time + import random except ImportError: pass else: - np.random.seed(_generate_states(int(time.time()), worker_id)) + seed = base_seed + worker_id + random.seed(seed) + paddle.seed(seed) + np.random.seed(_generate_states(base_seed, worker_id)) global _worker_info _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, - dataset=dataset) + dataset=dataset, + seed=base_seed) init_exception = None try: diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py index 2d6cdac4854f7..e2ed2d8003a46 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py @@ -181,10 +181,11 @@ def _collate_fn(sample_list): for i in range(10): indices_queue.put([i, i + 10]) indices_queue.put(None) + base_seed = 1234 _worker_loop(loader._dataset, 0, indices_queue, loader._data_queue, loader._workers_done_event, True, _collate_fn, True, _init_fn, 0, 1, - loader._use_shared_memory) + loader._use_shared_memory, base_seed) self.assertTrue(False) except AssertionError: pass @@ -223,10 +224,11 @@ def _collate_fn(sample_list): indices_queue.put([i, i + 10]) indices_queue.put(None) loader._workers_done_event.set() + base_seed = 1234 _worker_loop(loader._dataset, 0, indices_queue, loader._data_queue, loader._workers_done_event, True, _collate_fn, True, _init_fn, 0, 1, - loader._use_shared_memory) + loader._use_shared_memory, base_seed) self.assertTrue(True) except AssertionError: pass