diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 2c6120b86452..19e8801319af 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -2989,7 +2989,7 @@ def evaluation(
         if offline_evaluation_type is not NotProvided:
             self.offline_evaluation_type = offline_evaluation_type
         if offline_eval_runner_class is not NotProvided:
-            self.offline_eval_runner_cls = offline_eval_runner_class
+            self.offline_eval_runner_class = offline_eval_runner_class
         if offline_loss_for_module_fn is not NotProvided:
             self.offline_loss_for_module_fn = offline_loss_for_module_fn
         if offline_eval_batch_size_per_runner is not NotProvided:
@@ -5328,8 +5328,8 @@ def _validate_offline_settings(self):
         from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner

-        if self.prelearner_class and not issubclass(
-            self.prelearner_class, OfflineEvaluationRunner
+        if self.offline_eval_runner_class and not issubclass(
+            self.offline_eval_runner_class, OfflineEvaluationRunner
         ):
             self._value_error(
                 "Unknown `offline_eval_runner_class`. OfflineEvaluationRunner class needs to inherit "
diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py
index de972a119a90..e2f3ac2eff44 100644
--- a/rllib/algorithms/cql/cql.py
+++ b/rllib/algorithms/cql/cql.py
@@ -302,15 +302,20 @@ def training_step(self) -> None:

         # Sampling from offline data.
         with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+            # Whether we should use an iterator in the learner(s). Note that in the
+            # case of multiple learners we must always return a list of iterators.
+            return_iterator = (
+                self.config.num_learners > 0
+                or self.config.dataset_num_iters_per_learner != 1
+            )
+            # Return an iterator in case we are using remote learners.
             batch_or_iterator = self.offline_data.sample(
                 num_samples=self.config.train_batch_size_per_learner,
                 num_shards=self.config.num_learners,
                 # Return an iterator, if a `Learner` should update
                 # multiple times per RLlib iteration.
-                return_iterator=self.config.dataset_num_iters_per_learner > 1
-                if self.config.dataset_num_iters_per_learner
-                else True,
+                return_iterator=return_iterator,
             )

         # Updating the policy.
diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py
index b0a06ae6d2d8..4ebf1d9333a4 100644
--- a/rllib/algorithms/marwil/marwil.py
+++ b/rllib/algorithms/marwil/marwil.py
@@ -457,11 +457,13 @@ class (multi-/single-learner setup) and evaluation on
             # the user that sth. is not right, although it is as
             # we do not step the env.
             with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+                # Whether we should use an iterator in the learner(s). Note that in the
+                # case of multiple learners we must always return a list of iterators.
                 return_iterator = (
-                    self.config.dataset_num_iters_per_learner > 1
-                    if self.config.dataset_num_iters_per_learner
-                    else True
+                    self.config.num_learners > 0
+                    or self.config.dataset_num_iters_per_learner != 1
                 )
+                # Sampling from offline data.
                 batch_or_iterator = self.offline_data.sample(
                     num_samples=self.config.train_batch_size_per_learner,