Skip to content

Commit

Permalink
[Feature] BatchSizeTransform (#2030)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 26, 2024
1 parent 2b95b41 commit a7bf5a4
Show file tree
Hide file tree
Showing 6 changed files with 679 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ to be able to create this other composition:
Transform
TransformedEnv
ActionMask
BatchSizeTransform
BinarizeReward
BurnInTransform
CatFrames
Expand Down
280 changes: 279 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@
from torchrl.envs.transforms import VecNorm
from torchrl.envs.transforms.r3m import _R3MNet
from torchrl.envs.transforms.rlhf import KLRewardTransform
from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED
from torchrl.envs.transforms.transforms import (
_has_tv,
BatchSizeTransform,
FORWARD_NOT_IMPLEMENTED,
)
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs, step_mdp
Expand Down Expand Up @@ -10301,6 +10305,280 @@ def test_multistep_transform_changes(self):
assert t._buffer["steps"][-1] == data["steps"][-1]


class TestBatchSizeTransform(TransformBase):
class MyEnv(EnvBase):
batch_locked = False

def __init__(self):
super().__init__()
self.observation_spec = CompositeSpec(
observation=UnboundedContinuousTensorSpec(3)
)
self.reward_spec = UnboundedContinuousTensorSpec(1)
self.action_spec = UnboundedContinuousTensorSpec(1)

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_batch_size = (
tensordict.batch_size if tensordict is not None else torch.Size([])
)
result = self.observation_spec.rand(tensordict_batch_size)
result.update(self.full_done_spec.zero(tensordict_batch_size))
return result

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
result = self.observation_spec.rand(tensordict.batch_size)
result.update(self.full_done_spec.zero(tensordict.batch_size))
result.update(self.full_reward_spec.zero(tensordict.batch_size))
return result

def _set_seed(self, seed: int):
pass

@classmethod
def reset_func(tensordict, tensordict_reset, env):
result = env.observation_spec.rand()
result.update(env.full_done_spec.zero())
assert result.batch_size != torch.Size([])
return result

@pytest.mark.parametrize(
"stateless,reshape_fn",
[
[False, "reshape"],
[False, "unsqueeze"],
[False, "unflatten"],
[False, "squeeze"],
[False, "flatten"],
[True, None],
],
)
def test_single_trans_env_check(self, stateless, reshape_fn):
if stateless:
base_env = self.MyEnv()
transform = BatchSizeTransform(batch_size=[10])
expected_batch_size = torch.Size([10])
assert transform.reshape_fn is None
else:
if reshape_fn == "reshape":
base_env = CountingEnv(max_steps=3)
reshape_fn = lambda x: x.reshape(1, 1)
expected_batch_size = torch.Size([1, 1])
elif reshape_fn == "unsqueeze":
base_env = CountingEnv(max_steps=3)
reshape_fn = lambda x: x.unsqueeze(0)
expected_batch_size = torch.Size([1])
elif reshape_fn == "unflatten":
base_env = SerialEnv(1, lambda: CountingEnv(max_steps=3))
reshape_fn = lambda x: x.unflatten(0, (1, 1))
expected_batch_size = torch.Size([1, 1])
elif reshape_fn == "squeeze":
base_env = SerialEnv(1, lambda: CountingEnv(max_steps=3))
reshape_fn = lambda x: x.squeeze(0)
expected_batch_size = torch.Size([])
elif reshape_fn == "flatten":
base_env = SerialEnv(1, lambda: CountingEnv(max_steps=3))
reshape_fn = lambda x: x.unflatten(0, (1, 1)).flatten(0, 1)
expected_batch_size = torch.Size([1])
else:
raise NotImplementedError(reshape_fn)

transform = BatchSizeTransform(reshape_fn=reshape_fn)
assert transform.batch_size is None

env = TransformedEnv(base_env, transform)
assert env.batch_size == expected_batch_size
check_env_specs(env)

@pytest.mark.parametrize(
"stateless,reshape_fn",
[
[False, "reshape"],
[True, None],
],
)
def test_serial_trans_env_check(self, stateless, reshape_fn):
def make_env(stateless=stateless, reshape_fn=reshape_fn):
if stateless:
base_env = self.MyEnv()
transform = BatchSizeTransform(batch_size=[10])
expected_batch_size = torch.Size([10])
assert transform.reshape_fn is None
else:
if reshape_fn == "reshape":
base_env = CountingEnv(max_steps=3)
reshape_fn = lambda x: x.reshape(1, 1)
expected_batch_size = torch.Size([1, 1])
else:
raise NotImplementedError(reshape_fn)

transform = BatchSizeTransform(reshape_fn=reshape_fn)
assert transform.batch_size is None

env = TransformedEnv(base_env, transform)
assert env.batch_size == expected_batch_size
return env

env = SerialEnv(2, make_env)
assert env.batch_size == (2, *make_env().batch_size)
check_env_specs(env)

@pytest.mark.parametrize(
"stateless,reshape_fn",
[
[False, "reshape"],
[True, None],
],
)
def test_parallel_trans_env_check(self, stateless, reshape_fn):
def make_env(stateless=stateless, reshape_fn=reshape_fn):
if stateless:
base_env = self.MyEnv()
transform = BatchSizeTransform(batch_size=[10])
expected_batch_size = torch.Size([10])
assert transform.reshape_fn is None
else:
if reshape_fn == "reshape":
base_env = CountingEnv(max_steps=3)
reshape_fn = lambda x: x.reshape(1, 1)
expected_batch_size = torch.Size([1, 1])
else:
raise NotImplementedError(reshape_fn)

transform = BatchSizeTransform(reshape_fn=reshape_fn)
assert transform.batch_size is None

env = TransformedEnv(base_env, transform)
assert env.batch_size == expected_batch_size
return env

env = ParallelEnv(2, make_env, mp_start_method="fork")
assert env.batch_size == (2, *make_env().batch_size)
check_env_specs(env)

@pytest.mark.parametrize(
"stateless,reshape_fn",
[
[False, "reshape"],
],
)
def test_trans_serial_env_check(self, stateless, reshape_fn):
def make_env(stateless=stateless, reshape_fn=reshape_fn):
if reshape_fn == "reshape":
base_env = CountingEnv(max_steps=3)
else:
raise NotImplementedError(reshape_fn)
return base_env

if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1, 2)
expected_batch_size = torch.Size([1, 2])
else:
raise NotImplementedError(reshape_fn)

transform = BatchSizeTransform(reshape_fn=reshape_fn)
assert transform.batch_size is None

env = TransformedEnv(SerialEnv(2, make_env), transform)
assert env.batch_size == expected_batch_size
check_env_specs(env)

@pytest.mark.parametrize(
"stateless,reshape_fn",
[
[False, "reshape"],
],
)
def test_trans_parallel_env_check(self, stateless, reshape_fn):
def make_env(stateless=stateless, reshape_fn=reshape_fn):
if reshape_fn == "reshape":
base_env = CountingEnv(max_steps=3)
else:
raise NotImplementedError(reshape_fn)
return base_env

if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1, 2)
expected_batch_size = torch.Size([1, 2])
else:
raise NotImplementedError(reshape_fn)

transform = BatchSizeTransform(reshape_fn=reshape_fn)
assert transform.batch_size is None

env = TransformedEnv(
ParallelEnv(2, make_env, mp_start_method="fork"), transform
)
assert env.batch_size == expected_batch_size
check_env_specs(env)

@pytest.mark.parametrize("stateless,reshape_fn", [[False, "reshape"]])
def test_transform_no_env(self, stateless, reshape_fn):
if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1)
expected_batch_size = torch.Size([1])
else:
raise NotImplementedError(reshape_fn)
transform = BatchSizeTransform(reshape_fn=reshape_fn)
base_env = CountingEnv(max_steps=3)
assert transform._call(base_env.reset()).batch_size == expected_batch_size

@pytest.mark.parametrize("stateless,reshape_fn", [[False, "reshape"]])
def test_transform_compose(self, stateless, reshape_fn):
if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1)
expected_batch_size = torch.Size([1])
else:
raise NotImplementedError(reshape_fn)
transform = Compose(BatchSizeTransform(reshape_fn=reshape_fn))
base_env = CountingEnv(max_steps=3)
assert transform(base_env.reset()).batch_size == expected_batch_size

@pytest.mark.parametrize("stateless,reshape_fn", [[False, "reshape"]])
def test_transform_env(self, stateless, reshape_fn):
# tested in single_env
return

@pytest.mark.parametrize("stateless,reshape_fn", [[False, "reshape"]])
def test_transform_model(self, stateless, reshape_fn):
if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1)
expected_batch_size = torch.Size([1])
else:
raise NotImplementedError(reshape_fn)
transform = nn.Sequential(Compose(BatchSizeTransform(reshape_fn=reshape_fn)))
base_env = CountingEnv(max_steps=3)
assert transform(base_env.reset()).batch_size == expected_batch_size

@pytest.mark.parametrize("stateless,reshape_fn", [[False, "reshape"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass, stateless, reshape_fn):
if reshape_fn == "reshape":
reshape_fn = lambda x: x.reshape(1, -1)
expected_batch_size = torch.Size([1, 12])
else:
raise NotImplementedError(reshape_fn)
rb = rbclass(storage=LazyTensorStorage(20))
transform = Compose(BatchSizeTransform(reshape_fn=reshape_fn))
rb.append_transform(transform)

batch = (20, 3)
td = TensorDict({"a": {"b": {"c": {}}}}, batch)

rb.extend(td)
if rbclass is TensorDictReplayBuffer:
with pytest.raises(RuntimeError, match="Failed to set the metadata"):
assert rb.sample(4).shape == expected_batch_size
else:
assert rb.sample(4).shape == expected_batch_size

def test_transform_inverse(self):
# Tested in single_env
return


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading

0 comments on commit a7bf5a4

Please sign in to comment.