14 changes: 13 additions & 1 deletion rllib/execution/train_ops.py
@@ -192,24 +192,36 @@ def __call__(self,
         # (2) Execute minibatch SGD on loaded data.
         fetches = {}
         for policy_id, tuples_per_device in num_loaded_tuples.items():
+            policy = self.workers.local_worker().get_policy(policy_id)
             optimizer = self.optimizers[policy_id]
             num_batches = max(
                 1,
                 int(tuples_per_device) // int(self.per_device_batch_size))
             logger.debug("== sgd epochs for {} ==".format(policy_id))
+            sgd_error = False
             for i in range(self.num_sgd_iter):
                 iter_extra_fetches = defaultdict(list)
                 permutation = np.random.permutation(num_batches)
                 for batch_index in range(num_batches):
                     batch_fetches = optimizer.optimize(
                         self.sess, permutation[batch_index] *
                         self.per_device_batch_size)
+                    sgd_error = policy.check_sgd_iter_errors(batch_fetches)
                     for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                         iter_extra_fetches[k].append(v)
+                    if sgd_error:
+                        break
                 if logger.getEffectiveLevel() <= logging.DEBUG:
                     avg = averaged(iter_extra_fetches)
                     logger.debug("{} {}".format(i, avg))
-            fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
+                # Stop policy updates on any runtime errors.
+                if sgd_error:
+                    break
+            # Note: this captures only the last SGD iteration.
+            fetches[policy_id] = averaged(
+                iter_extra_fetches,
+                axis=0,
+                dict_averaging_func=policy.aggregate_dict_metric)
 
         load_timer.push_units_processed(samples.count)
         learn_timer.push_units_processed(samples.count)
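With this change a policy can abort minibatch SGD early: `check_sgd_iter_errors` runs after every minibatch, a `True` result breaks out of the current epoch, and the outer `sgd_error` check then skips the remaining epochs for that policy. A minimal sketch of a policy-side override, assuming scalar learner stats; the mixin name is made up, and the `LEARNER_STATS_KEY` import path is the one this RLlib version uses:

```python
import logging

import numpy as np
from ray.rllib.policy.policy import LEARNER_STATS_KEY

logger = logging.getLogger(__name__)


class NanGuardMixin:
    """Illustrative mixin: stop SGD once any learner stat goes non-finite."""

    def check_sgd_iter_errors(self, minibatch_fetches):
        stats = minibatch_fetches.get(LEARNER_STATS_KEY, {})
        for name, value in stats.items():
            if (isinstance(value, (int, float, np.floating))
                    and not np.isfinite(value)):
                logger.warning("Non-finite stat %s=%r; stopping SGD.",
                               name, value)
                return True
        return False
```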
16 changes: 16 additions & 0 deletions rllib/policy/tf_policy.py
@@ -1,5 +1,7 @@
 import errno
 import logging
+from typing import Dict, List, Any
+
 import numpy as np
 import os
 
@@ -475,6 +477,20 @@ def extra_compute_grad_fetches(self):
"""Extra values to fetch and return from compute_gradients()."""
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.

@DeveloperAPI
def aggregate_dict_metric(self, name: str, dict_list: List[dict]) -> Any:
""" Aggregate dictonary metrics created by extra grad fetches. By default return the first
element only."""
return dict_list[0]

@DeveloperAPI
def check_sgd_iter_errors(self, minibatch_fetches: Any) -> bool:
"""
Check (and possibly log) sgd iteration errors based on evaluated metrics.
:return: True if there are errors that must break optimization loop.
"""
return False

@DeveloperAPI
def optimizer(self):
"""TF optimizer to use for policy optimization."""
9 changes: 7 additions & 2 deletions rllib/utils/sgd.py
@@ -1,4 +1,5 @@
"""Utils for minibatch SGD across multiple RLlib policies."""
from typing import Optional, Callable, List, Any

import numpy as np
import logging
@@ -13,13 +14,14 @@
 logger = logging.getLogger(__name__)
 
 
-def averaged(kv, axis=None):
+def averaged(kv, axis=None, dict_averaging_func: Optional[Callable[[str, List[dict]], Any]] = None):
     """Average the value lists of a dictionary.
 
     For non-scalar values, we simply pick the first value.
 
     Arguments:
         kv (dict): dictionary with values that are lists of floats.
+        dict_averaging_func: optional function used to aggregate non-numeric (dict) values.
 
     Returns:
         dictionary with single averaged float as values.
@@ -29,7 +31,10 @@ def averaged(kv, axis=None):
         if v[0] is not None and not isinstance(v[0], dict):
             out[k] = np.mean(v, axis=axis)
         else:
-            out[k] = v[0]
+            if dict_averaging_func is not None:
+                out[k] = dict_averaging_func(k, v)
+            else:
+                out[k] = v[0]
     return out
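For illustration, the extended `averaged()` on a mixed stats dict (the metric names are made up):

```python
from ray.rllib.utils.sgd import averaged

stats = {
    "loss": [0.6, 0.3, 0.3],
    "td_error": [{"mean": 1.0}, {"mean": 0.6}],
}

# Default behavior: dict-valued metrics keep their first element.
averaged(stats)
# -> {"loss": 0.4, "td_error": {"mean": 1.0}}

# With dict_averaging_func, dict-valued metrics are combined as well.
averaged(
    stats,
    dict_averaging_func=lambda key, dicts: {
        k: sum(d[k] for d in dicts) / len(dicts)
        for k in dicts[0]
    })
# -> {"loss": 0.4, "td_error": {"mean": 0.8}}
```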

