Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Device transform #1472

Merged
merged 10 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ to be able to create this other composition:
CatTensors
CenterCrop
Compose
DeviceCastTransform
DiscreteActionProjection
DoubleToFloat
DTypeCastTransform
Expand Down
2 changes: 1 addition & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
**kwargs,
):
super().__init__(
device="cpu",
device=kwargs.pop("device", "cpu"),
dtype=torch.get_default_dtype(),
)
self.set_seed(seed)
Expand Down
102 changes: 102 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
LazyMemmapStorage,
LazyTensorStorage,
ReplayBuffer,
TensorDictReplayBuffer,
TensorStorage,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
Expand All @@ -49,6 +51,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
EnvBase,
Expand Down Expand Up @@ -8133,6 +8136,105 @@ def test_kl_lstm(self):
klt(env.rollout(3, policy))


class TestDeviceCastTransform(TransformBase):
    """Tests for ``DeviceCastTransform``.

    Every test uses the ``"cpu:0"`` and ``"cpu:1"`` device identifiers as the
    source and destination devices, so the device cast is observable through
    the ``device`` attribute of the env, transform or tensordict
    (presumably so the suite can run without an accelerator -- confirm).
    """

    def test_single_trans_env_check(self):
        """A single transformed env reports the cast device and passes spec checks."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_serial_trans_env_check(self):
        """A SerialEnv built from transformed envs reports the cast device."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform("cpu:1"),
            )

        env = SerialEnv(2, make_env)
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_parallel_trans_env_check(self):
        """A ParallelEnv built from transformed envs reports the cast device."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform("cpu:1"),
            )

        env = ParallelEnv(2, make_env)
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_trans_serial_env_check(self):
        """The transform applied on top of a SerialEnv casts the env device."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(SerialEnv(2, make_env), DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_trans_parallel_env_check(self):
        """The transform applied on top of a ParallelEnv casts the env device."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_transform_no_env(self):
        """``_call`` on a standalone transform moves a tensordict to the target device."""
        t = DeviceCastTransform("cpu:1", "cpu:0")
        assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
            "cpu:1"
        )

    def test_transform_compose(self):
        """Inside a Compose, ``_call`` and ``_inv_call`` cast in opposite directions."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
            "cpu:1"
        )
        assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
            "cpu:0"
        )

    def test_transform_env(self):
        """The transform exposes ``device`` (target) and ``orig_device`` (base env)."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        assert env.device == torch.device("cpu:0")
        env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        assert env.transform.device == torch.device("cpu:1")
        assert env.transform.orig_device == torch.device("cpu:0")

    def test_transform_model(self):
        """The transform casts the device when called through ``nn.Sequential``."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        m = nn.Sequential(t)
        # Call the nn.Sequential wrapper (not the bare transform) so that the
        # module-composition path is what is actually being tested.
        assert m(TensorDict({}, [], device="cpu:0")).device == torch.device("cpu:1")

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    @pytest.mark.parametrize(
        "storage", [TensorStorage, LazyTensorStorage, LazyMemmapStorage]
    )
    def test_transform_rb(self, rbclass, storage):
        """Data added on "cpu:1" is stored on "cpu:0" and sampled back on "cpu:1"."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        # TensorStorage needs a pre-allocated buffer; the other storages are
        # created without one here.
        storage_kwargs = (
            {
                "storage": TensorDict(
                    {"a": torch.zeros(20, 1, device="cpu:0")}, [20], device="cpu:0"
                )
            }
            if storage is TensorStorage
            else {}
        )
        rb = rbclass(storage=storage(max_size=20, device="auto", **storage_kwargs))
        rb.append_transform(t)
        rb.add(TensorDict({"a": [1]}, [], device="cpu:1"))
        assert rb._storage._storage.device == torch.device("cpu:0")
        assert rb.sample(4).device == torch.device("cpu:1")

    def test_transform_inverse(self):
        """``_inv_call`` moves a tensordict back to the original device."""
        t = DeviceCastTransform("cpu:1", "cpu:0")
        assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
            "cpu:0"
        )


if __name__ == "__main__":
    # Flags argparse does not recognise are forwarded to pytest unchanged.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst"]
    pytest_argv.extend(passthrough)
    pytest.main(pytest_argv)
33 changes: 22 additions & 11 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ def add(self, data: Any) -> int:
Returns:
index where the data lives in the replay buffer.
"""
if self._transform is not None and (
is_tensor_collection(data) or len(self._transform)
):
data = self._transform.inv(data)
return self._add(data)

def _add(self, data):
with self._replay_lock:
index = self._writer.add(data)
self._sampler.add(index)
Expand All @@ -271,9 +278,9 @@ def extend(self, data: Sequence) -> torch.Tensor:
Returns:
Indices of the data added to the replay buffer.
"""
if self._transform is not None and is_tensor_collection(data):
data = self._transform.inv(data)
elif self._transform is not None and len(self._transform):
if self._transform is not None and (
is_tensor_collection(data) or len(self._transform)
):
data = self._transform.inv(data)
return self._extend(data)

Expand Down Expand Up @@ -675,19 +682,24 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
return priority

def add(self, data: TensorDictBase) -> int:
if self._transform is not None:
data = self._transform.inv(data)

if is_tensor_collection(data):
data_add = TensorDict(
{
"_data": data,
},
batch_size=[],
device=data.device,
)
if data.batch_size:
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)

else:
data_add = data
index = super().add(data_add)

index = super()._add(data_add)
if is_tensor_collection(data_add):
data_add.set("index", index)

Expand All @@ -699,7 +711,8 @@ def add(self, data: TensorDictBase) -> int:
def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
if is_tensor_collection(tensordicts):
tensordicts = TensorDict(
{"_data": tensordicts}, batch_size=tensordicts.batch_size[:1]
{"_data": tensordicts},
batch_size=tensordicts.batch_size[:1],
)
if tensordicts.batch_dims > 1:
# we want the tensordict to have one dimension only. The batch size
Expand Down Expand Up @@ -730,14 +743,12 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
stacked_td = tensordicts

if self._transform is not None:
stacked_td.set("_data", self._transform.inv(stacked_td.get("_data")))
tensordicts = self._transform.inv(stacked_td.get("_data"))
stacked_td.set("_data", tensordicts)
if tensordicts.device is not None:
stacked_td = stacked_td.to(tensordicts.device)

index = super()._extend(stacked_td)
# stacked_td.set(
# "index",
# torch.tensor(index, dtype=torch.int, device=stacked_td.device),
# inplace=True,
# )
self.update_tensordict_priority(stacked_td)
return index

Expand Down
35 changes: 27 additions & 8 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,14 @@ class TensorStorage(Storage):
"""A storage for tensors and tensordicts.

Args:
data (tensor or TensorDict): the data buffer to be used.
storage (tensor or TensorDict): the data buffer to be used.
max_size (int): size of the storage, i.e. maximum number of elements stored
in the buffer.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.

Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -230,7 +233,7 @@ def __new__(cls, *args, **kwargs):
cls._storage = None
return super().__new__(cls)

def __init__(self, storage, max_size=None, device=None):
def __init__(self, storage, max_size=None, device="cpu"):
if not ((storage is None) ^ (max_size is None)):
if storage is None:
raise ValueError("Expected storage to be non-null.")
Expand All @@ -247,7 +250,13 @@ def __init__(self, storage, max_size=None, device=None):
self._len = max_size
else:
self._len = 0
self.device = device if device else torch.device("cpu")
self.device = (
torch.device(device)
if device != "auto"
else storage.device
if storage is not None
else "auto"
)
self._storage = storage

def state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -345,6 +354,9 @@ class LazyTensorStorage(TensorStorage):
in the buffer.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.

Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -396,12 +408,14 @@ class LazyTensorStorage(TensorStorage):

"""

def __init__(self, max_size, device=None):
super().__init__(None, max_size, device=device)
def __init__(self, max_size, device="cpu"):
super().__init__(storage=None, max_size=max_size, device=device)

def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if VERBOSE:
print("Creating a TensorStorage...")
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create an empty tensor of the desired shape, device and dtype
out = torch.empty(
Expand Down Expand Up @@ -436,6 +450,9 @@ class LazyMemmapStorage(LazyTensorStorage):
scratch_dir (str or path): directory where memmap-tensors will be written.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.

Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -486,15 +503,15 @@ class LazyMemmapStorage(LazyTensorStorage):

"""

def __init__(self, max_size, scratch_dir=None, device=None):
def __init__(self, max_size, scratch_dir=None, device="cpu"):
super().__init__(max_size)
self.initialized = False
self.scratch_dir = None
if scratch_dir is not None:
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
self.scratch_dir += "/"
self.device = device if device else torch.device("cpu")
self.device = torch.device(device) if device != "auto" else device
self._len = 0

def state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -552,6 +569,8 @@ def load_state_dict(self, state_dict):
def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if VERBOSE:
print("Creating a MemmapStorage...")
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
out = MemmapTensor(
Expand Down Expand Up @@ -682,7 +701,7 @@ def _get_default_collate(storage, _is_tensordict=False):
return torch.utils.data._utils.collate.default_collate
elif isinstance(storage, LazyMemmapStorage):
return _collate_as_tensor
elif isinstance(storage, (LazyTensorStorage,)):
elif isinstance(storage, (TensorStorage,)):
return _collate_contiguous
else:
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
DTypeCastTransform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
DTypeCastTransform,
Expand Down
Loading
Loading