From 4b2092f55a504c1c70e88ff48897c665a8b0f51e Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Mon, 23 Mar 2020 15:31:14 +0100
Subject: [PATCH] Remove base network

---
 tests/test_distributions.py         |  7 +++-
 torchy_baselines/common/policies.py | 26 -------------
 torchy_baselines/sac/policies.py    | 38 ++++++++++---------
 torchy_baselines/td3/policies.py    | 57 +++++++++++++++++------------
 4 files changed, 60 insertions(+), 68 deletions(-)

diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index c63900dc0..eb9fbaf07 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -4,7 +4,7 @@
 from torchy_baselines import A2C, PPO
 from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector,
                                                     StateDependentNoiseDistribution,
-                                                    CategoricalDistribution)
+                                                    CategoricalDistribution, SquashedDiagGaussianDistribution)
 from torchy_baselines.common.utils import set_random_seed

@@ -35,6 +35,11 @@ def test_squashed_gaussian(model_class):
     model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100,
                         policy_kwargs=dict(squash_output=True))
     model.learn(500)
+    gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
+    dist = SquashedDiagGaussianDistribution(N_ACTIONS)
+    _, log_std = dist.proba_distribution_net(N_FEATURES)
+    actions, _ = dist.proba_distribution(gaussian_mean, log_std)
+    assert th.max(th.abs(actions)) <= 1.0


 def test_sde_distribution():
     n_actions = 1
diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py
index a5e08b5c6..751f49a2c 100644
--- a/torchy_baselines/common/policies.py
+++ b/torchy_baselines/common/policies.py
@@ -139,32 +139,6 @@ def create_sde_feature_extractor(features_dim: int,
     return sde_feature_extractor, latent_sde_dim


-class BaseNetwork(nn.Module):
-    """
-    Abstract class for the different networks (actor/critic)
-    that implements two helpers for using CEM with their weights.
-    """
-    def __init__(self):
-        super(BaseNetwork, self).__init__()
-
-    def load_from_vector(self, vector: np.ndarray):
-        """
-        Load parameters from a 1D vector.
-
-        :param vector: (np.ndarray)
-        """
-        device = next(self.parameters()).device
-        th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters())
-
-    def parameters_to_vector(self) -> np.ndarray:
-        """
-        Convert the parameters to a 1D vector.
-
-        :return: (np.ndarray)
-        """
-        return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
-
-
 _policy_registry = dict()  # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]

diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py
index 2b7367a50..3467d23a5 100644
--- a/torchy_baselines/sac/policies.py
+++ b/torchy_baselines/sac/policies.py
@@ -5,7 +5,7 @@
 import torch.nn as nn

 from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
-from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
+from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
                                               create_sde_feature_extractor)
 from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution

@@ -14,12 +14,12 @@
 LOG_STD_MIN = -20


-class Actor(BaseNetwork):
+class Actor(BasePolicy):
     """
     Actor network (policy) for SAC.

-    :param obs_dim: (int) Dimension of the observation
-    :param action_dim: (int) Dimension of the action space
+    :param observation_space: (gym.spaces.Space) Observation space
+    :param action_space: (gym.spaces.Space) Action space
     :param net_arch: ([int]) Network architecture
     :param activation_fn: (nn.Module) Activation function
     :param use_sde: (bool) Whether to use State Dependent Exploration or not
@@ -34,8 +34,8 @@ class Actor(BaseNetwork):
         above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
     :param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
     """
-    def __init__(self, obs_dim: int,
-                 action_dim: int,
+    def __init__(self, observation_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
                  net_arch: List[int],
                  activation_fn: nn.Module = nn.ReLU,
                  use_sde: bool = False,
@@ -44,7 +44,10 @@ def __init__(self, obs_dim: int,
                  sde_net_arch: Optional[List[int]] = None,
                  use_expln: bool = False,
                  clip_mean: float = 2.0):
-        super(Actor, self).__init__()
+        super(Actor, self).__init__(observation_space, action_space)
+
+        obs_dim = get_obs_dim(self.observation_space)
+        action_dim = get_action_dim(self.action_space)

         latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
         self.latent_pi = nn.Sequential(*latent_pi_net)
@@ -128,20 +131,23 @@ def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)


-class Critic(BaseNetwork):
+class Critic(BasePolicy):
     """
     Critic network (q-value function) for SAC.

-    :param obs_dim: (int) Dimension of the observation
-    :param action_dim: (int) Dimension of the action space
+    :param observation_space: (gym.spaces.Space) Observation space
+    :param action_space: (gym.spaces.Space) Action space
     :param net_arch: ([int]) Network architecture
     :param activation_fn: (nn.Module) Activation function
     """
-    def __init__(self, obs_dim: int,
-                 action_dim: int,
+    def __init__(self, observation_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
                  net_arch: List[int],
                  activation_fn: nn.Module = nn.ReLU):
-        super(Critic, self).__init__()
+        super(Critic, self).__init__(observation_space, action_space)
+
+        obs_dim = get_obs_dim(self.observation_space)
+        action_dim = get_action_dim(self.action_space)

         q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
         self.q1_net = nn.Sequential(*q1_net)
@@ -192,13 +198,11 @@ def __init__(self, observation_space: gym.spaces.Space,
         if net_arch is None:
             net_arch = [256, 256]

-        self.obs_dim = get_obs_dim(self.observation_space)
-        self.action_dim = get_action_dim(self.action_space)
         self.net_arch = net_arch
         self.activation_fn = activation_fn
         self.net_args = {
-            'obs_dim': self.obs_dim,
-            'action_dim': self.action_dim,
+            'observation_space': self.observation_space,
+            'action_space': self.action_space,
             'net_arch': self.net_arch,
             'activation_fn': self.activation_fn
         }
diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py
index 9b6f5d70a..a556f2b43 100644
--- a/torchy_baselines/td3/policies.py
+++ b/torchy_baselines/td3/policies.py
@@ -5,17 +5,17 @@
 import torch.nn as nn

 from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
-from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
+from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
                                               create_sde_feature_extractor)
 from torchy_baselines.common.distributions import StateDependentNoiseDistribution, Distribution


-class Actor(BaseNetwork):
+class Actor(BasePolicy):
     """
     Actor network (policy) for TD3.

-    :param obs_dim: (int) Dimension of the observation
-    :param action_dim: (int) Dimension of the action space
+    :param observation_space: (gym.spaces.Space) Observation space
+    :param action_space: (gym.spaces.Space) Action space
     :param net_arch: ([int]) Network architecture
     :param activation_fn: (nn.Module) Activation function
     :param use_sde: (bool) Whether to use State Dependent Exploration or not
@@ -32,8 +32,8 @@ class Actor(BaseNetwork):
         above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
     """
     def __init__(self,
-                 obs_dim: int,
-                 action_dim: int,
+                 observation_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
                  net_arch: List[int],
                  activation_fn: nn.Module = nn.ReLU,
                  use_sde: bool = False,
@@ -43,15 +43,17 @@ def __init__(self,
                  full_std: bool = False,
                  sde_net_arch: Optional[List[int]] = None,
                  use_expln: bool = False):
-        super(Actor, self).__init__()
+        super(Actor, self).__init__(observation_space, action_space)

         self.latent_pi, self.log_std = None, None
         self.weights_dist, self.exploration_mat = None, None
         self.use_sde, self.sde_optimizer = use_sde, None
-        self.action_dim = action_dim
         self.full_std = full_std
         self.sde_feature_extractor = None

+        obs_dim = get_obs_dim(self.observation_space)
+        action_dim = get_action_dim(self.action_space)
+
         if use_sde:
             latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_output=False)
             self.latent_pi = nn.Sequential(*latent_pi_net)
@@ -95,7 +97,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
         mean_actions = self.mu(latent_pi)
         return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)

-    def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]:
+    def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         latent_pi = self.latent_pi(obs)

         if self.sde_feature_extractor is not None:
@@ -145,19 +147,24 @@ def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
         return self.mu(obs)


-class Critic(BaseNetwork):
+class Critic(BasePolicy):
     """
     Critic network for TD3, in fact it represents the action-state value function (Q-value function)

-    :param obs_dim: (int) Dimension of the observation
-    :param action_dim: (int) Dimension of the action space
+    :param observation_space: (gym.spaces.Space) Observation space
+    :param action_space: (gym.spaces.Space) Action space
     :param net_arch: ([int]) Network architecture
     :param activation_fn: (nn.Module) Activation function
     """
-    def __init__(self, obs_dim: int, action_dim: int,
-                 net_arch: List[int], activation_fn: nn.Module = nn.ReLU):
-        super(Critic, self).__init__()
+    def __init__(self, observation_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 net_arch: List[int],
+                 activation_fn: nn.Module = nn.ReLU):
+        super(Critic, self).__init__(observation_space, action_space)
+
+        obs_dim = get_obs_dim(self.observation_space)
+        action_dim = get_action_dim(self.action_space)

         q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
         self.q1_net = nn.Sequential(*q1_net)
@@ -173,18 +180,22 @@ def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
         return self.q1_net(th.cat([obs, action], dim=1))


-class ValueFunction(BaseNetwork):
+class ValueFunction(BasePolicy):
     """
     Value function for TD3 when doing on-policy exploration with SDE.

-    :param obs_dim: (int) Dimension of the observation
+    :param observation_space: (gym.spaces.Space) Observation space
+    :param action_space: (gym.spaces.Space) Action space
     :param net_arch: (Optional[List[int]]) Network architecture
     :param activation_fn: (nn.Module) Activation function
     """
-    def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None,
+    def __init__(self, observation_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 net_arch: Optional[List[int]] = None,
                  activation_fn: nn.Module = nn.Tanh):
-        super(ValueFunction, self).__init__()
+        super(ValueFunction, self).__init__(observation_space, action_space)
+        obs_dim = get_obs_dim(self.observation_space)

         if net_arch is None:
             net_arch = [64, 64]
@@ -232,13 +243,11 @@ def __init__(self, observation_space: gym.spaces.Space,
         if net_arch is None:
             net_arch = [400, 300]

-        self.obs_dim = get_obs_dim(self.observation_space)
-        self.action_dim = get_action_dim(self.action_space)
         self.net_arch = net_arch
         self.activation_fn = activation_fn
         self.net_args = {
-            'obs_dim': self.obs_dim,
-            'action_dim': self.action_dim,
+            'observation_space': self.observation_space,
+            'action_space': self.action_space,
             'net_arch': self.net_arch,
             'activation_fn': self.activation_fn
         }
@@ -273,7 +282,7 @@ def _build(self, lr_schedule: Callable) -> None:
         self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))

         if self.use_sde:
-            self.vf_net = ValueFunction(self.obs_dim)
+            self.vf_net = ValueFunction(self.observation_space, self.action_space)
             self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})  # pytype: disable=attribute-error

     def reset_noise(self) -> None:
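
For reference, a short usage sketch (not part of the patch) of the constructor change above: the actor and critic networks now receive the gym spaces directly and derive the flat dimensions internally via get_obs_dim / get_action_dim, instead of taking obs_dim / action_dim integers. The snippet assumes the TD3 signatures introduced in this diff and a standard Pendulum-v0 environment; the variable names are illustrative only.

    import gym
    import torch as th

    from torchy_baselines.td3.policies import Actor, Critic

    env = gym.make('Pendulum-v0')

    # Networks are built from the spaces; obs_dim / action_dim are computed
    # inside __init__ with get_obs_dim / get_action_dim.
    actor = Actor(env.observation_space, env.action_space, net_arch=[400, 300])
    critic = Critic(env.observation_space, env.action_space, net_arch=[400, 300])

    obs = th.as_tensor(env.reset(), dtype=th.float32).unsqueeze(0)
    action = actor(obs, deterministic=True)
    q_value = critic.q1_forward(obs, action)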