[BugFix] Fix replay buffer extension with lists #1937

Merged (2 commits) on Feb 20, 2024
60 changes: 32 additions & 28 deletions test/test_rb.py
@@ -156,44 +156,20 @@ def _get_datum(self, datatype):

    def _get_data(self, datatype, size):
        if datatype is None:
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
        elif datatype == "tensor":
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
        elif datatype == "tensordict":
            data = TensorDict(
                {
-                    "a": torch.randint(
-                        100,
-                        (
-                            size,
-                            1,
-                        ),
-                    ),
+                    "a": torch.randint(100, (size, 1)),
                    "next": {"reward": torch.randn(size, 1)},
                },
                [size],
            )
        elif datatype == "pytree":
            data = {
-                "a": torch.randint(
-                    100,
-                    (
-                        size,
-                        1,
-                    ),
-                ),
+                "a": torch.randint(100, (size, 1)),
                "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]},
                30: torch.zeros(size, 2),
            }
@@ -838,6 +814,34 @@ def test_set_tensorclass(self, max_size, shape, storage):
        tc_sample = mystorage.get(idx)
        assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]])

+    def test_extend_list_pytree(self, max_size, shape, storage):
+        memory = ReplayBuffer(
+            storage=storage(max_size=max_size),
+            sampler=SamplerWithoutReplacement(),
+        )
+        data = [
+            (
+                torch.full(shape, i),
+                {"a": torch.full(shape, i), "b": (torch.full(shape, i))},
+                [torch.full(shape, i)],
+            )
+            for i in range(10)
+        ]
+        memory.extend(data)
+        sample = memory.sample(10)
+        for leaf in torch.utils._pytree.tree_leaves(sample):
+            assert (leaf.unique(sorted=True) == torch.arange(10)).all()
+        memory = ReplayBuffer(
+            storage=storage(max_size=max_size),
+            sampler=SamplerWithoutReplacement(),
+        )
+        t1x4 = torch.Tensor([0.1, 0.2, 0.3, 0.4])
+        t1x1 = torch.Tensor([0.01])
+        with pytest.raises(
+            RuntimeError, match="Stacking the elements of the list resulted in an error"
+        ):
+            memory.extend([t1x4, t1x1, t1x4 + 0.4, t1x1 + 0.01])
+

@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
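The new test above exercises the behavior this PR fixes: passing a Python list to ReplayBuffer.extend is treated as a sequence of individual items whose matching pytree leaves are stacked into a batch, and a clear error is raised when the items cannot be stacked. Below is a minimal sketch of that usage outside the test harness; the choice of LazyTensorStorage and the tensor shapes are assumptions for illustration (the test parametrizes over several storage types).

# Minimal sketch of the list-extension behavior covered by the new test.
# LazyTensorStorage and the shapes are illustrative assumptions.
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(max_size=100))

# A list is treated item by item: each entry is one element to add, and
# matching pytree leaves are stacked along a new leading batch dimension.
items = [
    (torch.full((3,), i), {"a": torch.full((3,), i)})
    for i in range(10)
]
rb.extend(items)
sample = rb.sample(5)  # a pytree whose leaves have shape [5, 3]

If the items in the list do not share the same tree structure, or their leaves cannot be stacked, extend now fails with the explicit "Stacking the elements of the list resulted in an error" message instead of an obscure downstream storage error.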
22 changes: 22 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -542,6 +542,21 @@ def set(
        else:
            self._len = max(self._len, max(cursor) + 1)

+        if isinstance(data, list):
+            # flip list
+            try:
+                data = _flip_list(data)
+            except Exception:
+                raise RuntimeError(
+                    "Stacking the elements of the list resulted in "
+                    "an error. "
+                    f"Storages of type {type(self)} expect all elements of the list "
+                    f"to have the same tree structure. If the list is compact (each "
+                    f"leaf is itself a batch with the appropriate number of elements) "
+                    f"consider using a tuple instead, as lists are used within `extend` "
+                    f"for per-item addition."
+                )
+
        if not self.initialized:
            if not isinstance(cursor, INT_CLASSES):
                if is_tensor_collection(data):
@@ -1319,3 +1334,10 @@ def save_tensor(tensor_path: str, tensor: torch.Tensor):
        out.append(save_tensor(tensor_path, tensor))

    return tree_unflatten(out, data_specs)
+
+
+def _flip_list(data):
+    flat_data, flat_specs = zip(*[tree_flatten(item) for item in data])
+    flat_data = zip(*flat_data)
+    stacks = [torch.stack(item) for item in flat_data]
+    return tree_unflatten(stacks, flat_specs[0])
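The _flip_list helper added above is what turns a list of identically structured pytree items into a single batched pytree before it is written to storage: each item is flattened, the leaves are grouped position-wise across items, stacked, and re-assembled with the first item's tree spec. The following self-contained illustration performs the same transformation; the dictionary layout and tensor shapes are arbitrary examples, not part of the PR.

# Illustration of what _flip_list does (shapes and keys are arbitrary examples).
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Four items with identical structure.
items = [{"a": torch.full((3,), i), "b": (torch.full((2,), i),)} for i in range(4)]

flat_data, flat_specs = zip(*[tree_flatten(item) for item in items])
stacks = [torch.stack(leaves) for leaves in zip(*flat_data)]
batched = tree_unflatten(stacks, flat_specs[0])

print(batched["a"].shape)     # torch.Size([4, 3])
print(batched["b"][0].shape)  # torch.Size([4, 2])

This also clarifies the distinction drawn by the new error message: a list means "these are N separate items to stack", whereas data that is already batched (each leaf carrying the batch dimension) should be passed as a tuple or another non-list pytree so it is written as-is.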
13 changes: 11 additions & 2 deletions torchrl/data/replay_buffers/writers.py
@@ -18,7 +18,16 @@
from tensordict import is_tensor_collection, MemoryMappedTensor
from tensordict.utils import _STRDTYPE2DTYPE
from torch import multiprocessing as mp
-from torch.utils._pytree import tree_flatten

+try:
+    from torch.utils._pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_flatten
+
+    def tree_leaves(data):  # noqa: D103
+        tree_flat, _ = tree_flatten(data)
+        return tree_flat
+
+
from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import _reduce
@@ -125,7 +134,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
        elif isinstance(data, list):
            batch_size = len(data)
        else:
-            batch_size = len(tree_flatten(data)[0][0])
+            batch_size = len(tree_leaves(data)[0])
        if batch_size == 0:
            raise RuntimeError("Expected at least one element in extend.")
        device = data.device if hasattr(data, "device") else None
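The writers.py change swaps the tree_flatten(data)[0][0] lookup for tree_leaves(data)[0], with a small shim for PyTorch versions whose torch.utils._pytree does not yet export tree_leaves. Both spellings resolve the batch size of an already-batched pytree from the first leaf, as in this rough sketch; the example data is illustrative, not taken from the PR.

import torch
from torch.utils._pytree import tree_flatten

try:
    from torch.utils._pytree import tree_leaves
except ImportError:
    # Same idea as the fallback in the diff above: tree_leaves(x) is tree_flatten(x)[0].
    def tree_leaves(data):
        return tree_flatten(data)[0]

# Illustrative already-batched pytree: every leaf carries batch size 4 in dim 0.
data = {"obs": torch.zeros(4, 3), "extras": (torch.ones(4, 2),)}

leaves = tree_leaves(data)
batch_size = len(leaves[0])  # 4: len() of a tensor is the size of its first dim
print(batch_size)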