diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py
index ef4d465bb9b9..96625df39dd8 100644
--- a/rllib/execution/train_ops.py
+++ b/rllib/execution/train_ops.py
@@ -192,11 +192,13 @@ def __call__(self,
             # (2) Execute minibatch SGD on loaded data.
             fetches = {}
             for policy_id, tuples_per_device in num_loaded_tuples.items():
+                policy = self.workers.local_worker().get_policy(policy_id)
                 optimizer = self.optimizers[policy_id]
                 num_batches = max(
                     1,
                     int(tuples_per_device) // int(self.per_device_batch_size))
                 logger.debug("== sgd epochs for {} ==".format(policy_id))
+                sgd_error = False
                 for i in range(self.num_sgd_iter):
                     iter_extra_fetches = defaultdict(list)
                     permutation = np.random.permutation(num_batches)
@@ -204,12 +206,22 @@ def __call__(self,
                         batch_fetches = optimizer.optimize(
                             self.sess, permutation[batch_index] *
                             self.per_device_batch_size)
+                        sgd_error = policy.check_sgd_iter_errors(batch_fetches)
                         for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                             iter_extra_fetches[k].append(v)
+                        if sgd_error:
+                            break
                     if logger.getEffectiveLevel() <= logging.DEBUG:
                         avg = averaged(iter_extra_fetches)
                         logger.debug("{} {}".format(i, avg))
-                fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
+                    # Stop policy updates on any SGD iteration error.
+                    if sgd_error:
+                        break
+                # Note: This captures stats from the last SGD iteration only.
+                fetches[policy_id] = averaged(
+                    iter_extra_fetches,
+                    axis=0,
+                    dict_averaging_func=policy.aggregate_dict_metric)
 
         load_timer.push_units_processed(samples.count)
         learn_timer.push_units_processed(samples.count)
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py
index d3b00a2d89d1..7c58eeb0bce8 100644
--- a/rllib/policy/tf_policy.py
+++ b/rllib/policy/tf_policy.py
@@ -1,5 +1,7 @@
 import errno
 import logging
+from typing import List, Any
+
 import numpy as np
 import os
 
@@ -475,6 +477,20 @@ def extra_compute_grad_fetches(self):
         """Extra values to fetch and return from compute_gradients()."""
         return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.
 
+    @DeveloperAPI
+    def aggregate_dict_metric(self, name: str, dict_list: List[dict]) -> Any:
+        """Aggregate dictionary metrics created by extra grad fetches.
+        By default, returns the first element only."""
+        return dict_list[0]
+
+    @DeveloperAPI
+    def check_sgd_iter_errors(self, minibatch_fetches: Any) -> bool:
+        """Check (and possibly log) SGD iteration errors based on the
+        evaluated metrics. Returns True if there are errors that must
+        break the optimization loop.
+        """
+        return False
+
     @DeveloperAPI
     def optimizer(self):
         """TF optimizer to use for policy optimization."""
diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py
index bc2e920b0713..1b7b84c2ae5d 100644
--- a/rllib/utils/sgd.py
+++ b/rllib/utils/sgd.py
@@ -1,4 +1,5 @@
 """Utils for minibatch SGD across multiple RLlib policies."""
+from typing import Optional, Callable, List, Any
 
 import numpy as np
 import logging
@@ -13,13 +14,15 @@
 logger = logging.getLogger(__name__)
 
 
-def averaged(kv, axis=None):
+def averaged(kv, axis=None, dict_averaging_func: Optional[
+        Callable[[str, List[dict]], Any]] = None):
     """Average the value lists of a dictionary.
 
     For non-scalar values, we simply pick the first value.
 
     Arguments:
         kv (dict): dictionary with values that are lists of floats.
+        dict_averaging_func: optional function to aggregate dict values.
 
     Returns:
         dictionary with single averaged float as values.
@@ -29,7 +32,10 @@ def averaged(kv, axis=None):
         if v[0] is not None and not isinstance(v[0], dict):
             out[k] = np.mean(v, axis=axis)
         else:
-            out[k] = v[0]
+            if dict_averaging_func is not None:
+                out[k] = dict_averaging_func(k, v)
+            else:
+                out[k] = v[0]
     return out
 
 
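For reviewers, here is a minimal sketch of how a policy could plug into the two new hooks. `NanGuardPolicy`, the `"total_loss"` stats key, and the per-key mean aggregation are hypothetical illustrations, not part of this PR; only the hook signatures come from the diff above, and the import path for `LEARNER_STATS_KEY` is assumed to match this RLlib version.

```python
import numpy as np

from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.policy import LEARNER_STATS_KEY


class NanGuardPolicy(TFPolicy):
    """Hypothetical policy using the new hooks (illustration only)."""

    def aggregate_dict_metric(self, name, dict_list):
        # Instead of keeping only the first dict, average each key
        # across all collected minibatch fetches (assumes numeric values).
        return {
            k: float(np.mean([d[k] for d in dict_list]))
            for k in dict_list[0]
        }

    def check_sgd_iter_errors(self, minibatch_fetches):
        # Break the SGD loop as soon as the loss turns non-finite.
        # "total_loss" is an assumed stats key, not defined by this PR.
        loss = minibatch_fetches[LEARNER_STATS_KEY].get("total_loss")
        return loss is not None and not np.isfinite(loss)
```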
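And a quick check of the extended `averaged` utility in isolation, against the patched `rllib/utils/sgd.py`; the `custom_stats` metric and `mean_dicts` helper are made up for the example:

```python
import numpy as np

from ray.rllib.utils.sgd import averaged


def mean_dicts(name, dict_list):
    # Per-key mean across the dicts collected over minibatches.
    return {k: float(np.mean([d[k] for d in dict_list]))
            for k in dict_list[0]}


fetches = {
    "policy_loss": [0.4, 0.2],                     # scalar metric
    "custom_stats": [{"kl": 0.01}, {"kl": 0.03}],  # dict-valued metric
}

# Without the new argument, only the first dict would be kept;
# with it, the dicts are merged by mean_dicts.
print(averaged(fetches, dict_averaging_func=mean_dicts))
# -> {'policy_loss': ~0.3, 'custom_stats': {'kl': 0.02}}
```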