Commit 1e5094f

[RLlib - Offline RL] Fix bug in return_iterator in multi-learner settings. (#55693)
Parent: b830b8d

3 files changed: +16 −9 lines


rllib/algorithms/algorithm_config.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -2989,7 +2989,7 @@ def evaluation(
         if offline_evaluation_type is not NotProvided:
             self.offline_evaluation_type = offline_evaluation_type
         if offline_eval_runner_class is not NotProvided:
-            self.offline_eval_runner_cls = offline_eval_runner_class
+            self.offline_eval_runner_class = offline_eval_runner_class
         if offline_loss_for_module_fn is not NotProvided:
             self.offline_loss_for_module_fn = offline_loss_for_module_fn
         if offline_eval_batch_size_per_runner is not NotProvided:
@@ -5328,8 +5328,8 @@ def _validate_offline_settings(self):

         from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner

-        if self.prelearner_class and not issubclass(
-            self.prelearner_class, OfflineEvaluationRunner
+        if self.offline_eval_runner_class and not issubclass(
+            self.offline_eval_runner_class, OfflineEvaluationRunner
         ):
             self._value_error(
                 "Unknown `offline_eval_runner_class`. OfflineEvaluationRunner class needs to inherit "
```

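For context, `evaluation()` is the config setter that receives `offline_eval_runner_class`; before this fix it stored the value under the misspelled attribute `offline_eval_runner_cls`, while validation checked `prelearner_class` instead. A minimal sketch of the now-consistent path (the runner subclass below is hypothetical):

```python
from ray.rllib.algorithms.cql import CQLConfig
from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner


# Hypothetical custom runner; anything passed as `offline_eval_runner_class`
# must inherit from OfflineEvaluationRunner, or _validate_offline_settings()
# flags it via self._value_error().
class MyOfflineEvalRunner(OfflineEvaluationRunner):
    pass


config = CQLConfig().evaluation(offline_eval_runner_class=MyOfflineEvalRunner)

# Post-fix, the setter and the validation read the same attribute.
assert config.offline_eval_runner_class is MyOfflineEvalRunner
```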
rllib/algorithms/cql/cql.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -302,15 +302,20 @@ def training_step(self) -> None:

         # Sampling from offline data.
         with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+            # Whether to use an iterator in the learner(s). Note that with
+            # multiple learners we must always return a list of iterators.
+            return_iterator = (
+                self.config.num_learners > 0
+                or self.config.dataset_num_iters_per_learner != 1
+            )
+
             # Return an iterator in case we are using remote learners.
             batch_or_iterator = self.offline_data.sample(
                 num_samples=self.config.train_batch_size_per_learner,
                 num_shards=self.config.num_learners,
                 # Return an iterator, if a `Learner` should update
                 # multiple times per RLlib iteration.
-                return_iterator=self.config.dataset_num_iters_per_learner > 1
-                if self.config.dataset_num_iters_per_learner
-                else True,
+                return_iterator=return_iterator,
             )

         # Updating the policy.
```
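The behavioral change is easiest to see in isolation. Below is a self-contained sketch of the two conditions (plain Python stand-ins for the config expressions, not RLlib code):

```python
def old_return_iterator(dataset_num_iters_per_learner):
    # Pre-fix condition: the number of learners is never consulted.
    return (
        dataset_num_iters_per_learner > 1
        if dataset_num_iters_per_learner
        else True
    )


def new_return_iterator(num_learners, dataset_num_iters_per_learner):
    # Post-fix condition: multiple (remote) learners always get iterators.
    return num_learners > 0 or dataset_num_iters_per_learner != 1


# The bug: two remote learners with exactly one update iteration per learner
# received a plain batch instead of the required list of iterators.
assert old_return_iterator(1) is False
assert new_return_iterator(num_learners=2, dataset_num_iters_per_learner=1)

# Single local learner doing one pass per RLlib iteration: a batch suffices.
assert not new_return_iterator(num_learners=0, dataset_num_iters_per_learner=1)
```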

rllib/algorithms/marwil/marwil.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -457,11 +457,13 @@ class (multi-/single-learner setup) and evaluation on
         # the user that sth. is not right, although it is as
         # we do not step the env.
         with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+            # Whether to use an iterator in the learner(s). Note that with
+            # multiple learners we must always return a list of iterators.
             return_iterator = (
-                self.config.dataset_num_iters_per_learner > 1
-                if self.config.dataset_num_iters_per_learner
-                else True
+                self.config.num_learners > 0
+                or self.config.dataset_num_iters_per_learner != 1
             )
+
             # Sampling from offline data.
             batch_or_iterator = self.offline_data.sample(
                 num_samples=self.config.train_batch_size_per_learner,
```
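MARWIL's `training_step()` receives the identical condition. A usage-level sketch of the two regimes, assuming the standard `learners()` and `offline_data()` config setters (illustrative values only):

```python
from ray.rllib.algorithms.marwil import MARWILConfig

# Two remote learners: `return_iterator` is now always True, so
# offline_data.sample() hands each learner its own streaming iterator.
# This was the case the old condition got wrong.
multi_learner = MARWILConfig().learners(num_learners=2).offline_data(
    dataset_num_iters_per_learner=1,
)

# Single local learner, exactly one update per RLlib iteration:
# `return_iterator` stays False and a single in-memory batch is returned.
local_learner = MARWILConfig().learners(num_learners=0).offline_data(
    dataset_num_iters_per_learner=1,
)
```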
