Commit

Refactor code and apply formatting
ikostrikov2 committed Mar 14, 2019
1 parent 88080da commit f60ac80
Showing 12 changed files with 426 additions and 285 deletions.
3 changes: 2 additions & 1 deletion a2c_ppo_acktr/algo/a2c_acktr.py
@@ -37,7 +37,8 @@ def update(self, rollouts):

values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
rollouts.obs[:-1].view(-1, *obs_shape),
rollouts.recurrent_hidden_states[0].view(-1, self.actor_critic.recurrent_hidden_state_size),
rollouts.recurrent_hidden_states[0].view(
-1, self.actor_critic.recurrent_hidden_state_size),
rollouts.masks[:-1].view(-1, 1),
rollouts.actions.view(-1, action_shape))

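For context, and not part of the commit itself: the evaluate_actions call above flattens the rollout tensors so that all environment steps are processed in a single forward pass. A minimal sketch of that reshaping follows; num_steps, num_processes, and the observation shape are assumed values chosen for illustration, not taken from the diff.

import torch

num_steps, num_processes = 5, 16
obs_shape = (4, 84, 84)  # assumed stacked-frame observation shape
# The rollout storage keeps num_steps + 1 observations per process.
obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)

# obs[:-1] drops the final observation; .view merges the step and process
# dimensions into one batch dimension of size num_steps * num_processes.
flat_obs = obs[:-1].view(-1, *obs_shape)
print(flat_obs.shape)  # torch.Size([80, 4, 84, 84])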
15 changes: 9 additions & 6 deletions a2c_ppo_acktr/algo/ppo.py
@@ -55,21 +55,24 @@ def update(self, rollouts):

# Reshape to do in a single forward pass for all steps
values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
obs_batch, recurrent_hidden_states_batch,
masks_batch, actions_batch)
obs_batch, recurrent_hidden_states_batch, masks_batch,
actions_batch)

ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
ratio = torch.exp(action_log_probs -
old_action_log_probs_batch)
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
1.0 + self.clip_param) * adv_targ
1.0 + self.clip_param) * adv_targ
action_loss = -torch.min(surr1, surr2).mean()

if self.use_clipped_value_loss:
value_pred_clipped = value_preds_batch + \
(values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
value_losses = (values - return_batch).pow(2)
value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
value_losses_clipped = (
value_pred_clipped - return_batch).pow(2)
value_loss = 0.5 * torch.max(value_losses,
value_losses_clipped).mean()
else:
value_loss = 0.5 * (return_batch - values).pow(2).mean()

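For reference, and not part of the commit itself: the hunk above is PPO's clipped surrogate objective plus an optional clipped value loss. The self-contained sketch below reproduces that arithmetic on stand-in tensors; every shape and value here is an assumption made for illustration.

import torch

torch.manual_seed(0)
batch, clip_param = 8, 0.2

# Stand-in tensors; in the real update these come from the rollout sampler.
action_log_probs = torch.randn(batch, 1)
old_action_log_probs = torch.randn(batch, 1)
adv = torch.randn(batch, 1)
values = torch.randn(batch, 1)
value_preds = torch.randn(batch, 1)
returns = torch.randn(batch, 1)

# Clipped surrogate objective (policy loss).
ratio = torch.exp(action_log_probs - old_action_log_probs)
surr1 = ratio * adv
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
action_loss = -torch.min(surr1, surr2).mean()

# Clipped value loss: the new prediction may move at most clip_param away
# from the value recorded when the rollout was collected.
value_pred_clipped = value_preds + (values - value_preds).clamp(
    -clip_param, clip_param)
value_losses = (values - returns).pow(2)
value_losses_clipped = (value_pred_clipped - returns).pow(2)
value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
print(action_loss.item(), value_loss.item())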
191 changes: 133 additions & 58 deletions a2c_ppo_acktr/arguments.py
@@ -5,64 +5,139 @@

def get_args():
parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--algo', default='a2c',
help='algorithm to use: a2c | ppo | acktr')
parser.add_argument('--lr', type=float, default=7e-4,
help='learning rate (default: 7e-4)')
parser.add_argument('--eps', type=float, default=1e-5,
help='RMSprop optimizer epsilon (default: 1e-5)')
parser.add_argument('--alpha', type=float, default=0.99,
help='RMSprop optimizer alpha (default: 0.99)')
parser.add_argument('--gamma', type=float, default=0.99,
help='discount factor for rewards (default: 0.99)')
parser.add_argument('--use-gae', action='store_true', default=False,
help='use generalized advantage estimation')
parser.add_argument('--tau', type=float, default=0.95,
help='gae parameter (default: 0.95)')
parser.add_argument('--entropy-coef', type=float, default=0.01,
help='entropy term coefficient (default: 0.01)')
parser.add_argument('--value-loss-coef', type=float, default=0.5,
help='value loss coefficient (default: 0.5)')
parser.add_argument('--max-grad-norm', type=float, default=0.5,
help='max norm of gradients (default: 0.5)')
parser.add_argument('--seed', type=int, default=1,
help='random seed (default: 1)')
parser.add_argument('--cuda-deterministic', action='store_true', default=False,
help="sets flags for determinism when using CUDA (potentially slow!)")
parser.add_argument('--num-processes', type=int, default=16,
help='how many training CPU processes to use (default: 16)')
parser.add_argument('--num-steps', type=int, default=5,
help='number of forward steps in A2C (default: 5)')
parser.add_argument('--ppo-epoch', type=int, default=4,
help='number of ppo epochs (default: 4)')
parser.add_argument('--num-mini-batch', type=int, default=32,
help='number of batches for ppo (default: 32)')
parser.add_argument('--clip-param', type=float, default=0.2,
help='ppo clip parameter (default: 0.2)')
parser.add_argument('--log-interval', type=int, default=10,
help='log interval, one log per n updates (default: 10)')
parser.add_argument('--save-interval', type=int, default=100,
help='save interval, one save per n updates (default: 100)')
parser.add_argument('--eval-interval', type=int, default=None,
help='eval interval, one eval per n updates (default: None)')
parser.add_argument('--num-env-steps', type=int, default=10e6,
help='number of environment steps to train (default: 10e6)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
help='environment to train on (default: PongNoFrameskip-v4)')
parser.add_argument('--log-dir', default='/tmp/gym/',
help='directory to save agent logs (default: /tmp/gym)')
parser.add_argument('--save-dir', default='./trained_models/',
help='directory to save agent logs (default: ./trained_models/)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--use-proper-time-limits', action='store_true', default=False,
help='compute returns taking into account time limits')
parser.add_argument('--recurrent-policy', action='store_true', default=False,
help='use a recurrent policy')
parser.add_argument('--use-linear-lr-decay', action='store_true', default=False,
help='use a linear schedule on the learning rate')
parser.add_argument('--use-linear-clip-decay', action='store_true', default=False,
help='use a linear schedule on the ppo clipping parameter')
parser.add_argument(
'--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr')
parser.add_argument(
'--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)')
parser.add_argument(
'--eps',
type=float,
default=1e-5,
help='RMSprop optimizer epsilon (default: 1e-5)')
parser.add_argument(
'--alpha',
type=float,
default=0.99,
help='RMSprop optimizer alpha (default: 0.99)')
parser.add_argument(
'--gamma',
type=float,
default=0.99,
help='discount factor for rewards (default: 0.99)')
parser.add_argument(
'--use-gae',
action='store_true',
default=False,
help='use generalized advantage estimation')
parser.add_argument(
'--tau',
type=float,
default=0.95,
help='gae parameter (default: 0.95)')
parser.add_argument(
'--entropy-coef',
type=float,
default=0.01,
help='entropy term coefficient (default: 0.01)')
parser.add_argument(
'--value-loss-coef',
type=float,
default=0.5,
help='value loss coefficient (default: 0.5)')
parser.add_argument(
'--max-grad-norm',
type=float,
default=0.5,
help='max norm of gradients (default: 0.5)')
parser.add_argument(
'--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument(
'--cuda-deterministic',
action='store_true',
default=False,
help="sets flags for determinism when using CUDA (potentially slow!)")
parser.add_argument(
'--num-processes',
type=int,
default=16,
help='how many training CPU processes to use (default: 16)')
parser.add_argument(
'--num-steps',
type=int,
default=5,
help='number of forward steps in A2C (default: 5)')
parser.add_argument(
'--ppo-epoch',
type=int,
default=4,
help='number of ppo epochs (default: 4)')
parser.add_argument(
'--num-mini-batch',
type=int,
default=32,
help='number of batches for ppo (default: 32)')
parser.add_argument(
'--clip-param',
type=float,
default=0.2,
help='ppo clip parameter (default: 0.2)')
parser.add_argument(
'--log-interval',
type=int,
default=10,
help='log interval, one log per n updates (default: 10)')
parser.add_argument(
'--save-interval',
type=int,
default=100,
help='save interval, one save per n updates (default: 100)')
parser.add_argument(
'--eval-interval',
type=int,
default=None,
help='eval interval, one eval per n updates (default: None)')
parser.add_argument(
'--num-env-steps',
type=int,
default=10e6,
help='number of environment steps to train (default: 10e6)')
parser.add_argument(
'--env-name',
default='PongNoFrameskip-v4',
help='environment to train on (default: PongNoFrameskip-v4)')
parser.add_argument(
'--log-dir',
default='/tmp/gym/',
help='directory to save agent logs (default: /tmp/gym)')
parser.add_argument(
'--save-dir',
default='./trained_models/',
help='directory to save agent logs (default: ./trained_models/)')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument(
'--use-proper-time-limits',
action='store_true',
default=False,
help='compute returns taking into account time limits')
parser.add_argument(
'--recurrent-policy',
action='store_true',
default=False,
help='use a recurrent policy')
parser.add_argument(
'--use-linear-lr-decay',
action='store_true',
default=False,
help='use a linear schedule on the learning rate')
parser.add_argument(
'--use-linear-clip-decay',
action='store_true',
default=False,
help='use a linear schedule on the ppo clipping parameter')
args = parser.parse_args()

args.cuda = not args.no_cuda and torch.cuda.is_available()
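As an aside, not part of the commit: a parser assembled from add_argument calls like the ones above resolves types, defaults, and store_true flags as in the short sketch below. The abbreviated parser and the argument list passed to parse_args are illustrative assumptions.

import argparse

parser = argparse.ArgumentParser(description='RL')
parser.add_argument(
    '--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr')
parser.add_argument(
    '--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)')
parser.add_argument(
    '--use-gae',
    action='store_true',
    default=False,
    help='use generalized advantage estimation')

# Dashes in option names become underscores on the parsed namespace.
args = parser.parse_args(['--algo', 'ppo', '--lr', '2.5e-4', '--use-gae'])
print(args.algo, args.lr, args.use_gae)  # ppo 0.00025 True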
26 changes: 13 additions & 13 deletions a2c_ppo_acktr/distributions.py
@@ -5,7 +5,6 @@
import torch.nn.functional as F

from a2c_ppo_acktr.utils import AddBias, init

"""
Modify standard PyTorch distributions so they are compatible with this code.
"""
@@ -21,28 +20,30 @@
FixedCategorical.sample = lambda self: old_sample(self).unsqueeze(-1)

log_prob_cat = FixedCategorical.log_prob
FixedCategorical.log_probs = lambda self, actions: log_prob_cat(self, actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1)
FixedCategorical.log_probs = lambda self, actions: log_prob_cat(
self, actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1)

FixedCategorical.mode = lambda self: self.probs.argmax(dim=-1, keepdim=True)


# Normal
FixedNormal = torch.distributions.Normal

log_prob_normal = FixedNormal.log_prob
FixedNormal.log_probs = lambda self, actions: log_prob_normal(self, actions).sum(-1, keepdim=True)
FixedNormal.log_probs = lambda self, actions: log_prob_normal(
self, actions).sum(
-1, keepdim=True)

normal_entropy = FixedNormal.entropy
FixedNormal.entropy = lambda self: normal_entropy(self).sum(-1)

FixedNormal.mode = lambda self: self.mean


# Bernoulli
FixedBernoulli = torch.distributions.Bernoulli

log_prob_bernoulli = FixedBernoulli.log_prob
FixedBernoulli.log_probs = lambda self, actions: log_prob_bernoulli(self, actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)
FixedBernoulli.log_probs = lambda self, actions: log_prob_bernoulli(
self, actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)

bernoulli_entropy = FixedBernoulli.entropy
FixedBernoulli.entropy = lambda self: bernoulli_entropy(self).sum(-1)
@@ -53,7 +54,8 @@ class Categorical(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(Categorical, self).__init__()

init_ = lambda m: init(m,
init_ = lambda m: init(
m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0),
gain=0.01)
@@ -69,9 +71,8 @@ class DiagGaussian(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(DiagGaussian, self).__init__()

init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0))
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
constant_(x, 0))

self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
self.logstd = AddBias(torch.zeros(num_outputs))
@@ -92,9 +93,8 @@ class Bernoulli(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(Bernoulli, self).__init__()

init_ = lambda m: init(m,
nn.init.orthogonal_,
lambda x: nn.init.constant_(x, 0))
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
constant_(x, 0))

self.linear = init_(nn.Linear(num_inputs, num_outputs))

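For context, and not part of the commit itself: the monkey-patched wrappers above make sampled actions and their log-probabilities carry a consistent trailing singleton dimension. A small sketch of the categorical case, using made-up probabilities:

import torch

probs = torch.tensor([[0.10, 0.70, 0.20],
                      [0.50, 0.25, 0.25]])  # assumed action probabilities
dist = torch.distributions.Categorical(probs=probs)

# Like FixedCategorical.sample: keep a trailing singleton dimension.
actions = dist.sample().unsqueeze(-1)                      # shape (2, 1)
# Like FixedCategorical.log_probs: squeeze for log_prob, then reshape back.
log_probs = dist.log_prob(actions.squeeze(-1)).view(
    actions.size(0), -1).sum(-1).unsqueeze(-1)             # shape (2, 1)
# Like FixedCategorical.mode: the greedy (argmax) action per row.
mode = dist.probs.argmax(dim=-1, keepdim=True)             # shape (2, 1)
print(actions.shape, log_probs.shape, mode.squeeze(-1))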
(Diffs for the remaining 8 changed files are not shown.)

0 comments on commit f60ac80
