Merge pull request DLR-RM#51 from Antonin-Raffin/fix/entropy-squashed
Fix entropy loss for squashed Gaussian and VecEnv seeding
araffin authored and GitHub Enterprise committed Feb 11, 2020
2 parents 02a080f + 240833f commit cbb0843
Showing 9 changed files with 78 additions and 30 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -24,6 +24,8 @@ Bug Fixes:
^^^^^^^^^^
- Fix loading model on CPU that were trained on GPU
- Fix `reset_num_timesteps` that was not used
- Fix entropy computation for squashed Gaussian (approximate it now)
- Fix seeding when using multiple environments (different seed per env)

Deprecations:
^^^^^^^^^^^^^
9 changes: 9 additions & 0 deletions tests/test_distributions.py
@@ -1,6 +1,7 @@
import pytest
import torch as th

from torchy_baselines import A2C, PPO
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
StateDependentNoiseDistribution
from torchy_baselines.common.utils import set_random_seed
@@ -21,6 +22,14 @@ def test_bijector():
# Check the inverse method
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()

@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_squashed_gaussian(model_class):
"""
Test run with squashed Gaussian (notably entropy computation)
"""
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
model.learn(500)


def test_sde_distribution():
n_samples = int(5e6)
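To exercise just the new test, one option (a sketch, assuming the tests/ layout shown in the file header above) is to invoke pytest programmatically:

import pytest

# Run only the new squashed-Gaussian test added in this commit
pytest.main(["tests/test_distributions.py", "-k", "test_squashed_gaussian", "-v"])
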
10 changes: 7 additions & 3 deletions torchy_baselines/a2c/a2c.py
@@ -77,7 +77,7 @@ def _setup_model(self):
lr=self.learning_rate(1), alpha=0.99,
eps=self.rms_prop_eps, weight_decay=0)

def train(self, gradient_steps, batch_size=None):
def train(self, gradient_steps: int, batch_size=None):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# A2C with gradient_steps > 1 does not make sense
@@ -107,7 +107,11 @@ def train(self, gradient_steps, batch_size=None):
value_loss = F.mse_loss(return_batch, values)

# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)

loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

@@ -123,7 +127,7 @@ def train(self, gradient_steps, batch_size=None):
self.rollout_buffer.values.flatten())

logger.logkv("explained_variance", explained_var)
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):
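The same None-aware entropy handling is repeated in PPO and in the TD3 SDE update further down. As a standalone sketch (dist and log_prob are hypothetical stand-ins for the policy's action distribution and the log-probabilities of the sampled actions):

import torch as th

def entropy_loss_term(dist, log_prob: th.Tensor) -> th.Tensor:
    # Use the analytical entropy when available; the squashed Gaussian
    # returns None, so fall back to the -log_prob Monte-Carlo estimate.
    entropy = dist.entropy()
    if entropy is None:
        return -log_prob.mean()
    return -th.mean(entropy)
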
7 changes: 4 additions & 3 deletions torchy_baselines/common/buffers.py
@@ -4,6 +4,7 @@
import torch as th

from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples


class BaseBuffer(object):
@@ -177,7 +178,7 @@ def add(self,
def _get_samples(self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None
) -> Tuple[th.Tensor, ...]:
) -> ReplayBufferSamples:
data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
self.actions[batch_inds, 0, :],
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
@@ -305,7 +306,7 @@ def add(self,
if self.pos == self.buffer_size:
self.full = True

def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
assert self.full, ''
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@@ -325,7 +326,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ..
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray,
env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]:
env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
data = (self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),
23 changes: 17 additions & 6 deletions torchy_baselines/common/distributions.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
@@ -8,24 +10,25 @@ class Distribution(object):
def __init__(self):
super(Distribution, self).__init__()

def log_prob(self, x):
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
returns the log likelihood
:param x: (object) the taken action
:param x: (th.Tensor) the taken action
:return: (th.Tensor) The log likelihood of the distribution
"""
raise NotImplementedError

def entropy(self):
def entropy(self) -> Optional[th.Tensor]:
"""
Returns shannon's entropy of the probability
:return: (th.Tensor) the entropy
:return: (Optional[th.Tensor]) the entropy,
return None if no analytical form is known
"""
raise NotImplementedError

def sample(self):
def sample(self) -> th.Tensor:
"""
returns a sample from the probabilty distribution
@@ -145,6 +148,11 @@ def mode(self):
# Squash the output
return th.tanh(self.gaussian_action)

def entropy(self):
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None

def sample(self):
self.gaussian_action = self.distribution.rsample()
return th.tanh(self.gaussian_action)
@@ -371,7 +379,10 @@ def sample(self, latent_sde):
return action

def entropy(self):
# TODO: account for the squashing?
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
if self.bijector is not None:
return None
return self.distribution.entropy()

def log_prob_from_params(self, mean_actions, log_std, latent_sde):
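Why there is no closed form once the output is squashed: applying tanh to a Gaussian sample adds a Jacobian term to the log-density, and the entropy of the squashed variable has no analytical expression, so it has to be estimated from samples. A minimal sketch in plain PyTorch (not this repository's API; the 1e-6 epsilon is an illustrative choice):

import torch as th
from torch.distributions import Normal

dist = Normal(th.zeros(3), th.ones(3))
u = dist.rsample((10_000,))   # unsquashed Gaussian samples
a = th.tanh(u)                # squashed actions in (-1, 1)

# Change of variables: log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2)
log_prob = dist.log_prob(u).sum(-1) - th.log(1 - a.pow(2) + 1e-6).sum(-1)

# No analytical entropy for the squashed distribution, so approximate it
# from samples, which is exactly what the -log_prob.mean() fallback does.
entropy_estimate = -log_prob.mean()
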
8 changes: 6 additions & 2 deletions torchy_baselines/common/type_aliases.py
@@ -3,12 +3,16 @@
"""
from typing import Union, Type, Optional, Dict, Any, List, Tuple

import torch
import torch as th
import gym

from torchy_baselines.common.vec_env import VecEnv


GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, torch.Tensor]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
# obs, action, old_values, old_log_prob, advantage, return_batch
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
# obs, action, next_obs, done, reward
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
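Since these aliases are plain tuples, callers unpack the samples positionally in the order given by the comments; a hypothetical example with dummy data (shapes are illustrative only):

import torch as th
from torchy_baselines.common.type_aliases import ReplayBufferSamples

# Dummy batch of 4 transitions, just to show the expected field order
batch: ReplayBufferSamples = (th.zeros(4, 2), th.zeros(4, 2), th.zeros(4, 2), th.zeros(4, 1), th.zeros(4, 1))
obs, actions, next_obs, dones, rewards = batch
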
3 changes: 2 additions & 1 deletion torchy_baselines/common/vec_env/base_vec_env.py
@@ -152,8 +152,9 @@ def seed(self, seed, indices=None):
:param indices: ([int])
"""
indices = self._get_indices(indices)
# Different seed per environment
if not hasattr(seed, 'len'):
seed = [seed] * len(indices)
seed = [seed + i for i in range(len(indices))]
assert len(seed) == len(indices)
return [self.env_method('seed', seed[i], indices=i) for i in indices]

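The effect of the seeding change in a small sketch (a standalone helper, not the VecEnv method itself): a base seed of 7 with 4 environments now yields seeds 7, 8, 9, 10 instead of four copies of 7, so parallel environments no longer produce identical episodes.

# Per-env seed derivation introduced above, extracted as a plain helper
def per_env_seeds(seed: int, n_envs: int) -> list:
    return [seed + i for i in range(n_envs)]

assert per_env_seeds(7, 4) == [7, 8, 9, 10]
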
22 changes: 14 additions & 8 deletions torchy_baselines/ppo/ppo.py
@@ -116,9 +116,7 @@ def _setup_model(self):
# Action is a scalar
action_dim = 1

# TODO: different seed for each env when n_envs > 1
if self.n_envs == 1:
self.set_random_seed(self.seed)
self.set_random_seed(self.seed)

self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
@@ -208,15 +206,14 @@ def collect_rollouts(self,

return obs, continue_training

def train(self, gradient_steps, batch_size=64):
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress)
logger.logkv("clip_range", clip_range)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress)
logger.logkv("clip_range_vf", clip_range_vf)

for gradient_step in range(gradient_steps):
approx_kl_divs = []
@@ -258,7 +255,11 @@ def train(self, gradient_steps, batch_size=64):
value_loss = F.mse_loss(return_batch, values_pred)

# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)

loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

@@ -278,9 +279,14 @@ def train(self, gradient_steps, batch_size=64):
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
self.rollout_buffer.values.flatten())

logger.logkv("clip_range", clip_range)
if self.clip_range_vf is not None:
logger.logkv("clip_range_vf", clip_range_vf)


logger.logkv("explained_variance", explained_var)
# TODO: gather stats for the entropy and other losses?
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):
24 changes: 17 additions & 7 deletions torchy_baselines/td3/td3.py
@@ -1,11 +1,12 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

import torch as th
import torch.nn.functional as F
import numpy as np

from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.type_aliases import ReplayBufferSamples
from torchy_baselines.td3.policies import TD3Policy


@@ -132,7 +133,10 @@ def predict(self, observation, state=None, mask=None, deterministic=True):
"""
return self.unscale_action(self.select_action(observation, deterministic=deterministic))

def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
def train_critic(self, gradient_steps: int = 1,
batch_size: int = 100,
replay_data: Optional[ReplayBufferSamples] = None,
tau: float = 0.0):
# Update optimizer learning rate
self._update_learning_rate(self.critic.optimizer)

@@ -171,9 +175,11 @@ def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
tau_critic=0.005,
replay_data=None):
def train_actor(self, gradient_steps: int = 1,
batch_size: int = 100,
tau_actor: float = 0.005,
tau_critic: float = 0.005,
replay_data: Optional[ReplayBufferSamples] = None):
# Update optimizer learning rate
self._update_learning_rate(self.actor.optimizer)

@@ -200,7 +206,7 @@ def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)

def train(self, gradient_steps, batch_size=100, policy_delay=2):
def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2):

for gradient_step in range(gradient_steps):

@@ -234,7 +240,11 @@ def train_sde(self):
policy_loss = -(advantage * log_prob).mean()

# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)

vf_coef = 0.5
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss
