Merged
39 changes: 29 additions & 10 deletions python/ray/rllib/agents/agent.py
@@ -129,6 +129,10 @@
"compress_observations": False,
# Drop metric batches from unresponsive workers after this many seconds
"collect_metrics_timeout": 180,
# If using num_envs_per_worker > 1, whether to create those new envs in
# remote processes instead of in the same worker. This adds overheads, but
# can make sense if your envs are very CPU intensive (e.g., for StarCraft).
"remote_worker_envs": False,

# === Offline Datasets ===
# __sphinx_doc_input_begin__
@@ -463,7 +467,9 @@ def make_local_evaluator(self,
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}),
extra_config or {}))
extra_config or {}),
remote_worker_envs=False,
)

@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
@@ -476,9 +482,16 @@ def make_remote_evaluators(self, env_creator, policy_graph, count):
}

cls = PolicyEvaluator.as_remote(**remote_args).remote

return [
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
self._make_evaluator(
cls,
env_creator,
policy_graph,
i + 1,
self.config,
remote_worker_envs=self.config["remote_worker_envs"])
for i in range(count)
]

@DeveloperAPI
@@ -544,8 +557,13 @@ def _validate_config(config):
raise ValueError(
"`input_evaluation` should not be set when input=sampler")

def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def _make_evaluator(self,
cls,
env_creator,
policy_graph,
worker_index,
config,
remote_worker_envs=False):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
@@ -573,10 +591,10 @@ def session_creator():
compress_columns=config["output_compress_columns"]))
else:
output_creator = (lambda ioctx: JsonWriter(
config["output"],
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
config["output"],
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))

return cls(
env_creator,
@@ -605,7 +623,8 @@ def session_creator():
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation_method=config["input_evaluation"],
output_creator=output_creator)
output_creator=output_creator,
remote_worker_envs=remote_worker_envs)

@override(Trainable)
def _export_model(self, export_formats, export_dir):
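For context, a minimal sketch (not part of the diff) of how the new flag might be enabled through the Python API; the PPOAgent class name and result key are assumed from the RLlib version this PR targets:

import ray
from ray.rllib.agents.ppo import PPOAgent

ray.init()

# remote_worker_envs only takes effect together with num_envs_per_worker > 1;
# each of the vectorized envs is then hosted in its own zero-CPU Ray actor.
agent = PPOAgent(
    env="CartPole-v1",
    config={
        "num_workers": 1,
        "num_envs_per_worker": 2,
        "remote_worker_envs": True,
    })

print(agent.train()["episode_reward_mean"])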
17 changes: 15 additions & 2 deletions python/ray/rllib/env/base_env.py
@@ -66,10 +66,18 @@ class BaseEnv(object):
"""

@staticmethod
def to_base_env(env, make_env=None, num_envs=1):
def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
"""Wraps any env type as needed to expose the async interface."""
if remote_envs and num_envs == 1:
raise ValueError(
"Remote envs only make sense to use if num_envs > 1 "
"(i.e. vectorization is enabled).")
if not isinstance(env, BaseEnv):
if isinstance(env, MultiAgentEnv):
if remote_envs:
raise NotImplementedError(
"Remote multiagent environments are not implemented")

env = _MultiAgentEnvToBaseEnv(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
elif isinstance(env, ExternalEnv):
@@ -81,7 +89,12 @@ def to_base_env(env, make_env=None, num_envs=1):
env = _VectorEnvToBaseEnv(env)
else:
env = VectorEnv.wrap(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
make_env=make_env,
existing_envs=[env],
num_envs=num_envs,
remote_envs=remote_envs,
action_space=env.action_space,
observation_space=env.observation_space)
env = _VectorEnvToBaseEnv(env)
assert isinstance(env, BaseEnv)
return env
16 changes: 13 additions & 3 deletions python/ray/rllib/env/env_context.py
@@ -20,13 +20,23 @@ class EnvContext(dict):
uniquely identifies the worker the env is created in.
vector_index (int): When there are multiple envs per worker, this
uniquely identifies the env index within the worker.
remote (bool): Whether environment should be remote or not.
"""

def __init__(self, env_config, worker_index, vector_index=0):
def __init__(self, env_config, worker_index, vector_index=0, remote=False):
dict.__init__(self, env_config)
self.worker_index = worker_index
self.vector_index = vector_index
self.remote = remote

def with_vector_index(self, vector_index):
def copy_with_overrides(self,
env_config=None,
worker_index=None,
vector_index=None,
remote=None):
return EnvContext(
self, worker_index=self.worker_index, vector_index=vector_index)
env_config if env_config is not None else self,
worker_index if worker_index is not None else self.worker_index,
vector_index if vector_index is not None else self.vector_index,
remote if remote is not None else self.remote,
)
98 changes: 93 additions & 5 deletions python/ray/rllib/env/vector_env.py
@@ -2,8 +2,13 @@
from __future__ import division
from __future__ import print_function

import logging

import ray
from ray.rllib.utils.annotations import override, PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv(object):
@@ -18,8 +23,17 @@ class VectorEnv(object):
"""

@staticmethod
def wrap(make_env=None, existing_envs=None, num_envs=1):
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
def wrap(make_env=None,
existing_envs=None,
num_envs=1,
remote_envs=False,
action_space=None,
observation_space=None):
if remote_envs:
return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
observation_space)
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
action_space, observation_space)

@PublicAPI
def vector_reset(self):
@@ -70,14 +84,20 @@ class _VectorizedGymEnv(VectorEnv):
num_envs (int): Desired num gym envs to keep total.
"""

def __init__(self, make_env, existing_envs, num_envs):
def __init__(self,
make_env,
existing_envs,
num_envs,
action_space=None,
observation_space=None):
self.make_env = make_env
self.envs = existing_envs
self.num_envs = num_envs
while len(self.envs) < self.num_envs:
self.envs.append(self.make_env(len(self.envs)))
self.action_space = self.envs[0].action_space
self.observation_space = self.envs[0].observation_space
self.action_space = action_space or self.envs[0].action_space
self.observation_space = observation_space or \
self.envs[0].observation_space

@override(VectorEnv)
def vector_reset(self):
@@ -101,3 +121,71 @@ def vector_step(self, actions):
@override(VectorEnv)
def get_unwrapped(self):
return self.envs


@ray.remote(num_cpus=0)
class _RemoteEnv(object):
"""Wrapper class for making a gym env a remote actor."""

def __init__(self, make_env, i):
self.env = make_env(i)

def reset(self):
return self.env.reset()

def step(self, action):
return self.env.step(action)
Contributor Author:
Is it possible to call close() for the remote env? SC2 environments start an SC2 server, which is a separate process, and I guess the correct way to stop it in these situations would be calling the close method (though I see them dying after keyboard interrupt).

Contributor:
There's Python atexit, which I think should work. If not, we can add close() hooks (but I don't know if this is as reliable in case of errors).
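For illustration only, a minimal sketch (not part of the diff) of how an atexit-based close hook could look inside the remote wrapper; the class name is hypothetical and it assumes the wrapped env exposes close():

import atexit

import ray


@ray.remote(num_cpus=0)
class _RemoteEnvWithClose(object):
    """Hypothetical variant of _RemoteEnv that closes its env at exit."""

    def __init__(self, make_env, i):
        self.env = make_env(i)
        # Best-effort cleanup when the actor process exits; this is not
        # guaranteed to run if the process is killed hard.
        atexit.register(self.env.close)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)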



class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
Contributor:
Consider extending VectorEnv directly, since you don't seem to use much of the functionality of _VectorizedGymEnv.

Contributor Author:
Well, I do reuse the constructor and get_unwrapped, no need to copy those. I would leave it like this.

"""Internal wrapper for gym envs to implement VectorEnv as remote workers.
"""

def __init__(self,
make_env,
num_envs,
action_space=None,
observation_space=None):
self.make_local_env = make_env
self.num_envs = num_envs
self.initialized = False
self.action_space = action_space
self.observation_space = observation_space

def _initialize_if_needed(self):
if self.initialized:
return

self.initialized = True

def make_remote_env(i):
logger.info("Launching env {} in remote actor".format(i))
return _RemoteEnv.remote(self.make_local_env, i)

_VectorizedGymEnv.__init__(self, make_remote_env, [], self.num_envs,
self.action_space, self.observation_space)

for env in self.envs:
assert isinstance(env, ray.actor.ActorHandle), env

@override(_VectorizedGymEnv)
def vector_reset(self):
self._initialize_if_needed()
return ray.get([env.reset.remote() for env in self.envs])

@override(_VectorizedGymEnv)
def reset_at(self, index):
return ray.get(self.envs[index].reset.remote())

@override(_VectorizedGymEnv)
def vector_step(self, actions):
step_outs = ray.get(
[env.step.remote(act) for env, act in zip(self.envs, actions)])

obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
for obs, rew, done, info in step_outs:
obs_batch.append(obs)
rew_batch.append(rew)
done_batch.append(done)
info_batch.append(info)
return obs_batch, rew_batch, done_batch, info_batch
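To make the new code path concrete, here is a minimal usage sketch (not part of the diff) that exercises VectorEnv.wrap with remote_envs=True directly, assuming gym is installed and Ray has been initialized:

import gym
import ray

from ray.rllib.env.vector_env import VectorEnv

ray.init()


def make_env(vector_index):
    return gym.make("CartPole-v1")


# Spaces must be passed in explicitly, since the remote envs are created
# lazily on the first reset.
probe = gym.make("CartPole-v1")
vec_env = VectorEnv.wrap(
    make_env=make_env,
    num_envs=2,
    remote_envs=True,
    action_space=probe.action_space,
    observation_space=probe.observation_space)

obs = vec_env.vector_reset()  # launches the _RemoteEnv actors on first use
actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
obs, rewards, dones, infos = vec_env.vector_step(actions)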
16 changes: 13 additions & 3 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -117,7 +117,8 @@ def __init__(self,
callbacks=None,
input_creator=lambda ioctx: ioctx.default_sampler_input(),
input_evaluation_method=None,
output_creator=lambda ioctx: NoopOutput()):
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False):
"""Initialize a policy evaluator.

Arguments:
@@ -192,6 +193,10 @@ def __init__(self,
use this data for evaluation only and never for learning.
output_creator (func): Function that returns an OutputWriter object
for saving generated experiences.
remote_worker_envs (bool): If using num_envs > 1, whether to create
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
are very CPU intensive (e.g., for StarCraft).
"""

if log_level:
@@ -250,7 +255,9 @@ def wrap(env):

def make_env(vector_index):
return wrap(
env_creator(env_context.with_vector_index(vector_index)))
env_creator(
env_context.copy_with_overrides(
vector_index=vector_index, remote=remote_worker_envs)))

self.tf_sess = None
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
@@ -293,7 +300,10 @@ def make_env(vector_index):

# Always use vector env for consistency even if num_envs = 1
self.async_env = BaseEnv.to_base_env(
self.env, make_env=make_env, num_envs=num_envs)
self.env,
make_env=make_env,
num_envs=num_envs,
remote_envs=remote_worker_envs)
self.num_envs = num_envs

if self.batch_mode == "truncate_episodes":
7 changes: 7 additions & 0 deletions test/jenkins_tests/run_multi_node_tests.sh
@@ -71,6 +71,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v1 \
--run PPO \
--stop '{"training_iteration": 2}' \
--config '{"remote_worker_envs": true, "num_envs_per_worker": 2, "num_workers": 1, "train_batch_size": 100, "sgd_minibatch_size": 50}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env Pendulum-v0 \