
[Feature] Multicollector interruptor #963

Merged
merged 19 commits on Mar 17, 2023
62 changes: 62 additions & 0 deletions test/test_collector.py
@@ -26,6 +26,7 @@
from torchrl._utils import seed_generator
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import (
_Interruptor,
MultiaSyncDataCollector,
MultiSyncDataCollector,
RandomPolicy,
@@ -1232,6 +1233,67 @@ def weight_reset(m):
m.reset_parameters()


class TestPreemptiveThreshold:
@pytest.mark.parametrize("env_name", ["conv", "vec"])
def test_sync_collector_interruptor_mechanism(self, env_name, seed=100):
def env_fn(seed):
env = make_make_env(env_name)()
env.set_seed(seed)
return env

policy = make_policy(env_name)
interruptor = _Interruptor()
interruptor.start_collection()

collector = SyncDataCollector(
create_env_fn=env_fn,
create_env_kwargs={"seed": seed},
policy=policy,
frames_per_batch=50,
total_frames=200,
device="cpu",
interruptor=interruptor,
split_trajs=False,
)

interruptor.stop_collection()
for batch in collector:
assert batch["collector"]["traj_ids"][0] != -1
assert batch["collector"]["traj_ids"][1] == -1

@pytest.mark.parametrize("env_name", ["conv", "vec"])
def test_multisync_collector_interruptor_mechanism(self, env_name, seed=100):

frames_per_batch = 800

def env_fn(seed):
env = make_make_env(env_name)()
env.set_seed(seed)
return env

policy = make_policy(env_name)

collector = MultiSyncDataCollector(
create_env_fn=[env_fn] * 4,
create_env_kwargs=[{"seed": seed}] * 4,
policy=policy,
total_frames=800,
max_frames_per_traj=50,
frames_per_batch=frames_per_batch,
init_random_frames=-1,
reset_at_each_iter=False,
devices="cpu",
storing_devices="cpu",
split_trajs=False,
preemptive_threshold=0.0, # stop after one iteration
)

for batch in collector:
trajectory_ids = batch["collector"]["traj_ids"]
trajectory_ids_mask = trajectory_ids != -1 # valid frames mask
assert trajectory_ids[trajectory_ids_mask].numel() < frames_per_batch


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
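
The assertions above define the contract for preempted batches: frames that were never written keep the -1 placeholder in ("collector", "traj_ids"), so a boolean mask recovers the frames that were actually collected. A minimal consumption sketch, not part of this PR, assuming a collector built as in the tests and that traj_ids matches the batch shape:

# Sketch only: keep the frames collected before the interruption.
for batch in collector:
    traj_ids = batch["collector"]["traj_ids"]
    valid_mask = traj_ids != -1                      # skipped frames keep the -1 placeholder
    valid_frames = batch.masked_select(valid_mask)   # TensorDict restricted to collected frames
    # ... feed valid_frames to a replay buffer or loss computation here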
82 changes: 80 additions & 2 deletions torchrl/collectors/collectors.py
@@ -12,6 +12,7 @@
from collections import OrderedDict
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
from textwrap import indent
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union

@@ -72,6 +73,44 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.set("action", self.action_spec.rand())


class _Interruptor:
"""A class for managing the collection state of a process.

This class provides methods to start and stop collection, and to check
whether collection has been stopped. The collection state is protected
by a lock to ensure thread-safety.
"""

def __init__(self):
self._collect = True
self._lock = mp.Lock()

def start_collection(self):
with self._lock:
self._collect = True

def stop_collection(self):
with self._lock:
self._collect = False

def collection_stopped(self):
with self._lock:
return self._collect is False


class _InterruptorManager(SyncManager):
"""A custom SyncManager for managing the collection state of a process.

This class extends SyncManager and allows an _Interruptor object to be shared
between processes.
"""

pass


_InterruptorManager.register("_Interruptor", _Interruptor)
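
For orientation, here is a minimal sketch, not part of this diff, of how the two classes cooperate, mirroring the wiring _MultiDataCollector does further down; collect_frames is a hypothetical worker loop:

def collect_frames(interruptor):
    # hypothetical worker loop: keep collecting until told to stop
    while not interruptor.collection_stopped():
        pass  # step the environment and store a frame here

manager = _InterruptorManager()
manager.start()                        # spawns the manager process
interruptor = manager._Interruptor()   # proxy object shareable across processes
interruptor.start_collection()
worker = mp.Process(target=collect_frames, args=(interruptor,))
worker.start()
interruptor.stop_collection()          # worker now sees collection_stopped() == True
worker.join()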


def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
"""Maps the tensors to CPU through a nested dictionary."""
return OrderedDict(
@@ -341,6 +380,11 @@ class SyncDataCollector(_DataCollector):
updated. This feature should be used cautiously: if the same
tensordict is added to a replay buffer for instance,
the whole content of the buffer will be identical.
Default is False.
interruptor (_Interruptor, optional):
An _Interruptor object that can be used from outside the class to control rollout collection.
The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow
strategies such as preemptively stopping rollout collection to be implemented.
Default is ``None``.
reset_when_done (bool, optional): if ``True`` (default), an environment
that return a ``True`` value in its ``"done"`` or ``"truncated"``
@@ -419,6 +463,7 @@ def __init__(
exploration_mode: str = DEFAULT_EXPLORATION_MODE,
return_same_td: bool = False,
reset_when_done: bool = True,
interruptor=None,
):
self.closed = True

@@ -547,6 +592,7 @@ def __init__(
self.split_trajs = split_trajs
self._has_been_done = None
self._exclude_private_keys = True
self.interruptor = interruptor

# for RPC
def next(self):
@@ -679,6 +725,9 @@ def rollout(self) -> TensorDictBase:
if self.reset_at_each_iter:
self._tensordict.update(self.env.reset(), inplace=True)

# self._tensordict.fill_(("collector", "step_count"), 0)
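# fill traj_ids with -1 up front so frames that are never written (e.g. after an interruption) stay marked as invalid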
self._tensordict_out.fill_(("collector", "traj_ids"), -1)

with set_exploration_mode(self.exploration_mode):
for j in range(self.frames_per_batch):
if self._frames < self.init_random_frames:
@@ -701,6 +750,8 @@
self._tensordict_out.lock()

self._step_and_maybe_reset()
if self.interruptor and self.interruptor.collection_stopped():
break

return self._tensordict_out

@@ -882,7 +933,8 @@ class _MultiDataCollector(_DataCollector):
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.

preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
that will be allowed to finish collecting their rollout before the rest are forced to end early.
If ``None`` (default), preemption is disabled.
"""

def __init__(
@@ -907,6 +959,7 @@ def __init__(
split_trajs: Optional[bool] = None,
exploration_mode: str = DEFAULT_EXPLORATION_MODE,
reset_when_done: bool = True,
preemptive_threshold: float = None,
update_at_each_batch: bool = False,
devices=None,
storing_devices=None,
@@ -1047,6 +1100,15 @@ def device_err_msg(device_name, devices_list):
self.init_random_frames = init_random_frames
self.update_at_each_batch = update_at_each_batch
self.exploration_mode = exploration_mode
self.frames_per_worker = np.inf
if preemptive_threshold is not None:
self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0)
manager = _InterruptorManager()
manager.start()
self.interruptor = manager._Interruptor()
else:
self.preemptive_threshold = 1.0
self.interruptor = None
self._run_processes()
self._exclude_private_keys = True

@@ -1099,6 +1161,7 @@ def _run_processes(self) -> None:
"exploration_mode": self.exploration_mode,
"reset_when_done": self.reset_when_done,
"idx": i,
"interruptor": self.interruptor,
}
proc = mp.Process(target=_main_async_collector, kwargs=kwargs)
# proc.daemon can't be set as daemonic processes may be launched by the process itself
@@ -1399,6 +1462,18 @@ def iterator(self) -> Iterator[TensorDictBase]:

i += 1
max_traj_idx = None

if self.interruptor:
Contributor comment:
for clarity can we have if self.interruptor is not None?

self.interruptor.start_collection()
while self.queue_out.qsize() < int(
self.num_workers * self.preemptive_threshold
):
continue
self.interruptor.stop_collection()
# Now wait for stragglers to return
while self.queue_out.qsize() < int(self.num_workers):
continue

for _ in range(self.num_workers):
new_data, j = self.queue_out.get()
if j == 0:
@@ -1416,7 +1491,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
for idx in range(self.num_workers):
traj_ids = out_tensordicts_shared[idx].get(("collector", "traj_ids"))
if max_traj_idx is not None:
traj_ids += max_traj_idx
traj_ids[traj_ids != -1] += max_traj_idx
# out_tensordicts_shared[idx].set("traj_ids", traj_ids)
max_traj_idx = traj_ids.max().item() + 1
# out = out_tensordicts_shared[idx]
@@ -1428,6 +1503,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
prev_device = item.device
else:
same_device = same_device and (item.device == prev_device)

if same_device:
out_buffer = torch.cat(
list(out_tensordicts_shared.values()), 0, out=out_buffer
@@ -1803,6 +1879,7 @@ def _main_async_collector(
exploration_mode: str = DEFAULT_EXPLORATION_MODE,
reset_when_done: bool = True,
verbose: bool = VERBOSE,
interruptor=None,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -1823,6 +1900,7 @@
exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
return_same_td=True,
interruptor=interruptor,
)
if verbose:
print("Sync data collector created")