diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 704c8e6276a..0281061b007 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -68,7 +68,7 @@ other cases, the action written in the tensordict is simply the network output. :template: rl_template_noinherit.rst AdditiveGaussianWrapper - EGreedyWrapper + EGreedyModule OrnsteinUhlenbeckProcessWrapper Probabilistic actors diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 00c7bf5fc87..4d36614f199 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -101,7 +101,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec[env.action_key], + spec=env.unbatched_action_spec, ) collector = SyncDataCollector( diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 55c5ef012ba..222e0434db2 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -102,7 +102,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec[env.action_key], + spec=env.unbatched_action_spec, ) if cfg.loss.mixer_type == "qmix": diff --git a/test/test_actors.py b/test/test_actors.py index d16c95731d5..06d59de0a48 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -613,6 +613,39 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action assert values.shape == in_values.shape assert (values == in_values).all() + @pytest.mark.parametrize("action_space", ["categorical", "one-hot"]) + @pytest.mark.parametrize("action_n", [2, 3, 4, 5]) + def test_qvalue_mask(self, action_space, action_n): + torch.manual_seed(0) + shape = (3, 4, 3, action_n) + action_values = torch.randn(size=shape) + td = TensorDict({"action_value": action_values}, [3]) + module = QValueModule( + action_space=action_space, + action_value_key="action_value", + action_mask_key="action_mask", + ) + with pytest.raises(KeyError, match="Action mask key "): + module(td) + + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + while not action_mask.any(dim=-1).all() or action_mask.all(): + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + + td.set("action_mask", action_mask) + module(td) + new_action_values = td.get("action_value") + + assert (new_action_values[~action_mask] != action_values[~action_mask]).all() + assert (new_action_values[action_mask] == action_values[action_mask]).all() + assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all() + + if action_space == "one-hot": + assert (td.get("action")[action_mask]).any() + assert not (td.get("action")[~action_mask]).any() + else: + assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all() + @pytest.mark.parametrize("device", get_default_devices()) def test_value_based_policy(device): diff --git a/test/test_exploration.py b/test/test_exploration.py index c823dbaf4f4..c4cd44f0692 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,12 +14,18 @@ NestedCountingEnv, ) from scipy.stats import ttest_1samp -from tensordict.nn import InteractionType, TensorDictModule + +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from tensordict.tensordict import TensorDict from torch import nn from torchrl.collectors import 
SyncDataCollector
-from torchrl.data import BoundedTensorSpec, CompositeSpec
+from torchrl.data import (
+    BoundedTensorSpec,
+    CompositeSpec,
+    DiscreteTensorSpec,
+    OneHotDiscreteTensorSpec,
+)
 from torchrl.envs import SerialEnv
 from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
 from torchrl.envs.utils import set_exploration_type
@@ -30,23 +36,37 @@
     NormalParamWrapper,
 )
 from torchrl.modules.models.exploration import LazygSDEModule
-from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor
+from torchrl.modules.tensordict_module.actors import (
+    Actor,
+    ProbabilisticActor,
+    QValueActor,
+)
 from torchrl.modules.tensordict_module.exploration import (
     _OrnsteinUhlenbeckProcess,
     AdditiveGaussianWrapper,
+    EGreedyModule,
     EGreedyWrapper,
     OrnsteinUhlenbeckProcessWrapper,
 )
 
 
-@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
 class TestEGreedy:
-    def test_egreedy(self, eps_init):
+    @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+    @pytest.mark.parametrize("use_module", [True, False])
+    def test_egreedy(self, eps_init, use_module):
         torch.manual_seed(0)
         spec = BoundedTensorSpec(1, 1, torch.Size([4]))
         module = torch.nn.Linear(4, 4, bias=False)
+
         policy = Actor(spec=spec, module=module)
-        explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init)
+        if use_module:
+            explorative_policy = TensorDictSequential(
+                policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec)
+            )
+        else:
+            explorative_policy = EGreedyWrapper(
+                policy, eps_init=eps_init, eps_end=eps_init
+            )
         td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
         action = explorative_policy(td).get("action")
         if eps_init == 0:
@@ -58,6 +78,135 @@ def test_egreedy(self, eps_init):
             assert (action == 0).any()
             assert ((action == 1) | (action == 0)).all()
 
+    @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+    @pytest.mark.parametrize("use_module", [True, False])
+    @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
+    def test_egreedy_masked(self, use_module, eps_init, spec_class):
+        torch.manual_seed(0)
+        action_size = 4
+        batch_size = (3, 4, 2)
+        module = torch.nn.Linear(action_size, action_size, bias=False)
+        if spec_class == "discrete":
+            spec = DiscreteTensorSpec(action_size)
+        else:
+            spec = OneHotDiscreteTensorSpec(
+                action_size,
+                shape=(action_size,),
+            )
+        policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask")
+        if use_module:
+            explorative_policy = TensorDictSequential(
+                policy,
+                EGreedyModule(
+                    eps_init=eps_init,
+                    eps_end=eps_init,
+                    spec=spec,
+                    action_mask_key="action_mask",
+                ),
+            )
+        else:
+            explorative_policy = EGreedyWrapper(
+                policy,
+                eps_init=eps_init,
+                eps_end=eps_init,
+                action_mask_key="action_mask",
+            )
+
+        td = TensorDict(
+            {"observation": torch.zeros(*batch_size, action_size)},
+            batch_size=batch_size,
+        )
+        with pytest.raises(KeyError, match="Action mask key action_mask not found in"):
+            explorative_policy(td)
+
+        torch.manual_seed(0)
+        action_mask = torch.ones(*batch_size, action_size).to(torch.bool)
+        td = TensorDict(
+            {
+                "observation": torch.zeros(*batch_size, action_size),
+                "action_mask": action_mask,
+            },
+            batch_size=batch_size,
+        )
+        action = explorative_policy(td).get("action")
+
+        torch.manual_seed(0)
+        action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
+            torch.bool
+        )
+        while not action_mask.any(dim=-1).all() or action_mask.all():
+            action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
+                torch.bool
+            )
+
+        td = TensorDict(
+            {
"observation": torch.zeros(*batch_size, action_size), + "action_mask": action_mask, + }, + batch_size=batch_size, + ) + masked_action = explorative_policy(td).get("action") + + if spec_class == "discrete": + action = spec.to_one_hot(action) + masked_action = spec.to_one_hot(masked_action) + + assert not (action[~action_mask] == 0).all() + assert (masked_action[~action_mask] == 0).all() + + def test_egreedy_wrapper_deprecation(self): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 4, bias=False) + policy = Actor(spec=spec, module=module) + with pytest.deprecated_call(): + EGreedyWrapper(policy) + + def test_no_spec_error( + self, + ): + torch.manual_seed(0) + action_size = 4 + batch_size = (3, 4, 2) + module = torch.nn.Linear(action_size, action_size, bias=False) + spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,)) + policy = QValueActor(spec=spec, module=module) + explorative_policy = TensorDictSequential( + policy, + EGreedyModule(spec=None), + ) + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + }, + batch_size=batch_size, + ) + + with pytest.raises( + RuntimeError, match="spec must be provided to the exploration wrapper." + ): + explorative_policy(td) + + @pytest.mark.parametrize("module", [True, False]) + def test_wrong_action_shape(self, module): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 5, bias=False) + + policy = Actor(spec=spec, module=module) + if module: + explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec)) + else: + explorative_policy = EGreedyWrapper( + policy, + ) + td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + with pytest.raises( + ValueError, match="Action spec shape does not match the action shape" + ): + explorative_policy(td) + @pytest.mark.parametrize("device", get_default_devices()) class TestOrnsteinUhlenbeckProcessWrapper: diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index ad654dbc7c9..604bb3bdca7 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -53,6 +53,7 @@ DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, + EGreedyModule, EGreedyWrapper, LMHeadActorValueOperator, LSTMModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 645c7b6f122..d1930855ab2 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -23,6 +23,7 @@ from .common import SafeModule, VmapModule from .exploration import ( AdditiveGaussianWrapper, + EGreedyModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index da719102179..7606836caa0 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -327,6 +327,8 @@ class QValueModule(TensorDictModuleBase): conditions the action_space. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions, action values and chosen action value. 
Defaults to ``["action", "action_value", "chosen_action_value"]``. @@ -378,6 +380,7 @@ def __init__( self, action_space: Optional[str], action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: Optional[TensorSpec] = None, @@ -407,7 +410,11 @@ def __init__( ) if action_value_key is None: action_value_key = "action_value" - self.in_keys = [action_value_key] + self.action_mask_key = action_mask_key + in_keys = [action_value_key] + if self.action_mask_key is not None: + in_keys.append(self.action_mask_key) + self.in_keys = in_keys if out_keys is None: out_keys = ["action", action_value_key, "chosen_action_value"] elif action_value_key not in out_keys: @@ -446,6 +453,15 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) @@ -528,6 +544,8 @@ class DistributionalQValueModule(QValueModule): support (torch.Tensor): support of the action values. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions and action values. Defaults to ``["action", "action_value"]``. @@ -583,6 +601,7 @@ def __init__( action_space: Optional[str], support: torch.Tensor, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: TensorSpec = None, @@ -595,6 +614,7 @@ def __init__( super().__init__( action_space=action_space, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, var_nums=var_nums, spec=spec, @@ -609,6 +629,15 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) @@ -698,6 +727,8 @@ class QValueHook: action_value_key (str or tuple of str, optional): to be used when hooked on a TensorDictModule. The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): to be used when hooked on a TensorDictModule. The output keys representing the actions, action values and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``. 
@@ -733,6 +764,7 @@ def __init__( action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -747,6 +779,7 @@ def __init__( action_space=action_space, var_nums=var_nums, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -776,6 +809,11 @@ class DistributionalQValueHook(QValueHook): Args: action_space (str): Action space. Must be one of ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``. + action_value_key (str or tuple of str, optional): to be used when hooked on + a TensorDictModule. The input key representing the action value. Defaults + to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). support (torch.Tensor): support of the action values. var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each @@ -823,6 +861,7 @@ def __init__( support: torch.Tensor, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -837,6 +876,7 @@ def __init__( var_nums=var_nums, support=support, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -884,6 +924,8 @@ class QValueActor(SafeSequential): is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must match one of its output keys. Otherwise, this string represents the name of the action-value entry in the output tensordict. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). .. note:: ``out_keys`` cannot be passed. If the module is a :class:`tensordict.nn.TensorDictModule` @@ -942,6 +984,7 @@ def __init__( safe=False, action_space: Optional[str] = None, action_value_key=None, + action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): warnings.warn( @@ -987,6 +1030,7 @@ def __init__( spec=spec, safe=safe, action_space=action_space, + action_mask_key=action_mask_key, ) super().__init__(module, qvalue) @@ -1035,6 +1079,12 @@ class DistributionalQValueActor(QValueActor): make_log_softmax (bool, optional): if ``True`` and if the module is not of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax operation will be applied along dimension -2 of the action value tensor. + action_value_key (str or tuple of str, optional): if the input module + is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must + match one of its output keys. Otherwise, this string represents + the name of the action-value entry in the output tensordict. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). 
     Examples:
         >>> import torch
@@ -1079,6 +1129,7 @@ def __init__(
         var_nums: Optional[int] = None,
         action_space: Optional[str] = None,
         action_value_key: str = "action_value",
+        action_mask_key: Optional[NestedKey] = None,
         make_log_softmax: bool = True,
     ):
         if isinstance(action_space, TensorSpec):
@@ -1121,6 +1172,7 @@ def __init__(
             spec=spec,
             safe=safe,
             action_space=action_space,
+            action_mask_key=action_mask_key,
             support=support,
             var_nums=var_nums,
         )
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 20d26b7aabd..d2e8ed8e3a1 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -7,7 +7,12 @@
 import numpy as np
 import torch
-from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
+
+from tensordict.nn import (
+    TensorDictModule,
+    TensorDictModuleBase,
+    TensorDictModuleWrapper,
+)
 from tensordict.tensordict import TensorDictBase
 from tensordict.utils import expand_as_right, expand_right, NestedKey
@@ -17,13 +22,168 @@
 __all__ = [
     "EGreedyWrapper",
+    "EGreedyModule",
     "AdditiveGaussianWrapper",
     "OrnsteinUhlenbeckProcessWrapper",
 ]
 
 
+class EGreedyModule(TensorDictModuleBase):
+    """Epsilon-Greedy exploration module.
+
+    This module randomly updates the action(s) in a tensordict following an epsilon-greedy exploration strategy.
+    At each call, one random draw per action is compared against the epsilon threshold; where the draw falls
+    below epsilon, the corresponding action is replaced with a random sample drawn from the provided action spec.
+    The remaining actions are left unchanged.
+
+    Args:
+        spec (TensorSpec): the spec used for sampling actions.
+        eps_init (scalar, optional): initial epsilon value.
+            default: 1.0
+        eps_end (scalar, optional): final epsilon value.
+            default: 0.1
+        annealing_num_steps (int, optional): number of steps it will take for epsilon to reach
+            the ``eps_end`` value. Defaults to ``1000``.
+
+    Keyword Args:
+        action_key (NestedKey, optional): the key where the action can be found in the input tensordict.
+            Default is ``"action"``.
+        action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict.
+            Default is ``None`` (corresponding to no mask).
+
+    .. note::
+        It is crucial to incorporate a call to :meth:`~.step` in the training loop
+        to update the exploration factor.
+        Since this omission cannot easily be detected, no warning or exception
+        will be raised if the call is omitted!
+ + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictSequential + >>> from torchrl.modules import EGreedyModule, Actor + >>> from torchrl.data import BoundedTensorSpec + >>> torch.manual_seed(0) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(spec=spec, module=module) + >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2)) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td).get("action")) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.9055, -0.9277, -0.6295, -0.2532], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + + """ + + def __init__( + self, + spec: TensorSpec, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + *, + action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, + ): + self.action_key = action_key + self.action_mask_key = action_mask_key + in_keys = [self.action_key] + if self.action_mask_key is not None: + in_keys.append(self.action_mask_key) + self.in_keys = in_keys + self.out_keys = [self.action_key] + + super().__init__() + + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise RuntimeError("eps should decrease over time or be constant") + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """A step of epsilon decay. + + After `self.annealing_num_steps` calls to this method, calls result in no-op. + + Args: + frames (int, optional): number of frames since last step. Defaults to ``1``. 
+ + """ + for _ in range(frames): + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + if isinstance(self.action_key, tuple) and len(self.action_key) > 1: + action_tensordict = tensordict.get(self.action_key[:-1]) + action_key = self.action_key[-1] + else: + action_tensordict = tensordict + action_key = self.action_key + + out = action_tensordict.get(action_key) + eps = self.eps.item() + cond = ( + torch.rand(action_tensordict.shape, device=action_tensordict.device) + < eps + ).to(out.dtype) + cond = expand_as_right(cond, out) + spec = self.spec + if spec is not None: + if isinstance(spec, CompositeSpec): + spec = spec[self.action_key] + if spec.shape != out.shape: + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) + out = cond * spec.rand().to(out.device) + (1 - cond) * out + else: + raise RuntimeError("spec must be provided to the exploration wrapper.") + action_tensordict.set(action_key, out) + return tensordict + + class EGreedyWrapper(TensorDictModuleWrapper): - """Epsilon-Greedy PO wrapper. + """[Deprecated] Epsilon-Greedy PO wrapper. Args: policy (TensorDictModule): a deterministic policy. @@ -34,16 +194,16 @@ class EGreedyWrapper(TensorDictModuleWrapper): eps_end (scalar, optional): final epsilon value. default: 0.1 annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value - action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to - find the action spec. - Default is "action". + action_key (NestedKey, optional): the key where the action can be found in the input tensordict. + Default is ``"action"``. + action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. + Default is ``None`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be - projected onto the valid action space once explored. If not provided, + taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy. .. note:: - Once an environment has been wrapped in :class:`EGreedyWrapper`, it is + Once a module has been wrapped in :class:`EGreedyWrapper`, it is crucial to incorporate a call to :meth:`~.step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception @@ -82,8 +242,15 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): + warnings.warn( + "EGreedyWrapper is deprecated and it will be removed in v0.3. 
" + "Please use torchrl.modules.EGreedyModule instead.", + category=DeprecationWarning, + ) + super().__init__(policy) self.register_buffer("eps_init", torch.tensor([eps_init])) self.register_buffer("eps_end", torch.tensor([eps_end])) @@ -92,6 +259,7 @@ def __init__( self.annealing_num_steps = annealing_num_steps self.register_buffer("eps", torch.tensor([eps_init])) self.action_key = action_key + self.action_mask_key = action_mask_key if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) @@ -105,7 +273,7 @@ def __init__( if action_key not in self._spec.keys(): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = spec @property def spec(self): @@ -149,6 +317,25 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec is not None: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] + if spec.shape != out.shape: + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) out = cond * spec.rand().to(out.device) + (1 - cond) * out else: raise RuntimeError(