File tree 1 file changed +7
-10
lines changed
1 file changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -362,7 +362,6 @@ def _setup(self):
362
362
self ._setup_task ()
363
363
self ._setup_algorithm ()
364
364
self ._setup_collector ()
365
- self ._setup_buffers ()
366
365
self ._setup_name ()
367
366
self ._setup_logger ()
368
367
self ._on_setup ()
@@ -480,6 +479,13 @@ def _setup_algorithm(self):
480
479
self .test_env = self .algorithm .process_env_fun (lambda : self .test_env )()
481
480
self .env_func = self .algorithm .process_env_fun (self .env_func )
482
481
482
+ self .replay_buffers = {
483
+ group : self .algorithm .get_replay_buffer (
484
+ group = group ,
485
+ transforms = self .task .get_replay_buffer_transforms (self .test_env , group ),
486
+ )
487
+ for group in self .group_map .keys ()
488
+ }
483
489
self .losses = {
484
490
group : self .algorithm .get_loss_and_updater (group )[0 ]
485
491
for group in self .group_map .keys ()
@@ -528,15 +534,6 @@ def _setup_collector(self):
528
534
)
529
535
self .rollout_env = self .env_func ().to (self .config .sampling_device )
530
536
531
- def _setup_buffers (self ):
532
- self .replay_buffers = {
533
- group : self .algorithm .get_replay_buffer (
534
- group = group ,
535
- transforms = self .task .get_replay_buffer_transforms (self .test_env , group ),
536
- )
537
- for group in self .group_map .keys ()
538
- }
539
-
540
537
def _setup_name (self ):
541
538
self .algorithm_name = self .algorithm_config .associated_class ().__name__ .lower ()
542
539
self .model_name = self .model_config .associated_class ().__name__ .lower ()
You can’t perform that action at this time.
0 commit comments