Work on Actor-Critic algorithms
- Minor refactoring in Agent
- Created new class OnPolicyDeepAC
- Fixed some issues in PPO_BPTT and the policy class
- PPO_BPTT now uses our latest interface
- Implemented agent-side state normalization in PPO, TRPO, and PPO_BPTT (sketched below)
boris-il-forte committed Jan 22, 2024
1 parent 77affee commit a9f3f1a
Showing 8 changed files with 135 additions and 116 deletions.
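
The change shared by PPO, TRPO and PPO_BPTT below is the new `OnPolicyDeepAC._preprocess_state` helper: at fit time the batch is passed through the agent's preprocessors, and a copy normalized with the pre-update statistics (`state_old`) is kept, presumably so that the old-policy distribution is evaluated on the same inputs the policy saw during data collection. A minimal sketch of that call order, with a hypothetical `ToyStandardizer` standing in for the agent preprocessors (not the actual mushroom_rl classes):

```python
import torch


class ToyStandardizer:
    """Hypothetical running standardizer; only illustrates the call order."""

    def __init__(self):
        self.mean = torch.zeros(1)
        self.std = torch.ones(1)

    def __call__(self, x):
        return (x - self.mean) / self.std

    def update(self, x):
        self.mean = x.mean(dim=0)
        self.std = x.std(dim=0) + 1e-8


def preprocess_state(prep, state, next_state):
    # same ordering as OnPolicyDeepAC._preprocess_state in the diff below
    state_old = prep(state)        # normalized with the old statistics
    prep.update(state)             # refresh the running statistics
    state = prep(state)            # inputs used to fit critic and actor
    next_state = prep(next_state)
    return state, next_state, state_old


state, next_state = torch.randn(8, 3), torch.randn(8, 3)
s, ns, s_old = preprocess_state(ToyStandardizer(), state, next_state)
print(s.shape, ns.shape, s_old.shape)  # each torch.Size([8, 3])
```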
16 changes: 8 additions & 8 deletions examples/gym_recurrent_ppo.py
@@ -223,7 +223,7 @@ def experiment(
# setup critic
input_shape_critic = (mdp.info.observation_space.shape[0]+2*n_hidden_features,)
critic_params = dict(network=PPOCriticBPTTNetwork,
optimizer={'class': optim.Adam,
optimizer={'class': optim.Adam,
'params': {'lr': lr_critic,
'weight_decay': 0.0}},
loss=torch.nn.MSELoss(),
@@ -240,7 +240,7 @@ def experiment(
)

alg_params = dict(actor_optimizer={'class': optim.Adam,
'params': {'lr': lr_actor,
'params': {'lr': lr_actor,
'weight_decay': 0.0}},
n_epochs_policy=n_epochs_policy,
batch_size=batch_size_actor,
@@ -258,9 +258,9 @@

# Evaluation
dataset = core.evaluate(n_episodes=5)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
L = np.mean(dataset.episodes_length)
J = dataset.discounted_return.mean()
R = dataset.undiscounted_return.mean()
L = dataset.episodes_length.mean()
logger.log_numpy(R=R, J=J, L=L)
logger.epoch_info(0, R=R, J=J, L=L)

@@ -269,9 +269,9 @@

# Evaluation
dataset = core.evaluate(n_episodes=n_episode_eval)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
L = np.mean(dataset.episodes_length)
J = dataset.discounted_return.mean()
R = dataset.undiscounted_return.mean()
L = dataset.episodes_length.mean()
logger.log_numpy(R=R, J=J, L=L)
logger.epoch_info(i, R=R, J=J, L=L)

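For context, `J`, `R` and `L` in this example are the means of the discounted return, the undiscounted return and the episode length over the evaluation episodes. A small sketch of what the discounted return of a single episode amounts to (just the definition, not the mushroom_rl dataset code):

```python
def discounted_return(rewards, gamma):
    """J = sum_t gamma^t * r_t for one episode."""
    j, discount = 0.0, 1.0
    for r in rewards:
        j += discount * r
        discount *= gamma
    return j


print(discounted_return([1.0, 1.0, 1.0], gamma=0.99))  # 2.9701
```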
mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py
@@ -1,10 +1,8 @@
from .deep_actor_critic import DeepAC
from .deep_actor_critic import OnPolicyDeepAC, DeepAC
from .a2c import A2C
from .ddpg import DDPG
from .td3 import TD3
from .sac import SAC
from .trpo import TRPO
from .ppo import PPO
from .ppo_bptt import PPO_BPTT

__all__ = ['DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT']
mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py
@@ -2,10 +2,27 @@
from mushroom_rl.utils.torch import TorchUtils


class OnPolicyDeepAC(Agent):
def _preprocess_state(self, state, next_state, output_old=True):
state_old = None

if output_old:
state_old = self._agent_preprocess(state)

self._update_agent_preprocessor(state)
state = self._agent_preprocess(state)
next_state = self._agent_preprocess(next_state)

if output_old:
return state, next_state, state_old
else:
return state, next_state


class DeepAC(Agent):
"""
Base class for algorithms that uses the reparametrization trick, such as
SAC, DDPG and TD3.
Base class for off-policy deep actor-critic algorithms.
These algorithms use the reparametrization trick, such as SAC, DDPG and TD3.
"""

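The rewritten `DeepAC` docstring characterizes the off-policy algorithms by their use of the reparametrization trick: actions are drawn as a differentiable function of the policy parameters and independent noise, so the actor loss can be backpropagated through the sample. A minimal sketch with a Gaussian policy via `torch.distributions` (assumed shapes, not the mushroom_rl policy classes):

```python
import torch

mu = torch.zeros(4, requires_grad=True)
log_sigma = torch.zeros(4, requires_grad=True)

# rsample() draws a = mu + sigma * eps with eps ~ N(0, I), keeping the graph
dist = torch.distributions.Normal(mu, log_sigma.exp())
action = dist.rsample()

# toy actor loss; gradients flow back to mu and log_sigma through the sample
loss = (action ** 2).sum()
loss.backward()
print(mu.grad, log_sigma.grad)
```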
7 changes: 4 additions & 3 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py
@@ -3,7 +3,7 @@
import torch
import torch.nn.functional as F

from mushroom_rl.core import Agent
from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils
@@ -12,7 +12,7 @@
from mushroom_rl.rl_utils.parameters import to_parameter


class PPO(Agent):
class PPO(OnPolicyDeepAC):
"""
Proximal Policy Optimization algorithm.
"Proximal Policy Optimization Algorithms".
@@ -72,6 +72,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,

def fit(self, dataset):
state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')
state, next_state, state_old = self._preprocess_state(state, next_state)

v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last, self.mdp_info.gamma,
self._lambda())
@@ -80,7 +81,7 @@ def fit(self, dataset):
adv = adv.detach()
v_target = v_target.detach()

old_pol_dist = self.policy.distribution_t(state)
old_pol_dist = self.policy.distribution_t(state_old)
old_log_p = old_pol_dist.log_prob(action)[:, None].detach()

self._V.fit(state, v_target, **self._critic_fit_params)
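`old_log_p`, now computed from `state_old`, feeds the clipped surrogate objective that PPO's `_update_policy` optimizes (its body is unchanged and not shown in this diff). A hedged sketch of that loss for one minibatch, with `eps` as the clipping parameter:

```python
import torch


def ppo_clipped_loss(new_log_p, old_log_p, adv, eps=0.2):
    # importance ratio between the current and the data-collecting policy
    ratio = torch.exp(new_log_p - old_log_p)
    surrogate = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
    # maximize the surrogate, i.e. minimize its negation
    return -torch.mean(torch.min(surrogate, clipped))


new_log_p = torch.randn(16, 1, requires_grad=True)
old_log_p = new_log_p.detach() + 0.1 * torch.randn(16, 1)
adv = torch.randn(16, 1)
loss = ppo_clipped_loss(new_log_p, old_log_p, adv)
loss.backward()
print(loss.item())
```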
136 changes: 69 additions & 67 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py
@@ -1,15 +1,14 @@
import torch

from mushroom_rl.core import Agent
from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils
from mushroom_rl.utils.minibatches import minibatch_generator
from mushroom_rl.rl_utils.parameters import to_parameter
from mushroom_rl.rl_utils.preprocessors import StandardizationPreprocessor


class PPO_BPTT(Agent):
class PPO_BPTT(OnPolicyDeepAC):
"""
Proximal Policy Optimization algorithm.
"Proximal Policy Optimization Algorithms".
@@ -71,81 +70,84 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
_dim_env_state='primitive'
)

# add the standardization preprocessor
self._core_preprocessors.append(StandardizationPreprocessor(mdp_info))

def divide_state_to_env_hidden_batch(self, states):
assert len(states.shape) > 1, "This function only divides batches of states."
return states[:, 0:self._dim_env_state], states[:, self._dim_env_state:]

def fit(self, dataset):
obs, act, r, obs_next, absorbing, last = dataset.parse(to='torch')
state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')
state, next_state, state_old = self._preprocess_state(state, next_state)

policy_state, policy_next_state = dataset.parse_policy_state(to='torch')
obs_seq, policy_state_seq, act_seq, obs_next_seq, policy_next_state_seq, lengths = \
self.transform_to_sequences(obs, policy_state, act, obs_next, policy_next_state, last, absorbing)
state_old_seq, state_seq, policy_state_seq, act_seq, state_next_seq, policy_next_state_seq, lengths = \
self._transform_to_sequences(state_old, state, policy_state, action, next_state, policy_next_state,
last, absorbing)

v_target, adv = self.compute_gae(self._V, obs_seq, policy_state_seq, obs_next_seq, policy_next_state_seq,
lengths, r, absorbing, last, self.mdp_info.gamma, self._lambda())
v_target, adv = self.compute_gae(self._V, state_seq, policy_state_seq, state_next_seq, policy_next_state_seq,
lengths, reward, absorbing, last, self.mdp_info.gamma, self._lambda())
adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8)

old_pol_dist = self.policy.distribution_t(obs_seq, policy_state_seq, lengths)
old_log_p = old_pol_dist.log_prob(act)[:, None].detach()
old_pol_dist = self.policy.distribution_t(state_old_seq, policy_state_seq, lengths)
old_log_p = old_pol_dist.log_prob(action)[:, None].detach()

self._V.fit(obs_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params)
self._V.fit(state_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params)

self._update_policy(obs_seq, policy_state_seq, act, lengths, adv, old_log_p)
self._update_policy(state_seq, policy_state_seq, action, lengths, adv, old_log_p)

# Print fit information
self._log_info(dataset, obs_seq, policy_state_seq, lengths, v_target, old_pol_dist)
self._log_info(dataset, state_seq, policy_state_seq, lengths, v_target, old_pol_dist)
self._iter += 1

def transform_to_sequences(self, states, policy_states, actions, next_states, policy_next_states, last, absorbing):

s = torch.empty(len(states), self._truncation_length, states.shape[-1])
ps = torch.empty(len(states), policy_states.shape[-1])
a = torch.empty(len(actions), self._truncation_length, actions.shape[-1])
ss = torch.empty(len(states), self._truncation_length, states.shape[-1])
pss = torch.empty(len(states), policy_states.shape[-1])
lengths = torch.empty(len(states), dtype=torch.long)

for i in range(len(states)):
# determine the begin of a sequence
begin_seq = max(i - self._truncation_length + 1, 0)
end_seq = i + 1

# maybe the sequence contains more than one trajectory, so we need to cut it so that it contains only one
lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int()
begin_traj = torch.where(lasts_absorbing > 0)
sequence_is_shorter_than_requested = len(*begin_traj) > 0
if sequence_is_shorter_than_requested:
begin_seq = begin_seq + begin_traj[0][-1]

# get the sequences
states_seq = states[begin_seq:end_seq]
actions_seq = actions[begin_seq:end_seq]
next_states_seq = next_states[begin_seq:end_seq]

# apply padding
length_seq = len(states_seq)
padded_states = torch.concatenate([states_seq,
torch.zeros((self._truncation_length - states_seq.shape[0],
states_seq.shape[1]))])
padded_next_states = torch.concatenate([next_states_seq,
torch.zeros((self._truncation_length - next_states_seq.shape[0],
next_states_seq.shape[1]))])
padded_action_seq = torch.concatenate([actions_seq,
torch.zeros((self._truncation_length - actions_seq.shape[0],
actions_seq.shape[1]))])

s[i] = padded_states
ps[i] = policy_states[begin_seq]
a[i] = padded_action_seq
ss[i] = padded_next_states
pss[i] = policy_next_states[begin_seq]

lengths[i] = length_seq

return s.detach(), ps.detach(), a.detach(), ss.detach(), pss.detach(), lengths.detach()
def _transform_to_sequences(self, states_old, states, policy_states, actions, next_states, policy_next_states,
last, absorbing):
with torch.no_grad():
s_old = torch.empty(len(states), self._truncation_length, states.shape[-1])
s = torch.empty(len(states), self._truncation_length, states.shape[-1])
ps = torch.empty(len(states), policy_states.shape[-1])
a = torch.empty(len(actions), self._truncation_length, actions.shape[-1])
ss = torch.empty(len(states), self._truncation_length, states.shape[-1])
pss = torch.empty(len(states), policy_states.shape[-1])
lengths = torch.empty(len(states), dtype=torch.long)

for i in range(len(states)):
# determine the beginning of the sequence
begin_seq = max(i - self._truncation_length + 1, 0)
end_seq = i + 1

# the sequence may contain more than one trajectory, so cut it down to the most recent one
lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int()
begin_traj = torch.where(lasts_absorbing > 0)
sequence_is_shorter_than_requested = len(*begin_traj) > 0
if sequence_is_shorter_than_requested:
begin_seq = begin_seq + begin_traj[0][-1]

# get the sequences
states_old_seq = states_old[begin_seq:end_seq]
states_seq = states[begin_seq:end_seq]
actions_seq = actions[begin_seq:end_seq]
next_states_seq = next_states[begin_seq:end_seq]

# apply padding
length_seq = len(states_seq)
padded_states_old = torch.concatenate([states_old_seq,
torch.zeros((self._truncation_length - states_seq.shape[0],
states_seq.shape[1]))])
padded_states = torch.concatenate([states_seq,
torch.zeros((self._truncation_length - states_seq.shape[0],
states_seq.shape[1]))])
padded_next_states = torch.concatenate([next_states_seq,
torch.zeros((self._truncation_length - next_states_seq.shape[0],
next_states_seq.shape[1]))])
padded_action_seq = torch.concatenate([actions_seq,
torch.zeros((self._truncation_length - actions_seq.shape[0],
actions_seq.shape[1]))])

s_old[i] = padded_states_old
s[i] = padded_states
ps[i] = policy_states[begin_seq]
a[i] = padded_action_seq
ss[i] = padded_next_states
pss[i] = policy_next_states[begin_seq]

lengths[i] = length_seq

return s_old, s, ps, a, ss, pss, lengths

def _update_policy(self, obs, pi_h, act, lengths, adv, old_log_p):
for epoch in range(self._n_epochs_policy()):
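`_transform_to_sequences` builds, for each timestep `i`, a window of at most `truncation_length` past steps, cut at episode boundaries and zero-padded on the right, together with the true window length for the recurrent network. A small sketch of just the windowing rule on toy data (the real method also handles policy states, actions and the old states):

```python
import torch


def window_bounds(i, truncation_length, last, absorbing):
    begin_seq = max(i - truncation_length + 1, 0)
    # if an episode ended inside the window, start after the last boundary
    boundary = last[begin_seq - 1:i].int() + absorbing[begin_seq - 1:i].int()
    idx = torch.where(boundary > 0)[0]
    if len(idx) > 0:
        begin_seq = begin_seq + int(idx[-1])
    return begin_seq, i + 1


last = torch.tensor([0, 0, 1, 0, 0, 0])       # episode ends at t=2
absorbing = torch.zeros(6)
states = torch.arange(6, dtype=torch.float32).reshape(-1, 1)

begin, end = window_bounds(4, truncation_length=4, last=last, absorbing=absorbing)
seq = states[begin:end]
padded = torch.cat([seq, torch.zeros(4 - len(seq), 1)])
print(begin, end, padded.squeeze())  # 3 5 tensor([3., 4., 0., 0.])
```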
9 changes: 5 additions & 4 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py
@@ -5,15 +5,15 @@
import torch
import torch.nn.functional as F

from mushroom_rl.core import Agent
from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils
from mushroom_rl.rl_utils.value_functions import compute_gae
from mushroom_rl.rl_utils.parameters import to_parameter


class TRPO(Agent):
class TRPO(OnPolicyDeepAC):
"""
Trust Region Policy optimization algorithm.
"Trust Region Policy Optimization".
@@ -83,6 +83,7 @@ def __init__(self, mdp_info, policy, critic_params, ent_coeff=0., max_kl=.001, l

def fit(self, dataset):
state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')
state, next_state, state_old = self._preprocess_state(state, next_state)

v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last,
self.mdp_info.gamma, self._lambda())
@@ -93,8 +94,8 @@

# Policy update
self._old_policy = deepcopy(self.policy)
old_pol_dist = self._old_policy.distribution_t(state)
old_log_prob = self._old_policy.log_prob_t(state, action).detach()
old_pol_dist = self._old_policy.distribution_t(state_old)
old_log_prob = self._old_policy.log_prob_t(state_old, action).detach()

TorchUtils.zero_grad(self.policy.parameters())
loss = self._compute_loss(state, action, adv, old_log_prob)
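TRPO applies the same idea: the deep-copied `_old_policy` is evaluated on `state_old` to get the log-probabilities for the surrogate gain and the reference distribution for the KL trust-region constraint. A hedged sketch of those two quantities for Gaussian distributions (plain `torch.distributions` objects, not the mushroom_rl policy API):

```python
import torch
from torch.distributions import Normal, kl_divergence

old_dist = Normal(torch.zeros(16, 2), torch.ones(16, 2))
new_dist = Normal(0.1 * torch.ones(16, 2), 0.9 * torch.ones(16, 2))

action = old_dist.sample()
adv = torch.randn(16, 1)

# surrogate gain: importance-weighted advantage under the new policy
ratio = torch.exp(new_dist.log_prob(action).sum(-1, keepdim=True)
                  - old_dist.log_prob(action).sum(-1, keepdim=True))
surrogate = torch.mean(ratio * adv)

# mean KL between old and new policy, used as the trust-region constraint
mean_kl = kl_divergence(old_dist, new_dist).sum(-1).mean()
print(surrogate.item(), mean_kl.item())
```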
56 changes: 28 additions & 28 deletions mushroom_rl/core/agent.py
@@ -103,34 +103,6 @@ def draw_action(self, state, policy_state=None):

return self._convert_to_env_backend(action), self._convert_to_env_backend(next_policy_state)

def _agent_preprocess(self, state):
"""
Applies all the agent's preprocessors to the state.
Args:
state (Array): the state where the agent is;
Returns:
The preprocessed state.
"""
for p in self._agent_preprocessors:
state = p(state)
return state

def _update_agent_preprocessor(self, state):
"""
Updates the stats of all the agent's preprocessors given the state.
Args:
state (Array): the state where the agent is;
"""
for i, p in enumerate(self._agent_preprocessors, 1):
p.update(state)
if i < len(self._agent_preprocessors):
state = p(state)

def episode_start(self, initial_state, episode_info):
"""
Called by the Core when a new episode starts.
@@ -214,6 +186,34 @@ def _convert_to_env_backend(self, array):
def _convert_to_agent_backend(self, array):
return self._agent_backend.convert_to_backend(self._env_backend, array)

def _agent_preprocess(self, state):
"""
Applies all the agent's preprocessors to the state.
Args:
state (Array): the state where the agent is;
Returns:
The preprocessed state.
"""
for p in self._agent_preprocessors:
state = p(state)
return state

def _update_agent_preprocessor(self, state):
"""
Updates the stats of all the agent's preprocessors given the state.
Args:
state (Array): the state where the agent is;
"""
for i, p in enumerate(self._agent_preprocessors, 1):
p.update(state)
if i < len(self._agent_preprocessors):
state = p(state)

@property
def info(self):
return self._info
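The relocated `_agent_preprocess` and `_update_agent_preprocessor` chain the preprocessors: each preprocessor's statistics are updated on the state as transformed by the preprocessors before it, which is why the update loop applies `p` before moving on. A small sketch of that ordering with two toy preprocessors (hypothetical classes, only to illustrate the chaining):

```python
class Scale:
    """Toy preprocessor: fixed scaling, nothing to update."""
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, state):
        return [s * self.factor for s in state]

    def update(self, state):
        pass


class MeanTracker:
    """Toy preprocessor that centers the (already scaled) states."""
    def __init__(self):
        self.mean = 0.0

    def __call__(self, state):
        return [s - self.mean for s in state]

    def update(self, state):
        self.mean = sum(state) / len(state)


preprocessors = [Scale(0.5), MeanTracker()]

def agent_preprocess(state):
    for p in preprocessors:
        state = p(state)
    return state

def update_agent_preprocessor(state):
    for i, p in enumerate(preprocessors, 1):
        p.update(state)
        if i < len(preprocessors):
            state = p(state)   # the next preprocessor sees the transformed state

update_agent_preprocessor([2.0, 4.0])
print(agent_preprocess([2.0, 4.0]))  # [-0.5, 0.5]
```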