@@ -194,24 +194,36 @@ def __call__(self,
194194 # (2) Execute minibatch SGD on loaded data.
195195 fetches = {}
196196 for policy_id , tuples_per_device in num_loaded_tuples .items ():
197+ policy = self .workers .local_worker ().get_policy (policy_id )
197198 optimizer = self .optimizers [policy_id ]
198199 num_batches = max (
199200 1 ,
200201 int (tuples_per_device ) // int (self .per_device_batch_size ))
201202 logger .debug ("== sgd epochs for {} ==" .format (policy_id ))
203+ sgd_error = False
202204 for i in range (self .num_sgd_iter ):
203205 iter_extra_fetches = defaultdict (list )
204206 permutation = np .random .permutation (num_batches )
205207 for batch_index in range (num_batches ):
206208 batch_fetches = optimizer .optimize (
207209 self .sess , permutation [batch_index ] *
208210 self .per_device_batch_size )
211+ sgd_error = policy .check_sgd_iter_errors (batch_fetches )
209212 for k , v in batch_fetches [LEARNER_STATS_KEY ].items ():
210213 iter_extra_fetches [k ].append (v )
214+ if sgd_error :
215+ break
211216 if logger .getEffectiveLevel () <= logging .DEBUG :
212217 avg = averaged (iter_extra_fetches )
213218 logger .debug ("{} {}" .format (i , avg ))
214- fetches [policy_id ] = averaged (iter_extra_fetches , axis = 0 )
219+ # Stop policy updates on any runtime errors
220+ if sgd_error :
221+ break
222+ # Note: This captures only last SGD iteration.
223+ fetches [policy_id ] = averaged (
224+ iter_extra_fetches ,
225+ axis = 0 ,
226+ dict_averaging_func = policy .aggregate_dict_metric )
215227
216228 load_timer .push_units_processed (samples .count )
217229 learn_timer .push_units_processed (samples .count )
0 commit comments