diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py new file mode 100644 index 000000000000..99b2d73b520e --- /dev/null +++ b/python/ray/experimental/internal_kv.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray + + +def _internal_kv_initialized(): + worker = ray.worker.get_global_worker() + return hasattr(worker, "mode") and worker.mode is not None + + +def _internal_kv_get(key): + """Fetch the value of a binary key.""" + + worker = ray.worker.get_global_worker() + return worker.redis_client.hget(key, "value") + + +def _internal_kv_put(key, value): + """Globally associates a value with a given binary key. + + This only has an effect if the key does not already have a value. + + Returns + already_exists (bool): whether the value already exists. + """ + + worker = ray.worker.get_global_worker() + updated = worker.redis_client.hsetnx(key, "value", value) + return updated == 0 # already exists diff --git a/python/ray/experimental/named_actors.py b/python/ray/experimental/named_actors.py index 9ae7972fc37c..54deddd1f7d0 100644 --- a/python/ray/experimental/named_actors.py +++ b/python/ray/experimental/named_actors.py @@ -4,6 +4,7 @@ import ray import ray.cloudpickle as pickle +from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put def _calculate_key(name): @@ -29,9 +30,8 @@ def get_actor(name): Returns: The ActorHandle object corresponding to the name. """ - worker = ray.worker.get_global_worker() - actor_hash = _calculate_key(name) - pickled_state = worker.redis_client.hget(actor_hash, name) + actor_name = _calculate_key(name) + pickled_state = _internal_kv_get(actor_name) if pickled_state is None: raise ValueError("The actor with name={} doesn't exist".format(name)) handle = pickle.loads(pickled_state) @@ -45,17 +45,16 @@ def register_actor(name, actor_handle): name: The name of the named actor. 
actor_handle: The actor object to be associated with this name """ - worker = ray.worker.get_global_worker() if not isinstance(name, str): raise TypeError("The name argument must be a string.") if not isinstance(actor_handle, ray.actor.ActorHandle): raise TypeError("The actor_handle argument must be an ActorHandle " "object.") - actor_hash = _calculate_key(name) + actor_name = _calculate_key(name) pickled_state = pickle.dumps(actor_handle) # Add the actor to Redis if it does not already exist. - updated = worker.redis_client.hsetnx(actor_hash, name, pickled_state) - if updated == 0: + already_exists = _internal_kv_put(actor_name, pickled_state) + if already_exists: raise ValueError( "Error: the actor with name={} already exists".format(name)) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index dcacc5fdc46c..6375bc90d439 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -102,7 +102,7 @@ def session_creator(): batch_steps=self.config["batch_size"], batch_mode="truncate_episodes", tf_session_creator=session_creator, - registry=self.registry, env_config=self.config["env_config"], + env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, num_envs=self.config["num_envs"]) self.remote_evaluators = [ @@ -111,7 +111,7 @@ def session_creator(): batch_steps=self.config["batch_size"], batch_mode="truncate_episodes", sample_async=True, tf_session_creator=session_creator, - registry=self.registry, env_config=self.config["env_config"], + env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, num_envs=self.config["num_envs"]) for i in range(self.config["num_workers"])] diff --git a/python/ray/rllib/a3c/a3c_tf_policy.py b/python/ray/rllib/a3c/a3c_tf_policy.py index e2a8da233880..8532734c2561 100644 --- a/python/ray/rllib/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/a3c/a3c_tf_policy.py @@ -13,8 +13,7 @@ class A3CTFPolicyGraph(TFPolicyGraph): 
"""The TF policy base class.""" - def __init__(self, ob_space, action_space, registry, config): - self.registry = registry + def __init__(self, ob_space, action_space, config): self.local_steps = 0 self.config = config self.summarize = config.get("summarize") diff --git a/python/ray/rllib/a3c/a3c_torch_policy.py b/python/ray/rllib/a3c/a3c_torch_policy.py index 5a654fa5732c..773240f5d382 100644 --- a/python/ray/rllib/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/a3c/a3c_torch_policy.py @@ -17,8 +17,7 @@ class SharedTorchPolicy(PolicyGraph): """A simple, non-recurrent PyTorch policy example.""" - def __init__(self, obs_space, action_space, registry, config): - self.registry = registry + def __init__(self, obs_space, action_space, config): self.local_steps = 0 self.config = config self.summarize = config.get("summarize") @@ -29,8 +28,7 @@ def __init__(self, obs_space, action_space, registry, config): def setup_graph(self, obs_space, action_space): _, self.logit_dim = ModelCatalog.get_action_dist(action_space) self._model = ModelCatalog.get_torch_model( - self.registry, obs_space.shape, self.logit_dim, - self.config["model"]) + obs_space.shape, self.logit_dim, self.config["model"]) self.optimizer = torch.optim.Adam( self._model.parameters(), lr=self.config["lr"]) diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py index 3a093fa906f8..f77534c0eee2 100644 --- a/python/ray/rllib/a3c/shared_model.py +++ b/python/ray/rllib/a3c/shared_model.py @@ -10,15 +10,15 @@ class SharedModel(A3CTFPolicyGraph): - def __init__(self, ob_space, ac_space, registry, config, **kwargs): + def __init__(self, ob_space, ac_space, config, **kwargs): super(SharedModel, self).__init__( - ob_space, ac_space, registry, config, **kwargs) + ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self._model 
= ModelCatalog.get_model( - self.registry, self.x, self.logit_dim, self.config["model"]) + self.x, self.logit_dim, self.config["model"]) self.logits = self._model.outputs self.action_dist = dist_class(self.logits) self.vf = tf.reshape(linear(self._model.last_layer, 1, "value", diff --git a/python/ray/rllib/a3c/shared_model_lstm.py b/python/ray/rllib/a3c/shared_model_lstm.py index 7cb64e684aa6..d2bc1adde2ab 100644 --- a/python/ray/rllib/a3c/shared_model_lstm.py +++ b/python/ray/rllib/a3c/shared_model_lstm.py @@ -11,9 +11,9 @@ class SharedModelLSTM(A3CTFPolicyGraph): - def __init__(self, ob_space, ac_space, registry, config, **kwargs): + def __init__(self, ob_space, ac_space, config, **kwargs): super(SharedModelLSTM, self).__init__( - ob_space, ac_space, registry, config, **kwargs) + ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape)) diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index bbcc07fccadb..5e1db81c06a9 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -9,7 +9,7 @@ import pickle import tensorflow as tf -from ray.tune.registry import ENV_CREATOR +from ray.tune.registry import ENV_CREATOR, _global_registry from ray.tune.result import TrainingResult from ray.tune.trainable import Trainable @@ -56,8 +56,6 @@ class Agent(Trainable): env_creator (func): Function that creates a new training env. config (obj): Algorithm-specific configuration data. logdir (str): Directory in which training outputs should be placed. - registry (obj): Tune object registry which holds user-registered - classes and objects by name. """ _allow_unknown_configs = False @@ -72,16 +70,13 @@ def resource_help(cls, config): "The config of this agent is: " + json.dumps(config)) def __init__( - self, config=None, env=None, registry=None, - logger_creator=None): + self, config=None, env=None, logger_creator=None): """Initialize an RLLib agent. 
Args: config (dict): Algorithm-specific configuration data. env (str): Name of the environment to use. Note that this can also be specified as the `env` key in config. - registry (obj): Object registry for user-defined envs, models, etc. - If unspecified, the default registry will be used. logger_creator (func): Function that creates a ray.tune.Logger object. If unspecified, a default logger is created. """ @@ -90,14 +85,14 @@ def __init__( # Agents allow env ids to be passed directly to the constructor. self._env_id = env or config.get("env") - Trainable.__init__(self, config, registry, logger_creator) + Trainable.__init__(self, config, logger_creator) def _setup(self): env = self._env_id if env: self.config["env"] = env - if self.registry and self.registry.contains(ENV_CREATOR, env): - self.env_creator = self.registry.get(ENV_CREATOR, env) + if _global_registry.contains(ENV_CREATOR, env): + self.env_creator = _global_registry.get(ENV_CREATOR, env) else: import gym # soft dependency self.env_creator = lambda env_config: gym.make(env) diff --git a/python/ray/rllib/bc/bc.py b/python/ray/rllib/bc/bc.py index cdfc7ab98878..501f535215f5 100644 --- a/python/ray/rllib/bc/bc.py +++ b/python/ray/rllib/bc/bc.py @@ -63,14 +63,13 @@ def default_resource_request(cls, config): def _init(self): self.local_evaluator = BCEvaluator( - self.registry, self.env_creator, self.config, self.logdir) + self.env_creator, self.config, self.logdir) if self.config["use_gpu_for_workers"]: remote_cls = GPURemoteBCEvaluator else: remote_cls = RemoteBCEvaluator self.remote_evaluators = [ - remote_cls.remote( - self.registry, self.env_creator, self.config, self.logdir) + remote_cls.remote(self.env_creator, self.config, self.logdir) for _ in range(self.config["num_workers"])] self.optimizer = AsyncOptimizer( self.config["optimizer"], self.local_evaluator, diff --git a/python/ray/rllib/bc/bc_evaluator.py b/python/ray/rllib/bc/bc_evaluator.py index 22739d5928fb..87a7d497656e 100644 --- 
a/python/ray/rllib/bc/bc_evaluator.py +++ b/python/ray/rllib/bc/bc_evaluator.py @@ -13,12 +13,11 @@ class BCEvaluator(PolicyEvaluator): - def __init__(self, registry, env_creator, config, logdir): - env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator( + def __init__(self, env_creator, config, logdir): + env = ModelCatalog.get_preprocessor_as_wrapper(env_creator( config["env_config"]), config["model"]) self.dataset = ExperienceDataset(config["dataset_path"]) - self.policy = BCPolicy(registry, env.observation_space, - env.action_space, config) + self.policy = BCPolicy(env.observation_space, env.action_space, config) self.config = config self.logdir = logdir self.metrics_queue = queue.Queue() diff --git a/python/ray/rllib/bc/policy.py b/python/ray/rllib/bc/policy.py index 2c4210a57cf5..998d39249004 100644 --- a/python/ray/rllib/bc/policy.py +++ b/python/ray/rllib/bc/policy.py @@ -10,8 +10,7 @@ class BCPolicy(object): - def __init__(self, registry, obs_space, action_space, config): - self.registry = registry + def __init__(self, obs_space, action_space, config): self.local_steps = 0 self.config = config self.summarize = config.get("summarize") @@ -24,7 +23,7 @@ def _setup_graph(self, obs_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self._model = ModelCatalog.get_model( - self.registry, self.x, self.logit_dim, self.config["model"]) + self.x, self.logit_dim, self.config["model"]) self.logits = self._model.outputs self.curr_dist = dist_class(self.logits) self.sample = self.curr_dist.sample() diff --git a/python/ray/rllib/ddpg/ddpg_policy_graph.py b/python/ray/rllib/ddpg/ddpg_policy_graph.py index 51572659b4e9..da1b64a3026b 100644 --- a/python/ray/rllib/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/ddpg/ddpg_policy_graph.py @@ -22,12 +22,12 @@ Q_TARGET_SCOPE = "target_q_func" -def _build_p_network(registry, inputs, dim_actions, config): +def 
_build_p_network(inputs, dim_actions, config): """ map an observation (i.e., state) to an action where each entry takes value from (0, 1) due to the sigmoid function """ - frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"]) + frontend = ModelCatalog.get_model(inputs, 1, config["model"]) hiddens = config["actor_hiddens"] action_out = frontend.last_layer @@ -66,8 +66,8 @@ def _build_action_network(p_values, low_action, high_action, stochastic, eps, lambda: deterministic_actions) -def _build_q_network(registry, inputs, action_inputs, config): - frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"]) +def _build_q_network(inputs, action_inputs, config): + frontend = ModelCatalog.get_model(inputs, 1, config["model"]) hiddens = config["critic_hiddens"] @@ -81,7 +81,7 @@ def _build_q_network(registry, inputs, action_inputs, config): class DDPGPolicyGraph(TFPolicyGraph): - def __init__(self, observation_space, action_space, registry, config): + def __init__(self, observation_space, action_space, config): if not isinstance(action_space, Box): raise UnsupportedSpaceException( "Action space {} is not supported for DDPG.".format( @@ -105,7 +105,7 @@ def __init__(self, observation_space, action_space, registry, config): # Actor: P (policy) network with tf.variable_scope(P_SCOPE) as scope: - p_values = _build_p_network(registry, self.cur_observations, + p_values = _build_p_network(self.cur_observations, dim_actions, config) self.p_func_vars = _scope_vars(scope.name) @@ -136,13 +136,11 @@ def __init__(self, observation_space, action_space, registry, config): # p network evaluation with tf.variable_scope(P_SCOPE, reuse=True) as scope: - self.p_t = _build_p_network( - registry, self.obs_t, dim_actions, config) + self.p_t = _build_p_network(self.obs_t, dim_actions, config) # target p network evaluation with tf.variable_scope(P_TARGET_SCOPE) as scope: - p_tp1 = _build_p_network( - registry, self.obs_tp1, dim_actions, config) + p_tp1 = 
_build_p_network(self.obs_tp1, dim_actions, config) target_p_func_vars = _scope_vars(scope.name) # Action outputs @@ -161,17 +159,15 @@ def __init__(self, observation_space, action_space, registry, config): # q network evaluation with tf.variable_scope(Q_SCOPE) as scope: - q_t = _build_q_network( - registry, self.obs_t, self.act_t, config) + q_t = _build_q_network(self.obs_t, self.act_t, config) self.q_func_vars = _scope_vars(scope.name) with tf.variable_scope(Q_SCOPE, reuse=True): - q_tp0 = _build_q_network( - registry, self.obs_t, output_actions, config) + q_tp0 = _build_q_network(self.obs_t, output_actions, config) # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: q_tp1 = _build_q_network( - registry, self.obs_tp1, output_actions_estimated, config) + self.obs_tp1, output_actions_estimated, config) target_q_func_vars = _scope_vars(scope.name) q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) diff --git a/python/ray/rllib/dqn/common/wrappers.py b/python/ray/rllib/dqn/common/wrappers.py index a968888aab71..cfcd90ddd34e 100644 --- a/python/ray/rllib/dqn/common/wrappers.py +++ b/python/ray/rllib/dqn/common/wrappers.py @@ -6,7 +6,7 @@ from ray.rllib.utils.atari_wrappers import wrap_deepmind -def wrap_dqn(registry, env, options, random_starts): +def wrap_dqn(env, options, random_starts): """Apply a common set of wrappers for DQN.""" is_atari = hasattr(env.unwrapped, "ale") @@ -17,4 +17,4 @@ def wrap_dqn(registry, env, options, random_starts): return wrap_deepmind( env, random_starts=random_starts, dim=options.get("dim", 80)) - return ModelCatalog.get_preprocessor_as_wrapper(registry, env, options) + return ModelCatalog.get_preprocessor_as_wrapper(env, options) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 2824f36d7359..83dc1078ebc4 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -129,7 +129,7 @@ def _init(self): batch_steps=adjusted_batch_size, batch_mode="truncate_episodes", 
preprocessor_pref="deepmind", compress_observations=True, - registry=self.registry, env_config=self.config["env_config"], + env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, num_envs=self.config["num_envs"]) remote_cls = CommonPolicyEvaluator.as_remote( @@ -141,7 +141,7 @@ def _init(self): batch_steps=adjusted_batch_size, batch_mode="truncate_episodes", preprocessor_pref="deepmind", compress_observations=True, - registry=self.registry, env_config=self.config["env_config"], + env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, num_envs=self.config["num_envs"]) for _ in range(self.config["num_workers"])] diff --git a/python/ray/rllib/dqn/dqn_policy_graph.py b/python/ray/rllib/dqn/dqn_policy_graph.py index ffafc5be5231..9c7ceedc4e19 100644 --- a/python/ray/rllib/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/dqn/dqn_policy_graph.py @@ -46,7 +46,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): class DQNPolicyGraph(TFPolicyGraph): - def __init__(self, observation_space, action_space, registry, config): + def __init__(self, observation_space, action_space, config): if not isinstance(action_space, Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format( @@ -65,7 +65,7 @@ def __init__(self, observation_space, action_space, registry, config): # Action Q network with tf.variable_scope(Q_SCOPE) as scope: q_values = _build_q_network( - registry, self.cur_observations, num_actions, config) + self.cur_observations, num_actions, config) self.q_func_vars = _scope_vars(scope.name) # Action outputs @@ -89,13 +89,11 @@ def __init__(self, observation_space, action_space, registry, config): # q network evaluation with tf.variable_scope(Q_SCOPE, reuse=True): - q_t = _build_q_network( - registry, self.obs_t, num_actions, config) + q_t = _build_q_network(self.obs_t, num_actions, config) # target q network evalution with 
tf.variable_scope(Q_TARGET_SCOPE) as scope: - q_tp1 = _build_q_network( - registry, self.obs_tp1, num_actions, config) + q_tp1 = _build_q_network(self.obs_tp1, num_actions, config) self.target_q_func_vars = _scope_vars(scope.name) # q scores for actions which we know were selected in the given state. @@ -106,7 +104,7 @@ def __init__(self, observation_space, action_space, registry, config): if config["double_q"]: with tf.variable_scope(Q_SCOPE, reuse=True): q_tp1_using_online_net = _build_q_network( - registry, self.obs_tp1, num_actions, config) + self.obs_tp1, num_actions, config) q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) q_tp1_best = tf.reduce_sum( q_tp1 * tf.one_hot( @@ -236,10 +234,10 @@ def _postprocess_dqn(policy_graph, sample_batch): return batch -def _build_q_network(registry, inputs, num_actions, config): +def _build_q_network(inputs, num_actions, config): dueling = config["dueling"] hiddens = config["hiddens"] - frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"]) + frontend = ModelCatalog.get_model(inputs, 1, config["model"]) frontend_out = frontend.last_layer with tf.variable_scope("action_value"): diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index 23b7268efafb..8e5dbe06417b 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -64,7 +64,7 @@ def sample_index(self, dim): @ray.remote class Worker(object): - def __init__(self, registry, config, policy_params, env_creator, noise, + def __init__(self, config, policy_params, env_creator, noise, min_task_runtime=0.2): self.min_task_runtime = min_task_runtime self.config = config @@ -73,12 +73,11 @@ def __init__(self, registry, config, policy_params, env_creator, noise, self.env = env_creator(config["env_config"]) from ray.rllib import models - self.preprocessor = models.ModelCatalog.get_preprocessor( - registry, self.env) + self.preprocessor = models.ModelCatalog.get_preprocessor(self.env) self.sess = 
utils.make_session(single_threaded=True) self.policy = policies.GenericPolicy( - registry, self.sess, self.env.action_space, self.preprocessor, + self.sess, self.env.action_space, self.preprocessor, config["observation_filter"], **policy_params) def rollout(self, timestep_limit, add_noise=True): @@ -152,12 +151,11 @@ def _init(self): env = self.env_creator(self.config["env_config"]) from ray.rllib import models - preprocessor = models.ModelCatalog.get_preprocessor( - self.registry, env) + preprocessor = models.ModelCatalog.get_preprocessor(env) self.sess = utils.make_session(single_threaded=False) self.policy = policies.GenericPolicy( - self.registry, self.sess, env.action_space, preprocessor, + self.sess, env.action_space, preprocessor, self.config["observation_filter"], **policy_params) self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"]) @@ -170,8 +168,7 @@ def _init(self): print("Creating actors.") self.workers = [ Worker.remote( - self.registry, self.config, policy_params, self.env_creator, - noise_id) + self.config, policy_params, self.env_creator, noise_id) for _ in range(self.config["num_workers"])] self.episodes_so_far = 0 diff --git a/python/ray/rllib/es/policies.py b/python/ray/rllib/es/policies.py index a1746673bd1c..eb492373ff83 100644 --- a/python/ray/rllib/es/policies.py +++ b/python/ray/rllib/es/policies.py @@ -38,7 +38,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False): class GenericPolicy(object): - def __init__(self, registry, sess, action_space, preprocessor, + def __init__(self, sess, action_space, preprocessor, observation_filter, action_noise_std): self.sess = sess self.action_space = action_space @@ -52,7 +52,7 @@ def __init__(self, registry, sess, action_space, preprocessor, # Policy network. 
dist_class, dist_dim = ModelCatalog.get_action_dist( self.action_space, dist_type="deterministic") - model = ModelCatalog.get_model(registry, self.inputs, dist_dim) + model = ModelCatalog.get_model(self.inputs, dist_dim) dist = dist_class(model.outputs) self.sampler = dist.sample() diff --git a/python/ray/rllib/examples/multiagent_mountaincar.py b/python/ray/rllib/examples/multiagent_mountaincar.py index 74f818d7e552..29a7590b3407 100644 --- a/python/ray/rllib/examples/multiagent_mountaincar.py +++ b/python/ray/rllib/examples/multiagent_mountaincar.py @@ -8,7 +8,7 @@ import ray import ray.rllib.ppo as ppo -from ray.tune.registry import get_registry, register_env +from ray.tune.registry import register_env env_name = "MultiAgentMountainCarEnv" @@ -51,6 +51,6 @@ def create_env(env_config): "multiagent_shared_model": False, "multiagent_fcnet_hiddens": [[32, 32]] * 2} config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config) + alg = ppo.PPOAgent(env=env_name, config=config) for i in range(1): alg.train() diff --git a/python/ray/rllib/examples/multiagent_pendulum.py b/python/ray/rllib/examples/multiagent_pendulum.py index 20cd5d7ace77..9754f681e18d 100644 --- a/python/ray/rllib/examples/multiagent_pendulum.py +++ b/python/ray/rllib/examples/multiagent_pendulum.py @@ -8,7 +8,7 @@ import ray import ray.rllib.ppo as ppo -from ray.tune.registry import get_registry, register_env +from ray.tune.registry import register_env env_name = "MultiAgentPendulumEnv" @@ -51,6 +51,6 @@ def create_env(env_config): "multiagent_shared_model": True, "multiagent_fcnet_hiddens": [[32, 32]] * 2} config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config) + alg = ppo.PPOAgent(env=env_name, config=config) for i in range(1): alg.train() diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 603073dbfebc..ada537340345 100644 --- 
a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -8,7 +8,7 @@ from functools import partial from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ - _default_registry + _global_registry from ray.rllib.models.action_dist import ( Categorical, Deterministic, DiagGaussian, MultiActionDistribution) @@ -45,7 +45,7 @@ class ModelCatalog(object): >>> observation = prep.transform(raw_observation) >>> dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space) - >>> model = ModelCatalog.get_model(registry, inputs, dist_dim) + >>> model = ModelCatalog.get_model(inputs, dist_dim) >>> dist = dist_cls(model.outputs) >>> action = dist.sample() """ @@ -123,11 +123,10 @@ def get_action_placeholder(action_space): " not supported".format(action_space)) @staticmethod - def get_model(registry, inputs, num_outputs, options={}): + def get_model(inputs, num_outputs, options={}): """Returns a suitable model conforming to given input and output specs. Args: - registry (obj): Registry of named objects (ray.tune.registry). inputs (Tensor): The input tensor to the model. num_outputs (int): The size of the output vector of the model. options (dict): Optional args to pass to the model constructor. @@ -139,7 +138,7 @@ def get_model(registry, inputs, num_outputs, options={}): if "custom_model" in options: model = options["custom_model"] print("Using custom model {}".format(model)) - return registry.get(RLLIB_MODEL, model)( + return _global_registry.get(RLLIB_MODEL, model)( inputs, num_outputs, options) obs_rank = len(inputs.shape) - 1 @@ -156,12 +155,11 @@ def get_model(registry, inputs, num_outputs, options={}): return FullyConnectedNetwork(inputs, num_outputs, options) @staticmethod - def get_torch_model(registry, input_shape, num_outputs, options={}): + def get_torch_model(input_shape, num_outputs, options={}): """Returns a PyTorch suitable model. This is currently only supported in A3C. 
Args: - registry (obj): Registry of named objects (ray.tune.registry). input_shape (tuple): The input shape to the model. num_outputs (int): The size of the output vector of the model. options (dict): Optional args to pass to the model constructor. @@ -177,7 +175,7 @@ def get_torch_model(registry, input_shape, num_outputs, options={}): if "custom_model" in options: model = options["custom_model"] print("Using custom torch model {}".format(model)) - return registry.get(RLLIB_MODEL, model)( + return _global_registry.get(RLLIB_MODEL, model)( input_shape, num_outputs, options) # TODO(alok): fix to handle Discrete(n) state spaces @@ -191,11 +189,10 @@ def get_torch_model(registry, input_shape, num_outputs, options={}): return PyTorchFCNet(input_shape[0], num_outputs, options) @staticmethod - def get_preprocessor(registry, env, options={}): + def get_preprocessor(env, options={}): """Returns a suitable processor for the given environment. Args: - registry (obj): Registry of named objects (ray.tune.registry). env (gym.Env): The gym environment to preprocess. options (dict): Options to pass to the preprocessor. @@ -211,18 +208,17 @@ def get_preprocessor(registry, env, options={}): if "custom_preprocessor" in options: preprocessor = options["custom_preprocessor"] print("Using custom preprocessor {}".format(preprocessor)) - return registry.get(RLLIB_PREPROCESSOR, preprocessor)( + return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( env.observation_space, options) preprocessor = get_preprocessor(env.observation_space) return preprocessor(env.observation_space, options) @staticmethod - def get_preprocessor_as_wrapper(registry, env, options={}): + def get_preprocessor_as_wrapper(env, options={}): """Returns a preprocessor as a gym observation wrapper. Args: - registry (obj): Registry of named objects (ray.tune.registry). env (gym.Env): The gym environment to wrap. options (dict): Options to pass to the preprocessor. 
@@ -230,7 +226,7 @@ def get_preprocessor_as_wrapper(registry, env, options={}): wrapper (gym.ObservationWrapper): Preprocessor in wrapper form. """ - preprocessor = ModelCatalog.get_preprocessor(registry, env, options) + preprocessor = ModelCatalog.get_preprocessor(env, options) return _RLlibPreprocessorWrapper(env, preprocessor) @staticmethod @@ -244,7 +240,7 @@ def register_custom_preprocessor(preprocessor_name, preprocessor_class): preprocessor_name (str): Name to register the preprocessor under. preprocessor_class (type): Python class of the preprocessor. """ - _default_registry.register( + _global_registry.register( RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class) @staticmethod @@ -258,7 +254,7 @@ def register_custom_model(model_name, model_class): model_name (str): Name to register the model under. model_class (type): Python class of the model. """ - _default_registry.register(RLLIB_MODEL, model_name, model_class) + _global_registry.register(RLLIB_MODEL, model_name, model_class) class _RLlibPreprocessorWrapper(gym.ObservationWrapper): diff --git a/python/ray/rllib/pg/pg.py b/python/ray/rllib/pg/pg.py index 79dcdd3a6f3d..1ca4eb49334b 100644 --- a/python/ray/rllib/pg/pg.py +++ b/python/ray/rllib/pg/pg.py @@ -55,7 +55,6 @@ def _init(self): "policy_graph": PGPolicyGraph, "batch_steps": self.config["batch_size"], "batch_mode": "truncate_episodes", - "registry": self.registry, "model_config": self.config["model"], "env_config": self.config["env_config"], "policy_config": self.config, diff --git a/python/ray/rllib/pg/pg_policy_graph.py b/python/ray/rllib/pg/pg_policy_graph.py index b605a513f39c..210af707d8ad 100644 --- a/python/ray/rllib/pg/pg_policy_graph.py +++ b/python/ray/rllib/pg/pg_policy_graph.py @@ -11,14 +11,14 @@ class PGPolicyGraph(TFPolicyGraph): - def __init__(self, obs_space, action_space, registry, config): + def __init__(self, obs_space, action_space, config): self.config = config # setup policy self.x = tf.placeholder(tf.float32, 
shape=[None]+list(obs_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(action_space) self.model = ModelCatalog.get_model( - registry, self.x, self.logit_dim, options=self.config["model"]) + self.x, self.logit_dim, options=self.config["model"]) self.dist = dist_class(self.model.outputs) # logit for each action # setup policy loss diff --git a/python/ray/rllib/ppo/loss.py b/python/ray/rllib/ppo/loss.py index 7f61efaf690c..dd0e03f47f97 100644 --- a/python/ray/rllib/ppo/loss.py +++ b/python/ray/rllib/ppo/loss.py @@ -16,14 +16,14 @@ def __init__( self, observation_space, action_space, observations, value_targets, advantages, actions, prev_logits, prev_vf_preds, logit_dim, - kl_coeff, distribution_class, config, sess, registry): + kl_coeff, distribution_class, config, sess): self.prev_dist = distribution_class(prev_logits) # Saved so that we can compute actions given different observations self.observations = observations self.curr_logits = ModelCatalog.get_model( - registry, observations, logit_dim, config["model"]).outputs + observations, logit_dim, config["model"]).outputs self.curr_dist = distribution_class(self.curr_logits) self.sampler = self.curr_dist.sample() @@ -35,7 +35,7 @@ def __init__( vf_config["free_log_std"] = False with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model( - registry, observations, 1, vf_config).outputs + observations, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) # Make loss functions. 
diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 7fb15c9d3103..144241c44031 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -103,14 +103,13 @@ def default_resource_request(cls, config): def _init(self): self.global_step = 0 self.local_evaluator = PPOEvaluator( - self.registry, self.env_creator, self.config, self.logdir, False) + self.env_creator, self.config, self.logdir, False) RemotePPOEvaluator = ray.remote( num_cpus=self.config["num_cpus_per_worker"], num_gpus=self.config["num_gpus_per_worker"])(PPOEvaluator) self.remote_evaluators = [ RemotePPOEvaluator.remote( - self.registry, self.env_creator, self.config, self.logdir, - True) + self.env_creator, self.config, self.logdir, True) for _ in range(self.config["num_workers"])] self.optimizer = LocalMultiGPUOptimizer( diff --git a/python/ray/rllib/ppo/ppo_evaluator.py b/python/ray/rllib/ppo/ppo_evaluator.py index 68f4437f3c8b..da9ae91a41b9 100644 --- a/python/ray/rllib/ppo/ppo_evaluator.py +++ b/python/ray/rllib/ppo/ppo_evaluator.py @@ -24,12 +24,11 @@ class PPOEvaluator(TFMultiGPUSupport): network weights. When run as a remote agent, only this graph is used. 
""" - def __init__(self, registry, env_creator, config, logdir, is_remote): - self.registry = registry + def __init__(self, env_creator, config, logdir, is_remote): self.config = config self.logdir = logdir self.env = ModelCatalog.get_preprocessor_as_wrapper( - registry, env_creator(config["env_config"]), config["model"]) + env_creator(config["env_config"]), config["model"]) if is_remote: config_proto = tf.ConfigProto() else: @@ -92,7 +91,7 @@ def build_tf_loss(self, input_placeholders): self.env.observation_space, self.env.action_space, obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim, self.kl_coeff, self.distribution_class, self.config, - self.sess, self.registry) + self.sess) def init_extra_ops(self, device_losses): self.extra_ops = OrderedDict() diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 64174866aa19..09fd52a314ef 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -14,7 +14,6 @@ from ray.rllib.agent import get_agent_class from ray.rllib.dqn.common.wrappers import wrap_dqn from ray.rllib.models import ModelCatalog -from ray.tune.registry import get_registry EXAMPLE_USAGE = """ example usage: @@ -74,10 +73,9 @@ if args.run == "DQN": env = gym.make(args.env) - env = wrap_dqn(get_registry(), env, args.config.get("model", {})) + env = wrap_dqn(env, args.config.get("model", {})) else: - env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(), - gym.make(args.env)) + env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env)) if args.out is not None: rollouts = [] steps = 0 diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index c5e503b717ee..e975b1b4cca5 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -5,7 +5,6 @@ from gym.spaces import Box, Discrete, Tuple import ray -from ray.tune.registry import get_registry from ray.rllib.models import ModelCatalog from ray.rllib.models.model import Model @@ 
-33,12 +32,10 @@ def tearDown(self): ray.worker.cleanup() def testGymPreprocessors(self): - p1 = ModelCatalog.get_preprocessor( - get_registry(), gym.make("CartPole-v0")) + p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0")) self.assertEqual(type(p1), NoPreprocessor) - p2 = ModelCatalog.get_preprocessor( - get_registry(), gym.make("FrozenLake-v0")) + p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v0")) self.assertEqual(type(p2), OneHotPreprocessor) def testTuplePreprocessor(self): @@ -48,8 +45,7 @@ class TupleEnv(object): def __init__(self): self.observation_space = Tuple( [Discrete(5), Box(0, 1, shape=(3,), dtype=np.float32)]) - p1 = ModelCatalog.get_preprocessor( - get_registry(), TupleEnv()) + p1 = ModelCatalog.get_preprocessor(TupleEnv()) self.assertEqual(p1.shape, (8,)) self.assertEqual( list(p1.transform((0, [1, 2, 3]))), @@ -60,33 +56,29 @@ def testCustomPreprocessor(self): ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor) ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2) env = gym.make("CartPole-v0") - p1 = ModelCatalog.get_preprocessor( - get_registry(), env, {"custom_preprocessor": "foo"}) + p1 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "foo"}) self.assertEqual(str(type(p1)), str(CustomPreprocessor)) - p2 = ModelCatalog.get_preprocessor( - get_registry(), env, {"custom_preprocessor": "bar"}) + p2 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "bar"}) self.assertEqual(str(type(p2)), str(CustomPreprocessor2)) - p3 = ModelCatalog.get_preprocessor(get_registry(), env) + p3 = ModelCatalog.get_preprocessor(env) self.assertEqual(type(p3), NoPreprocessor) def testDefaultModels(self): ray.init() with tf.variable_scope("test1"): - p1 = ModelCatalog.get_model( - get_registry(), np.zeros((10, 3), dtype=np.float32), 5) + p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = 
ModelCatalog.get_model( - get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5) + np.zeros((10, 80, 80, 3), dtype=np.float32), 5) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): ray.init() ModelCatalog.register_custom_model("foo", CustomModel) - p1 = ModelCatalog.get_model( - get_registry(), 1, 5, {"custom_model": "foo"}) + p1 = ModelCatalog.get_model(1, 5, {"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel)) diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index bf3124002a9f..5c6e8c362e2a 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -59,8 +59,7 @@ def check_support(alg, config, stats): for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items(): print("=== Testing", alg, action_space, obs_space, "===") stub_env = make_stub_env(action_space, obs_space) - register_env( - "stub_env", lambda c: stub_env()) + register_env("stub_env", lambda c: stub_env()) stat = "ok" a = None try: diff --git a/python/ray/rllib/utils/common_policy_evaluator.py b/python/ray/rllib/utils/common_policy_evaluator.py index 6366375d36f9..95be18fcefbe 100644 --- a/python/ray/rllib/utils/common_policy_evaluator.py +++ b/python/ray/rllib/utils/common_policy_evaluator.py @@ -18,7 +18,6 @@ from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync from ray.rllib.utils.tf_policy_graph import TFPolicyGraph from ray.rllib.utils.vector_env import VectorEnv -from ray.tune.registry import get_registry from ray.tune.result import TrainingResult @@ -97,7 +96,6 @@ def __init__( compress_observations=False, num_envs=1, observation_filter="NoFilter", - registry=None, env_config=None, model_config=None, policy_config=None): @@ -137,15 +135,11 @@ def __init__( and vectorize the computation of actions. This has no effect if if the env already implements VectorEnv. 
observation_filter (str): Name of observation filter to use. - registry (tune.Registry): User-registered objects. Pass in the - value from tune.registry.get_registry() if you're having - trouble resolving things like custom envs. env_config (dict): Config to pass to the env creator. model_config (dict): Config to use when creating the policy model. policy_config (dict): Config to pass to the policy. """ - registry = registry or get_registry() env_config = env_config or {} policy_config = policy_config or {} model_config = model_config or {} @@ -169,7 +163,7 @@ def wrap(env): else: def wrap(env): return ModelCatalog.get_preprocessor_as_wrapper( - registry, env, model_config) + env, model_config) self.env = wrap(self.env) def make_env(): @@ -187,11 +181,11 @@ def make_env(): with self.sess.as_default(): policy = policy_graph( self.env.observation_space, self.env.action_space, - registry, policy_config) + policy_config) else: policy = policy_graph( self.env.observation_space, self.env.action_space, - registry, policy_config) + policy_config) self.policy_map = { "default": policy } diff --git a/python/ray/rllib/utils/policy_graph.py b/python/ray/rllib/utils/policy_graph.py index fdd22ede6bae..45f48684caf0 100644 --- a/python/ray/rllib/utils/policy_graph.py +++ b/python/ray/rllib/utils/policy_graph.py @@ -17,11 +17,10 @@ class PolicyGraph(object): graphs and multi-GPU support. """ - def __init__(self, registry, observation_space, action_space, config): + def __init__(self, observation_space, action_space, config): """Initialize the graph. Args: - registry (obj): Object registry for user-defined envs, models, etc. observation_space (gym.Space): Observation space of the env. action_space (gym.Space): Action space of the env. config (dict): Policy-specific configuration data. 
diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index cb9505771e63..3441fc793bed 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -4,10 +4,10 @@ from types import FunctionType -import numpy as np - import ray -from ray.local_scheduler import ObjectID +import ray.cloudpickle as pickle +from ray.experimental.internal_kv import _internal_kv_initialized, \ + _internal_kv_get, _internal_kv_put TRAINABLE_CLASS = "trainable_class" ENV_CREATOR = "env_creator" @@ -35,7 +35,7 @@ def register_trainable(name, trainable): if not issubclass(trainable, Trainable): raise TypeError("Second argument must be convertable to Trainable", trainable) - _default_registry.register(TRAINABLE_CLASS, name, trainable) + _global_registry.register(TRAINABLE_CLASS, name, trainable) def register_env(name, env_creator): @@ -48,62 +48,59 @@ def register_env(name, env_creator): if not isinstance(env_creator, FunctionType): raise TypeError("Second argument must be a function.", env_creator) - _default_registry.register(ENV_CREATOR, name, env_creator) - - -def get_registry(): - """Use this to access the registry. This requires ray to be initialized.""" + _global_registry.register(ENV_CREATOR, name, env_creator) - _default_registry.flush_values_to_object_store() - # returns a registry copy that doesn't include the hard refs - return _Registry(_default_registry._all_objects) +def _make_key(category, key): + """Generate a binary key for the given category and key. + Args: + category (str): The category of the item + key (str): The unique identifier for the item -def _to_pinnable(obj): - """Converts obj to a form that can be pinned in object store memory. - - Currently only numpy arrays are pinned in memory, if you have a strong - reference to the array value. + Returns: + The key to use for storing the value.
""" - - return (obj, np.zeros(1)) - - -def _from_pinnable(obj): - """Retrieve from _to_pinnable format.""" - - return obj[0] + return (b"TuneRegistry:" + category.encode("ascii") + b"/" + + key.encode("ascii")) class _Registry(object): - def __init__(self, objs=None): - self._all_objects = {} if objs is None else objs.copy() - self._refs = [] # hard refs that prevent eviction of objects + def __init__(self): + self._to_flush = {} def register(self, category, key, value): if category not in KNOWN_CATEGORIES: from ray.tune import TuneError raise TuneError("Unknown category {} not among {}".format( category, KNOWN_CATEGORIES)) - self._all_objects[(category, key)] = value + self._to_flush[(category, key)] = pickle.dumps(value) + if _internal_kv_initialized(): + self.flush_values() def contains(self, category, key): - return (category, key) in self._all_objects + if _internal_kv_initialized(): + value = _internal_kv_get(_make_key(category, key)) + return value is not None + else: + return (category, key) in self._to_flush def get(self, category, key): - value = self._all_objects[(category, key)] - if type(value) == ObjectID: - return _from_pinnable(ray.get(value)) + if _internal_kv_initialized(): + value = _internal_kv_get(_make_key(category, key)) + if value is None: + raise ValueError( + "Registry value for {}/{} doesn't exist.".format( + category, key)) + return pickle.loads(value) else: - return value + return pickle.loads(self._to_flush[(category, key)]) - def flush_values_to_object_store(self): - for k, v in self._all_objects.items(): - if type(v) != ObjectID: - obj = ray.put(_to_pinnable(v)) - self._all_objects[k] = obj - self._refs.append(ray.get(obj)) + def flush_values(self): + for (category, key), value in self._to_flush.items(): + _internal_kv_put(_make_key(category, key), value) + self._to_flush.clear() -_default_registry = _Registry() +_global_registry = _Registry() +ray.worker._post_init_hooks.append(_global_registry.flush_values) diff --git 
a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index f762b4848728..832bcde284bd 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -11,7 +11,7 @@ from ray.tune import Trainable, TuneError from ray.tune import register_env, register_trainable, run_experiments -from ray.tune.registry import _default_registry, TRAINABLE_CLASS +from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment @@ -595,7 +595,7 @@ def train(config, reporter): def testTrialErrorOnStart(self): ray.init() - _default_registry.register(TRAINABLE_CLASS, "asdf", None) + _global_registry.register(TRAINABLE_CLASS, "asdf", None) trial = Trial("asdf", resources=Resources(1, 0)) try: trial.start() @@ -690,7 +690,7 @@ def testErrorHandling(self): }, "resources": Resources(cpu=1, gpu=1), } - _default_registry.register(TRAINABLE_CLASS, "asdf", None) + _global_registry.register(TRAINABLE_CLASS, "asdf", None) trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index f0351e128834..c0e4838cb3f8 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -46,11 +46,9 @@ class Trainable(object): Attributes: config (obj): The hyperparam configuration for this trial. logdir (str): Directory in which training outputs should be placed. - registry (obj): Tune object registry which holds user-registered - classes and objects by name. """ - def __init__(self, config=None, registry=None, logger_creator=None): + def __init__(self, config=None, logger_creator=None): """Initialize an Trainable. 
Subclasses should prefer defining ``_setup()`` instead of overriding @@ -58,20 +56,13 @@ def __init__(self, config=None, registry=None, logger_creator=None): Args: config (dict): Trainable-specific configuration data. - registry (obj): Object registry for user-defined envs, models, etc. - If unspecified, the default registry will be used. logger_creator (func): Function that creates a ray.tune.Logger object. If unspecified, a default logger is created. """ - if registry is None: - from ray.tune.registry import get_registry - registry = get_registry() - self._initialize_ok = False self._experiment_id = uuid.uuid4().hex self.config = config or {} - self.registry = registry if logger_creator: self._result_logger = logger_creator(self.config) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 7d7442572e6a..e8f42ab1cec8 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -57,7 +57,7 @@ def gpu_total(self): def has_trainable(trainable_name): - return ray.tune.registry._default_registry.contains( + return ray.tune.registry._global_registry.contains( ray.tune.registry.TRAINABLE_CLASS, trainable_name) @@ -377,12 +377,10 @@ def logger_creator(config): # Logging for trials is handled centrally by TrialRunner, so # configure the remote runner to use a noop-logger. 
self.runner = cls.remote( - config=self.config, - registry=ray.tune.registry.get_registry(), - logger_creator=logger_creator) + config=self.config, logger_creator=logger_creator) def _get_trainable_cls(self): - return ray.tune.registry.get_registry().get( + return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) def set_verbose(self, verbose): diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 27bef5cc1557..f6ae3ab28ec7 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -import base64 from six.moves import queue +import base64 +import numpy as np import threading import ray -from ray.tune.registry import _to_pinnable, _from_pinnable _pinned_objects = [] _fetch_requests = queue.Queue() @@ -63,6 +63,22 @@ def _serve_get_pin_requests(): pass +def _to_pinnable(obj): + """Converts obj to a form that can be pinned in object store memory. + + Currently only numpy arrays are pinned in memory, if you have a strong + reference to the array value. 
+ """ + + return (obj, np.zeros(1)) + + +def _from_pinnable(obj): + """Retrieve from _to_pinnable format.""" + + return obj[0] + + if __name__ == '__main__': ray.init() X = pin_in_object_store("hello") diff --git a/python/ray/worker.py b/python/ray/worker.py index e2077aa24070..6567d327575e 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1741,7 +1741,7 @@ def init(redis_address=None, redis_address = services.address_to_ip(redis_address) info = {"node_ip_address": node_ip_address, "redis_address": redis_address} - return _init( + ret = _init( address_info=info, start_ray_local=(redis_address is None), num_workers=num_workers, @@ -1758,6 +1758,13 @@ def init(redis_address=None, include_webui=include_webui, object_store_memory=object_store_memory, use_raylet=use_raylet) + for hook in _post_init_hooks: + hook() + return ret + + +# Functions to run as callback after a successful ray init +_post_init_hooks = [] def cleanup(worker=global_worker):