Commit

[Feature] Making action masks compatible with q value modules and e-greedy (pytorch#1499)

Signed-off-by: Matteo Bettini <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
matteobettini and vmoens committed Oct 10, 2023
1 parent d8a0bc8 commit 7ee8f13
Showing 9 changed files with 442 additions and 19 deletions.
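
At a glance: the commit threads an optional mask of allowed actions through the tensordict. QValueModule / QValueActor and the e-greedy exploration modules gain an action_mask_key argument, and EGreedyModule is exported as the successor to EGreedyWrapper. The snippet below is a minimal usage sketch assembled from the tests added in this diff, not part of the commit itself; the key names ("observation", "action_mask") and import paths follow those tests.

import torch
from tensordict.nn import TensorDictSequential
from tensordict.tensordict import TensorDict
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule, QValueActor

action_size = 4
spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,))
value_net = torch.nn.Linear(action_size, action_size, bias=False)

# Both the greedy actor and the exploration module read the mask from "action_mask".
policy = QValueActor(spec=spec, module=value_net, action_mask_key="action_mask")
explorative_policy = TensorDictSequential(
    policy,
    EGreedyModule(spec=spec, action_mask_key="action_mask"),
)

td = TensorDict(
    {
        "observation": torch.zeros(2, action_size),
        "action_mask": torch.tensor(
            [[True, False, True, True], [False, True, True, False]]
        ),
    },
    batch_size=[2],
)
action = explorative_policy(td).get("action")  # disallowed actions are never picked
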
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -68,7 +68,7 @@ other cases, the action written in the tensordict is simply the network output.
:template: rl_template_noinherit.rst

AdditiveGaussianWrapper
EGreedyWrapper
EGreedyModule
OrnsteinUhlenbeckProcessWrapper

Probabilistic actors
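
This docs change tracks a behavioural one: EGreedyWrapper now emits a DeprecationWarning (see test_egreedy_wrapper_deprecation added below), and the recommended pattern is to append an EGreedyModule to the policy with TensorDictSequential. A hedged before/after sketch, where policy is any actor, action_spec stands in for its action spec, and the eps values are illustrative:

# Deprecated: wrapping the policy now raises a DeprecationWarning.
explorative_policy = EGreedyWrapper(policy, eps_init=0.1, eps_end=0.1)

# Preferred: append the exploration module to the policy.
explorative_policy = TensorDictSequential(
    policy,
    EGreedyModule(spec=action_spec, eps_init=0.1, eps_end=0.1),
)
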
2 changes: 1 addition & 1 deletion examples/multiagent/iql.py
@@ -101,7 +101,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec[env.action_key],
spec=env.unbatched_action_spec,
)

collector = SyncDataCollector(
2 changes: 1 addition & 1 deletion examples/multiagent/qmix_vdn.py
@@ -102,7 +102,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec[env.action_key],
spec=env.unbatched_action_spec,
)

if cfg.loss.mixer_type == "qmix":
33 changes: 33 additions & 0 deletions test/test_actors.py
@@ -613,6 +613,39 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action
assert values.shape == in_values.shape
assert (values == in_values).all()

@pytest.mark.parametrize("action_space", ["categorical", "one-hot"])
@pytest.mark.parametrize("action_n", [2, 3, 4, 5])
def test_qvalue_mask(self, action_space, action_n):
torch.manual_seed(0)
shape = (3, 4, 3, action_n)
action_values = torch.randn(size=shape)
td = TensorDict({"action_value": action_values}, [3])
module = QValueModule(
action_space=action_space,
action_value_key="action_value",
action_mask_key="action_mask",
)
with pytest.raises(KeyError, match="Action mask key "):
module(td)

action_mask = torch.randint(high=2, size=shape).to(torch.bool)
while not action_mask.any(dim=-1).all() or action_mask.all():
action_mask = torch.randint(high=2, size=shape).to(torch.bool)

td.set("action_mask", action_mask)
module(td)
new_action_values = td.get("action_value")

assert (new_action_values[~action_mask] != action_values[~action_mask]).all()
assert (new_action_values[action_mask] == action_values[action_mask]).all()
assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all()

if action_space == "one-hot":
assert (td.get("action")[action_mask]).any()
assert not (td.get("action")[~action_mask]).any()
else:
assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all()


@pytest.mark.parametrize("device", get_default_devices())
def test_value_based_policy(device):
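
The same mask handling is available on the standalone QValueModule via the new action_mask_key argument. The sketch below relies only on the behaviour asserted in test_qvalue_mask above and is not part of the diff:

import torch
from tensordict.tensordict import TensorDict
from torchrl.modules import QValueModule

qv = QValueModule(
    action_space="categorical",
    action_value_key="action_value",
    action_mask_key="action_mask",
)
td = TensorDict(
    {
        "action_value": torch.randn(2, 4),
        "action_mask": torch.tensor(
            [[True, False, True, True], [False, True, True, False]]
        ),
    },
    [2],
)
qv(td)
# Masked-out entries of "action_value" are overwritten with a very low value, so
# "action" and "chosen_action_value" only ever come from allowed actions.
# Calling the module without an "action_mask" entry raises a KeyError.
td.get("action")  # categorical indices, all pointing at allowed actions
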
161 changes: 155 additions & 6 deletions test/test_exploration.py
@@ -14,12 +14,18 @@
NestedCountingEnv,
)
from scipy.stats import ttest_1samp
from tensordict.nn import InteractionType, TensorDictModule

from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.tensordict import TensorDict
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data import BoundedTensorSpec, CompositeSpec
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
OneHotDiscreteTensorSpec,
)
from torchrl.envs import SerialEnv
from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
from torchrl.envs.utils import set_exploration_type
@@ -30,23 +36,37 @@
NormalParamWrapper,
)
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor
from torchrl.modules.tensordict_module.actors import (
Actor,
ProbabilisticActor,
QValueActor,
)
from torchrl.modules.tensordict_module.exploration import (
_OrnsteinUhlenbeckProcess,
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessWrapper,
)


@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
class TestEGreedy:
def test_egreedy(self, eps_init):
@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("module", [True, False])
def test_egreedy(self, eps_init, module):
torch.manual_seed(0)
spec = BoundedTensorSpec(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 4, bias=False)

policy = Actor(spec=spec, module=module)
explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init)
if module:
explorative_policy = TensorDictSequential(
policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec)
)
else:
explorative_policy = EGreedyWrapper(
policy, eps_init=eps_init, eps_end=eps_init
)
td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
action = explorative_policy(td).get("action")
if eps_init == 0:
@@ -58,6 +78,135 @@ def test_egreedy(self, eps_init):
assert (action == 0).any()
assert ((action == 1) | (action == 0)).all()

@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("module", [True, False])
@pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
def test_egreedy_masked(self, module, eps_init, spec_class):
torch.manual_seed(0)
action_size = 4
batch_size = (3, 4, 2)
module = torch.nn.Linear(action_size, action_size, bias=False)
if spec_class == "discrete":
spec = DiscreteTensorSpec(action_size)
else:
spec = OneHotDiscreteTensorSpec(
action_size,
shape=(action_size,),
)
policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask")
if module:
explorative_policy = TensorDictSequential(
policy,
EGreedyModule(
eps_init=eps_init,
eps_end=eps_init,
spec=spec,
action_mask_key="action_mask",
),
)
else:
explorative_policy = EGreedyWrapper(
policy,
eps_init=eps_init,
eps_end=eps_init,
action_mask_key="action_mask",
)

td = TensorDict(
{"observation": torch.zeros(*batch_size, action_size)},
batch_size=batch_size,
)
with pytest.raises(KeyError, match="Action mask key action_mask not found in"):
explorative_policy(td)

torch.manual_seed(0)
action_mask = torch.ones(*batch_size, action_size).to(torch.bool)
td = TensorDict(
{
"observation": torch.zeros(*batch_size, action_size),
"action_mask": action_mask,
},
batch_size=batch_size,
)
action = explorative_policy(td).get("action")

torch.manual_seed(0)
action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
torch.bool
)
while not action_mask.any(dim=-1).all() or action_mask.all():
action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
torch.bool
)

td = TensorDict(
{
"observation": torch.zeros(*batch_size, action_size),
"action_mask": action_mask,
},
batch_size=batch_size,
)
masked_action = explorative_policy(td).get("action")

if spec_class == "discrete":
action = spec.to_one_hot(action)
masked_action = spec.to_one_hot(masked_action)

assert not (action[~action_mask] == 0).all()
assert (masked_action[~action_mask] == 0).all()

def test_egreedy_wrapper_deprecation(self):
torch.manual_seed(0)
spec = BoundedTensorSpec(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 4, bias=False)
policy = Actor(spec=spec, module=module)
with pytest.deprecated_call():
EGreedyWrapper(policy)

def test_no_spec_error(
self,
):
torch.manual_seed(0)
action_size = 4
batch_size = (3, 4, 2)
module = torch.nn.Linear(action_size, action_size, bias=False)
spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,))
policy = QValueActor(spec=spec, module=module)
explorative_policy = TensorDictSequential(
policy,
EGreedyModule(spec=None),
)
td = TensorDict(
{
"observation": torch.zeros(*batch_size, action_size),
},
batch_size=batch_size,
)

with pytest.raises(
RuntimeError, match="spec must be provided to the exploration wrapper."
):
explorative_policy(td)

@pytest.mark.parametrize("module", [True, False])
def test_wrong_action_shape(self, module):
torch.manual_seed(0)
spec = BoundedTensorSpec(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 5, bias=False)

policy = Actor(spec=spec, module=module)
if module:
explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec))
else:
explorative_policy = EGreedyWrapper(
policy,
)
td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
with pytest.raises(
ValueError, match="Action spec shape does not match the action shape"
):
explorative_policy(td)


@pytest.mark.parametrize("device", get_default_devices())
class TestOrnsteinUhlenbeckProcessWrapper:
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -53,6 +53,7 @@
DistributionalQValueActor,
DistributionalQValueHook,
DistributionalQValueModule,
EGreedyModule,
EGreedyWrapper,
LMHeadActorValueOperator,
LSTMModule,
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -23,6 +23,7 @@
from .common import SafeModule, VmapModule
from .exploration import (
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessWrapper,
)
