Remove base network
araffin committed Mar 23, 2020
1 parent dcb54b5 commit 4b2092f
Showing 4 changed files with 60 additions and 68 deletions.
7 changes: 6 additions & 1 deletion tests/test_distributions.py
@@ -4,7 +4,7 @@
from torchy_baselines import A2C, PPO
from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector,
StateDependentNoiseDistribution,
CategoricalDistribution)
CategoricalDistribution, SquashedDiagGaussianDistribution)
from torchy_baselines.common.utils import set_random_seed


@@ -35,6 +35,11 @@ def test_squashed_gaussian(model_class):
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
model.learn(500)

gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
dist = SquashedDiagGaussianDistribution(N_ACTIONS)
_, log_std = dist.proba_distribution_net(N_FEATURES)
actions, _ = dist.proba_distribution(gaussian_mean, log_std)
assert th.max(th.abs(actions)) <= 1.0

def test_sde_distribution():
n_actions = 1
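For context, the new assertion checks that every sampled action lies in [-1, 1]. A minimal sketch of how a squashed diagonal Gaussian achieves this (assumed behavior, not necessarily the library's exact implementation of SquashedDiagGaussianDistribution):

import math
import torch as th

# Draw u ~ N(mean, std) and squash with tanh, so actions land in [-1, 1].
mean = th.zeros(5000, 3)
log_std = th.full((5000, 3), -0.5)
u = mean + log_std.exp() * th.randn_like(mean)
actions = th.tanh(u)
assert th.max(th.abs(actions)) <= 1.0

# Change of variables: log p(a) = log N(u; mean, std) - sum_i log(1 - tanh(u_i)^2).
# The epsilon keeps the correction finite near the action boundaries.
gaussian_log_prob = (-0.5 * ((u - mean) / log_std.exp()) ** 2
                     - log_std - 0.5 * math.log(2 * math.pi)).sum(dim=1)
log_prob = gaussian_log_prob - th.log(1.0 - actions ** 2 + 1e-6).sum(dim=1)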
26 changes: 0 additions & 26 deletions torchy_baselines/common/policies.py
@@ -139,32 +139,6 @@ def create_sde_feature_extractor(features_dim: int,
return sde_feature_extractor, latent_sde_dim


class BaseNetwork(nn.Module):
"""
Abstract base class for the different networks (actor/critic)
that implements two helpers for using CEM (Cross-Entropy Method) with their weights.
"""
def __init__(self):
super(BaseNetwork, self).__init__()

def load_from_vector(self, vector: np.ndarray):
"""
Load parameters from a 1D vector.
:param vector: (np.ndarray)
"""
device = next(self.parameters()).device
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters())

def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.
:return: (np.ndarray)
"""
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()


_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]


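Note that the deleted helpers are thin wrappers around torch.nn.utils, so any nn.Module (including BasePolicy subclasses) can get the same behavior without a dedicated base class. A hedged sketch of equivalent standalone functions (illustrative names, not the library's API):

import numpy as np
import torch as th
import torch.nn as nn

def load_from_vector(module: nn.Module, vector: np.ndarray) -> None:
    # Copy a flat numpy vector back into the module's parameters.
    device = next(module.parameters()).device
    th.nn.utils.vector_to_parameters(
        th.as_tensor(vector, dtype=th.float32).to(device), module.parameters())

def parameters_to_vector(module: nn.Module) -> np.ndarray:
    # Flatten all parameters into a single 1D numpy vector.
    return th.nn.utils.parameters_to_vector(module.parameters()).detach().cpu().numpy()

# Round trip on a toy network, e.g. to evaluate a perturbed candidate in CEM:
net = nn.Linear(4, 2)
flat = parameters_to_vector(net)
load_from_vector(net, flat + 0.1)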
38 changes: 21 additions & 17 deletions torchy_baselines/sac/policies.py
@@ -5,7 +5,7 @@
import torch.nn as nn

from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution

@@ -14,12 +14,12 @@
LOG_STD_MIN = -20


class Actor(BaseNetwork):
class Actor(BasePolicy):
"""
Actor network (policy) for SAC.
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param activation_fn: (nn.Module) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
@@ -34,8 +34,8 @@ class Actor(BaseNetwork):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
"""
def __init__(self, obs_dim: int,
action_dim: int,
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
activation_fn: nn.Module = nn.ReLU,
use_sde: bool = False,
@@ -44,7 +44,10 @@ def __init__(self, obs_dim: int,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0):
super(Actor, self).__init__()
super(Actor, self).__init__(observation_space, action_space)

obs_dim = get_obs_dim(self.observation_space)
action_dim = get_action_dim(self.action_space)

latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
@@ -128,20 +131,23 @@ def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)


class Critic(BaseNetwork):
class Critic(BasePolicy):
"""
Critic network (q-value function) for SAC.
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param activation_fn: (nn.Module) Activation function
"""
def __init__(self, obs_dim: int,
action_dim: int,
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
activation_fn: nn.Module = nn.ReLU):
super(Critic, self).__init__()
super(Critic, self).__init__(observation_space, action_space)

obs_dim = get_obs_dim(self.observation_space)
action_dim = get_action_dim(self.action_space)

q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
self.q1_net = nn.Sequential(*q1_net)
@@ -192,13 +198,11 @@ def __init__(self, observation_space: gym.spaces.Space,
if net_arch is None:
net_arch = [256, 256]

self.obs_dim = get_obs_dim(self.observation_space)
self.action_dim = get_action_dim(self.action_space)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
'obs_dim': self.obs_dim,
'action_dim': self.action_dim,
'observation_space': self.observation_space,
'action_space': self.action_space,
'net_arch': self.net_arch,
'activation_fn': self.activation_fn
}
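In short, the SAC sub-networks now receive the gym spaces instead of pre-computed dimensions. A hedged usage sketch of the new signature (keyword names follow the diff above; the spaces are made up for illustration):

import gym
import numpy as np
from torchy_baselines.sac.policies import Actor, Critic

observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

# Before this commit: Actor(obs_dim=3, action_dim=1, net_arch=[256, 256])
actor = Actor(observation_space, action_space, net_arch=[256, 256])
critic = Critic(observation_space, action_space, net_arch=[256, 256])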
57 changes: 33 additions & 24 deletions torchy_baselines/td3/policies.py
@@ -5,17 +5,17 @@
import torch.nn as nn

from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import StateDependentNoiseDistribution, Distribution


class Actor(BaseNetwork):
class Actor(BasePolicy):
"""
Actor network (policy) for TD3.
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param activation_fn: (nn.Module) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
@@ -32,8 +32,8 @@ class Actor(BaseNetwork):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
"""
def __init__(self,
obs_dim: int,
action_dim: int,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
activation_fn: nn.Module = nn.ReLU,
use_sde: bool = False,
@@ -43,15 +43,17 @@ def __init__(self,
full_std: bool = False,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False):
super(Actor, self).__init__()
super(Actor, self).__init__(observation_space, action_space)

self.latent_pi, self.log_std = None, None
self.weights_dist, self.exploration_mat = None, None
self.use_sde, self.sde_optimizer = use_sde, None
self.action_dim = action_dim
self.full_std = full_std
self.sde_feature_extractor = None

obs_dim = get_obs_dim(self.observation_space)
action_dim = get_action_dim(self.action_space)

if use_sde:
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_output=False)
self.latent_pi = nn.Sequential(*latent_pi_net)
@@ -95,7 +97,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
mean_actions = self.mu(latent_pi)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)

def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]:
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
latent_pi = self.latent_pi(obs)

if self.sde_feature_extractor is not None:
@@ -145,19 +147,24 @@ def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.mu(obs)


class Critic(BaseNetwork):
class Critic(BasePolicy):
"""
Critic network for TD3; it represents the action-state value function (Q-value function).
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param activation_fn: (nn.Module) Activation function
"""
def __init__(self, obs_dim: int, action_dim: int,
net_arch: List[int], activation_fn: nn.Module = nn.ReLU):
super(Critic, self).__init__()
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
activation_fn: nn.Module = nn.ReLU):
super(Critic, self).__init__(observation_space, action_space)

obs_dim = get_obs_dim(self.observation_space)
action_dim = get_action_dim(self.action_space)

q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
self.q1_net = nn.Sequential(*q1_net)
@@ -173,18 +180,22 @@ def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
return self.q1_net(th.cat([obs, action], dim=1))


class ValueFunction(BaseNetwork):
class ValueFunction(BasePolicy):
"""
Value function for TD3 when doing on-policy exploration with SDE.
:param obs_dim: (int) Dimension of the observation
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: (Optional[List[int]]) Network architecture
:param activation_fn: (nn.Module) Activation function
"""
def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None,
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: Optional[List[int]] = None,
activation_fn: nn.Module = nn.Tanh):
super(ValueFunction, self).__init__()
super(ValueFunction, self).__init__(observation_space, action_space)

obs_dim = get_obs_dim(self.observation_space)
if net_arch is None:
net_arch = [64, 64]

@@ -232,13 +243,11 @@ def __init__(self, observation_space: gym.spaces.Space,
if net_arch is None:
net_arch = [400, 300]

self.obs_dim = get_obs_dim(self.observation_space)
self.action_dim = get_action_dim(self.action_space)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
'obs_dim': self.obs_dim,
'action_dim': self.action_dim,
'observation_space': self.observation_space,
'action_space': self.action_space,
'net_arch': self.net_arch,
'activation_fn': self.activation_fn
}
@@ -273,7 +282,7 @@ def _build(self, lr_schedule: Callable) -> None:
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))

if self.use_sde:
self.vf_net = ValueFunction(self.obs_dim)
self.vf_net = ValueFunction(self.observation_space, self.action_space)
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error

def reset_noise(self) -> None:
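Both files lean on get_obs_dim and get_action_dim to recover flat dimensions from the spaces. A minimal sketch of what these helpers are assumed to do for Box spaces (the library's versions in torchy_baselines.common.preprocessing presumably also handle other space types, such as Discrete):

import gym
import numpy as np

def get_obs_dim(observation_space: gym.spaces.Space) -> int:
    # Flatten the observation shape into a single dimension (Box case only).
    return int(np.prod(observation_space.shape))

def get_action_dim(action_space: gym.spaces.Space) -> int:
    # Flatten the action shape into a single dimension (Box case only).
    return int(np.prod(action_space.shape))

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, 3), dtype=np.float32)
assert get_obs_dim(space) == 6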
