-
Notifications
You must be signed in to change notification settings - Fork 42
/
experiment.py
926 lines (801 loc) · 33.9 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations
import copy
import importlib
import os
import time
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm
from benchmarl.algorithms import IppoConfig, MappoConfig
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task
from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import Logger
from benchmarl.models import GnnConfig, SequenceModelConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import _read_yaml_config, seed_everything
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule has been loaded, so import it explicitly before using ``find_spec``.
import importlib.util

# Hydra is an optional dependency: it is only used to discover the run's output
# directory when the experiment is launched through Hydra.
_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
    from hydra.core.hydra_config import HydraConfig
@dataclass
class ExperimentConfig:
    """
    Configuration class for experiments.

    This class acts as a schema for loading and validating yaml configurations.

    Parameters in this class aim to be agnostic of the algorithm, task or model used.
    To know their meaning, please check out the descriptions in ``benchmarl/conf/experiment/base_experiment.yaml``
    """

    # Using ``dataclasses.MISSING`` as the default makes every field required:
    # all values must be supplied explicitly (e.g. from the yaml file).
    sampling_device: str = MISSING
    train_device: str = MISSING
    buffer_device: str = MISSING

    share_policy_params: bool = MISSING
    prefer_continuous_actions: bool = MISSING
    collect_with_grad: bool = MISSING

    gamma: float = MISSING
    lr: float = MISSING
    adam_eps: float = MISSING
    clip_grad_norm: bool = MISSING
    clip_grad_val: Optional[float] = MISSING

    soft_target_update: bool = MISSING
    polyak_tau: float = MISSING
    hard_target_update_frequency: int = MISSING

    exploration_eps_init: float = MISSING
    exploration_eps_end: float = MISSING
    exploration_anneal_frames: Optional[int] = MISSING

    max_n_iters: Optional[int] = MISSING
    max_n_frames: Optional[int] = MISSING

    on_policy_collected_frames_per_batch: int = MISSING
    on_policy_n_envs_per_worker: int = MISSING
    on_policy_n_minibatch_iters: int = MISSING
    on_policy_minibatch_size: int = MISSING

    off_policy_collected_frames_per_batch: int = MISSING
    off_policy_n_envs_per_worker: int = MISSING
    off_policy_n_optimizer_steps: int = MISSING
    off_policy_train_batch_size: int = MISSING
    off_policy_memory_size: int = MISSING
    off_policy_init_random_frames: int = MISSING

    evaluation: bool = MISSING
    render: bool = MISSING
    evaluation_interval: int = MISSING
    evaluation_episodes: int = MISSING
    evaluation_deterministic_actions: bool = MISSING

    loggers: List[str] = MISSING
    project_name: str = MISSING
    create_json: bool = MISSING

    save_folder: Optional[str] = MISSING
    restore_file: Optional[str] = MISSING
    restore_map_location: Optional[Any] = MISSING
    checkpoint_interval: int = MISSING
    checkpoint_at_end: bool = MISSING
    keep_checkpoints_num: Optional[int] = MISSING

    def train_batch_size(self, on_policy: bool) -> int:
        """
        The batch size of tensors used for training.

        On-policy algorithms train on everything collected in one iteration;
        off-policy algorithms sample ``off_policy_train_batch_size`` frames.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.collected_frames_per_batch(on_policy)
            if on_policy
            else self.off_policy_train_batch_size
        )

    def train_minibatch_size(self, on_policy: bool) -> int:
        """
        The minibatch size of tensors used for training.

        On-policy algorithms are trained by splitting the train_batch_size (equal to the collected frames) into minibatches.
        Off-policy algorithms do not go through this process and thus have the ``train_minibatch_size==train_batch_size``

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.on_policy_minibatch_size
            if on_policy
            else self.train_batch_size(on_policy)
        )

    def n_optimizer_steps(self, on_policy: bool) -> int:
        """
        Number of times to loop over the training step per collection iteration.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.on_policy_n_minibatch_iters
            if on_policy
            else self.off_policy_n_optimizer_steps
        )

    def replay_buffer_memory_size(self, on_policy: bool) -> int:
        """
        Size of the replay buffer memory in terms of frames.

        On-policy buffers only need to hold one collection batch.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.collected_frames_per_batch(on_policy)
            if on_policy
            else self.off_policy_memory_size
        )

    def collected_frames_per_batch(self, on_policy: bool) -> int:
        """
        Number of collected frames per collection iteration.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.on_policy_collected_frames_per_batch
            if on_policy
            else self.off_policy_collected_frames_per_batch
        )

    def n_envs_per_worker(self, on_policy: bool) -> int:
        """
        Number of environments used for collection

        - In vectorized environments, this will be the vectorized batch_size.
        - In other environments, this will be emulated by running them sequentially.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            self.on_policy_n_envs_per_worker
            if on_policy
            else self.off_policy_n_envs_per_worker
        )

    def get_max_n_frames(self, on_policy: bool) -> int:
        """
        Get the maximum number of frames collected before the experiment ends.

        When both ``max_n_frames`` and ``max_n_iters`` are set, the tighter of
        the two limits is used.

        Args:
            on_policy (bool): is the algorithms on_policy

        Raises:
            ValueError: if neither ``max_n_frames`` nor ``max_n_iters`` is set.
        """
        if self.max_n_frames is not None and self.max_n_iters is not None:
            return min(
                self.max_n_frames,
                self.max_n_iters * self.collected_frames_per_batch(on_policy),
            )
        elif self.max_n_frames is not None:
            return self.max_n_frames
        elif self.max_n_iters is not None:
            return self.max_n_iters * self.collected_frames_per_batch(on_policy)
        # Falling through would return None and break every caller that does
        # arithmetic on the result, so fail loudly instead.
        raise ValueError("max_n_iters and max_n_frames are both not set")

    def get_max_n_iters(self, on_policy: bool) -> int:
        """
        Get the maximum number of experiment iterations before the experiment ends.

        This is the frame budget divided by the frames per iteration, rounded up.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        # Ceil division via negated floor division.
        return -(
            -self.get_max_n_frames(on_policy)
            // self.collected_frames_per_batch(on_policy)
        )

    def get_exploration_anneal_frames(self, on_policy: bool):
        """
        Get the number of frames for exploration annealing.

        If self.exploration_anneal_frames is None this will be a third of the total frames to collect.

        Args:
            on_policy (bool): is the algorithms on_policy
        """
        return (
            (self.get_max_n_frames(on_policy) // 3)
            if self.exploration_anneal_frames is None
            else self.exploration_anneal_frames
        )

    @staticmethod
    def get_from_yaml(path: Optional[str] = None):
        """
        Load the experiment configuration from yaml

        Args:
            path (str, optional): The full path of the yaml file to load from.
                If None, it will default to
                ``benchmarl/conf/experiment/base_experiment.yaml``

        Returns:
            the loaded :class:`~benchmarl.experiment.ExperimentConfig`
        """
        if path is None:
            yaml_path = (
                Path(__file__).parent.parent
                / "conf"
                / "experiment"
                / "base_experiment.yaml"
            )
            return ExperimentConfig(**_read_yaml_config(str(yaml_path.resolve())))
        else:
            return ExperimentConfig(**_read_yaml_config(path))

    def validate(self, on_policy: bool):
        """
        Validates config.

        Args:
            on_policy (bool): is the algorithms on_policy

        Raises:
            ValueError: if an interval is incompatible with the collection batch
                size, if ``keep_checkpoints_num`` is not positive, or if no
                stopping criterion (``max_n_iters`` / ``max_n_frames``) is set.
        """
        # The training loop computes ``total_frames % evaluation_interval``, so a
        # zero interval would only surface later as a ZeroDivisionError.
        if self.evaluation and self.evaluation_interval <= 0:
            raise ValueError(
                f"evaluation_interval ({self.evaluation_interval}) "
                f"must be greater than zero when evaluation is enabled"
            )
        if (
            self.evaluation
            and self.evaluation_interval % self.collected_frames_per_batch(on_policy)
            != 0
        ):
            raise ValueError(
                f"evaluation_interval ({self.evaluation_interval}) "
                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
            )
        if (
            self.checkpoint_interval != 0
            and self.checkpoint_interval % self.collected_frames_per_batch(on_policy)
            != 0
        ):
            raise ValueError(
                f"checkpoint_interval ({self.checkpoint_interval}) "
                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
            )
        if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
            raise ValueError("keep_checkpoints_num must be greater than zero or null")
        if self.max_n_frames is None and self.max_n_iters is None:
            raise ValueError("max_n_iters and max_n_frames are both not set")
class Experiment(CallbackNotifier):
    """
    Main experiment class in BenchMARL.

    Args:
        task (Task): the task configuration
        algorithm_config (AlgorithmConfig): the algorithm configuration
        model_config (ModelConfig): the policy model configuration
        seed (int): the seed for the experiment
        config (ExperimentConfig): the experiment config
        critic_model_config (ModelConfig, optional): the critic model configuration.
            If None, it defaults to a deep copy of ``model_config``
        callbacks (list of Callback, optional): callbacks for this experiment
    """

    def __init__(
        self,
        task: Task,
        algorithm_config: AlgorithmConfig,
        model_config: ModelConfig,
        seed: int,
        config: ExperimentConfig,
        critic_model_config: Optional[ModelConfig] = None,
        callbacks: Optional[List[Callback]] = None,
    ):
        super().__init__(
            experiment=self, callbacks=callbacks if callbacks is not None else []
        )

        self.config = config

        self.task = task
        self.model_config = model_config
        # A deep copy is used so that flagging the critic below does not
        # mutate the policy model config.
        self.critic_model_config = (
            critic_model_config
            if critic_model_config is not None
            else copy.deepcopy(model_config)
        )
        self.critic_model_config.is_critic = True
        self.algorithm_config = algorithm_config
        self.seed = seed

        self._setup()

        # Run counters; overwritten by ``load_state_dict`` when restoring.
        self.total_time = 0
        self.total_frames = 0
        self.n_iters_performed = 0
        self.mean_return = 0

        if self.config.restore_file is not None:
            self._load_experiment()

    @property
    def on_policy(self) -> bool:
        """Whether the algorithm has to be run on policy."""
        return self.algorithm_config.on_policy()

    def _setup(self):
        """Validate the config and build every component of the experiment, in dependency order."""
        self.config.validate(self.on_policy)
        seed_everything(self.seed)
        self._perfrom_checks()
        self._set_action_type()
        self._setup_task()
        self._setup_algorithm()
        self._setup_collector()
        self._setup_name()
        self._setup_logger()
        self._on_setup()

    def _perfrom_checks(self):
        """Reject model/algorithm combinations that are known to be unsupported."""
        for config in (self.model_config, self.critic_model_config):
            if isinstance(config, SequenceModelConfig):
                for layer_config in config.model_configs[1:]:
                    if isinstance(layer_config, GnnConfig) and (
                        layer_config.position_key is not None
                        or layer_config.velocity_key is not None
                    ):
                        raise ValueError(
                            "GNNs reading position or velocity keys are currently only usable in first"
                            " layer of sequence models"
                        )

        # ``algorithm_config`` is an instance, so its type must be checked with
        # ``isinstance`` (an ``in`` test against the classes never matches).
        if isinstance(self.algorithm_config, (MappoConfig, IppoConfig)):
            critic_model_config = self.critic_model_config
            if isinstance(critic_model_config, SequenceModelConfig):
                critic_model_config = self.critic_model_config.model_configs[0]
            if (
                isinstance(critic_model_config, GnnConfig)
                and critic_model_config.topology == "from_pos"
            ):
                raise ValueError(
                    "GNNs in PPO critics with topology 'from_pos' are currently not available, "
                    "see https://github.com/pytorch/rl/issues/2537"
                )

    def _set_action_type(self):
        """Pick continuous or discrete actions based on task/algorithm support and user preference."""
        if (
            self.task.supports_continuous_actions()
            and self.algorithm_config.supports_continuous_actions()
            and self.config.prefer_continuous_actions
        ):
            self.continuous_actions = True
        elif (
            self.task.supports_discrete_actions()
            and self.algorithm_config.supports_discrete_actions()
        ):
            self.continuous_actions = False
        elif (
            self.task.supports_continuous_actions()
            and self.algorithm_config.supports_continuous_actions()
        ):
            self.continuous_actions = True
        else:
            raise ValueError(
                f"Algorithm {self.algorithm_config} is not compatible"
                f" with the action space of task {self.task} "
            )

    def _setup_task(self):
        """Create the evaluation env, the training env factory and read the task specs."""
        test_env = self.task.get_env_fun(
            num_envs=self.config.evaluation_episodes,
            continuous_actions=self.continuous_actions,
            seed=self.seed,
            device=self.config.sampling_device,
        )()
        env_func = self.task.get_env_fun(
            num_envs=self.config.n_envs_per_worker(self.on_policy),
            continuous_actions=self.continuous_actions,
            seed=self.seed,
            device=self.config.sampling_device,
        )

        # Training envs additionally get the reward-sum transform.
        transforms_env = self.task.get_env_transforms(test_env)
        transforms_training = transforms_env + [
            self.task.get_reward_sum_transform(test_env)
        ]
        transforms_env = Compose(*transforms_env)
        transforms_training = Compose(*transforms_training)

        # Envs without a batch dimension are not vectorized, so multiple
        # training envs are emulated by running them in a SerialEnv.
        if test_env.batch_size == ():
            self.env_func = lambda: TransformedEnv(
                SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func),
                transforms_training.clone(),
            )
        else:
            self.env_func = lambda: TransformedEnv(
                env_func(), transforms_training.clone()
            )

        self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
            self.config.sampling_device
        )

        self.observation_spec = self.task.observation_spec(self.test_env)
        self.info_spec = self.task.info_spec(self.test_env)
        self.state_spec = self.task.state_spec(self.test_env)
        self.action_mask_spec = self.task.action_mask_spec(self.test_env)
        self.action_spec = self.task.action_spec(self.test_env)
        self.group_map = self.task.group_map(self.test_env)
        self.train_group_map = copy.deepcopy(self.group_map)
        self.max_steps = self.task.max_steps(self.test_env)

    def _setup_algorithm(self):
        """Build the algorithm plus per-group replay buffers, losses, target updaters and optimizers."""
        self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

        self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
        self.env_func = self.algorithm.process_env_fun(self.env_func)

        self.replay_buffers = {
            group: self.algorithm.get_replay_buffer(
                group=group,
                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
            )
            for group in self.group_map.keys()
        }

        # Fetch each group's (loss, target updater) pair once so both dicts are
        # guaranteed to refer to the same pair.
        self.losses = {}
        self.target_updaters = {}
        for group in self.group_map.keys():
            loss, target_updater = self.algorithm.get_loss_and_updater(group)
            self.losses[group] = loss
            self.target_updaters[group] = target_updater

        # One Adam optimizer per (group, loss term) parameter set.
        self.optimizers = {
            group: {
                loss_name: torch.optim.Adam(
                    params, lr=self.config.lr, eps=self.config.adam_eps
                )
                for loss_name, params in self.algorithm.get_parameters(group).items()
            }
            for group in self.group_map.keys()
        }

    def _setup_collector(self):
        """Create the collection policy and either a SyncDataCollector or a rollout env."""
        self.policy = self.algorithm.get_policy_for_collection()

        self.group_policies = {}
        for group in self.group_map.keys():
            group_policy = self.policy.select_subsequence(out_keys=[(group, "action")])
            assert len(group_policy) == 1
            self.group_policies.update({group: group_policy[0]})

        if not self.config.collect_with_grad:
            self.collector = SyncDataCollector(
                self.env_func,
                self.policy,
                device=self.config.sampling_device,
                storing_device=self.config.train_device,
                frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
                total_frames=self.config.get_max_n_frames(self.on_policy),
                init_random_frames=(
                    self.config.off_policy_init_random_frames
                    if not self.on_policy
                    else 0
                ),
            )
        else:
            if self.config.off_policy_init_random_frames and not self.on_policy:
                raise TypeError(
                    "Collection via rollouts does not support initial random frames as of now."
                )
            self.rollout_env = self.env_func().to(self.config.sampling_device)

    def _setup_name(self):
        """Derive the experiment name and output folder, creating the folder when anything will be written."""
        self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
        self.model_name = self.model_config.associated_class().__name__.lower()
        self.environment_name = self.task.env_name().lower()
        self.task_name = self.task.name.lower()
        self._checkpointed_files = deque([])

        if self.config.save_folder is not None:
            # If the user specified a folder for the experiment we use that
            save_folder = Path(self.config.save_folder)
        else:
            # Otherwise, if the user is restoring from a folder, we will save in the folder they are restoring from
            if self.config.restore_file is not None:
                save_folder = Path(
                    self.config.restore_file
                ).parent.parent.parent.resolve()
            # Otherwise, the user is not restoring and did not specify a save_folder so we save in the hydra directory
            # of the experiment or in the directory where the experiment was run (if hydra is not used)
            else:
                if _has_hydra and HydraConfig.initialized():
                    save_folder = Path(HydraConfig.get().runtime.output_dir)
                else:
                    save_folder = Path(os.getcwd())

        if self.config.restore_file is None:
            self.name = generate_exp_name(
                f"{self.algorithm_name}_{self.task_name}_{self.model_name}", ""
            )
            self.folder_name = save_folder / self.name
        else:
            # If restoring, we use the name of the previous experiment
            self.name = Path(self.config.restore_file).parent.parent.resolve().name
            self.folder_name = save_folder / self.name

        if (
            len(self.config.loggers)
            or self.config.checkpoint_interval > 0
            or self.config.create_json
        ):
            self.folder_name.mkdir(parents=False, exist_ok=True)

    def _setup_logger(self):
        """Create the logger and record the hyperparameters of this run."""
        self.logger = Logger(
            project_name=self.config.project_name,
            experiment_name=self.name,
            folder_name=str(self.folder_name),
            experiment_config=self.config,
            algorithm_name=self.algorithm_name,
            model_name=self.model_name,
            environment_name=self.environment_name,
            task_name=self.task_name,
            group_map=self.group_map,
            seed=self.seed,
        )
        self.logger.log_hparams(
            experiment_config=self.config.__dict__,
            algorithm_config=self.algorithm_config.__dict__,
            model_config=self.model_config.__dict__,
            task_config=self.task.config,
            continuous_actions=self.continuous_actions,
            on_policy=self.on_policy,
        )

    def run(self):
        """Run the experiment until completion."""
        try:
            torch.cuda.empty_cache()
            self._collection_loop()
        except KeyboardInterrupt:
            print("\n\nExperiment was closed gracefully\n\n")
            self.close()
            raise
        except Exception:
            print("\n\nExperiment failed and is closing gracefully\n\n")
            self.close()
            raise

    def evaluate(self):
        """Run just the evaluation loop once."""
        self._evaluation_loop()
        self.logger.commit()
        print(
            f"Evaluation results logged to loggers={self.config.loggers}"
            f"{' and to a json file in the experiment folder.' if self.config.create_json else ''}"
        )

    def _collection_loop(self):
        """Alternate data collection and training until the iteration budget is spent."""
        max_n_iters = self.config.get_max_n_iters(self.on_policy)
        # Context manager guarantees the progress bar is closed.
        with tqdm(initial=self.n_iters_performed, total=max_n_iters) as pbar:
            if not self.config.collect_with_grad:
                iterator = iter(self.collector)
            else:
                reset_batch = self.rollout_env.reset()

            # Training/collection iterations
            for _ in range(self.n_iters_performed, max_n_iters):
                iteration_start = time.time()
                if not self.config.collect_with_grad:
                    batch = next(iterator)
                else:
                    with set_exploration_type(ExplorationType.RANDOM):
                        batch = self.rollout_env.rollout(
                            # Ceil division: enough steps per env to reach the frame target.
                            max_steps=-(
                                -self.config.collected_frames_per_batch(self.on_policy)
                                // self.rollout_env.batch_size.numel()
                            ),
                            policy=self.policy,
                            break_when_any_done=False,
                            auto_reset=False,
                            tensordict=reset_batch,
                        )
                    # Continue the next rollout from the last step of this one.
                    reset_batch = step_mdp(
                        batch[..., -1],
                        reward_keys=self.rollout_env.reward_keys,
                        action_keys=self.rollout_env.action_keys,
                        done_keys=self.rollout_env.done_keys,
                    )

                # Logging collection
                collection_time = time.time() - iteration_start
                current_frames = batch.numel()
                self.total_frames += current_frames
                self.mean_return = self.logger.log_collection(
                    batch,
                    total_frames=self.total_frames,
                    task=self.task,
                    step=self.n_iters_performed,
                )
                pbar.set_description(f"mean return = {self.mean_return}", refresh=False)

                # Callback
                self._on_batch_collected(batch)
                batch = batch.detach()

                # Loop over groups
                training_start = time.time()
                for group in self.train_group_map.keys():
                    group_batch = batch.exclude(*self._get_excluded_keys(group))
                    group_batch = self.algorithm.process_batch(group, group_batch)
                    if not self.algorithm.has_rnn:
                        group_batch = group_batch.reshape(-1)
                    self.replay_buffers[group].extend(group_batch)

                    # Ceil division: number of minibatches covering one train batch.
                    n_minibatches = -(
                        -self.config.train_batch_size(self.on_policy)
                        // self.config.train_minibatch_size(self.on_policy)
                    )
                    training_tds = []
                    for _optimizer_step in range(
                        self.config.n_optimizer_steps(self.on_policy)
                    ):
                        for _minibatch in range(n_minibatches):
                            training_tds.append(self._optimizer_loop(group))
                    training_td = torch.stack(training_tds)
                    self.logger.log_training(
                        group, training_td, step=self.n_iters_performed
                    )

                    # Callback
                    self._on_train_end(training_td, group)

                    # Exploration update
                    if isinstance(self.group_policies[group], TensorDictSequential):
                        explore_layer = self.group_policies[group][-1]
                    else:
                        explore_layer = self.group_policies[group]
                    if hasattr(explore_layer, "step"):  # Step exploration annealing
                        explore_layer.step(current_frames)

                # Update policy in collector
                if not self.config.collect_with_grad:
                    self.collector.update_policy_weights_()

                # Training timer
                training_time = time.time() - training_start

                # Evaluation
                if (
                    self.config.evaluation
                    and (
                        self.total_frames % self.config.evaluation_interval == 0
                        or self.n_iters_performed == 0
                    )
                    and (len(self.config.loggers) or self.config.create_json)
                ):
                    self._evaluation_loop()

                # End of step
                iteration_time = time.time() - iteration_start
                self.total_time += iteration_time
                self.logger.log(
                    {
                        "timers/collection_time": collection_time,
                        "timers/training_time": training_time,
                        "timers/iteration_time": iteration_time,
                        "timers/total_time": self.total_time,
                        "counters/current_frames": current_frames,
                        "counters/total_frames": self.total_frames,
                        "counters/iter": self.n_iters_performed,
                    },
                    step=self.n_iters_performed,
                )
                self.n_iters_performed += 1
                self.logger.commit()
                if (
                    self.config.checkpoint_interval > 0
                    and self.total_frames % self.config.checkpoint_interval == 0
                ):
                    self._save_experiment()
                pbar.update()

        if self.config.checkpoint_at_end:
            self._save_experiment()
        self.close()

    def close(self):
        """Close the experiment."""
        if not self.config.collect_with_grad:
            self.collector.shutdown()
        else:
            self.rollout_env.close()
        self.test_env.close()
        self.logger.finish()

    def _get_excluded_keys(self, group: str):
        """Keys to drop from a batch so that only ``group``'s data (minus info entries) remains."""
        excluded_keys = []
        for other_group in self.group_map.keys():
            if other_group != group:
                excluded_keys += [other_group, ("next", other_group)]
        excluded_keys += ["info", (group, "info"), ("next", group, "info")]
        return excluded_keys

    def _optimizer_loop(self, group: str) -> TensorDictBase:
        """Sample one minibatch for ``group``, step every optimizer and return the detached loss values."""
        subdata = self.replay_buffers[group].sample().to(self.config.train_device)
        loss_vals = self.losses[group](subdata)
        training_td = loss_vals.detach()
        loss_vals = self.algorithm.process_loss_vals(group, loss_vals)

        for loss_name, loss_value in loss_vals.items():
            # Only loss terms that own an optimizer are back-propagated.
            if loss_name in self.optimizers[group].keys():
                optimizer = self.optimizers[group][loss_name]

                loss_value.backward()

                grad_norm = self._grad_clip(optimizer)

                training_td.set(
                    f"grad_norm_{loss_name}",
                    torch.tensor(grad_norm, device=self.config.train_device),
                )

                optimizer.step()
                optimizer.zero_grad()
        self.replay_buffers[group].update_tensordict_priority(subdata)
        if self.target_updaters[group] is not None:
            self.target_updaters[group].step()

        callback_loss = self._on_train_step(subdata, group)
        if callback_loss is not None:
            training_td.update(callback_loss)

        return training_td

    def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
        """Clip the optimizer's gradients as configured and return their (pre-clip) L2 norm."""
        params = []
        for param_group in optimizer.param_groups:
            params += param_group["params"]

        if self.config.clip_grad_norm and self.config.clip_grad_val is not None:
            total_norm = torch.nn.utils.clip_grad_norm_(
                params, self.config.clip_grad_val
            )
        else:
            norm_type = 2.0
            norms = [
                torch.linalg.vector_norm(p.grad, norm_type)
                for p in params
                if p.grad is not None
            ]
            # torch.stack fails on an empty list, so no gradients means norm 0.
            if norms:
                total_norm = torch.linalg.vector_norm(torch.stack(norms), norm_type)
            else:
                total_norm = 0.0
            if self.config.clip_grad_val is not None:
                torch.nn.utils.clip_grad_value_(params, self.config.clip_grad_val)

        return float(total_norm)

    @torch.no_grad()
    def _evaluation_loop(self):
        """Roll out the policy on the test env and log the resulting episodes."""
        evaluation_start = time.time()
        with set_exploration_type(
            ExplorationType.DETERMINISTIC
            if self.config.evaluation_deterministic_actions
            else ExplorationType.RANDOM
        ):
            if self.task.has_render(self.test_env) and self.config.render:
                video_frames = []

                def callback(env, td):
                    video_frames.append(
                        self.task.__class__.render_callback(self, env, td)
                    )

            else:
                video_frames = None
                callback = None

            if self.test_env.batch_size == ():
                rollouts = []
                for eval_episode in range(self.config.evaluation_episodes):
                    rollouts.append(
                        self.test_env.rollout(
                            max_steps=self.max_steps,
                            policy=self.policy,
                            # Only the first episode is passed the render callback.
                            callback=callback if eval_episode == 0 else None,
                            auto_cast_to_device=True,
                            break_when_any_done=True,
                        )
                    )
            else:
                rollouts = self.test_env.rollout(
                    max_steps=self.max_steps,
                    policy=self.policy,
                    callback=callback,
                    auto_cast_to_device=True,
                    break_when_any_done=False,
                    # We are running vectorized evaluation we do not want it to stop when just one env is done
                )
                rollouts = list(rollouts.unbind(0))
        evaluation_time = time.time() - evaluation_start
        self.logger.log(
            {"timers/evaluation_time": evaluation_time}, step=self.n_iters_performed
        )
        self.logger.log_evaluation(
            rollouts,
            video_frames=video_frames,
            step=self.n_iters_performed,
            total_frames=self.total_frames,
        )
        # Callback
        self._on_evaluation_end(rollouts)

    # Saving experiment state
    def state_dict(self) -> OrderedDict:
        """Get the state_dict for the experiment."""
        state = OrderedDict(
            total_time=self.total_time,
            total_frames=self.total_frames,
            n_iters_performed=self.n_iters_performed,
            mean_return=self.mean_return,
        )
        state_dict = OrderedDict(
            state=state,
            **{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
            **{
                # Empty buffers are stored as None and skipped on load.
                f"buffer_{k}": item.state_dict() if len(item) else None
                for k, item in self.replay_buffers.items()
            },
        )
        if not self.config.collect_with_grad:
            state_dict.update({"collector": self.collector.state_dict()})
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the state_dict for the experiment.

        Args:
            state_dict (dict): the state dict
        """
        for group in self.group_map.keys():
            self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
            if state_dict[f"buffer_{group}"] is not None:
                self.replay_buffers[group].load_state_dict(
                    state_dict[f"buffer_{group}"]
                )
        if not self.config.collect_with_grad:
            self.collector.load_state_dict(state_dict["collector"])
        self.total_time = state_dict["state"]["total_time"]
        self.total_frames = state_dict["state"]["total_frames"]
        self.n_iters_performed = state_dict["state"]["n_iters_performed"]
        self.mean_return = state_dict["state"]["mean_return"]

    def _save_experiment(self) -> None:
        """Checkpoint trainer, keeping at most ``keep_checkpoints_num`` files from this run."""
        if self.config.keep_checkpoints_num is not None:
            while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
                file_to_delete = self._checkpointed_files.popleft()
                # A checkpoint removed externally must not crash the training run.
                file_to_delete.unlink(missing_ok=True)
        checkpoint_folder = self.folder_name / "checkpoints"
        checkpoint_folder.mkdir(parents=False, exist_ok=True)
        checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
        torch.save(self.state_dict(), checkpoint_file)
        self._checkpointed_files.append(checkpoint_file)

    def _load_experiment(self) -> Experiment:
        """Load trainer from checkpoint"""
        # NOTE(review): torch.load unpickles the file; only restore checkpoints
        # from trusted sources.
        loaded_dict: OrderedDict = torch.load(
            self.config.restore_file, map_location=self.config.restore_map_location
        )
        self.load_state_dict(loaded_dict)
        return self