31 changes: 31 additions & 0 deletions python/ray/experimental/internal_kv.py
@@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray


def _internal_kv_initialized():
worker = ray.worker.get_global_worker()
return hasattr(worker, "mode") and worker.mode is not None


def _internal_kv_get(key):
"""Fetch the value of a binary key."""

worker = ray.worker.get_global_worker()
return worker.redis_client.hget(key, "value")


def _internal_kv_put(key, value):
"""Globally associates a value with a given binary key.

This only has an effect if the key does not already have a value.

Returns:
already_exists (bool): whether the key already had a value (in which case this put is ignored).
"""

worker = ray.worker.get_global_worker()
updated = worker.redis_client.hsetnx(key, "value", value)
return updated == 0 # already exists
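
A minimal usage sketch of the new key-value helpers, assuming ray.init() has already connected the global worker to Redis; the key and value bytes are illustrative:

import ray
from ray.experimental.internal_kv import (
    _internal_kv_get, _internal_kv_initialized, _internal_kv_put)

ray.init()
assert _internal_kv_initialized()

# hsetnx semantics: the first put wins, later puts for the same key are ignored.
already_exists = _internal_kv_put(b"example_key", b"example_value")
assert already_exists is False
assert _internal_kv_put(b"example_key", b"other_value") is True
assert _internal_kv_get(b"example_key") == b"example_value"
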
13 changes: 6 additions & 7 deletions python/ray/experimental/named_actors.py
@@ -4,6 +4,7 @@

import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put


def _calculate_key(name):
@@ -29,9 +30,8 @@ def get_actor(name):
Returns:
The ActorHandle object corresponding to the name.
"""
worker = ray.worker.get_global_worker()
actor_hash = _calculate_key(name)
pickled_state = worker.redis_client.hget(actor_hash, name)
actor_name = _calculate_key(name)
pickled_state = _internal_kv_get(actor_name)
if pickled_state is None:
raise ValueError("The actor with name={} doesn't exist".format(name))
handle = pickle.loads(pickled_state)
@@ -45,17 +45,16 @@ def register_actor(name, actor_handle):
name: The name of the named actor.
actor_handle: The actor object to be associated with this name
"""
worker = ray.worker.get_global_worker()
if not isinstance(name, str):
raise TypeError("The name argument must be a string.")
if not isinstance(actor_handle, ray.actor.ActorHandle):
raise TypeError("The actor_handle argument must be an ActorHandle "
"object.")
actor_hash = _calculate_key(name)
actor_name = _calculate_key(name)
pickled_state = pickle.dumps(actor_handle)

# Add the actor to Redis if it does not already exist.
updated = worker.redis_client.hsetnx(actor_hash, name, pickled_state)
if updated == 0:
already_exists = _internal_kv_put(actor_name, pickled_state)
if already_exists:
raise ValueError(
"Error: the actor with name={} already exists".format(name))
4 changes: 2 additions & 2 deletions python/ray/rllib/a3c/a3c.py
@@ -102,7 +102,7 @@ def session_creator():
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes",
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
self.remote_evaluators = [
@@ -111,7 +111,7 @@ def session_creator():
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes", sample_async=True,
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for i in range(self.config["num_workers"])]
3 changes: 1 addition & 2 deletions python/ray/rllib/a3c/a3c_tf_policy.py
@@ -13,8 +13,7 @@
class A3CTFPolicyGraph(TFPolicyGraph):
"""The TF policy base class."""

def __init__(self, ob_space, action_space, registry, config):
self.registry = registry
def __init__(self, ob_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
6 changes: 2 additions & 4 deletions python/ray/rllib/a3c/a3c_torch_policy.py
@@ -17,8 +17,7 @@
class SharedTorchPolicy(PolicyGraph):
"""A simple, non-recurrent PyTorch policy example."""

def __init__(self, obs_space, action_space, registry, config):
self.registry = registry
def __init__(self, obs_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
@@ -29,8 +28,7 @@ def __init__(self, obs_space, action_space, registry, config):
def setup_graph(self, obs_space, action_space):
_, self.logit_dim = ModelCatalog.get_action_dist(action_space)
self._model = ModelCatalog.get_torch_model(
self.registry, obs_space.shape, self.logit_dim,
self.config["model"])
obs_space.shape, self.logit_dim, self.config["model"])
self.optimizer = torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])

6 changes: 3 additions & 3 deletions python/ray/rllib/a3c/shared_model.py
@@ -10,15 +10,15 @@

class SharedModel(A3CTFPolicyGraph):

def __init__(self, ob_space, ac_space, registry, config, **kwargs):
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModel, self).__init__(
ob_space, ac_space, registry, config, **kwargs)
ob_space, ac_space, config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.action_dist = dist_class(self.logits)
self.vf = tf.reshape(linear(self._model.last_layer, 1, "value",
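
The same registry-free call pattern applies to ModelCatalog throughout this diff; a sketch of the updated call, where the placeholder shape and model options are illustrative:

import tensorflow as tf
from ray.rllib.models import ModelCatalog

# Build a model for a 4-dimensional observation and 2 output logits,
# without threading a registry object through the call.
observations = tf.placeholder(tf.float32, [None, 4])
model = ModelCatalog.get_model(observations, 2, {"fcnet_hiddens": [64, 64]})
logits = model.outputs
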
4 changes: 2 additions & 2 deletions python/ray/rllib/a3c/shared_model_lstm.py
@@ -11,9 +11,9 @@

class SharedModelLSTM(A3CTFPolicyGraph):

def __init__(self, ob_space, ac_space, registry, config, **kwargs):
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModelLSTM, self).__init__(
ob_space, ac_space, registry, config, **kwargs)
ob_space, ac_space, config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
15 changes: 5 additions & 10 deletions python/ray/rllib/agent.py
@@ -9,7 +9,7 @@
import pickle

import tensorflow as tf
from ray.tune.registry import ENV_CREATOR
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable

@@ -56,8 +56,6 @@ class Agent(Trainable):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""

_allow_unknown_configs = False
@@ -72,16 +70,13 @@ def resource_help(cls, config):
"The config of this agent is: " + json.dumps(config))

def __init__(
self, config=None, env=None, registry=None,
logger_creator=None):
self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib agent.

Args:
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
@@ -90,14 +85,14 @@ def __init__(

# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, registry, logger_creator)
Trainable.__init__(self, config, logger_creator)

def _setup(self):
env = self._env_id
if env:
self.config["env"] = env
if self.registry and self.registry.contains(ENV_CREATOR, env):
self.env_creator = self.registry.get(ENV_CREATOR, env)
if _global_registry.contains(ENV_CREATOR, env):
self.env_creator = _global_registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
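
To illustrate the effect on the public API: environments registered with Tune are now resolved through the module-level _global_registry, so agents no longer take a registry argument. A sketch, with the env name and the choice of PPOAgent purely illustrative:

import gym
import ray
from ray.rllib import ppo
from ray.tune.registry import register_env

ray.init()

# register_env() writes into the same _global_registry that
# Agent._setup() now consults via ENV_CREATOR.
register_env("my_cartpole", lambda env_config: gym.make("CartPole-v0"))

# No registry argument anymore; the env id alone is enough.
agent = ppo.PPOAgent(config={"num_workers": 1}, env="my_cartpole")
print(agent.train())
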
5 changes: 2 additions & 3 deletions python/ray/rllib/bc/bc.py
@@ -63,14 +63,13 @@ def default_resource_request(cls, config):

def _init(self):
self.local_evaluator = BCEvaluator(
self.registry, self.env_creator, self.config, self.logdir)
self.env_creator, self.config, self.logdir)
if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteBCEvaluator
else:
remote_cls = RemoteBCEvaluator
self.remote_evaluators = [
remote_cls.remote(
self.registry, self.env_creator, self.config, self.logdir)
remote_cls.remote(self.env_creator, self.config, self.logdir)
for _ in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
7 changes: 3 additions & 4 deletions python/ray/rllib/bc/bc_evaluator.py
@@ -13,12 +13,11 @@


class BCEvaluator(PolicyEvaluator):
def __init__(self, registry, env_creator, config, logdir):
env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator(
def __init__(self, env_creator, config, logdir):
env = ModelCatalog.get_preprocessor_as_wrapper(env_creator(
config["env_config"]), config["model"])
self.dataset = ExperienceDataset(config["dataset_path"])
self.policy = BCPolicy(registry, env.observation_space,
env.action_space, config)
self.policy = BCPolicy(env.observation_space, env.action_space, config)
self.config = config
self.logdir = logdir
self.metrics_queue = queue.Queue()
5 changes: 2 additions & 3 deletions python/ray/rllib/bc/policy.py
@@ -10,8 +10,7 @@


class BCPolicy(object):
def __init__(self, registry, obs_space, action_space, config):
self.registry = registry
def __init__(self, obs_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
@@ -24,7 +23,7 @@ def _setup_graph(self, obs_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.curr_dist = dist_class(self.logits)
self.sample = self.curr_dist.sample()
26 changes: 11 additions & 15 deletions python/ray/rllib/ddpg/ddpg_policy_graph.py
@@ -22,12 +22,12 @@
Q_TARGET_SCOPE = "target_q_func"


def _build_p_network(registry, inputs, dim_actions, config):
def _build_p_network(inputs, dim_actions, config):
"""
map an observation (i.e., state) to an action where
each entry takes value from (0, 1) due to the sigmoid function
"""
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
frontend = ModelCatalog.get_model(inputs, 1, config["model"])

hiddens = config["actor_hiddens"]
action_out = frontend.last_layer
@@ -66,8 +66,8 @@ def _build_action_network(p_values, low_action, high_action, stochastic, eps,
lambda: deterministic_actions)


def _build_q_network(registry, inputs, action_inputs, config):
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
def _build_q_network(inputs, action_inputs, config):
frontend = ModelCatalog.get_model(inputs, 1, config["model"])

hiddens = config["critic_hiddens"]

@@ -81,7 +81,7 @@ def _build_q_network(registry, inputs, action_inputs, config):


class DDPGPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space, registry, config):
def __init__(self, observation_space, action_space, config):
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for DDPG.".format(
@@ -105,7 +105,7 @@ def __init__(self, observation_space, action_space, registry, config):

# Actor: P (policy) network
with tf.variable_scope(P_SCOPE) as scope:
p_values = _build_p_network(registry, self.cur_observations,
p_values = _build_p_network(self.cur_observations,
dim_actions, config)
self.p_func_vars = _scope_vars(scope.name)

@@ -136,13 +136,11 @@ def __init__(self, observation_space, action_space, registry, config):

# p network evaluation
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
self.p_t = _build_p_network(
registry, self.obs_t, dim_actions, config)
self.p_t = _build_p_network(self.obs_t, dim_actions, config)

# target p network evaluation
with tf.variable_scope(P_TARGET_SCOPE) as scope:
p_tp1 = _build_p_network(
registry, self.obs_tp1, dim_actions, config)
p_tp1 = _build_p_network(self.obs_tp1, dim_actions, config)
target_p_func_vars = _scope_vars(scope.name)

# Action outputs
@@ -161,17 +159,15 @@

# q network evaluation
with tf.variable_scope(Q_SCOPE) as scope:
q_t = _build_q_network(
registry, self.obs_t, self.act_t, config)
q_t = _build_q_network(self.obs_t, self.act_t, config)
self.q_func_vars = _scope_vars(scope.name)
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0 = _build_q_network(
registry, self.obs_t, output_actions, config)
q_tp0 = _build_q_network(self.obs_t, output_actions, config)

# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1 = _build_q_network(
registry, self.obs_tp1, output_actions_estimated, config)
self.obs_tp1, output_actions_estimated, config)
target_q_func_vars = _scope_vars(scope.name)

q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
4 changes: 2 additions & 2 deletions python/ray/rllib/dqn/common/wrappers.py
@@ -6,7 +6,7 @@
from ray.rllib.utils.atari_wrappers import wrap_deepmind


def wrap_dqn(registry, env, options, random_starts):
def wrap_dqn(env, options, random_starts):
"""Apply a common set of wrappers for DQN."""

is_atari = hasattr(env.unwrapped, "ale")
@@ -17,4 +17,4 @@ def wrap_dqn(registry, env, options, random_starts):
return wrap_deepmind(
env, random_starts=random_starts, dim=options.get("dim", 80))

return ModelCatalog.get_preprocessor_as_wrapper(registry, env, options)
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
4 changes: 2 additions & 2 deletions python/ray/rllib/dqn/dqn.py
@@ -129,7 +129,7 @@ def _init(self):
batch_steps=adjusted_batch_size,
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
remote_cls = CommonPolicyEvaluator.as_remote(
@@ -141,7 +141,7 @@ def _init(self):
batch_steps=adjusted_batch_size,
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for _ in range(self.config["num_workers"])]