diff --git a/test/test_collector.py b/test/test_collector.py
index 372c54a04e8..93ee0a53f1f 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -26,6 +26,7 @@
 from torchrl._utils import seed_generator
 from torchrl.collectors import aSyncDataCollector, SyncDataCollector
 from torchrl.collectors.collectors import (
+    _Interruptor,
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
     RandomPolicy,
@@ -1232,6 +1233,67 @@ def weight_reset(m):
         m.reset_parameters()
 
 
+class TestPreemptiveThreshold:
+    @pytest.mark.parametrize("env_name", ["conv", "vec"])
+    def test_sync_collector_interruptor_mechanism(self, env_name, seed=100):
+        def env_fn(seed):
+            env = make_make_env(env_name)()
+            env.set_seed(seed)
+            return env
+
+        policy = make_policy(env_name)
+        interruptor = _Interruptor()
+        interruptor.start_collection()
+
+        collector = SyncDataCollector(
+            create_env_fn=env_fn,
+            create_env_kwargs={"seed": seed},
+            policy=policy,
+            frames_per_batch=50,
+            total_frames=200,
+            device="cpu",
+            interruptor=interruptor,
+            split_trajs=False,
+        )
+
+        interruptor.stop_collection()
+        for batch in collector:
+            assert batch["collector"]["traj_ids"][0] != -1
+            assert batch["collector"]["traj_ids"][1] == -1
+
+    @pytest.mark.parametrize("env_name", ["conv", "vec"])
+    def test_multisync_collector_interruptor_mechanism(self, env_name, seed=100):
+
+        frames_per_batch = 800
+
+        def env_fn(seed):
+            env = make_make_env(env_name)()
+            env.set_seed(seed)
+            return env
+
+        policy = make_policy(env_name)
+
+        collector = MultiSyncDataCollector(
+            create_env_fn=[env_fn] * 4,
+            create_env_kwargs=[{"seed": seed}] * 4,
+            policy=policy,
+            total_frames=800,
+            max_frames_per_traj=50,
+            frames_per_batch=frames_per_batch,
+            init_random_frames=-1,
+            reset_at_each_iter=False,
+            devices="cpu",
+            storing_devices="cpu",
+            split_trajs=False,
+            preemptive_threshold=0.0,  # stop after one iteration
+        )
+
+        for batch in collector:
+            trajectory_ids = batch["collector"]["traj_ids"]
+            trajectory_ids_mask = trajectory_ids != -1  # valid frames mask
+            assert trajectory_ids[trajectory_ids_mask].numel() < frames_per_batch
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index c3ba6152f5b..9e2640522ca 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -12,6 +12,7 @@
 from collections import OrderedDict
 from copy import deepcopy
 from multiprocessing import connection, queues
+from multiprocessing.managers import SyncManager
 from textwrap import indent
 from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union
 
@@ -72,6 +73,44 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
         return td.set("action", self.action_spec.rand())
 
 
+class _Interruptor:
+    """A class for managing the collection state of a process.
+
+    This class provides methods to start and stop collection, and to check
+    whether collection has been stopped. The collection state is protected
+    by a lock to ensure thread-safety.
+ """ + + def __init__(self): + self._collect = True + self._lock = mp.Lock() + + def start_collection(self): + with self._lock: + self._collect = True + + def stop_collection(self): + with self._lock: + self._collect = False + + def collection_stopped(self): + with self._lock: + return self._collect is False + + +class _InterruptorManager(SyncManager): + """A custom SyncManager for managing the collection state of a process. + + This class extends the SyncManager class and allows to share an Interruptor object + between processes. + """ + + pass + + +_InterruptorManager.register("_Interruptor", _Interruptor) + + def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: """Maps the tensors to CPU through a nested dictionary.""" return OrderedDict( @@ -341,6 +380,11 @@ class SyncDataCollector(_DataCollector): updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. + Default is False. + interruptor : (_Interruptor, optional) + An _Interruptor object that can be used from outside the class to control rollout collection. + The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement + strategies such as preeptively stopping rollout collection. Default is ``False``. reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` @@ -419,6 +463,7 @@ def __init__( exploration_mode: str = DEFAULT_EXPLORATION_MODE, return_same_td: bool = False, reset_when_done: bool = True, + interruptor=None, ): self.closed = True @@ -547,6 +592,7 @@ def __init__( self.split_trajs = split_trajs self._has_been_done = None self._exclude_private_keys = True + self.interruptor = interruptor # for RPC def next(self): @@ -679,6 +725,9 @@ def rollout(self) -> TensorDictBase: if self.reset_at_each_iter: self._tensordict.update(self.env.reset(), inplace=True) + # self._tensordict.fill_(("collector", "step_count"), 0) + self._tensordict_out.fill_(("collector", "traj_ids"), -1) + with set_exploration_mode(self.exploration_mode): for j in range(self.frames_per_batch): if self._frames < self.init_random_frames: @@ -701,6 +750,8 @@ def rollout(self) -> TensorDictBase: self._tensordict_out.lock() self._step_and_maybe_reset() + if self.interruptor is not None and self.interruptor.collection_stopped(): + break return self._tensordict_out @@ -882,7 +933,8 @@ class _MultiDataCollector(_DataCollector): update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()` will be called before (sync) or after (async) each data collection. Defaults to ``False``. - + preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finished collecting their rollout before the rest are forced to end early. 
""" def __init__( @@ -907,6 +959,7 @@ def __init__( split_trajs: Optional[bool] = None, exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, + preemptive_threshold: float = None, update_at_each_batch: bool = False, devices=None, storing_devices=None, @@ -1047,6 +1100,15 @@ def device_err_msg(device_name, devices_list): self.init_random_frames = init_random_frames self.update_at_each_batch = update_at_each_batch self.exploration_mode = exploration_mode + self.frames_per_worker = np.inf + if preemptive_threshold is not None: + self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) + manager = _InterruptorManager() + manager.start() + self.interruptor = manager._Interruptor() + else: + self.preemptive_threshold = 1.0 + self.interruptor = None self._run_processes() self._exclude_private_keys = True @@ -1099,6 +1161,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, + "interruptor": self.interruptor, } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1399,6 +1462,18 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 max_traj_idx = None + + if self.interruptor is not None: + self.interruptor.start_collection() + while self.queue_out.qsize() < int( + self.num_workers * self.preemptive_threshold + ): + continue + self.interruptor.stop_collection() + # Now wait for stragglers to return + while self.queue_out.qsize() < int(self.num_workers): + continue + for _ in range(self.num_workers): new_data, j = self.queue_out.get() if j == 0: @@ -1416,7 +1491,7 @@ def iterator(self) -> Iterator[TensorDictBase]: for idx in range(self.num_workers): traj_ids = out_tensordicts_shared[idx].get(("collector", "traj_ids")) if max_traj_idx is not None: - traj_ids += max_traj_idx + traj_ids[traj_ids != -1] += max_traj_idx # out_tensordicts_shared[idx].set("traj_ids", traj_ids) max_traj_idx = traj_ids.max().item() + 1 # out = out_tensordicts_shared[idx] @@ -1428,6 +1503,7 @@ def iterator(self) -> Iterator[TensorDictBase]: prev_device = item.device else: same_device = same_device and (item.device == prev_device) + if same_device: out_buffer = torch.cat( list(out_tensordicts_shared.values()), 0, out=out_buffer @@ -1803,6 +1879,7 @@ def _main_async_collector( exploration_mode: str = DEFAULT_EXPLORATION_MODE, reset_when_done: bool = True, verbose: bool = VERBOSE, + interruptor=None, ) -> None: pipe_parent.close() #  init variables that will be cleared when closing @@ -1823,6 +1900,7 @@ def _main_async_collector( exploration_mode=exploration_mode, reset_when_done=reset_when_done, return_same_td=True, + interruptor=interruptor, ) if verbose: print("Sync data collector created")