10 changes: 10 additions & 0 deletions rllib/agents/ddpg/apex.py
@@ -23,6 +23,15 @@
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
Member:
Actually, I wonder: why do they need to be on the same node?

Contributor Author:
For APEX, the local node is the learner, so the data (once in the buffer shards) never has to travel again. I think that's the sole intention here.

Member:
I see. To be honest, this doesn't feel like a requirement to me, more like an optimization.
Since we don't have a viability guarantee from Ray core, if it were up to me, I would do this as a best-effort thing:
try to colocate everything, and if that fails, schedule the remaining replay-buffer shards anywhere.
Then we wouldn't need the while loop, and the scheduling could finish in at most 2 steps (see the sketch below).

That is obviously too big a change; maybe just add a note/TODO somewhere?

As written, I am a little worried a stack may fail with a mysterious error message like "failed to schedule RB actors" even though there are enough CPUs overall, just a small head node.
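A minimal sketch of the best-effort placement described above (hypothetical `create_best_effort_colocated` helper and `ReplayShard` stand-in actor; not the actual RLlib `create_colocated_actors` utility): over-provision candidate actors, keep the ones that landed on the target node, and schedule whatever is still missing without any placement constraint, so the whole thing finishes in two steps and never loops.

```python
import platform
from typing import Any, List, Sequence

import ray


@ray.remote(num_cpus=0)
class ReplayShard:
    """Stand-in for RLlib's ReplayActor (illustration only)."""

    def node(self) -> str:
        # Report which node this actor ended up on.
        return platform.node()


def create_best_effort_colocated(actor_cls, args: Sequence[Any], count: int,
                                 node: str) -> List:
    """Step 1: try to colocate `count` actors on `node`.
    Step 2: schedule whatever could not be colocated anywhere else."""
    # Over-provision candidates, then keep those that landed on `node`.
    candidates = [actor_cls.remote(*args) for _ in range(count * 2)]
    locations = ray.get([c.node.remote() for c in candidates])
    colocated = []
    for actor, actor_node in zip(candidates, locations):
        if actor_node == node and len(colocated) < count:
            colocated.append(actor)
        else:
            ray.kill(actor)
    # Best effort: remaining shards go wherever Ray core finds room.
    remaining = count - len(colocated)
    return colocated + [actor_cls.remote(*args) for _ in range(remaining)]


if __name__ == "__main__":
    ray.init()
    shards = create_best_effort_colocated(
        ReplayShard, [], count=4, node=platform.node())
    print(f"Created {len(shards)} replay shards.")
```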

Contributor Author:
Just a note: This is nothing new that I introduced here for APEX. We have always forced all replay shards to be located on the driver. This change actually allows users (via setting this new flag to False) to relax this constraint.

Contributor Author:
I can add a comment to explain this more. ...

Contributor Author:
done

Member:
appreciate!

"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
@@ -31,6 +40,7 @@
"worker_side_prioritization": True,
"min_iter_time_s": 30,
},
_allow_unknown_configs=True,
Member:
Why do we need this?

Contributor Author (@sven1977, Jan 12, 2022):
We are using the Trainer's config-merging utility. If the second config (APEX-DDPG's) contains new keys, it requires this flag to be set to True; otherwise, it would complain about the new key not being found in the first config (DDPG's). (See the sketch below.)

Member:
👌 👌
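For context, a simplified sketch of the merge behavior being discussed (hypothetical `deep_merge` helper, not RLlib's actual utility): a key present only in the override dict raises an error unless unknown keys are explicitly allowed, which is what `_allow_unknown_configs=True` does above.

```python
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any],
               allow_unknown_keys: bool = False) -> Dict[str, Any]:
    """Recursively merge `overrides` into a copy of `base`."""
    merged = dict(base)
    for key, value in overrides.items():
        if key not in base:
            if not allow_unknown_keys:
                raise KeyError(
                    f"Unknown config key '{key}' not in base config!")
            merged[key] = value
        elif isinstance(base[key], dict) and isinstance(value, dict):
            merged[key] = deep_merge(base[key], value, allow_unknown_keys)
        else:
            merged[key] = value
    return merged


# APEX-DDPG introduces keys (such as the colocation flag above) that the
# plain DDPG config does not contain, so unknown keys must be allowed.
apex_overrides = {"replay_buffer_shards_colocated_with_driver": True}
merged = deep_merge({"buffer_size": 50000}, apex_overrides,
                    allow_unknown_keys=True)
```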

)


36 changes: 33 additions & 3 deletions rllib/agents/dqn/apex.py
@@ -14,6 +14,7 @@

import collections
import copy
import platform
from typing import Tuple

import ray
@@ -32,7 +33,7 @@
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
@@ -55,10 +56,21 @@
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,

"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,

"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
@@ -129,7 +141,8 @@ def execution_plan(workers: WorkerSet, config: dict,
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"][
"num_replay_buffer_shards"]
replay_actors = create_colocated(ReplayActor, [

replay_actor_args = [
num_replay_buffer_shards,
config["learning_starts"],
config["buffer_size"],
@@ -139,7 +152,24 @@
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
], num_replay_buffer_shards)
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if config["replay_buffer_shards_colocated_with_driver"]:
replay_actors = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {},
num_replay_buffer_shards)
],
node=platform.node(), # localhost
)[0] # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
replay_actors = [
ReplayActor(*replay_actor_args)
for _ in range(num_replay_buffer_shards)
]

# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
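For reference, a user-facing sketch of relaxing the colocation constraint (assuming the registered "APEX" trainer string and the `tune.run()` entry point of this RLlib version; the other config values are illustrative):

```python
import ray
from ray import tune

ray.init()
tune.run(
    "APEX",
    stop={"timesteps_total": 100000},
    config={
        "env": "CartPole-v0",
        "num_workers": 4,
        "num_gpus": 0,
        # Allow replay buffer shards to be scheduled on any node
        # instead of forcing them onto the learner (driver) node.
        "replay_buffer_shards_colocated_with_driver": False,
    },
)
```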
31 changes: 19 additions & 12 deletions rllib/agents/trainer.py
@@ -11,9 +11,11 @@
import pickle
import tempfile
import time
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Set, Tuple, \
Type, Union

import ray
from ray.actor import ActorHandle
from ray.exceptions import RayError
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.env_context import EnvContext
@@ -722,8 +724,9 @@ def default_logger_creator(config):
self._episode_history = []
self._episodes_to_be_collected = []

# Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
self.evaluation_workers = None
# Evaluation WorkerSet.
self.evaluation_workers: Optional[WorkerSet] = None
# Metrics most recently returned by `self.evaluate()`.
self.evaluation_metrics = {}

super().__init__(config, logger_creator, remote_checkpoint_dir,
@@ -798,12 +801,19 @@ def env_creator_from_classpath(env_context):
self.local_replay_buffer = (
self._create_local_replay_buffer_if_necessary(self.config))

# Create a dict, mapping ActorHandles to sets of open remote
# requests (object refs). This way, we keep track of how many
# (e.g. `sample()`) requests each actor inside this Trainer
# (e.g. a remote RolloutWorker) has already been sent.
self.remote_requests_in_flight: \
DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
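A minimal sketch of how such a bookkeeping dict is typically used (the `MAX_REQUESTS_IN_FLIGHT` cap and the `worker.sample` call are illustrative assumptions, not the Trainer's actual internals): only actors below the cap receive new requests, and finished refs are dropped from the dict once `ray.wait()` reports them ready.

```python
from collections import defaultdict
from typing import DefaultDict, List, Set

import ray
from ray.actor import ActorHandle

MAX_REQUESTS_IN_FLIGHT = 2

remote_requests_in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]] = \
    defaultdict(set)


def maybe_launch_sample_requests(workers: List[ActorHandle]) -> None:
    """Send `sample.remote()` only to workers below the in-flight cap."""
    for worker in workers:
        if len(remote_requests_in_flight[worker]) < MAX_REQUESTS_IN_FLIGHT:
            ref = worker.sample.remote()
            remote_requests_in_flight[worker].add(ref)


def collect_ready_results(timeout: float = 0.01) -> List:
    """Fetch finished samples and clear them from the bookkeeping dict."""
    all_refs = [r for refs in remote_requests_in_flight.values() for r in refs]
    if not all_refs:
        return []
    ready, _ = ray.wait(all_refs, num_returns=len(all_refs), timeout=timeout)
    ready_set = set(ready)
    for refs in remote_requests_in_flight.values():
        refs -= ready_set
    return ray.get(ready)
```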

# Deprecated way of implementing Trainer sub-classes (or "templates"
# via the soon-to-be deprecated `build_trainer` utility function).
# Instead, sub-classes should override the Trainable's `setup()`
# method and call super().setup() from within that override at some
# point.
self.workers = None
self.workers: Optional[WorkerSet] = None
self.train_exec_impl = None

# Old design: Override `Trainer._init` (or use `build_trainer()`, which
@@ -845,13 +855,10 @@ def env_creator_from_classpath(env_context):
self.workers, self.config,
**self._kwargs_for_execution_plan())

# TODO: Now that workers have been created, update our policy
# specs in the config[multiagent] dict with the correct spaces.
# However, this leads to a problem with the evaluation
# workers' observation one-hot preprocessor in
# `examples/documentation/rllib_in_6sec.py` script.
# self.config["multiagent"]["policies"] = \
# self.workers.local_worker().policy_map.policy_specs
# Now that workers have been created, update our policy
# specs in the config[multiagent] dict with the correct spaces.
self.config["multiagent"]["policies"] = \
self.workers.local_worker().policy_dict

# Evaluation WorkerSet setup.
# User would like to setup a separate evaluation worker set.
@@ -912,7 +919,7 @@ def env_creator_from_classpath(env_context):
# If evaluation_num_workers=0, use the evaluation set's local
# worker for evaluation, otherwise, use its remote workers
# (parallelized evaluation).
self.evaluation_workers = self._make_workers(
self.evaluation_workers: WorkerSet = self._make_workers(
env_creator=self.env_creator,
validate_env=None,
policy_class=self.get_default_policy_class(self.config),
22 changes: 16 additions & 6 deletions rllib/evaluation/rollout_worker.py
@@ -10,6 +10,7 @@
TYPE_CHECKING, Union

import ray
from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
@@ -537,16 +538,16 @@ def make_sub_env(vector_index):
self.make_sub_env_fn = make_sub_env
self.spaces = spaces

policy_dict = _determine_spaces_for_multi_agent_dict(
self.policy_dict = _determine_spaces_for_multi_agent_dict(
policy_spec,
self.env,
spaces=self.spaces,
policy_config=policy_config)

# List of IDs of those policies, which should be trained.
# By default, these are all policies found in the policy_dict.
# By default, these are all policies found in `self.policy_dict`.
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
self.policy_dict.keys())
self.set_policies_to_train(self.policies_to_train)

self.policy_map: PolicyMap = None
@@ -583,7 +584,7 @@ def make_sub_env(vector_index):
f"is ignored.")

self._build_policy_map(
policy_dict,
self.policy_dict,
policy_config,
session_creator=tf_session_creator,
seed=seed)
@@ -1111,7 +1112,7 @@ def add_policy(
"""
if policy_id in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
policy_dict = _determine_spaces_for_multi_agent_dict(
policy_dict_to_add = _determine_spaces_for_multi_agent_dict(
{
policy_id: PolicySpec(policy_cls, observation_space,
action_space, config or {})
@@ -1120,8 +1121,9 @@
spaces=self.spaces,
policy_config=self.policy_config,
)
self.policy_dict.update(policy_dict_to_add)
self._build_policy_map(
policy_dict,
policy_dict_to_add,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]
@@ -1386,6 +1388,14 @@ def set_weights(self,
>>> # Set `global_vars` (timestep) as well.
>>> worker.set_weights(weights, {"timestep": 42})
"""
# If per-policy weights are object refs, `ray.get()` them first.
if weights and isinstance(next(iter(weights.values())), ObjectRef):
actual_weights = ray.get(list(weights.values()))
weights = {
pid: actual_weights[i]
for i, pid in enumerate(weights.keys())
}

for pid, w in weights.items():
self.policy_map[pid].set_weights(w)
if global_vars:
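A brief sketch of how a driver could exploit the new object-ref path (hypothetical `broadcast_weights` helper; `local_worker` and `remote_workers` stand for a WorkerSet's local and remote rollout workers): the weights are put into the object store once, and the same cheap-to-send refs are broadcast to all remote workers, whose `set_weights()` then resolves them via `ray.get()` as shown above.

```python
from typing import Dict, List

import ray
from ray.actor import ActorHandle


def broadcast_weights(local_worker, remote_workers: List[ActorHandle]) -> None:
    """Put each policy's weights into the object store exactly once, then
    send the same (cheap-to-serialize) refs to every remote rollout worker."""
    weights: Dict[str, dict] = local_worker.get_weights()
    weight_refs = {pid: ray.put(w) for pid, w in weights.items()}
    for worker in remote_workers:
        worker.set_weights.remote(weight_refs, {"timestep": 42})
```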