File tree Expand file tree Collapse file tree 3 files changed +16
-9
lines changed
Expand file tree Collapse file tree 3 files changed +16
-9
lines changed Original file line number Diff line number Diff line change @@ -2989,7 +2989,7 @@ def evaluation(
29892989 if offline_evaluation_type is not NotProvided :
29902990 self .offline_evaluation_type = offline_evaluation_type
29912991 if offline_eval_runner_class is not NotProvided :
2992- self .offline_eval_runner_cls = offline_eval_runner_class
2992+ self .offline_eval_runner_class = offline_eval_runner_class
29932993 if offline_loss_for_module_fn is not NotProvided :
29942994 self .offline_loss_for_module_fn = offline_loss_for_module_fn
29952995 if offline_eval_batch_size_per_runner is not NotProvided :
@@ -5328,8 +5328,8 @@ def _validate_offline_settings(self):
53285328
53295329 from ray .rllib .offline .offline_evaluation_runner import OfflineEvaluationRunner
53305330
5331- if self .prelearner_class and not issubclass (
5332- self .prelearner_class , OfflineEvaluationRunner
5331+ if self .offline_eval_runner_class and not issubclass (
5332+ self .offline_eval_runner_class , OfflineEvaluationRunner
53335333 ):
53345334 self ._value_error (
53355335 "Unknown `offline_eval_runner_class`. OfflineEvaluationRunner class needs to inherit "
Original file line number Diff line number Diff line change @@ -302,15 +302,20 @@ def training_step(self) -> None:
302302
303303 # Sampling from offline data.
304304 with self .metrics .log_time ((TIMERS , OFFLINE_SAMPLING_TIMER )):
305+ # If we should use an iterator in the learner(s). Note, in case of
306+ # multiple learners we must always return a list of iterators.
307+ return_iterator = (
308+ self .config .num_learners > 0
309+ or self .config .dataset_num_iters_per_learner != 1
310+ )
311+
305312 # Return an iterator in case we are using remote learners.
306313 batch_or_iterator = self .offline_data .sample (
307314 num_samples = self .config .train_batch_size_per_learner ,
308315 num_shards = self .config .num_learners ,
309316 # Return an iterator, if a `Learner` should update
310317 # multiple times per RLlib iteration.
311- return_iterator = self .config .dataset_num_iters_per_learner > 1
312- if self .config .dataset_num_iters_per_learner
313- else True ,
318+ return_iterator = return_iterator ,
314319 )
315320
316321 # Updating the policy.
Original file line number Diff line number Diff line change @@ -457,11 +457,13 @@ class (multi-/single-learner setup) and evaluation on
457457 # the user that sth. is not right, although it is as
458458 # we do not step the env.
459459 with self .metrics .log_time ((TIMERS , OFFLINE_SAMPLING_TIMER )):
460+ # If we should use an iterator in the learner(s). Note, in case of
461+ # multiple learners we must always return a list of iterators.
460462 return_iterator = (
461- self .config .dataset_num_iters_per_learner > 1
462- if self .config .dataset_num_iters_per_learner
463- else True
463+ self .config .num_learners > 0
464+ or self .config .dataset_num_iters_per_learner != 1
464465 )
466+
465467 # Sampling from offline data.
466468 batch_or_iterator = self .offline_data .sample (
467469 num_samples = self .config .train_batch_size_per_learner ,
You can’t perform that action at this time.
0 commit comments