Skip to content

Commit c7031b4

Browse files
dmlyubimEdilmo
authored andcommitted
Adding runtime error related apis, metric aggregation, early bailout in TFGraphGPUMulti (#55)
1 parent b5e7e9a commit c7031b4

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

rllib/execution/train_ops.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,24 +194,36 @@ def __call__(self,
194194
# (2) Execute minibatch SGD on loaded data.
195195
fetches = {}
196196
for policy_id, tuples_per_device in num_loaded_tuples.items():
197+
policy = self.workers.local_worker().get_policy(policy_id)
197198
optimizer = self.optimizers[policy_id]
198199
num_batches = max(
199200
1,
200201
int(tuples_per_device) // int(self.per_device_batch_size))
201202
logger.debug("== sgd epochs for {} ==".format(policy_id))
203+
sgd_error = False
202204
for i in range(self.num_sgd_iter):
203205
iter_extra_fetches = defaultdict(list)
204206
permutation = np.random.permutation(num_batches)
205207
for batch_index in range(num_batches):
206208
batch_fetches = optimizer.optimize(
207209
self.sess, permutation[batch_index] *
208210
self.per_device_batch_size)
211+
sgd_error = policy.check_sgd_iter_errors(batch_fetches)
209212
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
210213
iter_extra_fetches[k].append(v)
214+
if sgd_error:
215+
break
211216
if logger.getEffectiveLevel() <= logging.DEBUG:
212217
avg = averaged(iter_extra_fetches)
213218
logger.debug("{} {}".format(i, avg))
214-
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
219+
# Stop policy updates on any runtime errors
220+
if sgd_error:
221+
break
222+
# Note: This captures only last SGD iteration.
223+
fetches[policy_id] = averaged(
224+
iter_extra_fetches,
225+
axis=0,
226+
dict_averaging_func=policy.aggregate_dict_metric)
215227

216228
load_timer.push_units_processed(samples.count)
217229
learn_timer.push_units_processed(samples.count)

rllib/policy/tf_policy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import errno
22
import gym
33
import logging
4+
from typing import Dict, List, Any
5+
46
import numpy as np
57
import os
68
from typing import Dict, List, Optional, Tuple, Union
@@ -602,6 +604,20 @@ def extra_compute_grad_fetches(self) -> Dict[str, any]:
602604
"""
603605
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
604606

607+
@DeveloperAPI
608+
def aggregate_dict_metric(self, name: str, dict_list: List[dict]) -> Any:
609+
""" Aggregate dictonary metrics created by extra grad fetches. By default return the first
610+
element only."""
611+
return dict_list[0]
612+
613+
@DeveloperAPI
614+
def check_sgd_iter_errors(self, minibatch_fetches: Any) -> bool:
615+
"""
616+
Check (and possibly log) sgd iteration errors based on evaluated metrics.
617+
:return: True if there are errors that must break optimization loop.
618+
"""
619+
return False
620+
605621
@DeveloperAPI
606622
def optimizer(self) -> "tf.keras.optimizers.Optimizer":
607623
"""TF optimizer to use for policy optimization.

rllib/utils/sgd.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utils for minibatch SGD across multiple RLlib policies."""
2+
from typing import Optional, Callable, List, Any
23

34
import numpy as np
45
import logging
@@ -13,13 +14,14 @@
1314
logger = logging.getLogger(__name__)
1415

1516

16-
def averaged(kv, axis=None):
17+
def averaged(kv, axis=None, dict_averaging_func: Optional[Callable[[str, List[dict]], Any]] = None):
1718
"""Average the value lists of a dictionary.
1819
1920
For non-scalar values, we simply pick the first value.
2021
2122
Args:
2223
kv (dict): dictionary with values that are lists of floats.
24+
dict_averaging_func: optional function averaging non-numeric (dictionary) arguments
2325
2426
Returns:
2527
dictionary with single averaged float as values.
@@ -29,7 +31,10 @@ def averaged(kv, axis=None):
2931
if v[0] is not None and not isinstance(v[0], dict):
3032
out[k] = np.mean(v, axis=axis)
3133
else:
32-
out[k] = v[0]
34+
if dict_averaging_func is not None:
35+
out[k] = dict_averaging_func(k, v)
36+
else:
37+
out[k] = v[0]
3338
return out
3439

3540

0 commit comments

Comments
 (0)