90 commits
222bf13
keep_dims -> keepdims
alok May 2, 2018
a2ab3f7
WIP
alok May 2, 2018
7fc25c6
Get single sample
alok May 2, 2018
470d56a
add test scripts
alok May 2, 2018
19b6510
WIP
alok May 3, 2018
4a24483
Silence PyTorch warnings
alok May 3, 2018
4c52ede
Test A3C with only 1 worker
alok May 3, 2018
18b7692
Use PyTorch's new scalar support
alok May 3, 2018
6745a4b
Use F.mse_loss instead of rolling our own
alok May 3, 2018
aaabe16
Use correct samplebatch key
alok May 3, 2018
5e7fe40
Write magic methods for SampleBatch/PartialRollout
alok May 4, 2018
3d53186
WIP
alok May 4, 2018
19490f4
Fix IndentationError
alok May 6, 2018
0f0a17b
rm Variable for torch 0.4.0
alok May 8, 2018
1dade7c
misc
alok May 8, 2018
353cff9
Fix some shape errors in TRPO
alok May 8, 2018
0fd8f4a
Merge branch 'master' into trpo
alok May 8, 2018
06611c2
Use kl_divergence provided by PyTorch
alok May 8, 2018
06192a5
Use detach() over .data
alok May 8, 2018
bc3dca6
Rename variables
alok May 8, 2018
779fa5c
rm unnecessary `probs` attribute
alok May 8, 2018
2c46098
Fix rewards shape
alok May 9, 2018
0125eef
size -> out_size in SlimFC
alok May 9, 2018
a3ec08d
.size() -> .shape
alok May 9, 2018
6d45db2
Use chain to adjust only action head
alok May 9, 2018
eda75ae
WIP
alok May 9, 2018
8ba4c17
rm trailing comma
alok May 9, 2018
2d43fd0
Update test scripts
alok May 9, 2018
fec90b7
leave note to debug remote_evaluators
alok May 9, 2018
c2592fd
Use .item() to extract number from torch scalar
alok May 9, 2018
3fae65f
temporary fix to zero gradient
alok May 9, 2018
9601b69
fmt
alok May 9, 2018
dfd38c0
rm test scripts
alok May 9, 2018
280be28
fmt
alok May 11, 2018
2fdac3b
Undo lint changes
alok May 11, 2018
be5a5e5
Undo fmt
alok May 11, 2018
75e2d1e
Undo magic methods
alok May 11, 2018
e0da980
Merge branch 'master' into trpo
alok May 11, 2018
23535a8
Calculate remote evaluator only once
alok May 11, 2018
1ea501f
rm needless copy
alok May 11, 2018
82a6698
Merge branch 'master' into trpo
alok May 11, 2018
0151a0a
rm unnecessary copy
alok May 11, 2018
51901ad
Use F.softmax instead of a pointless network layer
alok May 11, 2018
5d7fc19
Use correct pytorch functions
alok May 11, 2018
8583616
Rename argument name to out_size
alok May 11, 2018
18c4a4c
Fix shapes of tensors
alok May 11, 2018
64ae2ab
Fmt
alok May 11, 2018
5b623c0
Register TRPO with other agents
alok May 12, 2018
3073e09
Drop use of numpy in TRPO _backward
alok May 12, 2018
85b5fbf
Re-add deepcopy as stopgap measure
alok May 12, 2018
8accdae
replace deprecated function
alok May 12, 2018
8645cd7
rm unnecessary Variable wrapper
alok May 14, 2018
12bd5d6
Clarify variable name
alok May 14, 2018
47e8ebd
rm all use of torch Variables
alok May 14, 2018
f9e4797
Merge branch 'master' into fix-a3c-torch
alok May 14, 2018
884a6a8
Ensure that values are flat list
alok May 14, 2018
7d1b205
Fix shape error in conv nets
alok May 14, 2018
8542341
Merge branch 'fix-a3c-torch' into trpo
alok May 15, 2018
72f2afc
rm unused import
alok May 15, 2018
2850279
rm unused functions
alok May 15, 2018
8999714
Handle partial rollouts
alok May 15, 2018
aeab1f3
Merge branch 'master' into fix-a3c-torch
alok May 17, 2018
662eaa5
Merge branch 'master' into fix-a3c-2
alok May 21, 2018
f9561d3
fmt
alok May 21, 2018
7f06a1f
Fix shape errors
alok May 24, 2018
0438707
Add TODO
alok May 24, 2018
da8d9e6
Use correct filter size
alok May 24, 2018
db9804d
Add missing channel major
alok May 24, 2018
9aac5bd
Merge branch 'fix-a3c-torch' into trpo
alok May 24, 2018
9246b65
Merge branch 'master' into trpo
alok May 24, 2018
e865a09
Merge branch 'master' into fix-a3c-torch
alok May 25, 2018
a62fa6e
Merge branch 'fix-a3c-torch' into trpo
alok May 25, 2018
27cd897
Revert reshape of action
alok May 25, 2018
75ea9a7
Squeeze action
alok May 29, 2018
87ab87e
Squeeze actions along first dimension
alok May 29, 2018
9acd029
try adding pytorch tests
richardliaw May 29, 2018
c4b8ca7
typo
richardliaw May 29, 2018
6a79793
fixup docker messages
richardliaw May 29, 2018
7cdedf3
Fix A3C for some envs
alok May 30, 2018
da414fc
fmt
alok May 30, 2018
3b9234f
nit flake
richardliaw May 30, 2018
9ddab77
small lint
richardliaw May 30, 2018
ca9b33c
Merge branch 'fix-a3c-torch' into trpo
alok May 30, 2018
51dc392
Use A3C's save/restore/optimizer
alok Jun 1, 2018
f3d401f
ent_coeff -> entropy_coeff
alok Jun 1, 2018
c57989d
Clean up config dicts
alok Jun 1, 2018
6c15780
fmt
alok Jun 1, 2018
cbbaf32
Use async optimizer for TRPO
alok Jun 3, 2018
78ab9d4
Use single quotes
alok Jun 3, 2018
826a2e1
Merge branch 'master' into trpo
alok Jun 3, 2018
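Several of the commits above ("rm Variable for torch 0.4.0", "Use PyTorch's new scalar support", "Use .item() to extract number from torch scalar", "Use detach() over .data", "Use F.mse_loss instead of rolling our own", "Use kl_divergence provided by PyTorch") track the PyTorch 0.4 migration. A minimal, self-contained sketch of those idioms, for orientation only (this is not code from the PR):

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, kl_divergence

# PyTorch 0.4 merges Variable into Tensor, so no Variable wrapper is needed.
values = torch.randn(8, requires_grad=True)
returns = torch.randn(8)

# F.mse_loss replaces a hand-rolled squared-error loss.
value_err = F.mse_loss(values, returns)

# .item() extracts a Python number from a 0-dim tensor (the new scalar support).
print(value_err.item())

# .detach() is preferred over .data for cutting a tensor out of the graph.
baseline = values.detach()

# torch.distributions provides kl_divergence for its built-in distributions.
p = Categorical(logits=torch.randn(8, 4))
q = Categorical(logits=torch.randn(8, 4))
print(kl_divergence(p, q).mean().item())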
2 changes: 1 addition & 1 deletion python/ray/rllib/__init__.py
@@ -9,7 +9,7 @@

def _register_all():
for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG",
"DDPG2", "APEX_DDPG", "__fake", "__sigmoid_fake_data",
"DDPG2", "APEX_DDPG", "TRPO", "__fake", "__sigmoid_fake_data",
"__parameter_tuning"]:
from ray.rllib.agent import get_agent_class
register_trainable(key, get_agent_class(key))
39 changes: 23 additions & 16 deletions python/ray/rllib/a3c/a3c.py
@@ -1,32 +1,29 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import numpy as np
import pickle
import os
import pickle

import numpy as np

import ray
from ray.rllib.a3c.a3c_evaluator import (A3CEvaluator, GPURemoteA3CEvaluator,
RemoteA3CEvaluator,)
from ray.rllib.agent import Agent
from ray.rllib.optimizers import AsyncOptimizer
from ray.rllib.utils import FilterManager
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
GPURemoteA3CEvaluator
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources

DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
# Size of rollout batch
# Size of rollout
"batch_size": 10,
# Use LSTM model - only applicable for image states
# Only applicable for image states
"use_lstm": False,
# Use PyTorch as backend - no LSTM support
# No LSTM support if PyTorch used
"use_pytorch": False,
# Which observation filter to apply to the observation
"observation_filter": "NoFilter",
# Which reward filter to apply to the reward
"reward_filter": "NoFilter",
# Discount factor of MDP
"gamma": 0.99,
@@ -36,9 +33,7 @@
"grad_clip": 40.0,
# Learning rate
"lr": 0.0001,
# Value Function Loss coefficient
"vf_loss_coeff": 0.5,
# Entropy coefficient
"entropy_coeff": -0.01,
# Whether to place workers on GPUs
"use_gpu_for_workers": False,
@@ -84,15 +79,18 @@ def _init(self):
self.config,
self.logdir,
start_sampler=False)

if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteA3CEvaluator
else:
remote_cls = RemoteA3CEvaluator

self.remote_evaluators = [
remote_cls.remote(self.registry, self.env_creator, self.config,
self.logdir)
for i in range(self.config["num_workers"])
]

self.optimizer = AsyncOptimizer(self.config["optimizer"],
self.local_evaluator,
self.remote_evaluators)
@@ -101,20 +99,23 @@ def _train(self):
self.optimizer.step()
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
res = self._fetch_metrics_from_remote_evaluators()
return res
result = self._fetch_metrics_from_remote_evaluators()
return result

def _fetch_metrics_from_remote_evaluators(self):
episode_rewards = []
episode_lengths = []

metric_lists = [
a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators
]

for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)

avg_reward = (np.mean(episode_rewards)
if episode_rewards else float('nan'))
avg_length = (np.mean(episode_lengths)
@@ -137,21 +138,27 @@ def _stop(self):
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))

agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])

extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()
}

pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))

return checkpoint_path

def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))

ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])

self.local_evaluator.restore(extra_data["local_state"])

def compute_action(self, observation):
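The DEFAULT_CONFIG above is a plain dict, so runs usually override only the keys they need. A hypothetical override for the PyTorch path, using only keys visible in the config above (how the dict is passed to the agent constructor is outside this diff):

from ray.rllib.a3c.a3c import DEFAULT_CONFIG

config = dict(DEFAULT_CONFIG)
config.update({
    "num_workers": 2,
    "use_pytorch": True,      # no LSTM support on the PyTorch backend
    "entropy_coeff": -0.01,   # negative, so the entropy term acts as a bonus
    "observation_filter": "NoFilter",
})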
35 changes: 22 additions & 13 deletions python/ray/rllib/a3c/a3c_evaluator.py
@@ -26,23 +26,31 @@ class A3CEvaluator(PolicyEvaluator):
rollouts.
logdir: Directory for logging.
"""
def __init__(
self, registry, env_creator, config, logdir, start_sampler=True):

def __init__(self,
registry,
env_creator,
config,
logdir,
start_sampler=True):
env = ModelCatalog.get_preprocessor_as_wrapper(
registry, env_creator(config["env_config"]), config["model"])
self.env = env
policy_cls = get_policy_cls(config)
# TODO(rliaw): should change this to be just env.observation_space
self.policy = policy_cls(
registry, env.observation_space.shape, env.action_space, config)
self.policy = policy_cls(registry, env.observation_space.shape,
env.action_space, config)
self.config = config

# Technically not needed when not remote
self.obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.obs_filter = get_filter(config["observation_filter"],
env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())
self.filters = {"obs_filter": self.obs_filter,
"rew_filter": self.rew_filter}
self.filters = {
"obs_filter": self.obs_filter,
"rew_filter": self.rew_filter
}

self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
config["batch_size"])
if start_sampler and self.sampler._async:
@@ -52,8 +60,11 @@ def __init__(
def sample(self):
rollout = self.sampler.get_data()
samples = process_rollout(
rollout, self.rew_filter, gamma=self.config["gamma"],
lambda_=self.config["lambda"], use_gae=True)
rollout,
self.rew_filter,
gamma=self.config["gamma"],
lambda_=self.config["lambda"],
use_gae=True)
return samples

def get_completed_rollout_metrics(self):
@@ -79,9 +90,7 @@ def set_weights(self, params):
def save(self):
filters = self.get_filters(flush_after=True)
weights = self.get_weights()
return pickle.dumps({
"filters": filters,
"weights": weights})
return pickle.dumps({"filters": filters, "weights": weights})

def restore(self, objs):
objs = pickle.loads(objs)
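sample() hands the rollout to process_rollout with gamma, lambda_, and use_gae=True, i.e. Generalized Advantage Estimation. A standalone sketch of that computation, assuming the usual convention that values carries one extra bootstrap entry (illustration only, not RLlib's process_rollout):

import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lambda_=1.0):
    # values has len(rewards) + 1 entries; the last one bootstraps the tail.
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lambda_ * running
        advantages[t] = running
    return advantages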
12 changes: 7 additions & 5 deletions python/ray/rllib/a3c/shared_torch_policy.py
@@ -6,8 +6,8 @@
import torch.nn.functional as F

from ray.rllib.a3c.torchpolicy import TorchPolicy
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.pytorch.misc import convert_batch, var_to_np


class SharedTorchPolicy(TorchPolicy):
@@ -28,7 +28,7 @@ def _setup_graph(self, ob_space, ac_space):
self._model.parameters(), lr=self.config["lr"])

def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
"""Should take in a SINGLE ob."""
with self.lock:
ob = torch.from_numpy(ob).float().unsqueeze(0)
logits, values = self._model(ob)
@@ -64,16 +64,18 @@ def _evaluate(self, obs, actions):
return values, action_log_probs, entropy

def _backward(self, batch):
"""Loss is encoded in here. Defining a new loss function
would start by rewriting this function"""
"""Loss is encoded in here.

Defining a new loss function would start by rewriting this
function
"""

states, actions, advs, rs, _ = convert_batch(batch)
values, action_log_probs, entropy = self._evaluate(states, actions)
pi_err = -advs.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), rs)

self.optimizer.zero_grad()

overall_err = sum([
pi_err,
self.config["vf_loss_coeff"] * value_err,
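_backward above encodes the A3C loss: the policy term, the vf_loss_coeff-weighted value error, and the entropy term. The TRPO policy this PR registers is not shown in the visible hunks, but the commit "Use kl_divergence provided by PyTorch" points at the standard surrogate-plus-KL formulation; a hedged sketch of that pair for a categorical policy (the actual trpo module may differ):

import torch
from torch.distributions import Categorical, kl_divergence

def surrogate_and_kl(old_logits, new_logits, actions, advantages):
    old_dist = Categorical(logits=old_logits.detach())
    new_dist = Categorical(logits=new_logits)
    # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = torch.exp(new_dist.log_prob(actions) - old_dist.log_prob(actions))
    surrogate = (ratio * advantages).mean()  # maximized subject to a KL bound
    mean_kl = kl_divergence(old_dist, new_dist).mean()
    return surrogate, mean_kl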
4 changes: 4 additions & 0 deletions python/ray/rllib/a3c/torchpolicy.py
@@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

from copy import deepcopy
import torch

from ray.rllib.a3c.policy import Policy
@@ -30,6 +31,9 @@ def __init__(self,
self.lock = Lock()

def apply_gradients(self, grads):
# TODO(alok): see how A3C fills gradient buffers so that they don't get
# cleared by zero_grad
grads = deepcopy(grads) # TODO rm
self.optimizer.zero_grad()
for g, p in zip(grads, self._model.parameters()):
p.grad = torch.from_numpy(g)
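apply_gradients writes NumPy gradients straight into p.grad before stepping the optimizer, and the deepcopy is flagged as a stopgap so zero_grad does not clobber shared buffers. The producer side of that exchange would look roughly like the helper below (hypothetical name, not part of this diff):

def extract_gradients(model):
    # Pull gradients out as NumPy arrays so they can be shipped between
    # processes and written back into p.grad by apply_gradients above.
    return [p.grad.detach().numpy() for p in model.parameters()]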
3 changes: 3 additions & 0 deletions python/ray/rllib/agent.py
@@ -261,6 +261,9 @@ def get_agent_class(alg):
elif alg == "PG":
from ray.rllib import pg
return pg.PGAgent
elif alg == "TRPO":
from ray.rllib import trpo
return trpo.TRPOAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
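A usage sketch for the new branch; only the class lookup is shown, since Agent constructor arguments are outside this diff:

from ray.rllib.agent import get_agent_class

trpo_cls = get_agent_class("TRPO")  # resolves to ray.rllib.trpo.TRPOAgent
print(trpo_cls.__name__)            # "TRPOAgent"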
6 changes: 6 additions & 0 deletions python/ray/rllib/trpo/__init__.py
@@ -0,0 +1,6 @@
from ray.rllib.trpo.trpo import DEFAULT_CONFIG, TRPOAgent

__all__ = [
'TRPOAgent',
'DEFAULT_CONFIG',
]
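With this package __init__ in place, the agent and its default config can be imported directly (DEFAULT_CONFIG itself is defined in trpo/trpo.py, which is not part of the visible diff):

from ray.rllib.trpo import DEFAULT_CONFIG, TRPOAgent

print(sorted(DEFAULT_CONFIG))  # inspect the available TRPO config keys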