6 changes: 3 additions & 3 deletions rllib/algorithms/algorithm_config.py
@@ -2989,7 +2989,7 @@ def evaluation(
         if offline_evaluation_type is not NotProvided:
             self.offline_evaluation_type = offline_evaluation_type
         if offline_eval_runner_class is not NotProvided:
-            self.offline_eval_runner_cls = offline_eval_runner_class
+            self.offline_eval_runner_class = offline_eval_runner_class
         if offline_loss_for_module_fn is not NotProvided:
             self.offline_loss_for_module_fn = offline_loss_for_module_fn
         if offline_eval_batch_size_per_runner is not NotProvided:
@@ -5328,8 +5328,8 @@ def _validate_offline_settings(self):

         from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner
 
-        if self.prelearner_class and not issubclass(
-            self.prelearner_class, OfflineEvaluationRunner
+        if self.offline_eval_runner_class and not issubclass(
+            self.offline_eval_runner_class, OfflineEvaluationRunner
         ):
             self._value_error(
                 "Unknown `offline_eval_runner_class`. OfflineEvaluationRunner class needs to inherit "
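
Taken together, the two hunks above fix one copy-paste slip: `evaluation()` stored the runner class under the wrong attribute name (`offline_eval_runner_cls`), and `_validate_offline_settings()` validated `prelearner_class` instead of the runner class. A minimal sketch of the corrected path, assuming BC as the algorithm (the `MyOfflineEvalRunner` subclass and the choice of `BCConfig` are illustrative, not part of this PR):

from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner


class MyOfflineEvalRunner(OfflineEvaluationRunner):
    """Hypothetical custom runner; any proper subclass passes the check."""


config = BCConfig().evaluation(
    offline_eval_runner_class=MyOfflineEvalRunner,
)
# With the fix, the class lands in `config.offline_eval_runner_class`,
# which is the same attribute `_validate_offline_settings()` now checks
# against `OfflineEvaluationRunner` before raising a ValueError.
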
11 changes: 8 additions & 3 deletions rllib/algorithms/cql/cql.py
@@ -302,15 +302,20 @@ def training_step(self) -> None:

         # Sampling from offline data.
         with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+            # If we should use an iterator in the learner(s). Note, in case of
+            # multiple learners we must always return a list of iterators.
+            return_iterator = (
+                self.config.num_learners > 0
+                or self.config.dataset_num_iters_per_learner != 1
+            )
+
             # Return an iterator in case we are using remote learners.
             batch_or_iterator = self.offline_data.sample(
                 num_samples=self.config.train_batch_size_per_learner,
                 num_shards=self.config.num_learners,
                 # Return an iterator, if a `Learner` should update
                 # multiple times per RLlib iteration.
-                return_iterator=self.config.dataset_num_iters_per_learner > 1
-                if self.config.dataset_num_iters_per_learner
-                else True,
+                return_iterator=return_iterator,
             )
 
         # Updating the policy.
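
The key change: whether to return an iterator now depends on `num_learners` as well, not only on `dataset_num_iters_per_learner`. A plain-Python sketch of the new rule under a few representative configs (`needs_iterator` is a hypothetical helper, not an RLlib API):

def needs_iterator(num_learners: int, dataset_num_iters_per_learner) -> bool:
    # Remote learners (num_learners > 0) must always receive a list of
    # iterators; a local learner needs one whenever it does not update
    # exactly once per RLlib iteration (`None` means "run through the
    # whole dataset", which also requires an iterator).
    return num_learners > 0 or dataset_num_iters_per_learner != 1


assert needs_iterator(2, 1) is True      # remote learners -> iterators
assert needs_iterator(0, 5) is True      # local learner, 5 updates/iter
assert needs_iterator(0, None) is True   # local learner, full pass
assert needs_iterator(0, 1) is False     # local learner, single batch
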
8 changes: 5 additions & 3 deletions rllib/algorithms/marwil/marwil.py
@@ -457,11 +457,13 @@ class (multi-/single-learner setup) and evaluation on
         # the user that sth. is not right, although it is as
         # we do not step the env.
         with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
+            # If we should use an iterator in the learner(s). Note, in case of
+            # multiple learners we must always return a list of iterators.
             return_iterator = (
-                self.config.dataset_num_iters_per_learner > 1
-                if self.config.dataset_num_iters_per_learner
-                else True
+                self.config.num_learners > 0
+                or self.config.dataset_num_iters_per_learner != 1
             )
+
             # Sampling from offline data.
             batch_or_iterator = self.offline_data.sample(
                 num_samples=self.config.train_batch_size_per_learner,
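
The MARWIL hunk applies the same fix as the CQL one above. To see why the old expression was wrong, a pure-Python comparison of the two rules (`old_rule` and `new_rule` are illustrative names, not RLlib code):

# Old: `x > 1 if x else True` falls back to True only when x is falsy
# (0 or None) and returns False for x == 1, even though remote learners
# always need iterators.
def old_rule(num_learners, num_iters):
    return num_iters > 1 if num_iters else True


def new_rule(num_learners, num_iters):
    return num_learners > 0 or num_iters != 1


# Two remote learners, one update per iteration: the old rule handed
# back a plain batch, but remote learners expect a list of iterators.
assert old_rule(2, 1) is False  # the bug
assert new_rule(2, 1) is True   # the fix
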