[Feature] Making action masks compatible with q value modules and e-greedy #1499

Merged
15 commits merged on Sep 7, 2023
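Below is a minimal sketch of how the two pieces added here compose. It only uses constructor arguments exercised by the tests in this diff; the batch shape, epsilon values and mask pattern are illustrative assumptions, not part of the change itself.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule, QValueModule

n_actions = 4
batch = (8,)
# Mirroring the tests, the spec carries the batch shape of the tensordict.
spec = OneHotDiscreteTensorSpec(n_actions, shape=(*batch, n_actions))

# QValueModule reads "action_value", overwrites the values of masked-out actions
# so they can never be the arg-max, and writes the greedy one-hot "action".
# EGreedyModule then draws its random actions from the masked spec, so
# exploration never leaves the mask either.
policy = TensorDictSequential(
    QValueModule(
        action_space="one-hot",
        action_value_key="action_value",
        action_mask_key="action_mask",
    ),
    EGreedyModule(
        spec=spec,
        eps_init=0.5,
        eps_end=0.5,
        action_mask_key="action_mask",
    ),
)

td = TensorDict(
    {
        "action_value": torch.randn(*batch, n_actions),
        "action_mask": torch.tensor([True, True, False, True]).repeat(*batch, 1),
    },
    batch_size=batch,
)
td = policy(td)
# Masked actions are never selected, greedily or through exploration.
assert (td["action"][~td["action_mask"]] == 0).all()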
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
@@ -69,6 +69,7 @@ other cases, the action written in the tensordict is simply the network output.

AdditiveGaussianWrapper
EGreedyWrapper
EGreedyModule
OrnsteinUhlenbeckProcessWrapper

Probabilistic actors
2 changes: 1 addition & 1 deletion examples/multiagent/iql.py
@@ -101,7 +101,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec[env.action_key],
spec=env.unbatched_action_spec,
)

collector = SyncDataCollector(
2 changes: 1 addition & 1 deletion examples/multiagent/qmix_vdn.py
@@ -102,7 +102,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec[env.action_key],
spec=env.unbatched_action_spec,
)

if cfg.loss.mixer_type == "qmix":
33 changes: 33 additions & 0 deletions test/test_actors.py
@@ -613,6 +613,39 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action
assert values.shape == in_values.shape
assert (values == in_values).all()

@pytest.mark.parametrize("action_space", ["categorical", "one-hot"])
@pytest.mark.parametrize("action_n", [2, 3, 4, 5])
def test_qvalue_mask(self, action_space, action_n):
torch.manual_seed(0)
shape = (3, 4, 3, action_n)
action_values = torch.randn(size=shape)
td = TensorDict({"action_value": action_values.clone()}, [3])
module = QValueModule(
action_space=action_space,
action_value_key="action_value",
action_mask_key="action_mask",
)
with pytest.raises(KeyError, match="Action mask key "):
module(td)

action_mask = torch.randint(high=2, size=shape).to(torch.bool)
while not action_mask.any(dim=-1).all() or action_mask.all():
action_mask = torch.randint(high=2, size=shape).to(torch.bool)

td.set("action_mask", action_mask)
module(td)
new_action_values = td.get("action_value")

assert (new_action_values[~action_mask] != action_values[~action_mask]).all()
assert (new_action_values[action_mask] == action_values[action_mask]).all()
assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all()

if action_space == "one-hot":
assert (td.get("action")[action_mask]).any()
assert not (td.get("action")[~action_mask]).any()
else:
assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all()


@pytest.mark.parametrize("device", get_default_devices())
def test_value_based_policy(device):
97 changes: 92 additions & 5 deletions test/test_exploration.py
@@ -14,12 +14,18 @@
NestedCountingEnv,
)
from scipy.stats import ttest_1samp
from tensordict.nn import InteractionType, TensorDictModule

from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.tensordict import TensorDict
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data import BoundedTensorSpec, CompositeSpec
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
OneHotDiscreteTensorSpec,
)
from torchrl.envs import SerialEnv
from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
from torchrl.envs.utils import set_exploration_type
@@ -30,23 +36,37 @@
NormalParamWrapper,
)
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor
from torchrl.modules.tensordict_module.actors import (
Actor,
ProbabilisticActor,
QValueActor,
)
from torchrl.modules.tensordict_module.exploration import (
_OrnsteinUhlenbeckProcess,
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessWrapper,
)


@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
class TestEGreedy:
def test_egreedy(self, eps_init):
@pytest.mark.parametrize("module", [True, False])
def test_egreedy(self, eps_init, module):
torch.manual_seed(0)
spec = BoundedTensorSpec(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 4, bias=False)

policy = Actor(spec=spec, module=module)
explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init)
if module:
explorative_policy = TensorDictSequential(
policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec)
)
else:
explorative_policy = EGreedyWrapper(
policy, eps_init=eps_init, eps_end=eps_init
)
td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
action = explorative_policy(td).get("action")
if eps_init == 0:
@@ -58,6 +78,73 @@ def test_egreedy(self, eps_init):
assert (action == 0).any()
assert ((action == 1) | (action == 0)).all()

@pytest.mark.parametrize("module", [True])
@pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
def test_egreedy_masked(self, module, eps_init, spec_class):
torch.manual_seed(0)
action_size = 4
batch_size = (3, 4, 2)
module = torch.nn.Linear(action_size, action_size, bias=False)
if spec_class == "discrete":
spec = DiscreteTensorSpec(action_size, shape=batch_size)
else:
spec = OneHotDiscreteTensorSpec(
action_size, shape=batch_size + (action_size,)
)
policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask")
if module:
explorative_policy = TensorDictSequential(
policy,
EGreedyModule(
eps_init=eps_init,
eps_end=eps_init,
spec=spec,
action_mask_key="action_mask",
),
)
else:
explorative_policy = EGreedyWrapper(
policy,
eps_init=eps_init,
eps_end=eps_init,
action_mask_key="action_mask",
)
torch.manual_seed(0)
action_mask = torch.ones(*batch_size, action_size).to(torch.bool)
td = TensorDict(
{
"observation": torch.zeros(*batch_size, action_size),
"action_mask": action_mask,
},
batch_size=batch_size,
)
action = explorative_policy(td).get("action")

torch.manual_seed(0)
action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
torch.bool
)
while not action_mask.any(dim=-1).all() or action_mask.all():
action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
torch.bool
)

td = TensorDict(
{
"observation": torch.zeros(*batch_size, action_size),
"action_mask": action_mask,
},
batch_size=batch_size,
)
masked_action = explorative_policy(td).get("action")

if spec_class == "discrete":
action = spec.to_one_hot(action)
masked_action = spec.to_one_hot(masked_action)

assert not (action[~action_mask] == 0).all()
assert (masked_action[~action_mask] == 0).all()


@pytest.mark.parametrize("device", get_default_devices())
class TestOrnsteinUhlenbeckProcessWrapper:
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -53,6 +53,7 @@
DistributionalQValueActor,
DistributionalQValueHook,
DistributionalQValueModule,
EGreedyModule,
EGreedyWrapper,
LMHeadActorValueOperator,
LSTMModule,
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -23,6 +23,7 @@
from .common import SafeModule, VmapModule
from .exploration import (
AdditiveGaussianWrapper,
EGreedyModule,
EGreedyWrapper,
OrnsteinUhlenbeckProcessWrapper,
)