Skip to content

Commit

Permalink
[Feature] End-of-life transform (#1605)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 5, 2023
1 parent 244f93a commit 37c01cc
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 79 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ to be able to create this other composition:
DiscreteActionProjection
DoubleToFloat
DTypeCastTransform
EndOfLifeTransform
ExcludeTransform
FiniteTensorDictCheck
FlattenObservation
Expand Down
2 changes: 1 addition & 1 deletion examples/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# use end-of-life as done key
loss_module.set_keys(done="eol", terminated="eol")
loss_module.set_keys(done="end-of-life", terminated="end-of-life")

# Create optimizer
optim = torch.optim.Adam(
Expand Down
36 changes: 2 additions & 34 deletions examples/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec
from torchrl.data import CompositeSpec
from torchrl.data.tensor_specs import DiscreteBox
from torchrl.envs import (
CatFrames,
DoubleToFloat,
EndOfLifeTransform,
EnvCreator,
ExplorationType,
GrayScale,
Expand All @@ -23,7 +24,6 @@
RewardSum,
StepCounter,
ToTensorImage,
Transform,
TransformedEnv,
VecNorm,
)
Expand All @@ -42,38 +42,6 @@
# --------------------------------------------------------------------


class EndOfLifeTransform(Transform):
    """Register the end-of-life signal of a Gym env exposing ``ale.lives()``.

    After each step the transform reads the number of remaining lives from
    the wrapped emulator and writes two entries into the next tensordict:

    - ``"eol"``: ``True`` when a life was lost during the step or when the
      environment reported ``"done"``;
    - ``"lives"``: the number of lives remaining after the step.

    Done by DeepMind for the DQN and co. It helps value estimation.
    """

    def _get_lives(self):
        """Return the remaining lives reported by the wrapped emulator."""
        return self.parent.base_env._env.unwrapped.ale.lives()

    def _step(self, tensordict, next_tensordict):
        """Write the ``"eol"`` and ``"lives"`` entries for the current step.

        Args:
            tensordict: data from the previous step; its ``"lives"`` entry
                holds the life count before the current step.
            next_tensordict: data produced by the current step; must contain
                a ``"done"`` entry.

        Returns:
            ``next_tensordict`` with ``"eol"`` and ``"lives"`` set.
        """
        lives = self._get_lives()
        # A life was lost when the count recorded before the step is greater
        # than the count read after it (the previous ``<`` comparison only
        # fired when the count went up).
        end_of_life = torch.tensor(
            [tensordict["lives"] > lives], device=self.parent.device
        )
        # A finished episode also counts as an end of life.
        end_of_life = end_of_life | next_tensordict.get("done")
        next_tensordict.set("eol", end_of_life)
        next_tensordict.set("lives", lives)
        return next_tensordict

    def reset(self, tensordict):
        """Record the initial life count and a ``False`` end-of-life flag."""
        lives = self._get_lives()
        end_of_life = False
        tensordict.set("eol", [end_of_life])
        tensordict.set("lives", lives)
        return tensordict

    def transform_observation_spec(self, observation_spec):
        """Add the specs of the ``"eol"`` and ``"lives"`` entries.

        ``"eol"`` reuses a clone of the parent's ``"done"`` spec; ``"lives"``
        is an unbounded discrete spec built with the parent's batch size and
        device.
        """
        full_done_spec = self.parent.output_spec["full_done_spec"]
        observation_spec["eol"] = full_done_spec["done"].clone()
        observation_spec["lives"] = UnboundedDiscreteTensorSpec(
            self.parent.batch_size, device=self.parent.device
        )
        return observation_spec


def make_base_env(
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
):
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# use end-of-life as done key
loss_module.set_keys(done="eol", terminated="eol")
loss_module.set_keys(done="end-of-life", terminated="end-of-life")

# Create optimizer
optim = torch.optim.Adam(
Expand Down
36 changes: 2 additions & 34 deletions examples/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.data.tensor_specs import DiscreteBox, UnboundedDiscreteTensorSpec
from torchrl.data.tensor_specs import DiscreteBox
from torchrl.envs import (
CatFrames,
DoubleToFloat,
EndOfLifeTransform,
EnvCreator,
ExplorationType,
GrayScale,
Expand All @@ -22,7 +23,6 @@
RewardSum,
StepCounter,
ToTensorImage,
Transform,
TransformedEnv,
VecNorm,
)
Expand All @@ -41,38 +41,6 @@
# --------------------------------------------------------------------


class EndOfLifeTransform(Transform):
    """Register the end-of-life signal of a Gym env exposing ``ale.lives()``.

    After each step the transform reads the number of remaining lives from
    the wrapped emulator and writes two entries into the next tensordict:

    - ``"eol"``: ``True`` when a life was lost during the step or when the
      environment reported ``"done"``;
    - ``"lives"``: the number of lives remaining after the step.

    Done by DeepMind for the DQN and co. It helps value estimation.
    """

    def _get_lives(self):
        """Return the remaining lives reported by the wrapped emulator."""
        return self.parent.base_env._env.unwrapped.ale.lives()

    def _step(self, tensordict, next_tensordict):
        """Write the ``"eol"`` and ``"lives"`` entries for the current step.

        Args:
            tensordict: data from the previous step; its ``"lives"`` entry
                holds the life count before the current step.
            next_tensordict: data produced by the current step; must contain
                a ``"done"`` entry.

        Returns:
            ``next_tensordict`` with ``"eol"`` and ``"lives"`` set.
        """
        lives = self._get_lives()
        # A life was lost when the count recorded before the step is greater
        # than the count read after it (the previous ``<`` comparison only
        # fired when the count went up).
        end_of_life = torch.tensor(
            [tensordict["lives"] > lives], device=self.parent.device
        )
        # A finished episode also counts as an end of life.
        end_of_life = end_of_life | next_tensordict.get("done")
        next_tensordict.set("eol", end_of_life)
        next_tensordict.set("lives", lives)
        return next_tensordict

    def reset(self, tensordict):
        """Record the initial life count and a ``False`` end-of-life flag."""
        lives = self._get_lives()
        end_of_life = False
        tensordict.set("eol", [end_of_life])
        tensordict.set("lives", lives)
        return tensordict

    def transform_observation_spec(self, observation_spec):
        """Add the specs of the ``"eol"`` and ``"lives"`` entries.

        ``"eol"`` reuses a clone of the parent's ``"done"`` spec; ``"lives"``
        is an unbounded discrete spec built with the parent's batch size and
        device.
        """
        full_done_spec = self.parent.output_spec["full_done_spec"]
        observation_spec["eol"] = full_done_spec["done"].clone()
        observation_spec["lives"] = UnboundedDiscreteTensorSpec(
            self.parent.batch_size, device=self.parent.device
        )
        return observation_spec


def make_base_env(
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
):
Expand Down
120 changes: 112 additions & 8 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
DiscreteActionProjection,
DMControlEnv,
DoubleToFloat,
EndOfLifeTransform,
EnvBase,
EnvCreator,
ExcludeTransform,
Expand Down Expand Up @@ -101,11 +102,11 @@
VIPTransform,
)
from torchrl.envs.libs.dm_control import _has_dm_control
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
from torchrl.envs.transforms import VecNorm
from torchrl.envs.transforms.r3m import _R3MNet
from torchrl.envs.transforms.rlhf import KLRewardTransform
from torchrl.envs.transforms.transforms import _has_tv
from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
Expand Down Expand Up @@ -8710,19 +8711,15 @@ def test_transform_env(self):

def test_transform_model(self):
t = ActionMask()
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
t(TensorDict({}, []))

def test_transform_rb(self):
t = ActionMask()
rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.append_transform(t)
rb.extend(TensorDict({"a": [1]}, [1]).expand(10))
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
rb.sample(3)

def test_transform_inverse(self):
Expand Down Expand Up @@ -8964,6 +8961,113 @@ def test_transform_no_env(self, batch):
assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))


# Test suite for EndOfLifeTransform, built on the gymnasium "ALE/Breakout-v5"
# environment. The whole class is skipped when Gym is not available.
# NOTE(review): indentation is not preserved in this view, so the exact nesting
# of the `with` blocks and closures below should be confirmed against the repo.
@pytest.mark.skipif(
not _has_gym, reason="EndOfLifeTransform can only be tested when Gym is present."
)
class TestEndOfLife(TransformBase):
# Transform applied on top of a ParallelEnv: expects the "not a gym env"
# UserWarning and an AttributeError.
def test_trans_parallel_env_check(self):
def make():
with set_gym_backend("gymnasium"):
return GymEnv("ALE/Breakout-v5")

with pytest.warns(UserWarning, match="The base_env is not a gym env"):
with pytest.raises(AttributeError):
env = TransformedEnv(
ParallelEnv(2, make), transform=EndOfLifeTransform()
)
check_env_specs(env)

# Transform applied on top of a SerialEnv: expects the "not a gym env"
# UserWarning, then runs the spec check on the resulting env.
def test_trans_serial_env_check(self):
def make():
with set_gym_backend("gymnasium"):
return GymEnv("ALE/Breakout-v5")

with pytest.warns(UserWarning, match="The base_env is not a gym env"):
env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform())
check_env_specs(env)

# Single transformed env; parametrized over flat and nested (tuple) keys for
# both the end-of-life and the lives entries.
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
def test_single_trans_env_check(self, eol_key, lives_key):
with set_gym_backend("gymnasium"):
env = TransformedEnv(
GymEnv("ALE/Breakout-v5"),
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
)
check_env_specs(env)

# SerialEnv made of two already-transformed envs, same key parametrization.
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
def test_serial_trans_env_check(self, eol_key, lives_key):
def make():
with set_gym_backend("gymnasium"):
return TransformedEnv(
GymEnv("ALE/Breakout-v5"),
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
)

env = SerialEnv(2, make)
check_env_specs(env)

# ParallelEnv made of two already-transformed envs, same key parametrization.
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
def test_parallel_trans_env_check(self, eol_key, lives_key):
def make():
with set_gym_backend("gymnasium"):
return TransformedEnv(
GymEnv("ALE/Breakout-v5"),
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
)

env = ParallelEnv(2, make)
check_env_specs(env)

# Calling _step on a transform with no parent env must raise a RuntimeError
# whose message matches the transform's NO_PARENT_ERR template.
def test_transform_no_env(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
t._step(TensorDict({}, []), TensorDict({}, []))

# Same expectation when the transform is wrapped in a Compose.
def test_transform_compose(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
Compose(t)._step(TensorDict({}, []), TensorDict({}, []))

# Checks the env specs, then that register_keys makes ("next", eol_key) one of
# the in_keys of a DQNLoss and of a GAE value estimator.
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
def test_transform_env(self, eol_key, lives_key):
from tensordict.nn import TensorDictModule
from torchrl.objectives import DQNLoss
from torchrl.objectives.value import GAE

with set_gym_backend("gymnasium"):
env = TransformedEnv(
GymEnv("ALE/Breakout-v5"),
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
)
check_env_specs(env)
loss = DQNLoss(nn.Identity(), action_space="categorical")
env.transform.register_keys(loss)
assert ("next", eol_key) in loss.in_keys
gae = GAE(
gamma=0.9,
lmbda=0.9,
value_network=TensorDictModule(nn.Identity(), ["x"], ["y"]),
)
env.transform.register_keys(gae)
assert ("next", eol_key) in gae.in_keys

# Calling the transform through nn.Sequential must raise a RuntimeError
# matching the FORWARD_NOT_IMPLEMENTED template.
def test_transform_model(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
nn.Sequential(t)(TensorDict({}, []))

# Intentionally a no-op: nothing is exercised for replay-buffer usage.
def test_transform_rb(self):
pass

# Intentionally a no-op: nothing is exercised for the inverse transform.
def test_transform_inverse(self):
pass


# Allow running this test module directly: command-line arguments that argparse
# does not recognise are forwarded to pytest, which runs this file with output
# capture disabled and stops at the first failure.
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
DiscreteActionProjection,
DoubleToFloat,
DTypeCastTransform,
EndOfLifeTransform,
ExcludeTransform,
FiniteTensorDictCheck,
FlattenObservation,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .gym_transforms import EndOfLifeTransform
from .r3m import R3MTransform
from .rlhf import KLRewardTransform
from .transforms import (
Expand Down
Loading

0 comments on commit 37c01cc

Please sign in to comment.