Skip to content

Commit 4b22ffe

Browse files
committed
[Feature] RB compatibility with compile
ghstack-source-id: f3bc11cd46440629f8e0fb51799c683a2679c0b9 Pull Request resolved: #2426
1 parent 82284a4 commit 4b22ffe

File tree

3 files changed

+90
-36
lines changed

3 files changed

+90
-36
lines changed

torchrl/data/replay_buffers/replay_buffers.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ class ReplayBuffer:
131131
.. warning:: As of now, the generator has no effect on the transforms.
132132
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
133133
Defaults to ``False``.
134+
compilable (bool, optional): whether the writer is compilable.
135+
If ``True``, the writer cannot be shared between multiple processes.
136+
Defaults to ``False``.
134137
135138
Examples:
136139
>>> import torch
@@ -216,11 +219,20 @@ def __init__(
216219
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
217220
generator: torch.Generator | None = None,
218221
shared: bool = False,
222+
compilable: bool = None,
219223
) -> None:
220224
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
221225
self._storage.attach(self)
226+
if compilable is not None:
227+
self._storage._compilable = compilable
228+
self._storage._len = self._storage._len
229+
222230
self._sampler = sampler if sampler is not None else RandomSampler()
223-
self._writer = writer if writer is not None else RoundRobinWriter()
231+
self._writer = (
232+
writer
233+
if writer is not None
234+
else RoundRobinWriter(compilable=bool(compilable))
235+
)
224236
self._writer.register_storage(self._storage)
225237

226238
self._get_collate_fn(collate_fn)
@@ -601,7 +613,9 @@ def _add(self, data):
601613
return index
602614

603615
def _extend(self, data: Sequence) -> torch.Tensor:
604-
with self._replay_lock, self._write_lock:
616+
is_compiling = torch.compiler.is_dynamo_compiling()
617+
nc = contextlib.nullcontext()
618+
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
605619
if self.dim_extend > 0:
606620
data = self._transpose(data)
607621
index = self._writer.extend(data)
@@ -654,7 +668,7 @@ def update_priority(
654668

655669
@pin_memory_output
656670
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
657-
with self._replay_lock:
671+
with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
658672
index, info = self._sampler.sample(self._storage, batch_size)
659673
info["index"] = index
660674
data = self._storage.get(index)
@@ -1753,6 +1767,7 @@ def __init__(
17531767
num_buffer_sampled: int | None = None,
17541768
generator: torch.Generator | None = None,
17551769
shared: bool = False,
1770+
compilable: bool = False,
17561771
**kwargs,
17571772
):
17581773

torchrl/data/replay_buffers/storages.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,15 @@ class Storage:
5757
_rng: torch.Generator | None = None
5858

5959
def __init__(
60-
self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
60+
self,
61+
max_size: int,
62+
checkpointer: StorageCheckpointerBase | None = None,
63+
compilable: bool = False,
6164
) -> None:
6265
self.max_size = int(max_size)
6366
self.checkpointer = checkpointer
67+
self._compilable = compilable
68+
self._attached_entities_set = set()
6469

6570
@property
6671
def checkpointer(self):
@@ -80,11 +85,11 @@ def _is_full(self):
8085
def _attached_entities(self):
8186
# RBs that use a given instance of Storage should add
8287
# themselves to this set.
83-
_attached_entities = self.__dict__.get("_attached_entities_set", None)
84-
if _attached_entities is None:
85-
_attached_entities = set()
86-
self.__dict__["_attached_entities_set"] = _attached_entities
87-
return _attached_entities
88+
return getattr(self, "_attached_entities_set", None)
89+
90+
@torch._dynamo.assume_constant_result
91+
def _attached_entities_iter(self):
92+
return list(self._attached_entities)
8893

8994
@abc.abstractmethod
9095
def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
@@ -140,6 +145,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
140145
def _empty(self):
141146
...
142147

148+
@torch._dynamo.disable()
143149
def _rand_given_ndim(self, batch_size):
144150
# a method to return random indices given the storage ndim
145151
if self.ndim == 1:
@@ -330,6 +336,9 @@ class TensorStorage(Storage):
330336
measuring the storage size. For instance, a storage of shape ``[3, 4]``
331337
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
332338
Defaults to ``1``.
339+
compilable (bool, optional): whether the storage is compilable.
340+
If ``True``, the storage cannot be shared between multiple processes.
341+
Defaults to ``False``.
333342
334343
Examples:
335344
>>> data = TensorDict({
@@ -389,6 +398,7 @@ def __init__(
389398
*,
390399
device: torch.device = "cpu",
391400
ndim: int = 1,
401+
compilable: bool = False,
392402
):
393403
if not ((storage is None) ^ (max_size is None)):
394404
if storage is None:
@@ -404,7 +414,7 @@ def __init__(
404414
else:
405415
max_size = tree_flatten(storage)[0][0].shape[0]
406416
self.ndim = ndim
407-
super().__init__(max_size)
417+
super().__init__(max_size, compilable=compilable)
408418
self.initialized = storage is not None
409419
if self.initialized:
410420
self._len = max_size
@@ -423,16 +433,23 @@ def __init__(
423433
@property
424434
def _len(self):
425435
_len_value = self.__dict__.get("_len_value", None)
436+
if not self._compilable or not isinstance(self._len_value, int):
437+
if _len_value is None:
438+
_len_value = self._len_value = mp.Value("i", 0)
439+
return _len_value.value
426440
if _len_value is None:
427-
_len_value = self._len_value = mp.Value("i", 0)
428-
return _len_value.value
441+
_len_value = self._len_value = 0
442+
return _len_value
429443

430444
@_len.setter
431445
def _len(self, value):
432446
_len_value = self.__dict__.get("_len_value", None)
433-
if _len_value is None:
434-
_len_value = self._len_value = mp.Value("i", 0)
435-
_len_value.value = value
447+
if not self._compilable:
448+
if _len_value is None:
449+
_len_value = self._len_value = mp.Value("i", 0)
450+
_len_value.value = value
451+
else:
452+
self._len_value = value
436453

437454
@property
438455
def _total_shape(self):
@@ -1184,9 +1201,9 @@ def _rng(self, value):
11841201
for storage in self._storages:
11851202
storage._rng = value
11861203

1187-
@property
1188-
def _attached_entities(self):
1189-
return set()
1204+
# @property
1205+
# def _attached_entities(self):
1206+
# return set()
11901207

11911208
def extend(self, value):
11921209
raise RuntimeError

torchrl/data/replay_buffers/writers.py

+40-18
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ class Writer(ABC):
4040
_storage: Storage
4141
_rng: torch.Generator | None = None
4242

43-
def __init__(self) -> None:
43+
def __init__(self, compilable: bool = False) -> None:
4444
self._storage = None
45+
self._compilable = compilable
4546

4647
def register_storage(self, storage: Storage) -> None:
4748
self._storage = storage
@@ -138,10 +139,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
138139

139140

140141
class RoundRobinWriter(Writer):
141-
"""A RoundRobin Writer class for composable replay buffers."""
142+
"""A RoundRobin Writer class for composable replay buffers.
142143
143-
def __init__(self, **kw) -> None:
144-
super().__init__(**kw)
144+
Args:
145+
compilable (bool, optional): whether the writer is compilable.
146+
If ``True``, the writer cannot be shared between multiple processes.
147+
Defaults to ``False``.
148+
149+
"""
150+
151+
def __init__(self, compilable: bool = False) -> None:
152+
super().__init__(compilable=compilable)
145153
self._cursor = 0
146154

147155
def dumps(self, path):
@@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
197205
# Other than that, a "flat" (1d) index is ok to write the data
198206
self._storage.set(index, data)
199207
index = self._replicate_index(index)
200-
for ent in self._storage._attached_entities:
208+
for ent in self._storage._attached_entities_iter():
201209
ent.mark_update(index)
202210
return index
203211

@@ -213,30 +221,44 @@ def _empty(self):
213221
@property
214222
def _cursor(self):
215223
_cursor_value = self.__dict__.get("_cursor_value", None)
224+
if not self._compilable or not isinstance(_cursor_value, int):
225+
if _cursor_value is None:
226+
_cursor_value = self._cursor_value = mp.Value("i", 0)
227+
return _cursor_value.value
216228
if _cursor_value is None:
217-
_cursor_value = self._cursor_value = mp.Value("i", 0)
218-
return _cursor_value.value
229+
_cursor_value = self._cursor_value = 0
230+
return _cursor_value
219231

220232
@_cursor.setter
221233
def _cursor(self, value):
222-
_cursor_value = self.__dict__.get("_cursor_value", None)
223-
if _cursor_value is None:
224-
_cursor_value = self._cursor_value = mp.Value("i", 0)
225-
_cursor_value.value = value
234+
if not self._compilable:
235+
_cursor_value = self.__dict__.get("_cursor_value", None)
236+
if _cursor_value is None:
237+
_cursor_value = self._cursor_value = mp.Value("i", 0)
238+
_cursor_value.value = value
239+
else:
240+
self._cursor_value = value
226241

227242
@property
228243
def _write_count(self):
229244
_write_count = self.__dict__.get("_write_count_value", None)
245+
if not self._compilable or not isinstance(_write_count, int):
246+
if _write_count is None:
247+
_write_count = self._write_count_value = mp.Value("i", 0)
248+
return _write_count.value
230249
if _write_count is None:
231-
_write_count = self._write_count_value = mp.Value("i", 0)
232-
return _write_count.value
250+
_write_count = self._write_count_value = 0
251+
return _write_count
233252

234253
@_write_count.setter
235254
def _write_count(self, value):
236-
_write_count = self.__dict__.get("_write_count_value", None)
237-
if _write_count is None:
238-
_write_count = self._write_count_value = mp.Value("i", 0)
239-
_write_count.value = value
255+
if not self._compilable:
256+
_write_count = self.__dict__.get("_write_count_value", None)
257+
if _write_count is None:
258+
_write_count = self._write_count_value = mp.Value("i", 0)
259+
_write_count.value = value
260+
else:
261+
self._write_count_value = value
240262

241263
def __getstate__(self):
242264
state = super().__getstate__()
@@ -248,7 +270,7 @@ def __getstate__(self):
248270

249271
def __setstate__(self, state):
250272
cursor = state.pop("cursor__context", None)
251-
if cursor is not None:
273+
if not state["_compilable"] and cursor is not None:
252274
_cursor_value = mp.Value("i", cursor)
253275
state["_cursor_value"] = _cursor_value
254276
self.__dict__.update(state)

0 commit comments

Comments
 (0)