Commit e7d7d43
[Feature] replay_buffer_chunk
ghstack-source-id: 4abe903dc1d3643d793f54f93cb4fe147cce8d06
Pull Request resolved: #2388
vmoens committed Aug 10, 2024
1 parent 228f68b commit e7d7d43
Showing 5 changed files with 112 additions and 21 deletions.
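The headline change is a new `replay_buffer_chunk` flag on the multiprocessed collectors. A minimal usage sketch, assuming the API added in this commit; the environment id, frame counts, and import paths are illustrative and may need adjusting to the installed torchrl version:

```python
import torch
from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, StepCounter
from torchrl.envs.utils import RandomPolicy

def make_env():
    # Illustrative env; StepCounter makes per-trajectory ordering visible.
    return GymEnv("CartPole-v1").append_transform(StepCounter())

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)

collector = MultiSyncDataCollector(
    [make_env, make_env],
    RandomPolicy(make_env().action_spec),
    replay_buffer=rb,
    # Per the diff: when False, each worker batch is extended into the buffer
    # as one contiguous block of transitions; when True (the default), the
    # buffer is handed down to the workers' inner collectors.
    replay_buffer_chunk=False,
    total_frames=256,
    frames_per_batch=32,
)

for batch in collector:
    # With a replay buffer attached, the collector yields None and the data
    # lands directly in `rb`.
    assert batch is None

collector.shutdown()
print(len(rb), rb.write_count)  # both 256 here: capacity was never exceeded
```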
77 changes: 59 additions & 18 deletions test/test_collector.py
@@ -9,7 +9,6 @@
import gc

import sys
import time

import numpy as np
import pytest
@@ -2798,44 +2797,76 @@ def test_collector_rb_sync(self):
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)

def test_collector_rb_multisync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)
@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
def test_collector_rb_multisync(self, replay_buffer_chunk, env_creator):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
frames_per_batch=32,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
pred_len += 32
assert c is None
assert len(rb) == pred_len
collector.shutdown()
assert len(rb) == 256

def test_collector_rb_multiasync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()

@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
def test_collector_rb_multiasync(self, replay_buffer_chunk, env_creator):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiaSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
@@ -2845,6 +2876,16 @@ def test_collector_rb_multiasync(self):
assert len(rb) >= pred_len
collector.shutdown()
assert len(rb) == 256
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()


if __name__ == "__main__":
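The two tests above fetch the action spec differently depending on how the environment is passed: a plain env exposes it directly, while an `EnvCreator` only carries metadata gathered at construction time. A small sketch of that lookup, assuming the same CartPole setup used in the tests (the env id and import paths are illustrative):

```python
from torchrl.envs import EnvCreator, GymEnv, StepCounter

# Plain env: the spec is read off the instance.
env = GymEnv("CartPole-v1").append_transform(StepCounter())
action_spec = env.action_spec

# EnvCreator: the env is only instantiated inside the worker, so the spec
# comes from the metadata the creator collected when it was built.
env_creator = EnvCreator(
    lambda: GymEnv("CartPole-v1").append_transform(StepCounter())
)
action_spec_from_meta = env_creator.meta_data.specs["input_spec", "full_action_spec"]
```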
29 changes: 27 additions & 2 deletions torchrl/collectors/collectors.py
@@ -54,6 +54,7 @@
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
@@ -1466,6 +1467,7 @@ def __init__(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
@@ -1510,6 +1512,8 @@ def __init__(

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
self._check_replay_buffer_init()
self.replay_buffer_chunk = replay_buffer_chunk
if (
replay_buffer is not None
and hasattr(replay_buffer, "shared")
@@ -1656,6 +1660,21 @@ def _get_weight_fn(weights=policy_weights):
)
self.cat_results = cat_results

def _check_replay_buffer_init(self):
try:
if not self.replay_buffer._storage.initialized:
if isinstance(self.create_env_fn, EnvCreator):
fake_td = self.create_env_fn.tensordict
else:
fake_td = self.create_env_fn[0](
**self.create_env_kwargs[0]
).fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)

self.replay_buffer._storage._init(fake_td)
except AttributeError:
pass

@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
@@ -1790,6 +1809,7 @@ def _run_processes(self) -> None:
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
"replay_buffer": self.replay_buffer,
"replay_buffer_chunk": self.replay_buffer_chunk,
"traj_pool": traj_pool,
}
proc = _ProcessNoWarn(
@@ -2799,6 +2819,7 @@ def _main_async_collector(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
traj_pool: _TrajectoryPool = None,
) -> None:
pipe_parent.close()
Expand All @@ -2824,7 +2845,7 @@ def _main_async_collector(
interruptor=interruptor,
set_truncated=set_truncated,
use_buffers=use_buffers,
replay_buffer=replay_buffer,
replay_buffer=replay_buffer if replay_buffer_chunk else None,
traj_pool=traj_pool,
)
use_buffers = inner_collector._use_buffers
@@ -2890,6 +2911,10 @@
continue

if replay_buffer is not None:
if not replay_buffer_chunk:
next_data.names = None
replay_buffer.extend(next_data)

try:
queue_out.put((idx, j), timeout=_TIMEOUT)
if verbose:
@@ -3026,7 +3051,7 @@ def __init__(self, ctx=None, lock: bool = False):
def get_traj_and_increment(self, n=1, device=None):
traj_id = []
with self.lock:
for i in range(n):
for _ in range(n):
traj_id.append(int(self._traj_id.value))
self._traj_id.value += 1
return torch.as_tensor(traj_id, device=device)
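The new `_check_replay_buffer_init` helper above pre-initializes the shared storage from a fake rollout so that worker processes always find an allocated buffer. A rough standalone sketch of the same idea using public APIs only (the env id is illustrative; the collector does this internally via `_storage._init`):

```python
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, StepCounter

rb = ReplayBuffer(storage=LazyTensorStorage(256))

# Build one representative sample so the lazy storage can allocate its buffers
# before any worker writes to it.
env = GymEnv("CartPole-v1").append_transform(StepCounter())
fake_td = env.fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)

# Adding and then emptying the buffer forces allocation without keeping the
# placeholder sample around.
rb.add(fake_td)
rb.empty()
assert rb._storage.initialized
```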
5 changes: 5 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -364,6 +364,11 @@ def __len__(self) -> int:
with self._replay_lock:
return len(self._storage)

@property
def write_count(self):
"""The total number of items written so far in the buffer through add and extend."""
return self._writer._write_count

def __repr__(self) -> str:
from torchrl.envs.transforms import Compose

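A quick illustration of the new `write_count` property next to `len()` once a circular buffer wraps around; the storage size, key, and batch sizes below are illustrative:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(10))

# Two writes totalling 14 items into a 10-slot circular storage.
rb.extend(TensorDict({"obs": torch.arange(8).unsqueeze(-1)}, batch_size=[8]))
rb.extend(TensorDict({"obs": torch.arange(6).unsqueeze(-1)}, batch_size=[6]))

print(len(rb))         # 10 -> items currently held, capped by the storage size
print(rb.write_count)  # 14 -> every item ever written through add/extend
```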
20 changes: 20 additions & 0 deletions torchrl/data/replay_buffers/writers.py
@@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
single_data=data
)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(_cursor, data)
@@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along0
self._write_count += batch_size
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -222,6 +224,20 @@ def _cursor(self, value):
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value

def __getstate__(self):
state = super().__getstate__()
if get_spawning_popen() is None:
@@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
# we need to update the cursor first to avoid race conditions between workers
max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
self._cursor = (index + 1) % max_size_along_dim0
self._write_count += 1
if not is_tensorclass(data):
data.set(
"index",
@@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along_dim0
self._write_count += batch_size
# storage must convert the data to the appropriate format if needed
if not is_tensorclass(data):
data.set(
Expand Down Expand Up @@ -469,6 +487,7 @@ def add(self, data: Any) -> int | torch.Tensor:
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -488,6 +507,7 @@ def extend(self, data: TensorDictBase) -> None:
for data_idx, sample in enumerate(data):
storage_idx = self.get_insert_index(sample)
if storage_idx is not None:
self._write_count += 1
data_to_replace[storage_idx] = data_idx

# -1 will be interpreted as invalid by prioritized buffers
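The `_write_count` bookkeeping added above reuses the lazily-created `mp.Value` pattern the writer already employs for `_cursor`, so the counter stays valid when the writer is shared across worker processes. A distilled, standalone sketch of that pattern (the class and attribute names here are hypothetical):

```python
import multiprocessing as mp

class SharedCounter:
    """Lazily-created, process-shared integer counter (pattern sketch)."""

    @property
    def count(self):
        # Create the shared value lazily on first access.
        value = self.__dict__.get("_count_value", None)
        if value is None:
            value = self._count_value = mp.Value("i", 0)
        return value.value

    @count.setter
    def count(self, new_value):
        value = self.__dict__.get("_count_value", None)
        if value is None:
            value = self._count_value = mp.Value("i", 0)
        value.value = new_value

counter = SharedCounter()
counter.count += 5    # reads the shared value, then writes it back
print(counter.count)  # 5
```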
2 changes: 1 addition & 1 deletion torchrl/envs/env_creator.py
@@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
del state_dict[key]

@property
def meta_data(self):
def meta_data(self) -> EnvMetaData:
if self._meta_data is None:
raise RuntimeError(
"meta_data is None in EnvCreator. " "Make sure init_() has been called."
