Skip to content

Commit a1c10ff

Browse files
committed
revert buffer update
1 parent a6952e0 commit a1c10ff

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

benchmarl/experiment/experiment.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ def _setup(self):
362362
self._setup_task()
363363
self._setup_algorithm()
364364
self._setup_collector()
365-
self._setup_buffers()
366365
self._setup_name()
367366
self._setup_logger()
368367
self._on_setup()
@@ -480,6 +479,13 @@ def _setup_algorithm(self):
480479
self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
481480
self.env_func = self.algorithm.process_env_fun(self.env_func)
482481

482+
self.replay_buffers = {
483+
group: self.algorithm.get_replay_buffer(
484+
group=group,
485+
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
486+
)
487+
for group in self.group_map.keys()
488+
}
483489
self.losses = {
484490
group: self.algorithm.get_loss_and_updater(group)[0]
485491
for group in self.group_map.keys()
@@ -528,15 +534,6 @@ def _setup_collector(self):
528534
)
529535
self.rollout_env = self.env_func().to(self.config.sampling_device)
530536

531-
def _setup_buffers(self):
532-
self.replay_buffers = {
533-
group: self.algorithm.get_replay_buffer(
534-
group=group,
535-
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
536-
)
537-
for group in self.group_map.keys()
538-
}
539-
540537
def _setup_name(self):
541538
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
542539
self.model_name = self.model_config.associated_class().__name__.lower()

0 commit comments

Comments
 (0)