diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index a2441f0b5bf6..c8f9049989ed 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -9,7 +9,7 @@ def _register_all(): for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", - "DDPG2", "APEX_DDPG", "__fake", "__sigmoid_fake_data", + "DDPG2", "APEX_DDPG", "TRPO", "__fake", "__sigmoid_fake_data", "__parameter_tuning"]: from ray.rllib.agent import get_agent_class register_trainable(key, get_agent_class(key)) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 569b50c44420..3a45ffae7931 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -1,32 +1,29 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -import numpy as np -import pickle import os +import pickle + +import numpy as np import ray +from ray.rllib.a3c.a3c_evaluator import (A3CEvaluator, GPURemoteA3CEvaluator, + RemoteA3CEvaluator,) from ray.rllib.agent import Agent from ray.rllib.optimizers import AsyncOptimizer from ray.rllib.utils import FilterManager -from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \ - GPURemoteA3CEvaluator from ray.tune.result import TrainingResult from ray.tune.trial import Resources DEFAULT_CONFIG = { # Number of workers (excluding master) "num_workers": 4, - # Size of rollout batch + # Size of rollout "batch_size": 10, - # Use LSTM model - only applicable for image states + # Only applicable for image states "use_lstm": False, - # Use PyTorch as backend - no LSTM support + # No LSTM support if PyTorch used "use_pytorch": False, - # Which observation filter to apply to the observation "observation_filter": "NoFilter", - # Which reward filter to apply to the reward "reward_filter": "NoFilter", # Discount factor of MDP "gamma": 0.99, @@ -36,9 +33,7 @@ "grad_clip": 40.0, # Learning rate "lr": 0.0001, - # Value Function Loss coefficient "vf_loss_coeff": 0.5, - # Entropy coefficient "entropy_coeff": -0.01, # Whether to place workers on GPUs "use_gpu_for_workers": False, @@ -84,15 +79,18 @@ def _init(self): self.config, self.logdir, start_sampler=False) + if self.config["use_gpu_for_workers"]: remote_cls = GPURemoteA3CEvaluator else: remote_cls = RemoteA3CEvaluator + self.remote_evaluators = [ remote_cls.remote(self.registry, self.env_creator, self.config, self.logdir) for i in range(self.config["num_workers"]) ] + self.optimizer = AsyncOptimizer(self.config["optimizer"], self.local_evaluator, self.remote_evaluators) @@ -101,20 +99,23 @@ def _train(self): self.optimizer.step() FilterManager.synchronize(self.local_evaluator.filters, self.remote_evaluators) - res = self._fetch_metrics_from_remote_evaluators() - return res + result = self._fetch_metrics_from_remote_evaluators() + return result def _fetch_metrics_from_remote_evaluators(self): episode_rewards = [] episode_lengths = [] + metric_lists = [ a.get_completed_rollout_metrics.remote() for a in self.remote_evaluators ] + for metrics in metric_lists: for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) + avg_reward = (np.mean(episode_rewards) if episode_rewards else float('nan')) avg_length = (np.mean(episode_lengths) @@ -137,21 +138,27 @@ def _stop(self): def _save(self, checkpoint_dir): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) + 
agent_state = ray.get( [a.save.remote() for a in self.remote_evaluators]) + extra_data = { "remote_state": agent_state, "local_state": self.local_evaluator.save() } + pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) + return checkpoint_path def _restore(self, checkpoint_path): extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) + ray.get([ a.restore.remote(o) for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) ]) + self.local_evaluator.restore(extra_data["local_state"]) def compute_action(self, observation): diff --git a/python/ray/rllib/a3c/a3c_evaluator.py b/python/ray/rllib/a3c/a3c_evaluator.py index 74d201016adf..6f5217fa8e89 100644 --- a/python/ray/rllib/a3c/a3c_evaluator.py +++ b/python/ray/rllib/a3c/a3c_evaluator.py @@ -26,23 +26,31 @@ class A3CEvaluator(PolicyEvaluator): rollouts. logdir: Directory for logging. """ - def __init__( - self, registry, env_creator, config, logdir, start_sampler=True): + + def __init__(self, + registry, + env_creator, + config, + logdir, + start_sampler=True): env = ModelCatalog.get_preprocessor_as_wrapper( registry, env_creator(config["env_config"]), config["model"]) self.env = env policy_cls = get_policy_cls(config) # TODO(rliaw): should change this to be just env.observation_space - self.policy = policy_cls( - registry, env.observation_space.shape, env.action_space, config) + self.policy = policy_cls(registry, env.observation_space.shape, + env.action_space, config) self.config = config # Technically not needed when not remote - self.obs_filter = get_filter( - config["observation_filter"], env.observation_space.shape) + self.obs_filter = get_filter(config["observation_filter"], + env.observation_space.shape) self.rew_filter = get_filter(config["reward_filter"], ()) - self.filters = {"obs_filter": self.obs_filter, - "rew_filter": self.rew_filter} + self.filters = { + "obs_filter": self.obs_filter, + "rew_filter": self.rew_filter + } + self.sampler = AsyncSampler(env, self.policy, self.obs_filter, config["batch_size"]) if start_sampler and self.sampler._async: @@ -52,8 +60,11 @@ def __init__( def sample(self): rollout = self.sampler.get_data() samples = process_rollout( - rollout, self.rew_filter, gamma=self.config["gamma"], - lambda_=self.config["lambda"], use_gae=True) + rollout, + self.rew_filter, + gamma=self.config["gamma"], + lambda_=self.config["lambda"], + use_gae=True) return samples def get_completed_rollout_metrics(self): @@ -79,9 +90,7 @@ def set_weights(self, params): def save(self): filters = self.get_filters(flush_after=True) weights = self.get_weights() - return pickle.dumps({ - "filters": filters, - "weights": weights}) + return pickle.dumps({"filters": filters, "weights": weights}) def restore(self, objs): objs = pickle.loads(objs) diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index d98a2f6dc436..150e1a890fd6 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -6,8 +6,8 @@ import torch.nn.functional as F from ray.rllib.a3c.torchpolicy import TorchPolicy -from ray.rllib.models.pytorch.misc import var_to_np, convert_batch from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.pytorch.misc import convert_batch, var_to_np class SharedTorchPolicy(TorchPolicy): @@ -28,7 +28,7 @@ def _setup_graph(self, ob_space, ac_space): self._model.parameters(), lr=self.config["lr"]) def compute(self, ob, *args): - """Should take in a SINGLE ob""" + """Should take in a 
SINGLE ob.""" with self.lock: ob = torch.from_numpy(ob).float().unsqueeze(0) logits, values = self._model(ob) @@ -64,8 +64,11 @@ def _evaluate(self, obs, actions): return values, action_log_probs, entropy def _backward(self, batch): - """Loss is encoded in here. Defining a new loss function - would start by rewriting this function""" + """Loss is encoded in here. + + Defining a new loss function would start by rewriting this + function + """ states, actions, advs, rs, _ = convert_batch(batch) values, action_log_probs, entropy = self._evaluate(states, actions) @@ -73,7 +76,6 @@ def _backward(self, batch): value_err = F.mse_loss(values.reshape(-1), rs) self.optimizer.zero_grad() - overall_err = sum([ pi_err, self.config["vf_loss_coeff"] * value_err, diff --git a/python/ray/rllib/a3c/torchpolicy.py b/python/ray/rllib/a3c/torchpolicy.py index 8c6a282568c0..768362a60c87 100644 --- a/python/ray/rllib/a3c/torchpolicy.py +++ b/python/ray/rllib/a3c/torchpolicy.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +from copy import deepcopy import torch from ray.rllib.a3c.policy import Policy @@ -30,6 +31,9 @@ def __init__(self, self.lock = Lock() def apply_gradients(self, grads): + # TODO(alok): see how A3C fills gradient buffers so that they don't get + # cleared by zero_grad + grads = deepcopy(grads) # TODO rm self.optimizer.zero_grad() for g, p in zip(grads, self._model.parameters()): p.grad = torch.from_numpy(g) diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 5699022b2a8e..08a09555fc6b 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -261,6 +261,9 @@ def get_agent_class(alg): elif alg == "PG": from ray.rllib import pg return pg.PGAgent + elif alg == "TRPO": + from ray.rllib import trpo + return trpo.TRPOAgent elif alg == "script": from ray.tune import script_runner return script_runner.ScriptRunner diff --git a/python/ray/rllib/trpo/__init__.py b/python/ray/rllib/trpo/__init__.py new file mode 100644 index 000000000000..928264c07505 --- /dev/null +++ b/python/ray/rllib/trpo/__init__.py @@ -0,0 +1,6 @@ +from ray.rllib.trpo.trpo import DEFAULT_CONFIG, TRPOAgent + +__all__ = [ + 'TRPOAgent', + 'DEFAULT_CONFIG', +] diff --git a/python/ray/rllib/trpo/policy.py b/python/ray/rllib/trpo/policy.py new file mode 100644 index 000000000000..1fa793168b3f --- /dev/null +++ b/python/ray/rllib/trpo/policy.py @@ -0,0 +1,240 @@ +"""Code adapted from https://github.com/mjacar/pytorch-trpo.""" +from __future__ import absolute_import, division, print_function + +from copy import deepcopy +from itertools import chain + +import numpy as np +import torch +import torch.nn.functional as F +from torch import distributions +from torch.distributions import kl_divergence +from torch.distributions.categorical import Categorical +from torch.nn.utils.convert_parameters import (_check_param_device, + parameters_to_vector, + vector_to_parameters,) + +import ray +from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy +from ray.rllib.models.pytorch.misc import convert_batch +from ray.rllib.utils.process_rollout import discount + + + + +def vector_to_gradient(v, parameters): + # TODO(alok) may have to rm the .data from v + r"""Convert one vector representing the + gradient to the .grad of the parameters. + + Arguments: + v (Tensor): a single vector represents the parameters of a model. + parameters (Iterable[Tensor]): an iterator of Tensors that are the + parameters of a model. 
+ """ + # Ensure v of type Tensor + if not isinstance(v, torch.Tensor): + raise TypeError('expected torch.Tensor, but got: {}'.format( + torch.typename(v))) + # Flag for the device where the parameter is located + param_device = None + + # Pointer for slicing the vector for each parameter + pointer = 0 + for param in parameters: + # Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device) + + # The length of the parameter + num_param = np.prod(param.grad.shape) + # Slice the vector, reshape it, and replace the old data of the parameter + param.grad.data = v[pointer:pointer + num_param].view( + param.shape).detach() + + # Increment the pointer + pointer += num_param + + +class TRPOPolicy(SharedTorchPolicy): + def __init__(self, registry, ob_space, ac_space, config, **kwargs): + super().__init__(registry, ob_space, ac_space, config, **kwargs) + + def _evaluate_action_dists(self, obs, *args): + logits, _ = self._model(obs) + # TODO(alok): Handle continuous case since this assumes a + # Categorical distribution. + action_dists = F.softmax(logits, dim=1) + return action_dists + + def mean_kl(self): + """Returns an estimate of the average KL divergence between a given + policy and self._model.""" + new_prob = F.softmax( + self._model(self._states)[0], dim=1).detach() + 1e-8 + old_prob = F.softmax(self._model(self._states)[0], dim=1) + + # TODO(alok): Handle continuous case since this assumes a + # Categorical distribution. + new_prob, old_prob = Categorical(new_prob), Categorical(old_prob) + + return kl_divergence(new_prob, old_prob).mean() + + def HVP(self, v): + """Returns the product of the Hessian of the KL divergence and the + given vector.""" + + self._model.zero_grad() + + g_kl = torch.autograd.grad( + outputs=self._kl, + inputs=chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters()), + create_graph=True, + ) + flat_g = torch.cat([grad.reshape(-1) for grad in g_kl]) + gvp = flat_g.dot(v) + + H = torch.autograd.grad( + outputs=gvp, + inputs=chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters()), + create_graph=True, + ) + fisher_vector_product = torch.cat( + [grad.reshape(-1) for grad in H]).detach() + + return fisher_vector_product + (self.config['cg_damping'] * v.detach()) + + def conjugate_gradient(self, b, cg_iters=10): + """Returns F^(-1)b where F is the Hessian of the KL divergence.""" + p, r = b.clone().detach(), b.clone().detach() + x = torch.zeros_like(b) + + # all the float casts are to avoid mixing torch scalars and numpy + # arrays + rdotr = r.dot(r) + + for _ in range(cg_iters): + z = self.HVP(p) + v = rdotr / p.dot(z) + x += v * p + r -= v * z + new_rdotr = r.dot(r) + mu = new_rdotr / rdotr + p = r + mu * p + rdotr = new_rdotr + + if rdotr < self.config['residual_tol']: + break + return x + + def surrogate_loss(self, params): + """Returns the surrogate loss wrt the given parameter vector params.""" + + new_policy = deepcopy(self._model) + + # only adjust parameters of action head and hidden layers + vector_to_parameters( + params, + chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters())) + + EPSILON = 1e-8 + + prob_new = new_policy(self._states)[0].gather( + 1, self._actions.view(-1, 1)).detach() + prob_old = self._model(self._states)[0].gather( + 1, self._actions.view(-1, 1)).detach() + EPSILON + + return -torch.mean(self._adv * (prob_new / prob_old)) + + def linesearch(self, x, fullstep, expected_improve_rate): + """Returns the scaled gradient 
that would improve the loss. + + Found via backtracking linesearch. + """ + + accept_ratio = 0.1 + max_backtracks = 10 + + loss = self.surrogate_loss + + for stepfrac in .5**np.arange(max_backtracks): + + g = stepfrac * fullstep + + actual_improve = loss(x) - loss(x.detach() + g) + expected_improve = expected_improve_rate * stepfrac + + if actual_improve / expected_improve > accept_ratio and actual_improve > 0: + return g + + # If no improvement could be obtained, return 0 gradient + else: + return torch.zeros_like(fullstep) + + def _backward(self, batch): + """Fills gradient buffers up.""" + + states, actions, advs, rewards, _ = convert_batch(batch) + values, _, entropy = self._evaluate(states, actions) + action_dists = self._evaluate_action_dists(states) + + # TODO find way to copy generator + # self._action_params = [ self._model.hidden_layers.parameters(), self._model.logits.parameters(), ] + + self._states = states + self._actions = actions + self._action_dists = action_dists + self._adv = advs + + self._kl = self.mean_kl() + + # Calculate the surrogate loss as the element-wise product of the + # advantage and the probability ratio of actions taken. + # TODO(alok): Do we need log probs or the actual probabilities here? + new_prob = self._action_dists.gather(1, self._actions.view(-1, 1)) + old_prob = new_prob.detach() + 1e-8 + prob_ratio = new_prob / old_prob + + surrogate_loss = -torch.mean(prob_ratio * self._adv) - ( + self.config['entropy_coeff'] * entropy) + + # Gradient wrt policy + self._model.zero_grad() + + # TODO just turn this into a flat list and work with it directly + surrogate_loss.backward(retain_graph=True) + + g = parameters_to_vector([ + p.grad for p in chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters()) + ]) + + # check that gradient is not 0 + if any(g): + step_dir = self.conjugate_gradient(-g) + + # Do line search to determine the stepsize of params in the direction of step_dir + shs = step_dir.dot(self.HVP(step_dir)) / 2 + lagrange_mult = torch.sqrt(shs / self.config['max_kl']) + fullstep = step_dir / lagrange_mult + g_step_dir = -g.dot(step_dir).detach().item() + grad = self.linesearch( + x=parameters_to_vector( + chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters())), + fullstep=fullstep, + expected_improve_rate=g_step_dir / lagrange_mult, + ) + + # Here we fill the gradient buffers + if not any(torch.isnan(grad)): + vector_to_gradient( + grad, + chain(self._model.hidden_layers.parameters(), + self._model.logits.parameters())) + + # Also get gradient wrt value function + value_err = F.mse_loss(values, rewards) + value_err.backward() diff --git a/python/ray/rllib/trpo/trpo.py b/python/ray/rllib/trpo/trpo.py new file mode 100644 index 000000000000..a8970558abb2 --- /dev/null +++ b/python/ray/rllib/trpo/trpo.py @@ -0,0 +1,172 @@ +from __future__ import absolute_import, division, print_function + +import os +import pickle + +import numpy as np + +import ray +from ray.rllib.agent import Agent +from ray.rllib.optimizers import AsyncOptimizer, LocalSyncOptimizer +from ray.rllib.trpo.trpo_evaluator import TRPOEvaluator +from ray.rllib.utils import FilterManager +from ray.tune.result import TrainingResult +from ray.tune.trial import Resources + +DEFAULT_CONFIG = { + # Number of workers (excluding master) + 'num_workers': 4, + 'use_gpu_for_workers': False, + 'vf_loss_coeff': 0.5, + 'use_lstm': False, + # Size of rollout + 'batch_size': 512, + # GAE(gamma) parameter + 'lambda': 1.0, + # Max global norm for each 
gradient calculated by worker + 'grad_clip': 40.0, + # Discount factor of MDP + 'gamma': 0.99, + 'use_pytorch': True, + 'observation_filter': 'NoFilter', + 'reward_filter': 'NoFilter', + 'entropy_coeff': 0.0, + 'max_kl': 0.001, + 'cg_damping': 0.001, + 'residual_tol': 1e-10, + # Number of steps after which the rollout gets cut + 'horizon': 500, + # Learning rate + 'lr': 0.0001, + # Arguments to pass to the RLlib optimizer + 'optimizer': { + # Number of gradients applied for each `train` step + 'grads_per_step': 100, + }, + # Model and preprocessor options + 'model': { + # (Image statespace) - Converts image to Channels = 1 + 'grayscale': True, + # (Image statespace) - Each pixel + 'zero_mean': False, + # (Image statespace) - Converts image to (dim, dim, C) + 'dim': 80, + # (Image statespace) - Converts image shape to (C, dim, dim) + # + # XXX set to true by default here since there's currently only a + # PyTorch implementation. + "channel_major": True, + }, + # Arguments to pass to the env creator + 'env_config': {}, +} + + +class TRPOAgent(Agent): + _agent_name = 'TRPO' + _default_config = DEFAULT_CONFIG + _allow_unknown_subkeys = ['model', 'optimizer', 'env_config'] + + @classmethod + def default_resource_request(cls, config): + cf = dict(cls._default_config, **config) + return Resources( + cpu=1, + gpu=0, + extra_cpu=cf['num_workers'], + extra_gpu=cf['use_gpu_for_workers'] and cf['num_workers'] or 0) + + def _init(self): + self.local_evaluator = TRPOEvaluator( + self.registry, + self.env_creator, + self.config, + self.logdir, + start_sampler=False, + ) + + RemoteTRPOEvaluator = ray.remote(TRPOEvaluator) + + self.remote_evaluators = [ + RemoteTRPOEvaluator.remote( + self.registry, + self.env_creator, + self.config, + self.logdir, + start_sampler=True, + ) for _ in range(self.config['num_workers']) + ] + + self.optimizer = AsyncOptimizer(self.config['optimizer'], + self.local_evaluator, + self.remote_evaluators) + + def _train(self): + self.optimizer.step() + FilterManager.synchronize(self.local_evaluator.filters, + self.remote_evaluators) + result = self._fetch_metrics_from_remote_evaluators() + return result + + def _fetch_metrics_from_remote_evaluators(self): + episode_rewards = [] + episode_lengths = [] + + metric_lists = [ + a.get_completed_rollout_metrics.remote() + for a in self.remote_evaluators + ] + + for metrics in metric_lists: + for episode in ray.get(metrics): + episode_lengths.append(episode.episode_length) + episode_rewards.append(episode.episode_reward) + + avg_reward = (np.mean(episode_rewards) + if episode_rewards else float('nan')) + avg_length = (np.mean(episode_lengths) + if episode_lengths else float('nan')) + timesteps = np.sum(episode_lengths) if episode_lengths else 0 + + result = TrainingResult( + episode_reward_mean=avg_reward, + episode_len_mean=avg_length, + timesteps_this_iter=timesteps, + info={}) + + return result + + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() + + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, + 'checkpoint-{}'.format(self.iteration)) + agent_state = ray.get( + [a.save.remote() for a in self.remote_evaluators]) + + extra_data = { + 'remote_state': agent_state, + 'local_state': self.local_evaluator.save() + } + + pickle.dump(extra_data, open(checkpoint_path + '.extra_data', 'wb')) + + return checkpoint_path + + def _restore(self, checkpoint_path): + extra_data = pickle.load(open(checkpoint_path + 
'.extra_data', 'rb')) + + ray.get([ + a.restore.remote(o) + for a, o in zip(self.remote_evaluators, extra_data['remote_state']) + ]) + + self.local_evaluator.restore(extra_data['local_state']) + + def compute_action(self, observation): + obs = self.local_evaluator.obs_filter(observation, update=False) + action, _ = self.local_evaluator.policy.compute(obs) + return action diff --git a/python/ray/rllib/trpo/trpo_evaluator.py b/python/ray/rllib/trpo/trpo_evaluator.py new file mode 100644 index 000000000000..006713f69b2d --- /dev/null +++ b/python/ray/rllib/trpo/trpo_evaluator.py @@ -0,0 +1,149 @@ +from __future__ import absolute_import, division, print_function + +import pickle + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.optimizers import PolicyEvaluator +from ray.rllib.trpo.policy import TRPOPolicy +from ray.rllib.utils.filter import NoFilter, get_filter +from ray.rllib.utils.process_rollout import process_rollout +from ray.rllib.utils.sampler import AsyncSampler + + +class TRPOEvaluator(PolicyEvaluator): + """Actor object to start running simulation on workers. + + The gradient computation is also executed from this object. + + Attributes: + policy: Copy of graph used for policy. Used by sampler and gradients. + obs_filter: Observation filter used in environment sampling + rew_filter: Reward filter used in rollout post-processing. + sampler: Component for interacting with environment and generating + rollouts. + """ + + def __init__( + self, + registry, + env_creator, + config, + logdir, + start_sampler=True, + ): + self.config = config + + self.env = ModelCatalog.get_preprocessor_as_wrapper( + registry, + env=env_creator(self.config['env_config']), + options=self.config['model'], + ) + + # TODO(alok): use ob_space directly rather than shape + self.policy = TRPOPolicy( + registry, + self.env.observation_space.shape, + self.env.action_space, + self.config, + ) + + self.obs_filter = get_filter( + self.config['observation_filter'], + self.env.observation_space.shape, + ) + self.rew_filter = get_filter(self.config['reward_filter'], ()) + self.filters = { + 'obs_filter': self.obs_filter, + 'rew_filter': self.rew_filter, + } + + self.sampler = AsyncSampler( + self.env, + self.policy, + obs_filter=NoFilter(), + num_local_steps=config['batch_size'], + horizon=config['horizon'], + ) + + if start_sampler and self.sampler._async: + self.sampler.start() + self.logdir = logdir + + def sample(self): + rollout = self.sampler.get_data() + + samples = process_rollout( + rollout, + self.rew_filter, + gamma=self.config['gamma'], + lambda_=self.config['lambda'], + use_gae=False, + ) + + return samples + + def get_completed_rollout_metrics(self): + """Returns metrics on previously completed rollouts. + + Calling this clears the queue of completed rollout metrics. + """ + return self.sampler.get_metrics() + + def compute_gradients(self, samples): + """Returns gradient w.r.t. + + samples. 
+ """ + gradient, _ = self.policy.compute_gradients(samples) + return gradient, {} + + def apply_gradients(self, grads): + """Applies gradients to evaluator weights.""" + + self.policy.apply_gradients(grads) + + def get_weights(self): + """Returns model weights.""" + + return self.policy.get_weights() + + def set_weights(self, weights): + """Sets model weights.""" + + return self.policy.set_weights(weights) + + def save(self): + filters = self.get_filters(flush_after=True) + weights = self.get_weights() + return pickle.dumps({'filters': filters, 'weights': weights}) + + def restore(self, objs): + objs = pickle.loads(objs) + self.sync_filters(objs['filters']) + self.set_weights(objs['weights']) + + def sync_filters(self, new_filters): + """Changes self's filter to given and rebases any accumulated delta. + + Args: + new_filters (dict): Filters with new state to update local copy. + """ + assert all(k in new_filters for k in self.filters) + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + """Returns a snapshot of filters. + + Args: + flush_after (bool): Clears the filter buffer state. + + Returns: + return_filters (dict): Dict for serializable filters + """ + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters