2 changes: 1 addition & 1 deletion doc/source/rllib-concepts.rst
@@ -108,7 +108,7 @@ This is how the example in the previous section looks when written using a policy
Trainers
--------

Trainers are the boilerplate classes that put the above components together. Trainer make algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.
Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.

Example of two equivalent ways of interacting with the PPO trainer:
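(The example itself is collapsed in this diff view. For orientation, here is a minimal sketch of the two styles that section describes, the Python API and the ``rllib train`` CLI. The PPOTrainer import path, environment, and config values below are illustrative assumptions, not the exact doc text.)

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()

# 1) Python API: construct the Trainer and step it manually.
trainer = PPOTrainer(env="CartPole-v0", config={"train_batch_size": 4000})
for _ in range(3):
    result = trainer.train()  # one call runs one training iteration
    print(result["episode_reward_mean"])

# 2) Command line: the same algorithm driven through Tune's Trainable API.
#    rllib train --run=PPO --env=CartPole-v0 --config='{"train_batch_size": 4000}'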

3 changes: 3 additions & 0 deletions python/ray/internal/internal_api.py
@@ -30,6 +30,9 @@ def free(object_ids, local_only=False, delete_creating_tasks=False):
"""
worker = ray.worker.get_global_worker()

if ray.worker._mode() == ray.worker.LOCAL_MODE:
oh this is nice

return

if isinstance(object_ids, ray.ObjectID):
object_ids = [object_ids]

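(As the reviewer notes, this early return makes free() a harmless no-op under local mode, where there is no plasma object store to release objects from. A hedged sketch of the resulting behavior, assuming local mode is enabled via ray.init(local_mode=True) in this Ray version:)

import ray

ray.init(local_mode=True)  # everything runs serially in the driver process

oid = ray.put({"x": 1})
ray.internal.free([oid])   # with this change: returns immediately in local mode
print(ray.get(oid))        # the object is still available; nothing was freed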
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/ars/ars.py
@@ -19,6 +19,7 @@
from ray.rllib.agents.ars import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager

logger = logging.getLogger(__name__)
@@ -312,7 +313,7 @@ def _collect_results(self, theta_id, min_episodes):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/es/es.py
@@ -18,6 +18,7 @@
from ray.rllib.agents.es import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager

logger = logging.getLogger(__name__)
@@ -309,7 +310,7 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/trainer.py
@@ -24,6 +24,7 @@
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
@@ -668,7 +669,7 @@ def _try_recover(self):
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
try:
ray.get(obj_id)
ray_get_and_free(obj_id)
healthy_evaluators.append(ev)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
3 changes: 2 additions & 1 deletion python/ray/rllib/env/remote_vector_env.py
@@ -6,6 +6,7 @@

import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -60,7 +61,7 @@ def make_remote_env(i):
actor = self.pending.pop(obj_id)
env_id = self.actors.index(actor)
env_ids.add(env_id)
ob, rew, done, info = ray.get(obj_id)
ob, rew, done, info = ray_get_and_free(obj_id)
obs[env_id] = ob
rewards[env_id] = rew
dones[env_id] = done
3 changes: 2 additions & 1 deletion python/ray/rllib/evaluation/metrics.py
@@ -10,6 +10,7 @@
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -69,7 +70,7 @@ def collect_episodes(local_evaluator=None,
"Timed out waiting for metrics from workers. You can configure "
"this timeout with `collect_metrics_timeout`.")

metric_lists = ray.get(collected)
metric_lists = ray_get_and_free(collected)
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics())
episodes = []
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/aso_aggregator.py
@@ -10,6 +10,7 @@
import ray
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free


class Aggregator(object):
@@ -143,7 +144,7 @@ def can_replay():
return len(self.replay_batches) > num_needed

for ev, sample_batch in sample_futures:
sample_batch = ray.get(sample_batch)
sample_batch = ray_get_and_free(sample_batch)
yield ev, sample_batch

if can_replay():
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/aso_tree_aggregator.py
@@ -14,6 +14,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
AggregationWorkerBase
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -86,7 +87,7 @@ def init(self, aggregators):
def iter_train_batches(self):
assert self.initialized, "Must call init() before using this class."
for agg, batches in self.agg_tasks.completed_prefetch():
for b in ray.get(batches):
for b in ray_get_and_free(batches):
self.num_sent_since_broadcast += 1
yield b
agg.set_weights.remote(self.broadcasted_weights)
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/async_gradients_optimizer.py
@@ -7,6 +7,7 @@
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class AsyncGradientsOptimizer(PolicyOptimizer):
@@ -49,7 +50,7 @@ def step(self):
ready_list = wait_results[0]
future = ready_list[0]

gradient, info = ray.get(future)
gradient, info = ray_get_and_free(future)
e = pending_gradients.pop(future)
self.learner_stats = get_learner_stats(info)

8 changes: 5 additions & 3 deletions python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -23,6 +23,7 @@
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat

@@ -143,7 +144,8 @@ def reset(self, remote_evaluators):

@override(PolicyOptimizer)
def stats(self):
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
self.debug))
timing = {
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
for k in self.timers
@@ -188,7 +190,7 @@ def _step(self):

with self.timers["sample_processing"]:
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
counts = ray_get_and_free([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]

@@ -220,7 +222,7 @@ def _step(self):
self.num_samples_dropped += 1
else:
with self.timers["get_samples"]:
samples = ray.get(replay)
samples = ray_get_and_free(replay)
# Defensive copy against plasma crashes, see #2610 #3452
self.learner.inqueue.put((ra, samples and samples.copy()))

6 changes: 3 additions & 3 deletions python/ray/rllib/optimizers/policy_optimizer.py
@@ -4,9 +4,9 @@

import logging

import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -140,7 +140,7 @@ def foreach_evaluator(self, func):
"""Apply the given function to each evaluator instance."""

local_result = [func(self.local_evaluator)]
remote_results = ray.get(
remote_results = ray_get_and_free(
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results

@@ -152,7 +152,7 @@ def foreach_evaluator_with_index(self, func):
"""

local_result = [func(self.local_evaluator, 0)]
remote_results = ray.get([
remote_results = ray_get_and_free([
ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)
])
5 changes: 3 additions & 2 deletions python/ray/rllib/optimizers/rollout.py
@@ -6,6 +6,7 @@

import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -25,7 +26,7 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
while agent_dict:
[fut_sample], _ = ray.wait(list(agent_dict))
agent = agent_dict.pop(fut_sample)
next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
assert next_sample.count >= sample_batch_size * num_envs_per_worker
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
@@ -63,7 +64,7 @@ def collect_samples_straggler_mitigation(agents, train_batch_size):
fut_sample2 = agent.sample.remote()
agent_dict[fut_sample2] = agent

next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)

python/ray/rllib/optimizers/sync_batch_replay_optimizer.py
@@ -11,6 +11,7 @@
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class SyncBatchReplayOptimizer(PolicyOptimizer):
@@ -51,7 +52,7 @@ def step(self):

with self.sample_timer:
if self.remote_evaluators:
batches = ray.get(
batches = ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators])
else:
batches = [self.local_evaluator.sample()]
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/sync_replay_optimizer.py
@@ -17,6 +17,7 @@
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -89,7 +90,7 @@ def step(self):
with self.sample_timer:
if self.remote_evaluators:
batch = SampleBatch.concat_samples(
ray.get(
ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators]))
else:
batch = self.local_evaluator.sample()
3 changes: 2 additions & 1 deletion python/ray/rllib/optimizers/sync_samples_optimizer.py
@@ -10,6 +10,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -50,7 +51,7 @@ def step(self):
while sum(s.count for s in samples) < self.train_batch_size:
if self.remote_evaluators:
samples.extend(
ray.get([
ray_get_and_free([
e.sample.remote() for e in self.remote_evaluators
]))
else:
2 changes: 2 additions & 0 deletions python/ray/rllib/setup-rllib-dev.py
@@ -47,6 +47,8 @@ def do_link(package, force=False):
do_link("tune", force=args.yes)
do_link("autoscaler", force=args.yes)
do_link("scripts", force=args.yes)
do_link("internal", force=args.yes)
do_link("experimental", force=args.yes)
print("Created links.\n\nIf you run into issues initializing Ray, please "
"ensure that your local repo and the installed Ray are in sync "
"(pip install -U the latest wheels at "
3 changes: 2 additions & 1 deletion python/ray/rllib/utils/filter_manager.py
@@ -4,6 +4,7 @@

import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free


@DeveloperAPI
@@ -24,7 +25,7 @@ def synchronize(local_filters, remotes, update_remote=True):
remotes (list): Remote evaluators with filters.
update_remote (bool): Whether to push updates to remote filters.
"""
remote_filters = ray.get(
remote_filters = ray_get_and_free(
[r.get_filters.remote(flush_after=True) for r in remotes])
for rf in remote_filters:
for k in local_filters:
41 changes: 41 additions & 0 deletions python/ray/rllib/utils/memory.py
@@ -3,6 +3,47 @@
from __future__ import print_function

import numpy as np
import time

import ray

FREE_DELAY_S = 10.0
MAX_FREE_QUEUE_SIZE = 100
_last_free_time = 0.0
_to_free = []


def ray_get_and_free(object_ids):
"""Call ray.get and then queue the object ids for deletion.

This function should be used whenever possible in RLlib, to optimize
memory usage. The only exception is when an object_id is shared among
multiple readers.

Args:
object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.

Returns:
The result of ray.get(object_ids).
"""

global _last_free_time
global _to_free

result = ray.get(object_ids)
if type(object_ids) is not list:
object_ids = [object_ids]
_to_free.extend(object_ids)

# batch calls to free to reduce overheads
now = time.time()
if (len(_to_free) > MAX_FREE_QUEUE_SIZE
or now - _last_free_time > FREE_DELAY_S):
batch the free calls to avoid too much overhead

ray.internal.free(_to_free)
_to_free = []
_last_free_time = now

return result


def aligned_array(size, dtype, align=64):
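Finally, a usage sketch of the new helper. The RLlib call sites above show the real pattern; this standalone example (the rollout task and its payload are made up) just illustrates that ray_get_and_free is a drop-in replacement for ray.get:

import ray
from ray.rllib.utils.memory import ray_get_and_free

ray.init()

@ray.remote
def rollout(i):
    # stand-in for an expensive sampling task
    return [i] * 1000

# Equivalent to ray.get(...), but the fetched object ids are queued and
# periodically released via ray.internal.free() to bound plasma memory use.
results = ray_get_and_free([rollout.remote(i) for i in range(8)])
assert len(results) == 8

# Per the docstring, skip the helper when an ObjectID is shared by multiple
# readers, since freeing it would pull the object out from under them.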