diff --git a/doc/source/rllib-package-ref.rst b/doc/source/rllib-package-ref.rst index db4b2dbfe0eb..6a4e6aed43f8 100644 --- a/doc/source/rllib-package-ref.rst +++ b/doc/source/rllib-package-ref.rst @@ -1,25 +1,11 @@ RLlib Package Reference ======================= -ray.rllib.agents +ray.rllib.policy ---------------- -.. automodule:: ray.rllib.agents +.. automodule:: ray.rllib.policy :members: - -.. autoclass:: ray.rllib.agents.a3c.A2CTrainer -.. autoclass:: ray.rllib.agents.a3c.A3CTrainer -.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGTrainer -.. autoclass:: ray.rllib.agents.ddpg.DDPGTrainer -.. autoclass:: ray.rllib.agents.dqn.ApexTrainer -.. autoclass:: ray.rllib.agents.dqn.DQNTrainer -.. autoclass:: ray.rllib.agents.es.ESTrainer -.. autoclass:: ray.rllib.agents.pg.PGTrainer -.. autoclass:: ray.rllib.agents.impala.ImpalaTrainer -.. autoclass:: ray.rllib.agents.ppo.APPOTrainer -.. autoclass:: ray.rllib.agents.ppo.PPOTrainer -.. autoclass:: ray.rllib.agents.marwil.MARWILTrainer - ray.rllib.env ------------- diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 5ea732f17508..e0731e87a809 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -2,15 +2,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \ DEFAULT_CONFIG as DDPG_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_DDPG_DEFAULT_CONFIG = merge_dicts( DDPG_CONFIG, # see also the options in ddpg.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DDPG_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -32,23 +31,7 @@ }, ) - -class ApexDDPGTrainer(DDPGTrainer): - """DDPG variant that uses the Ape-X distributed policy optimizer. - - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. - """ - - _name = "APEX_DDPG" - _default_config = APEX_DDPG_DEFAULT_CONFIG - - @override(DDPGTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexDDPGTrainer = DDPGTrainer.with_updates( + name="APEX_DDPG", + default_config=APEX_DDPG_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index a9676335eb3f..a6b42f1ca927 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -3,9 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer, \ + update_worker_explorations from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule # yapf: disable @@ -97,6 +97,11 @@ # optimization on initial policy parameters. Note that this will be # disabled when the action noise scale is set to 0 (e.g during evaluation). 
"pure_exploration_steps": 1000, + # Extra configuration that disables exploration. + "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -108,6 +113,11 @@ "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. "prioritized_replay_beta": 0.4, + # Fraction of entire training period over which the beta parameter is + # annealed + "beta_annealing_fraction": 0.2, + # Final value of beta + "final_prioritized_replay_beta": 0.4, # Epsilon to add to the TD errors when updating priorities. "prioritized_replay_eps": 1e-6, # Whether to LZ4 compress observations @@ -146,8 +156,6 @@ # to increase if your environment is particularly slow to sample, or if # you're using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -159,47 +167,56 @@ # yapf: enable -class DDPGTrainer(DQNTrainer): - """DDPG implementation in TensorFlow.""" - _name = "DDPG" - _default_config = DEFAULT_CONFIG - _policy = DDPGTFPolicy +def make_exploration_schedule(config, worker_index): + # Modification of DQN's schedule to take into account + # `exploration_ou_noise_scale` + if config["per_worker_exploration"]: + assert config["num_workers"] > 1, "This requires multiple workers" + if worker_index >= 0: + # FIXME: what do magic constants mean? (0.4, 7) + max_index = float(config["num_workers"] - 1) + exponent = 1 + worker_index / max_index * 7 + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + elif config["exploration_should_anneal"]: + return LinearSchedule( + schedule_timesteps=int(config["exploration_fraction"] * + config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_scale"]) + else: + # *always* add exploration noise + return ConstantSchedule(1.0) + + +def setup_ddpg_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] - @override(DQNTrainer) - def _train(self): - pure_expl_steps = self.config["pure_exploration_steps"] - if pure_expl_steps: - # tell workers whether they should do pure exploration - only_explore = self.global_timestep < pure_expl_steps - self.workers.local_worker().foreach_trainable_policy( + +def add_pure_exploration_phase(trainer): + global_timestep = trainer.optimizer.num_steps_sampled + pure_expl_steps = trainer.config["pure_exploration_steps"] + if pure_expl_steps: + # tell workers whether they should do pure exploration + only_explore = global_timestep < pure_expl_steps + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_pure_exploration_phase(only_explore)) + for e in trainer.workers.remote_workers(): + e.foreach_trainable_policy.remote( lambda p, _: p.set_pure_exploration_phase(only_explore)) - for e in self.workers.remote_workers(): - e.foreach_trainable_policy.remote( - lambda p, _: p.set_pure_exploration_phase(only_explore)) - return super(DDPGTrainer, self)._train() - - @override(DQNTrainer) - def _make_exploration_schedule(self, worker_index): - # Override DQN's schedule to take 
into account - # `exploration_ou_noise_scale` - if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - # FIXME: what do magic constants mean? (0.4, 7) - max_index = float(self.config["num_workers"] - 1) - exponent = 1 + worker_index / max_index * 7 - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - elif self.config["exploration_should_anneal"]: - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_scale"]) - else: - # *always* add exploration noise - return ConstantSchedule(1.0) + update_worker_explorations(trainer) + + +DDPGTrainer = GenericOffPolicyTrainer.with_updates( + name="DDPG", + default_config=DEFAULT_CONFIG, + default_policy=DDPGTFPolicy, + before_init=setup_ddpg_exploration, + before_train_step=add_pure_exploration_phase) diff --git a/python/ray/rllib/agents/ddpg/td3.py b/python/ray/rllib/agents/ddpg/td3.py index 714c39c6b2f8..ad3675294ce5 100644 --- a/python/ray/rllib/agents/ddpg/td3.py +++ b/python/ray/rllib/agents/ddpg/td3.py @@ -1,3 +1,9 @@ +"""A more stable successor to TD3. + +By default, this uses a near-identical configuration to that reported in the +TD3 paper. +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -36,7 +42,6 @@ "train_batch_size": 100, "use_huber": False, "target_network_update_freq": 0, - "optimizer_class": "SyncReplayOptimizer", "num_workers": 0, "num_gpus_per_worker": 0, "per_worker_exploration": False, @@ -48,10 +53,5 @@ }, ) - -class TD3Trainer(DDPGTrainer): - """A more stable successor to TD3. By default, this uses a near-identical - configuration to that reported in the TD3 paper.""" - - _name = "TD3" - _default_config = TD3_DEFAULT_CONFIG +TD3Trainer = DDPGTrainer.with_updates( + name="TD3", default_config=TD3_DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 129839a27119..ab89256a6b95 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -3,15 +3,14 @@ from __future__ import print_function from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.optimizers import AsyncReplayOptimizer from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( DQN_CONFIG, # see also the options in dqn.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DQN_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -36,22 +35,50 @@ # yapf: enable -class ApexTrainer(DQNTrainer): - """DQN variant that uses the Ape-X distributed policy optimizer. +def defer_make_workers(trainer, env_creator, policy, config): + # Hack to workaround https://github.com/ray-project/ray/issues/2541 + # The workers will be creatd later, after the optimizer is created + return trainer._make_workers(env_creator, policy, config, 0) - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - _name = "APEX" - _default_config = APEX_DEFAULT_CONFIG +def make_async_optimizer(workers, config): + assert len(workers.remote_workers()) == 0 + extra_config = config["optimizer"].copy() + for key in [ + "prioritized_replay", "prioritized_replay_alpha", + "prioritized_replay_beta", "prioritized_replay_eps" + ]: + if key in config: + extra_config[key] = config[key] + opt = AsyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **extra_config) + workers.add_workers(config["num_workers"]) + opt._set_workers(workers.remote_workers()) + return opt - @override(DQNTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 + +def update_target_based_on_num_steps_trained(trainer, fetches): + # Ape-X updates based on num steps trained, not sampled + if (trainer.optimizer.num_steps_trained - + trainer.state["last_target_update_ts"] > + trainer.config["target_network_update_freq"]): + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = ( + trainer.optimizer.num_steps_trained) + trainer.state["num_target_updates"] += 1 + + +APEX_TRAINER_PROPERTIES = { + "make_workers": defer_make_workers, + "make_policy_optimizer": make_async_optimizer, + "after_optimizer_step": update_target_based_on_num_steps_trained, +} + +ApexTrainer = DQNTrainer.with_updates( + name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 15379e3fb394..cc418907a0b9 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -3,27 +3,17 @@ from __future__ import print_function import logging -import time from ray import tune -from ray.rllib import optimizers -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.optimizers import SyncReplayOptimizer from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule logger = logging.getLogger(__name__) -OPTIMIZER_SHARED_CONFIGS = [ - "buffer_size", "prioritized_replay", "prioritized_replay_alpha", - "prioritized_replay_beta", "schedule_max_timesteps", - "beta_annealing_fraction", "final_prioritized_replay_beta", - "prioritized_replay_eps", "sample_batch_size", "train_batch_size", - "learning_starts" -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -53,7 +43,8 @@ # 1.0 to exploration_fraction over this number of timesteps scaled by # exploration_fraction "schedule_max_timesteps": 100000, - # Number of env steps to optimize for before returning + # Minimum env steps to optimize for per train call. 
This value does + # not affect learning, only the length of iterations. "timesteps_per_iteration": 1000, # Fraction of entire training period over which the exploration rate is # annealed @@ -70,6 +61,11 @@ # If True parameter space noise will be used for exploration # See https://blog.openai.com/better-exploration-with-parameter-noise/ "parameter_noise": False, + # Extra configuration that disables exploration. + "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -115,8 +111,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -128,202 +122,175 @@ # yapf: enable -class DQNTrainer(Trainer): - """DQN implementation in TensorFlow.""" - - _name = "DQN" - _default_config = DEFAULT_CONFIG - _policy = DQNTFPolicy - _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - - # Update effective batch size to include n-step - adjusted_batch_size = max(config["sample_batch_size"], - config.get("n_step", 1)) - config["sample_batch_size"] = adjusted_batch_size - - self.exploration0 = self._make_exploration_schedule(-1) - self.explorations = [ - self._make_exploration_schedule(i) - for i in range(config["num_workers"]) - ] - - for k in self._optimizer_shared_configs: - if self._name != "DQN" and k in [ - "schedule_max_timesteps", "beta_annealing_fraction", - "final_prioritized_replay_beta" - ]: - # only Rainbow needs annealing prioritized_replay_beta - continue - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - - if config.get("parameter_noise", False): - if config["callbacks"]["on_episode_start"]: - start_callback = config["callbacks"]["on_episode_start"] - else: - start_callback = None - - def on_episode_start(info): - # as a callback function to sample and pose parameter space - # noise on the parameters of network - policies = info["policy"] - for pi in policies.values(): - pi.add_parameter_noise() - if start_callback: - start_callback(info) - - config["callbacks"]["on_episode_start"] = tune.function( - on_episode_start) - if config["callbacks"]["on_episode_end"]: - end_callback = config["callbacks"]["on_episode_end"] - else: - end_callback = None - - def on_episode_end(info): - # as a callback function to monitor the distance - # between noisy policy and original policy - policies = info["policy"] - episode = info["episode"] - episode.custom_metrics["policy_distance"] = policies[ - DEFAULT_POLICY_ID].pi_distance - if end_callback: - end_callback(info) - - config["callbacks"]["on_episode_end"] = tune.function( - on_episode_end) - - if config["optimizer_class"] != "AsyncReplayOptimizer": - self.workers = self._make_workers( - env_creator, - self._policy, - config, - num_workers=self.config["num_workers"]) - workers_needed = 0 +def make_optimizer(workers, config): + return SyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + prioritized_replay=config["prioritized_replay"], + prioritized_replay_alpha=config["prioritized_replay_alpha"], + 
prioritized_replay_beta=config["prioritized_replay_beta"], + schedule_max_timesteps=config["schedule_max_timesteps"], + beta_annealing_fraction=config["beta_annealing_fraction"], + final_prioritized_replay_beta=config["final_prioritized_replay_beta"], + prioritized_replay_eps=config["prioritized_replay_eps"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **config["optimizer"]) + + +def check_config_and_setup_param_noise(config): + """Update the config based on settings. + + Rewrites sample_batch_size to take into account n_step truncation, and also + adds the necessary callbacks to support parameter space noise exploration. + """ + + # Update effective batch size to include n-step + adjusted_batch_size = max(config["sample_batch_size"], + config.get("n_step", 1)) + config["sample_batch_size"] = adjusted_batch_size + + if config.get("parameter_noise", False): + if config["batch_mode"] != "complete_episodes": + raise ValueError("Exploration with parameter space noise requires " + "batch_mode to be complete_episodes.") + if config.get("noisy", False): + raise ValueError( + "Exploration with parameter space noise and noisy network " + "cannot be used at the same time.") + if config["callbacks"]["on_episode_start"]: + start_callback = config["callbacks"]["on_episode_start"] + else: + start_callback = None + + def on_episode_start(info): + # as a callback function to sample and pose parameter space + # noise on the parameters of network + policies = info["policy"] + for pi in policies.values(): + pi.add_parameter_noise() + if start_callback: + start_callback(info) + + config["callbacks"]["on_episode_start"] = tune.function( + on_episode_start) + if config["callbacks"]["on_episode_end"]: + end_callback = config["callbacks"]["on_episode_end"] else: - # Hack to workaround https://github.com/ray-project/ray/issues/2541 - self.workers = self._make_workers( - env_creator, self._policy, config, num_workers=0) - workers_needed = self.config["num_workers"] - - self.optimizer = getattr(optimizers, config["optimizer_class"])( - self.workers, **config["optimizer"]) - - # Create the remote workers *after* the replay actors - if workers_needed > 0: - self.workers.add_workers(workers_needed) - self.optimizer._set_workers(self.workers.remote_workers()) - - self.last_target_update_ts = 0 - self.num_target_updates = 0 - - @override(Trainer) - def _train(self): - start_timestep = self.global_timestep - - # Update worker explorations - exp_vals = [self.exploration0.value(self.global_timestep)] - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.set_epsilon(exp_vals[0])) - for i, e in enumerate(self.workers.remote_workers()): - exp_val = self.explorations[i].value(self.global_timestep) - e.foreach_trainable_policy.remote( - lambda p, _: p.set_epsilon(exp_val)) - exp_vals.append(exp_val) - - # Do optimization steps - start = time.time() - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"] - ) or time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - self.update_target_if_needed() - - if self.config["per_worker_exploration"]: - # Only collect metrics from the third of workers with lowest eps - result = self.collect_metrics( - selected_workers=self.workers.remote_workers()[ - -len(self.workers.remote_workers()) // 3:]) + end_callback = None + + def on_episode_end(info): + # as a callback function to monitor the distance + # between noisy policy and original policy + policies = 
info["policy"] + episode = info["episode"] + episode.custom_metrics["policy_distance"] = policies[ + DEFAULT_POLICY_ID].pi_distance + if end_callback: + end_callback(info) + + config["callbacks"]["on_episode_end"] = tune.function(on_episode_end) + + +def get_initial_state(config): + return { + "last_target_update_ts": 0, + "num_target_updates": 0, + } + + +def make_exploration_schedule(config, worker_index): + # Use either a different `eps` per worker, or a linear schedule. + if config["per_worker_exploration"]: + assert config["num_workers"] > 1, \ + "This requires multiple workers" + if worker_index >= 0: + exponent = ( + 1 + worker_index / float(config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) else: - result = self.collect_metrics() - - result.update( - timesteps_this_iter=self.global_timestep - start_timestep, - info=dict({ - "min_exploration": min(exp_vals), - "max_exploration": max(exp_vals), - "num_target_updates": self.num_target_updates, - }, **self.optimizer.stats())) - - return result - - def update_target_if_needed(self): - if self.global_timestep - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.global_timestep - self.num_target_updates += 1 - - @property - def global_timestep(self): - return self.optimizer.num_steps_sampled - - def _evaluate(self): - logger.info("Evaluating current policy for {} episodes".format( - self.config["evaluation_num_episodes"])) - self.evaluation_workers.local_worker().restore( - self.workers.local_worker().save()) - self.evaluation_workers.local_worker().foreach_policy( - lambda p, _: p.set_epsilon(0)) - for _ in range(self.config["evaluation_num_episodes"]): - self.evaluation_workers.local_worker().sample() - metrics = collect_metrics(self.evaluation_workers.local_worker()) - return {"evaluation": metrics} - - def _make_exploration_schedule(self, worker_index): - # Use either a different `eps` per worker, or a linear schedule. 
- if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - exponent = ( - 1 + - worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_eps"]) - - def __getstate__(self): - state = Trainer.__getstate__(self) - state.update({ - "num_target_updates": self.num_target_updates, - "last_target_update_ts": self.last_target_update_ts, - }) - return state - - def __setstate__(self, state): - Trainer.__setstate__(self, state) - self.num_target_updates = state["num_target_updates"] - self.last_target_update_ts = state["last_target_update_ts"] - - def _validate_config(self): - if self.config.get("parameter_noise", False): - if self.config["batch_mode"] != "complete_episodes": - raise ValueError( - "Exploration with parameter space noise requires " - "batch_mode to be complete_episodes.") - if self.config.get("noisy", False): - raise ValueError( - "Exploration with parameter space noise and noisy network " - "cannot be used at the same time.") + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + return LinearSchedule( + schedule_timesteps=int( + config["exploration_fraction"] * config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_eps"]) + + +def setup_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] + + +def update_worker_explorations(trainer): + global_timestep = trainer.optimizer.num_steps_sampled + exp_vals = [trainer.exploration0.value(global_timestep)] + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_epsilon(exp_vals[0])) + for i, e in enumerate(trainer.workers.remote_workers()): + exp_val = trainer.explorations[i].value(global_timestep) + e.foreach_trainable_policy.remote(lambda p, _: p.set_epsilon(exp_val)) + exp_vals.append(exp_val) + trainer.train_start_timestep = global_timestep + trainer.cur_exp_vals = exp_vals + + +def add_trainer_metrics(trainer, result): + global_timestep = trainer.optimizer.num_steps_sampled + result.update( + timesteps_this_iter=global_timestep - trainer.train_start_timestep, + info=dict({ + "min_exploration": min(trainer.cur_exp_vals), + "max_exploration": max(trainer.cur_exp_vals), + "num_target_updates": trainer.state["num_target_updates"], + }, **trainer.optimizer.stats())) + + +def update_target_if_needed(trainer, fetches): + global_timestep = trainer.optimizer.num_steps_sampled + if global_timestep - trainer.state["last_target_update_ts"] > \ + trainer.config["target_network_update_freq"]: + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = global_timestep + trainer.state["num_target_updates"] += 1 + + +def collect_metrics(trainer): + if trainer.config["per_worker_exploration"]: + # Only collect metrics from the third of workers with lowest eps + result = trainer.collect_metrics( + selected_workers=trainer.workers.remote_workers()[ + 
-len(trainer.workers.remote_workers()) // 3:]) + else: + result = trainer.collect_metrics() + return result + + +def disable_exploration(trainer): + trainer.evaluation_workers.local_worker().foreach_policy( + lambda p, _: p.set_epsilon(0)) + + +GenericOffPolicyTrainer = build_trainer( + name="GenericOffPolicyAlgorithm", + default_policy=None, + default_config=DEFAULT_CONFIG, + validate_config=check_config_and_setup_param_noise, + get_initial_state=get_initial_state, + make_policy_optimizer=make_optimizer, + before_init=setup_exploration, + before_train_step=update_worker_explorations, + after_optimizer_step=update_target_if_needed, + after_train_result=add_trainer_metrics, + collect_metrics_fn=collect_metrics, + before_evaluate_fn=disable_exploration) + +DQNTrainer = GenericOffPolicyTrainer.with_updates( + name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index e025a4817f8f..b9699888bfaf 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -2,33 +2,16 @@ from __future__ import division from __future__ import print_function -import time - from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.utils.annotations import override from ray.tune.trainable import Trainable from ray.tune.trial import Resources -OPTIMIZER_SHARED_CONFIGS = [ - "lr", - "num_envs_per_worker", - "num_gpus", - "sample_batch_size", - "train_batch_size", - "replay_buffer_num_slots", - "replay_proportion", - "num_data_loader_buffers", - "max_sample_requests_in_flight_per_worker", - "broadcast_interval", - "num_sgd_iter", - "minibatch_buffer_size", - "num_aggregation_workers", -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -100,37 +83,57 @@ # yapf: enable -class ImpalaTrainer(Trainer): - """IMPALA implementation using DeepMind's V-trace.""" - - _name = "IMPALA" - _default_config = DEFAULT_CONFIG - _policy = VTraceTFPolicy - - @override(Trainer) - def _init(self, config, env_creator): - for k in OPTIMIZER_SHARED_CONFIGS: - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - policy_cls = self._get_policy() - self.workers = self._make_workers( - self.env_creator, policy_cls, self.config, num_workers=0) - - if self.config["num_aggregation_workers"] > 0: - # Create co-located aggregator actors first for placement pref - aggregators = TreeAggregator.precreate_aggregators( - self.config["num_aggregation_workers"]) - - self.workers.add_workers(config["num_workers"]) - self.optimizer = AsyncSamplesOptimizer(self.workers, - **config["optimizer"]) - if config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - - if self.config["num_aggregation_workers"] > 0: - # Assign the pre-created aggregators to the optimizer - self.optimizer.aggregator.init(aggregators) - +def choose_policy(config): + if config["vtrace"]: + return VTraceTFPolicy + else: + return A3CTFPolicy + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + + +def defer_make_workers(trainer, env_creator, policy, config): + # 
Defer worker creation to after the optimizer has been created. + return trainer._make_workers(env_creator, policy, config, 0) + + +def make_aggregators_and_optimizer(workers, config): + if config["num_aggregation_workers"] > 0: + # Create co-located aggregator actors first for placement pref + aggregators = TreeAggregator.precreate_aggregators( + config["num_aggregation_workers"]) + else: + aggregators = None + workers.add_workers(config["num_workers"]) + + optimizer = AsyncSamplesOptimizer( + workers, + lr=config["lr"], + num_envs_per_worker=config["num_envs_per_worker"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + train_batch_size=config["train_batch_size"], + replay_buffer_num_slots=config["replay_buffer_num_slots"], + replay_proportion=config["replay_proportion"], + num_data_loader_buffers=config["num_data_loader_buffers"], + max_sample_requests_in_flight_per_worker=config[ + "max_sample_requests_in_flight_per_worker"], + broadcast_interval=config["broadcast_interval"], + num_sgd_iter=config["num_sgd_iter"], + minibatch_buffer_size=config["minibatch_buffer_size"], + num_aggregation_workers=config["num_aggregation_workers"], + **config["optimizer"]) + + if aggregators: + # Assign the pre-created aggregators to the optimizer + optimizer.aggregator.init(aggregators) + return optimizer + + +class OverrideDefaultResourceRequest(object): @classmethod @override(Trainable) def default_resource_request(cls, config): @@ -143,22 +146,13 @@ def default_resource_request(cls, config): cf["num_aggregation_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - start = time.time() - self.optimizer.step() - while (time.time() - start < self.config["min_iter_time_s"] - or self.optimizer.num_steps_sampled == prev_steps): - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result - - def _get_policy(self): - if self.config["vtrace"]: - policy_cls = self._policy - else: - policy_cls = A3CTFPolicy - return policy_cls + +ImpalaTrainer = build_trainer( + name="IMPALA", + default_config=DEFAULT_CONFIG, + default_policy=VTraceTFPolicy, + validate_config=validate_config, + get_policy_class=choose_policy, + make_workers=defer_make_workers, + make_policy_optimizer=make_aggregators_and_optimizer, + mixins=[OverrideDefaultResourceRequest]) diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index b8e01806ca29..29be38a84c32 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -2,10 +2,10 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy from ray.rllib.optimizers import SyncBatchReplayOptimizer -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -39,30 +39,17 @@ # yapf: enable -class MARWILTrainer(Trainer): - """MARWIL implementation in TensorFlow.""" +def make_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["replay_buffer_size"], + train_batch_size=config["train_batch_size"], + ) - _name = 
"MARWIL" - _default_config = DEFAULT_CONFIG - _policy = MARWILPolicy - @override(Trainer) - def _init(self, config, env_creator): - self.workers = self._make_workers(env_creator, self._policy, config, - config["num_workers"]) - self.optimizer = SyncBatchReplayOptimizer( - self.workers, - learning_starts=config["learning_starts"], - buffer_size=config["replay_buffer_size"], - train_batch_size=config["train_batch_size"], - ) - - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=dict(fetches, **res.get("info", {}))) - return res +MARWILTrainer = build_trainer( + name="MARWIL", + default_config=DEFAULT_CONFIG, + default_policy=MARWILPolicy, + make_policy_optimizer=make_optimizer) diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index 0438b2714221..4b0d9945dec3 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -5,7 +5,6 @@ from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -51,14 +50,8 @@ # __sphinx_doc_end__ # yapf: enable - -class APPOTrainer(impala.ImpalaTrainer): - """PPO surrogate loss with IMPALA-architecture.""" - - _name = "APPO" - _default_config = DEFAULT_CONFIG - _policy = AsyncPPOTFPolicy - - @override(impala.ImpalaTrainer) - def _get_policy(self): - return AsyncPPOTFPolicy +APPOTrainer = impala.ImpalaTrainer.with_updates( + name="APPO", + default_config=DEFAULT_CONFIG, + default_policy=AsyncPPOTFPolicy, + get_policy_class=lambda _: AsyncPPOTFPolicy) diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index 65c91d655af2..aac5d83f726a 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -4,15 +4,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.qmix.qmix import QMixTrainer, \ DEFAULT_CONFIG as QMIX_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_QMIX_DEFAULT_CONFIG = merge_dicts( QMIX_CONFIG, # see also the options in qmix.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( QMIX_CONFIG["optimizer"], { @@ -34,23 +33,7 @@ }, ) - -class ApexQMixTrainer(QMixTrainer): - """QMIX variant that uses the Ape-X distributed policy optimizer. - - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - - _name = "APEX_QMIX" - _default_config = APEX_QMIX_DEFAULT_CONFIG - - @override(QMixTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexQMixTrainer = QMixTrainer.with_updates( + name="APEX_QMIX", + default_config=APEX_QMIX_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 2ad6a3e56f95..6a5bff9d63e8 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -3,8 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy +from ray.rllib.optimizers import SyncBatchReplayOptimizer # yapf: disable # __sphinx_doc_begin__ @@ -71,8 +72,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncBatchReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -90,12 +89,16 @@ # yapf: enable -class QMixTrainer(DQNTrainer): - """QMix implementation in PyTorch.""" +def make_sync_batch_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"]) - _name = "QMIX" - _default_config = DEFAULT_CONFIG - _policy = QMixTorchPolicy - _optimizer_shared_configs = [ - "learning_starts", "buffer_size", "train_batch_size" - ] + +QMixTrainer = GenericOffPolicyTrainer.with_updates( + name="QMIX", + default_config=DEFAULT_CONFIG, + default_policy=QMixTorchPolicy, + make_policy_optimizer=make_sync_batch_optimizer) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index f08b23e93fd7..78d8ecf14366 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -186,6 +186,9 @@ "remote_env_batch_wait_ms": 0, # Minimum time per iteration "min_iter_time_s": 0, + # Minimum env steps to optimize for per train call. This value does + # not affect learning, only the length of iterations. 
+ "timesteps_per_iteration": 0, # === Offline Datasets === # Specify how to generate experiences: @@ -499,6 +502,7 @@ def _evaluate(self): logger.info("Evaluating current policy for {} episodes".format( self.config["evaluation_num_episodes"])) + self._before_evaluate() self.evaluation_workers.local_worker().restore( self.workers.local_worker().save()) for _ in range(self.config["evaluation_num_episodes"]): @@ -507,6 +511,11 @@ def _evaluate(self): metrics = collect_metrics(self.evaluation_workers.local_worker()) return {"evaluation": metrics} + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + @PublicAPI def compute_action(self, observation, diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index 6af9e1c781e0..ee0b4181c337 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -6,6 +6,7 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -13,25 +14,47 @@ def build_trainer(name, default_policy, default_config=None, - make_policy_optimizer=None, validate_config=None, + get_initial_state=None, get_policy_class=None, + before_init=None, + make_workers=None, + make_policy_optimizer=None, + after_init=None, before_train_step=None, after_optimizer_step=None, - after_train_result=None): + after_train_result=None, + collect_metrics_fn=None, + before_evaluate_fn=None, + mixins=None): """Helper function for defining a custom trainer. + Functions will be run in this order to initialize the trainer: + 1. Config setup: validate_config, get_initial_state, get_policy + 2. Worker setup: before_init, make_workers, make_policy_optimizer + 3. Post setup: after_init + Arguments: name (str): name of the trainer (e.g., "PPO") default_policy (cls): the default Policy class to use default_config (dict): the default config dict of the algorithm, otherwises uses the Trainer default config - make_policy_optimizer (func): optional function that returns a - PolicyOptimizer instance given (WorkerSet, config) validate_config (func): optional callback that checks a given config for correctness. It may mutate the config as needed. + get_initial_state (func): optional function that returns the initial + state dict given the trainer instance as an argument. The state + dict must be serializable so that it can be checkpointed, and will + be available as the `trainer.state` variable. get_policy_class (func): optional callback that takes a config and returns the policy class to override the default with + before_init (func): optional function to run at the start of trainer + init that takes the trainer instance as argument + make_workers (func): override the method that creates rollout workers. + This takes in (trainer, env_creator, policy, config) as args. + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given (WorkerSet, config) + after_init (func): optional function to run at the end of trainer init + that takes the trainer instance as argument before_train_step (func): optional callback to run before each train() call. It takes the trainer instance as an argument. after_optimizer_step (func): optional callback to run after each @@ -40,27 +63,47 @@ def build_trainer(name, after_train_result (func): optional callback to run at the end of each train() call. 
It takes the trainer instance and result dict as arguments, and may mutate the result dict as needed. + collect_metrics_fn (func): override the method used to collect metrics. + It takes the trainer instance as argumnt. + before_evaluate_fn (func): callback to run before evaluation. This + takes the trainer instance as argument. + mixins (list): list of any class mixins for the returned trainer class. + These mixins will be applied in order and will have higher + precedence than the Trainer class Returns: a Trainer instance that uses the specified args. """ original_kwargs = locals().copy() + base = add_mixins(Trainer, mixins) - class trainer_cls(Trainer): + class trainer_cls(base): _name = name _default_config = default_config or COMMON_CONFIG _policy = default_policy + def __init__(self, config=None, env=None, logger_creator=None): + Trainer.__init__(self, config, env, logger_creator) + def _init(self, config, env_creator): if validate_config: validate_config(config) + if get_initial_state: + self.state = get_initial_state(self) + else: + self.state = {} if get_policy_class is None: policy = default_policy else: policy = get_policy_class(config) - self.workers = self._make_workers(env_creator, policy, config, - self.config["num_workers"]) + if before_init: + before_init(self) + if make_workers: + self.workers = make_workers(self, env_creator, policy, config) + else: + self.workers = self._make_workers(env_creator, policy, config, + self.config["num_workers"]) if make_policy_optimizer: self.optimizer = make_policy_optimizer(self.workers, config) else: @@ -69,6 +112,8 @@ def _init(self, config, env_creator): **{"train_batch_size": config["train_batch_size"]}) self.optimizer = SyncSamplesOptimizer(self.workers, **optimizer_config) + if after_init: + after_init(self) @override(Trainer) def _train(self): @@ -81,20 +126,46 @@ def _train(self): fetches = self.optimizer.step() if after_optimizer_step: after_optimizer_step(self, fetches) - if time.time() - start > self.config["min_iter_time_s"]: + if (time.time() - start >= self.config["min_iter_time_s"] + and self.optimizer.num_steps_sampled - prev_steps >= + self.config["timesteps_per_iteration"]): break - res = self.collect_metrics() + if collect_metrics_fn: + res = collect_metrics_fn(self) + else: + res = self.collect_metrics() res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=res.get("info", {})) + if after_train_result: after_train_result(self, res) return res + @override(Trainer) + def _before_evaluate(self): + if before_evaluate_fn: + before_evaluate_fn(self) + + def __getstate__(self): + state = Trainer.__getstate__(self) + state.update(self.state) + return state + + def __setstate__(self, state): + Trainer.__setstate__(self, state) + self.state = state + @staticmethod def with_updates(**overrides): + """Build a copy of this trainer with the specified overrides. + + Arguments: + overrides (dict): use this to override any of the arguments + originally passed to build_trainer() for this policy. 
+ """ return build_trainer(**dict(original_kwargs, **overrides)) trainer_cls.with_updates = with_updates diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py index b7f33fcb0887..37828bfe18b0 100644 --- a/python/ray/rllib/policy/tf_policy_template.py +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -89,13 +90,7 @@ def build_tf_policy(name, """ original_kwargs = locals().copy() - base = DynamicTFPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(DynamicTFPolicy, mixins) class policy_cls(base): def __init__(self, diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 1f4185f9c12e..f1b0c0c682d6 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -56,13 +57,7 @@ def build_torch_policy(name, """ original_kwargs = locals().copy() - base = TorchPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(TorchPolicy, mixins) class policy_cls(base): def __init__(self, obs_space, action_space, config): diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index aad5590fd097..bde901e22a9c 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -27,6 +27,21 @@ def __init__(self, *args, **kw): return DeprecationWrapper +def add_mixins(base, mixins): + """Returns a new class with mixins applied in priority order.""" + + mixins = list(mixins or []) + + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + return base + + def renamed_agent(cls): """Helper class for renaming Agent => Trainer with a warning."""
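
A minimal sketch of how downstream code might assemble a custom off-policy trainer from the build_trainer() template and then derive a variant with with_updates(), as introduced in this patch. DQNTFPolicy, SyncReplayOptimizer, merge_dicts, and the hook signatures are taken from the diff above; the trainer names, config overrides, mixin, and hook bodies are hypothetical placeholders, not part of the patch.

# Sketch only: names marked "hypothetical" below are illustrations, not APIs
# added by this diff.
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import SyncReplayOptimizer
from ray.rllib.utils import merge_dicts

# Start from DQN's defaults so DQNTFPolicy finds every config key it expects
# (hypothetical override of the train batch size).
MY_CONFIG = merge_dicts(DQN_CONFIG, {"train_batch_size": 64})


class StepCounterMixin(object):
    # Hypothetical mixin; add_mixins() places it ahead of Trainer in the MRO,
    # so the generated trainer class gains this method.
    def steps_trained(self):
        return self.optimizer.num_steps_trained


def get_initial_state(trainer):
    # The returned dict becomes `trainer.state` and is included in
    # checkpoints via __getstate__/__setstate__.
    return {"num_optimizer_steps": 0}


def make_replay_optimizer(workers, config):
    # Replaces the default SyncSamplesOptimizer with a replay-based one.
    return SyncReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        train_batch_size=config["train_batch_size"])


def count_optimizer_steps(trainer, fetches):
    # after_optimizer_step hook: runs once per optimizer.step() call.
    trainer.state["num_optimizer_steps"] += 1


MyOffPolicyTrainer = build_trainer(
    name="MyOffPolicyAlgorithm",
    default_policy=DQNTFPolicy,
    default_config=MY_CONFIG,
    get_initial_state=get_initial_state,
    make_policy_optimizer=make_replay_optimizer,
    after_optimizer_step=count_optimizer_steps,
    mixins=[StepCounterMixin])

# Variants are expressed as deltas rather than subclasses, mirroring how
# TD3Trainer and the Ape-X trainers are defined in this patch:
MyTunedTrainer = MyOffPolicyTrainer.with_updates(
    name="MyTunedAlgorithm",
    default_config=merge_dicts(MY_CONFIG, {"target_network_update_freq": 2000}))

Because each hook is just a keyword argument to build_trainer(), a bundle of overrides such as APEX_TRAINER_PROPERTIES in dqn/apex.py can be reused verbatim across DQN, DDPG, and QMIX via with_updates(), which is what replaces the old subclass-and-override pattern throughout this diff.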