[Feature] Masking actions #1421

Merged: 32 commits merged into main from masked_actions on Sep 3, 2023

Changes from all commits (32 commits)
7d291a7
init
vmoens Jul 27, 2023
77860cf
Merge branch 'main' into masked_actions
vmoens Aug 30, 2023
9a4a31c
Merge remote-tracking branch 'origin/main' into masked_actions
vmoens Aug 30, 2023
4ce72e1
amend
vmoens Aug 30, 2023
1888fe7
amend
vmoens Aug 30, 2023
2e11876
amend
vmoens Aug 30, 2023
dfbdf88
Merge branch 'main' into masked_actions
vmoens Sep 1, 2023
24ff374
Merge remote-tracking branch 'origin/main' into masked_actions
vmoens Sep 1, 2023
285a990
amend
vmoens Sep 1, 2023
0d2f791
Merge branch 'main' into masked_actions
vmoens Sep 1, 2023
40321a3
amend
vmoens Sep 1, 2023
77c1963
fix
vmoens Sep 1, 2023
e26b3be
fix
vmoens Sep 1, 2023
4e49953
fix
vmoens Sep 2, 2023
76c3f4f
fix
vmoens Sep 2, 2023
51ccbb3
init
vmoens Sep 2, 2023
1332824
lint
vmoens Sep 2, 2023
b832087
Merge branch 'fix_cliffwalk' into masked_actions
vmoens Sep 2, 2023
9eb7f23
fix
vmoens Sep 2, 2023
7be1080
fix
vmoens Sep 2, 2023
8313c92
fix
vmoens Sep 2, 2023
e60a94c
fix
vmoens Sep 3, 2023
9f5ad64
Merge remote-tracking branch 'origin/fix_cliffwalk' into masked_actions
vmoens Sep 3, 2023
a9e3f9e
amend
vmoens Sep 3, 2023
9c52b69
Merge branch 'fix_cliffwalk' into masked_actions
vmoens Sep 3, 2023
6551f80
amend
vmoens Sep 3, 2023
c369854
Merge branch 'fix_cliffwalk' into masked_actions
vmoens Sep 3, 2023
9a00095
amend
vmoens Sep 3, 2023
e4e3b99
amend
vmoens Sep 3, 2023
0199af8
Merge remote-tracking branch 'origin/fix_cliffwalk' into masked_actions
vmoens Sep 3, 2023
468a92d
amend
vmoens Sep 3, 2023
5d47b1f
Merge remote-tracking branch 'origin/fix_cliffwalk' into masked_actions
vmoens Sep 3, 2023
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -445,6 +445,7 @@ to be able to create this other composition:

Transform
TransformedEnv
ActionMask
BinarizeReward
CatFrames
CatTensors
5 changes: 3 additions & 2 deletions test/mocking_classes.py
@@ -1053,7 +1053,7 @@ def _step(
batch_size=self.batch_size,
device=self.device,
)
- return tensordict.select().set("next", tensordict)
+ return tensordict


class NestedCountingEnv(CountingEnv):
@@ -1696,7 +1696,8 @@ def _step(
done = self.output_spec["full_done_spec"].zero()
td = self.observation_spec.zero()

- one_hot_action = tensordict["action"].argmax(-1).unsqueeze(-1)
+ one_hot_action = tensordict["action"]
+ one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1)
reward["reward"] += one_hot_action.to(torch.float)
self.count += one_hot_action.to(torch.int)
td["observation"] += expand_right(self.count, td["observation"].shape)
2 changes: 1 addition & 1 deletion test/test_env.py
@@ -2013,7 +2013,7 @@ def check_rollout_consistency(td: TensorDict, max_steps: int):
== td["next", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for root
- action_is_count = td["action"].argmax(-1).to(torch.bool)
+ action_is_count = td["action"].long().argmax(-1).to(torch.bool)
assert (
td["next", "observation"][action_is_count]
== td["observation"][action_is_count] + 1
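Note (not part of the diff): the .long() casts added in test/mocking_classes.py and test/test_env.py above make the argmax over the one-hot action dimension dtype-safe. A minimal standalone sketch of the pattern; the assumption that the action tensor may arrive as a bool tensor is inferred from the cast rather than stated in the diff.

import torch

# A one-hot action may be stored as a bool tensor; casting to long before
# argmax keeps the reduction well-defined regardless of the stored dtype.
one_hot = torch.tensor([[False, True, False, False],
                        [True, False, False, False]])
print(one_hot.long().argmax(-1))  # tensor([1, 0])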
19 changes: 19 additions & 0 deletions test/test_libs.py
@@ -292,6 +292,25 @@ def info_reader(info, tensordict):
env.rand_step()
env.rollout(3)

@implement_for("gymnasium", "0.27.0", None)
def test_one_hot_and_categorical(self):
# tests that one-hot and categorical work ok when an integer is expected as action
cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True)
cliff_walking.rollout(10)
check_env_specs(cliff_walking)

cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=False)
cliff_walking.rollout(10)
check_env_specs(cliff_walking)

@implement_for("gym", None, "0.27.0")
def test_one_hot_and_categorical(self): # noqa: F811
# we do not skip (bc we may want to make sure nothing is skipped)
# but CliffWalking-v0 in earlier Gym versions uses np.bool, which
# was deprecated after np 1.20, and we don't want to install multiple np
# versions.
return


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
122 changes: 121 additions & 1 deletion test/test_specs.py
@@ -1168,7 +1168,7 @@ def test_one_hot_discrete_action_spec_rand(self):

sample = action_spec.rand((100000,))

- sample_list = sample.argmax(-1)
+ sample_list = sample.long().argmax(-1)
sample_list = [sum(sample_list == i).item() for i in range(10)]
assert chisquare(sample_list).pvalue > 0.1

@@ -3169,6 +3169,126 @@ def get_all_keys(spec: TensorSpec, include_exclusive: bool):
return keys


@pytest.mark.parametrize("shape", ((), (1,), (2, 3), (2, 3, 4)))
@pytest.mark.parametrize(
"spectype", ["one_hot", "categorical", "mult_one_hot", "mult_discrete"]
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("rand_shape", ((), (2,), (2, 3)))
class TestSpecMasking:
def _make_mask(self, shape):
torch.manual_seed(0)
mask = torch.zeros(shape, dtype=torch.bool).bernoulli_()
if len(shape) == 1:
while not mask.any() or mask.all():
mask = torch.zeros(shape, dtype=torch.bool).bernoulli_()
return mask
mask_view = mask.view(-1, shape[-1])
for i in range(mask_view.shape[0]):
t = mask_view[i]
while not t.any() or t.all():
t.copy_(torch.zeros_like(t).bernoulli_())
return mask

def _one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n])
mask = self._make_mask(shape).to(device)
return OneHotDiscreteTensorSpec(n, shape, device, mask=mask)

def _mult_one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n + n + 2])
mask = torch.cat(
[
self._make_mask(shape[:-1] + (n,)).to(device),
self._make_mask(shape[:-1] + (n + 2,)).to(device),
],
-1,
)
return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)

def _discrete_spec(self, shape, device, n):
mask = self._make_mask(torch.Size([*shape, n])).to(device)
return DiscreteTensorSpec(n, shape, device, mask=mask)

def _mult_discrete_spec(self, shape, device, n):
shape = torch.Size([*shape, 2])
mask = torch.cat(
[
self._make_mask(shape[:-1] + (n,)).to(device),
self._make_mask(shape[:-1] + (n + 2,)).to(device),
],
-1,
)
return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)

def test_equal(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
spec_clone = spec.clone()
assert spec == spec_clone
assert spec.unsqueeze(0).squeeze(0) == spec
spec.update_mask(~spec.mask)
assert (spec.mask != spec_clone.mask).any()
assert spec != spec_clone

def test_is_in(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
rand_shape = torch.Size(rand_shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
s = spec.rand(rand_shape)
assert spec.is_in(s)
spec.update_mask(~spec.mask)
assert not spec.is_in(s)

def test_project(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
rand_shape = torch.Size(rand_shape)
spec = (
self._one_hot_spec(shape, device, n=n)
if spectype == "one_hot"
else self._discrete_spec(shape, device, n=n)
if spectype == "categorical"
else self._mult_one_hot_spec(shape, device, n=n)
if spectype == "mult_one_hot"
else self._mult_discrete_spec(shape, device, n=n)
if spectype == "mult_discrete"
else None
)
s = spec.rand(rand_shape)
assert (spec.project(s) == s).all()
spec.update_mask(~spec.mask)
sp = spec.project(s)
assert sp.shape == s.shape
if spectype == "one_hot":
assert (sp != s).any(-1).all()
assert (sp.any(-1)).all()
elif spectype == "mult_one_hot":
assert (sp != s).any(-1).all()
assert (sp.sum(-1) == 2).all()
else:
assert (sp != s).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
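For reference, a minimal sketch (not part of the diff) of the masked-spec API exercised by TestSpecMasking above. The constructor arguments and method names mirror the test code; the final assertion relies on the usual contract of project (projected samples satisfy is_in), which the tests above do not check verbatim.

import torch
from torchrl.data import DiscreteTensorSpec

# Categorical spec over 4 actions, with action 1 masked out.
mask = torch.tensor([True, False, True, True])
spec = DiscreteTensorSpec(4, mask=mask)

sample = spec.rand((8,))      # only indices where the mask is True are sampled
assert spec.is_in(sample)

spec.update_mask(~mask)            # flip the set of allowed actions
projected = spec.project(sample)   # previous samples are mapped onto the new mask
assert spec.is_in(projected)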
111 changes: 111 additions & 0 deletions test/test_transforms.py
@@ -47,6 +47,7 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
ActionMask,
BinarizeReward,
CatFrames,
CatTensors,
@@ -8180,6 +8181,116 @@ def test_kl_lstm(self):
klt(env.rollout(3, policy))


class TestActionMask(TransformBase):
@property
def _env_class(self):
from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec

class MaskedEnv(EnvBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action_spec = DiscreteTensorSpec(4)
self.state_spec = CompositeSpec(
mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)
)
self.observation_spec = CompositeSpec(
obs=UnboundedContinuousTensorSpec(3),
mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool),
)
self.reward_spec = UnboundedContinuousTensorSpec(1)

def _reset(self, tensordict):
td = self.observation_spec.rand()
td.update(torch.ones_like(self.state_spec.rand()))
return td

def _step(self, data):
td = self.observation_spec.rand()
mask = data.get("mask")
action = data.get("action")
mask = mask.scatter(-1, action.unsqueeze(-1), 0)

td.set("mask", mask)
td.set("reward", self.reward_spec.rand())
td.set("done", ~(mask.any().view(1)))
return td

def _set_seed(self, seed):
return seed

return MaskedEnv

def test_single_trans_env_check(self):
env = self._env_class()
env = TransformedEnv(env, ActionMask())
check_env_specs(env)

def test_serial_trans_env_check(self):
env = SerialEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask()))
check_env_specs(env)

def test_parallel_trans_env_check(self):
env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask()))
check_env_specs(env)

def test_trans_serial_env_check(self):
env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask())
check_env_specs(env)

def test_trans_parallel_env_check(self):
env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask())
check_env_specs(env)

def test_transform_no_env(self):
t = ActionMask()
with pytest.raises(RuntimeError, match="parent cannot be None"):
t._call(TensorDict({}, []))

def test_transform_compose(self):
env = self._env_class()
env = TransformedEnv(env, Compose(ActionMask()))
check_env_specs(env)

def test_transform_env(self):
env = TransformedEnv(ContinuousActionVecMockEnv(), ActionMask())
with pytest.raises(ValueError, match="The action spec must be one of"):
env.rollout(2)
env = self._env_class()
env = TransformedEnv(env, ActionMask())
td = env.reset()
for _ in range(1000):
td = env.rand_action(td)
assert env.action_spec.is_in(td.get("action"))
td = env.step(td)
td = step_mdp(td)
if td.get("done"):
break
else:
raise RuntimeError
assert not td.get("mask").any()

def test_transform_model(self):
t = ActionMask()
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
t(TensorDict({}, []))

def test_transform_rb(self):
t = ActionMask()
rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.append_transform(t)
rb.extend(TensorDict({"a": [1]}, [1]).expand(10))
with pytest.raises(
RuntimeError, match="ActionMask must be executed within an environment"
):
rb.sample(3)

def test_transform_inverse(self):
# no inverse transform
return


class TestDeviceCastTransform(TransformBase):
def test_single_trans_env_check(self):
env = ContinuousActionVecMockEnv(device="cpu:0")
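Putting it together, a minimal usage sketch (not part of the diff) following test_transform_env in TestActionMask above. MaskedEnv stands in for the mask-emitting environment defined in that test (any env whose specs expose a boolean "mask" entry alongside its observations); the comment on what ActionMask does with that entry is an inference from the tests, not a statement taken from the diff.

from torchrl.envs import ActionMask, TransformedEnv
from torchrl.envs.utils import step_mdp

# MaskedEnv: the mask-emitting env defined in the test class above.
env = TransformedEnv(MaskedEnv(), ActionMask())

td = env.reset()
for _ in range(1000):
    # ActionMask keeps the action spec's mask in sync with the "mask" entry,
    # so random actions only ever pick currently allowed actions.
    td = env.rand_action(td)
    assert env.action_spec.is_in(td.get("action"))
    td = step_mdp(env.step(td))
    if td.get("done"):
        break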