-
Notifications
You must be signed in to change notification settings - Fork 112
/
learner.py
1061 lines (844 loc) · 47.2 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import glob
import os
import time
from abc import ABC, abstractmethod
from os.path import join
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from sample_factory.algo.learning.rnn_utils import build_core_out_from_seq, build_rnn_inputs
from sample_factory.algo.utils.action_distributions import get_action_distribution, is_continuous_action_space
from sample_factory.algo.utils.env_info import EnvInfo
from sample_factory.algo.utils.misc import LEARNER_ENV_STEPS, POLICY_ID_KEY, STATS_KEY, TRAIN_STATS, memory_stats
from sample_factory.algo.utils.model_sharing import ParameterServer
from sample_factory.algo.utils.optimizers import Lamb
from sample_factory.algo.utils.rl_utils import gae_advantages, prepare_and_normalize_obs
from sample_factory.algo.utils.shared_buffers import policy_device
from sample_factory.algo.utils.tensor_dict import TensorDict, shallow_recursive_copy
from sample_factory.algo.utils.torch_utils import masked_select, synchronize, to_scalar
from sample_factory.cfg.configurable import Configurable
from sample_factory.model.actor_critic import ActorCritic, create_actor_critic
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.decay import LinearDecay
from sample_factory.utils.dicts import iterate_recursively
from sample_factory.utils.timing import Timing
from sample_factory.utils.typing import ActionDistribution, Config, InitModelData, PolicyID
from sample_factory.utils.utils import ensure_dir_exists, experiment_dir, log
class LearningRateScheduler:
    """Base learning-rate scheduler that keeps the learning rate constant.

    Subclasses override ``update`` and opt in to being called by the learner after every
    minibatch and/or after every epoch through the two ``invoke_after_each_*`` hooks.
    """

    def update(self, current_lr, recent_kls):
        """Return the learning rate to use next; the constant schedule returns ``current_lr`` as is."""
        return current_lr

    def invoke_after_each_minibatch(self):
        """Whether the learner should call ``update`` after every minibatch."""
        return False

    def invoke_after_each_epoch(self):
        """Whether the learner should call ``update`` at the end of every epoch."""
        return False
class KlAdaptiveScheduler(LearningRateScheduler, ABC):
    """Adapt the learning rate to keep the mean KL-divergence near ``cfg.lr_schedule_kl_threshold``.

    The mean of the last ``num_recent_kls_to_use()`` KL values is compared with the threshold:
    above twice the threshold the LR is divided by 1.5 (floored at ``cfg.lr_adaptive_min``);
    below half the threshold the LR is multiplied by 1.5 (capped at ``cfg.lr_adaptive_max``).
    """

    def __init__(self, cfg: Config):
        self.lr_schedule_kl_threshold = cfg.lr_schedule_kl_threshold
        self.min_lr = cfg.lr_adaptive_min
        self.max_lr = cfg.lr_adaptive_max

    @abstractmethod
    def num_recent_kls_to_use(self) -> int:
        """Number of trailing entries of ``recent_kls`` to average in ``update``."""

    def update(self, current_lr, recent_kls):
        window = recent_kls[-self.num_recent_kls_to_use():]
        avg_kl = np.mean(window)
        too_high = 2.0 * self.lr_schedule_kl_threshold
        too_low = 0.5 * self.lr_schedule_kl_threshold

        new_lr = current_lr
        if avg_kl > too_high:
            # policy is moving too fast: slow down
            new_lr = max(current_lr / 1.5, self.min_lr)
        if avg_kl < too_low:
            # policy is barely moving: speed up
            new_lr = min(current_lr * 1.5, self.max_lr)
        return new_lr
class KlAdaptiveSchedulerPerMinibatch(KlAdaptiveScheduler):
    """KL-adaptive schedule applied after every minibatch, reacting to the latest KL value only."""

    def invoke_after_each_minibatch(self):
        return True

    def num_recent_kls_to_use(self) -> int:
        return 1
class KlAdaptiveSchedulerPerEpoch(KlAdaptiveScheduler):
    """KL-adaptive schedule applied once per epoch, averaging the KL values of that epoch's minibatches."""

    def __init__(self, cfg):
        super().__init__(cfg)
        # the learner records one KL value per minibatch, so one epoch contributes this many entries
        self.num_minibatches_per_epoch = cfg.num_batches_per_epoch

    def invoke_after_each_epoch(self):
        return True

    def num_recent_kls_to_use(self) -> int:
        return self.num_minibatches_per_epoch
class LinearDecayScheduler(LearningRateScheduler):
    """Decay the learning rate from ``cfg.learning_rate`` to 0 over the expected number of SGD steps.

    The expected step count is ``(train_for_env_steps // batch_size) * num_epochs``. The schedule
    advances one step per ``update`` call, i.e. once per minibatch.
    """

    def __init__(self, cfg):
        total_sgd_steps = (cfg.train_for_env_steps // cfg.batch_size) * cfg.num_epochs
        self.linear_decay = LinearDecay([(0, cfg.learning_rate), (total_sgd_steps, 0)])
        self.step = 0

    def invoke_after_each_minibatch(self):
        return True

    def update(self, current_lr, recent_kls):
        # the schedule depends only on the step count, not on the current LR or the KL history
        self.step += 1
        return self.linear_decay.at(self.step)
def get_lr_scheduler(cfg) -> LearningRateScheduler:
    """Create the learning-rate scheduler named by ``cfg.lr_schedule``.

    Raises:
        RuntimeError: if ``cfg.lr_schedule`` is not a known schedule name.
    """
    if cfg.lr_schedule == "constant":
        # the constant schedule takes no configuration
        return LearningRateScheduler()

    configurable_schedulers = (
        ("kl_adaptive_minibatch", KlAdaptiveSchedulerPerMinibatch),
        ("kl_adaptive_epoch", KlAdaptiveSchedulerPerEpoch),
        ("linear_decay", LinearDecayScheduler),
    )
    for schedule_name, scheduler_cls in configurable_schedulers:
        if cfg.lr_schedule == schedule_name:
            return scheduler_cls(cfg)

    raise RuntimeError(f"Unknown scheduler {cfg.lr_schedule}")
def model_initialization_data(
    cfg: Config, policy_id: PolicyID, actor_critic: Module, policy_version: int, device: torch.device
) -> InitModelData:
    """Package what is needed to initialize another copy of the policy model.

    Returns:
        ``(policy_id, state_dict, device, policy_version)``. ``state_dict`` is None in serial mode,
        where the same ``actor_critic`` object is used directly.
    """
    if cfg.serial_mode:
        weights = None
    else:
        weights = actor_critic.state_dict()
    return policy_id, weights, device, policy_version
class Learner(Configurable):
def __init__(
    self,
    cfg: Config,
    env_info: EnvInfo,
    policy_versions_tensor: Tensor,
    policy_id: PolicyID,
    param_server: ParameterServer,
):
    """Store configuration and shared objects; the model and optimizer are created later in ``init()``.

    Args:
        cfg: experiment configuration.
        env_info: environment description (its observation/action spaces are used in ``init()``).
        policy_versions_tensor: shared tensor indexed by policy id, holding the latest policy version.
        policy_id: id of the policy trained by this learner.
        param_server: parameter server that receives the weights and provides ``policy_lock``.
    """
    Configurable.__init__(self, cfg)
    self.timing = Timing(name=f"Learner {policy_id} profile")

    self.policy_id = policy_id
    self.env_info = env_info

    # torch-side objects, all created in init()
    self.device = None
    self.actor_critic: Optional[ActorCritic] = None
    self.optimizer = None
    self.curr_lr: Optional[float] = None
    self.lr_scheduler: Optional[LearningRateScheduler] = None
    self.exploration_loss_func: Optional[Callable] = None
    self.kl_loss_func: Optional[Callable] = None

    # training progress, restored from a checkpoint when one is loaded
    self.train_step: int = 0  # total number of SGD steps
    self.env_steps: int = 0  # total number of environment steps consumed by the learner
    self.best_performance = -1e9

    # pending requests, consumed by _maybe_update_cfg / _maybe_load_policy
    self.new_cfg: Optional[Dict] = None  # configuration update, i.e. from PBT
    self.policy_to_load: Optional[PolicyID] = None  # weights of another policy to load (i.e. with PBT)

    # Minimum seconds between train summaries as a function of train_step: 2 s at the start,
    # 60 s at step 100k and 120 s at step 1M (presumably interpolated linearly by LinearDecay).
    # Long experiments do not need frequent summaries.
    self.summary_rate_decay_seconds = LinearDecay([(0, 2), (100000, 60), (1000000, 120)])
    self.last_summary_time = 0
    self.last_milestone_time = 0

    # shared tensor used to share the latest policy version between processes
    self.policy_versions_tensor: Tensor = policy_versions_tensor
    self.param_server: ParameterServer = param_server

    self.is_initialized = False
def init(self) -> InitModelData:
    """Build the model, optimizer and LR scheduler, restoring state from a checkpoint if one exists.

    Steps, in order:
    - pick the exploration and KL loss functions from cfg;
    - seed torch and numpy when ``cfg.seed`` is set;
    - create the actor-critic on the policy device and move CPU tensors to shared memory;
    - build the optimizer and load the latest checkpoint;
    - publish the weights and version through the parameter server, then set up the LR.

    Returns:
        The tuple built by ``model_initialization_data`` for this policy.

    Raises:
        NotImplementedError: unknown ``cfg.exploration_loss`` (when its coefficient is non-zero).
        RuntimeError: unknown ``cfg.optimizer`` (or unknown ``cfg.lr_schedule`` via get_lr_scheduler).
    """
    # a zero coefficient disables the exploration loss entirely: it is replaced by a constant 0.0
    if self.cfg.exploration_loss_coeff == 0.0:
        self.exploration_loss_func = lambda action_distr, valids, num_invalids: 0.0
    elif self.cfg.exploration_loss == "entropy":
        self.exploration_loss_func = self._entropy_exploration_loss
    elif self.cfg.exploration_loss == "symmetric_kl":
        self.exploration_loss_func = self._symmetric_kl_exploration_loss
    else:
        raise NotImplementedError(f"{self.cfg.exploration_loss} not supported!")

    if self.cfg.kl_loss_coeff == 0.0:
        if is_continuous_action_space(self.env_info.action_space):
            log.warning(
                "WARNING! It is generally recommended to enable Fixed KL loss (https://arxiv.org/pdf/1707.06347.pdf) for continuous action tasks to avoid potential numerical issues. "
                "I.e. set --kl_loss_coeff=0.1"
            )
        # (kl_old=None, kl_loss=0.0): _train then computes kl_old itself for logging/LR schedules
        self.kl_loss_func = lambda action_space, action_logits, distribution, valids, num_invalids: (None, 0.0)
    else:
        self.kl_loss_func = self._kl_loss

    # initialize the Torch modules
    # (seeding happens before the model is created so weight initialization uses the seeded RNG)
    if self.cfg.seed is None:
        log.info("Starting seed is not provided")
    else:
        log.info("Setting fixed seed %d", self.cfg.seed)
        torch.manual_seed(self.cfg.seed)
        np.random.seed(self.cfg.seed)

    # initialize device
    self.device = policy_device(self.cfg, self.policy_id)

    log.debug("Initializing actor-critic model on device %s", self.device)

    # trainable torch module
    self.actor_critic = create_actor_critic(self.cfg, self.env_info.obs_space, self.env_info.action_space)
    log.debug("Created Actor Critic model with architecture:")
    log.debug(self.actor_critic)
    self.actor_critic.model_to_device(self.device)

    # move non-CUDA tensors into shared memory (presumably so other processes can read the weights)
    def share_mem(t):
        if t is not None and not t.is_cuda:
            return t.share_memory_()
        return t

    # noinspection PyProtectedMember
    self.actor_critic._apply(share_mem)
    self.actor_critic.train()

    params = list(self.actor_critic.parameters())

    optimizer_cls = dict(adam=torch.optim.Adam, lamb=Lamb)
    if self.cfg.optimizer not in optimizer_cls:
        raise RuntimeError(f"Unknown optimizer {self.cfg.optimizer}")

    # note: the name is rebound from the name->class mapping to the selected class
    optimizer_cls = optimizer_cls[self.cfg.optimizer]
    log.debug(f"Using optimizer {optimizer_cls}")

    optimizer_kwargs = dict(
        lr=self.cfg.learning_rate,  # use default lr only in ctor, then we use the one loaded from the checkpoint
        betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
    )

    # NOTE(review): always true given the check above; kept in case more optimizers are added
    if self.cfg.optimizer in ["adam", "lamb"]:
        optimizer_kwargs["eps"] = self.cfg.adam_eps

    self.optimizer = optimizer_cls(params, **optimizer_kwargs)

    # may overwrite train_step/env_steps/curr_lr and the model/optimizer state
    self.load_from_checkpoint(self.policy_id)
    self.param_server.init(self.actor_critic, self.train_step, self.device)
    self.policy_versions_tensor[self.policy_id] = self.train_step

    self.lr_scheduler = get_lr_scheduler(self.cfg)
    # curr_lr is only set at this point if it came from a checkpoint
    self.curr_lr = self.cfg.learning_rate if self.curr_lr is None else self.curr_lr
    self._apply_lr(self.curr_lr)

    self.is_initialized = True

    return model_initialization_data(self.cfg, self.policy_id, self.actor_critic, self.train_step, self.device)
@staticmethod
def checkpoint_dir(cfg, policy_id):
    """Return the checkpoint directory of ``policy_id``: ``<experiment_dir>/checkpoint_p<policy_id>``.

    The path is passed through ``ensure_dir_exists`` before it is returned.
    """
    experiment_root = experiment_dir(cfg=cfg)
    return ensure_dir_exists(join(experiment_root, f"checkpoint_p{policy_id}"))
@staticmethod
def get_checkpoints(checkpoints_dir, pattern="checkpoint_*"):
checkpoints = glob.glob(join(checkpoints_dir, pattern))
return sorted(checkpoints)
@staticmethod
def load_checkpoint(checkpoints, device):
    """Load the newest checkpoint (the last entry of ``checkpoints``) onto ``device``.

    Args:
        checkpoints: checkpoint paths sorted oldest-first, e.g. as returned by ``get_checkpoints``.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The loaded checkpoint dict, or None if ``checkpoints`` is empty or every load attempt failed.
    """
    if len(checkpoints) == 0:
        log.warning("No checkpoints found")
        return None

    latest_checkpoint = checkpoints[-1]
    # extra safety mechanism to recover from spurious filesystem errors
    num_attempts = 3
    retry_delay_seconds = 1.0
    for attempt in range(num_attempts):
        # noinspection PyBroadException
        try:
            log.warning("Loading state from checkpoint %s...", latest_checkpoint)
            checkpoint_dict = torch.load(latest_checkpoint, map_location=device)
            return checkpoint_dict
        except Exception:
            log.exception(f"Could not load from checkpoint, attempt {attempt}")
            if attempt + 1 < num_attempts:
                # an immediate retry rarely helps with a transient filesystem error; give it a moment
                time.sleep(retry_delay_seconds)

    # Callers treat None as "no checkpoint" and start training from scratch, so make this loud.
    log.error("Failed to load checkpoint %s after %d attempts", latest_checkpoint, num_attempts)
    return None
def _load_state(self, checkpoint_dict, load_progress=True):
    """Restore weights, optimizer state and learning rate from a loaded checkpoint dict.

    Args:
        checkpoint_dict: dict in the format produced by ``_get_checkpoint_dict``.
        load_progress: when False, keep this learner's ``train_step``/``env_steps``
            (used when loading another policy's checkpoint, e.g. under PBT).
    """
    ckpt = checkpoint_dict
    if load_progress:
        self.train_step = ckpt["train_step"]
        self.env_steps = ckpt["env_steps"]

    # optional keys fall back to the current best score and to the configured learning rate
    self.best_performance = ckpt.get("best_performance", self.best_performance)
    self.actor_critic.load_state_dict(ckpt["model"])
    self.optimizer.load_state_dict(ckpt["optimizer"])
    self.curr_lr = ckpt.get("curr_lr", self.cfg.learning_rate)

    log.info(f"Loaded experiment state at {self.train_step=}, {self.env_steps=}")
def load_from_checkpoint(self, policy_id: PolicyID, load_progress: bool = True) -> None:
    """Restore state from the newest checkpoint of ``policy_id``, if there is one.

    ``cfg.load_checkpoint_kind`` selects the file family: ``"latest"`` -> ``checkpoint_*`` files,
    ``"best"`` -> ``best_*`` files (any other value raises KeyError). ``load_progress=False``
    keeps this learner's step counters, e.g. when replacing our policy with another one under PBT.
    """
    prefix_by_kind = {"latest": "checkpoint", "best": "best"}
    name_prefix = prefix_by_kind[self.cfg.load_checkpoint_kind]
    policy_ckpt_dir = self.checkpoint_dir(self.cfg, policy_id)
    candidates = self.get_checkpoints(policy_ckpt_dir, pattern=f"{name_prefix}_*")

    checkpoint_dict = self.load_checkpoint(candidates, self.device)
    if checkpoint_dict is not None:
        log.debug("Loading model from checkpoint")
        self._load_state(checkpoint_dict, load_progress=load_progress)
    else:
        log.debug("Did not load from checkpoint, starting from scratch!")
def _should_save_summaries(self):
summaries_every_seconds = self.summary_rate_decay_seconds.at(self.train_step)
if time.time() - self.last_summary_time < summaries_every_seconds:
return False
return True
def _after_optimizer_step(self):
"""A hook to be called after each optimizer step."""
self.train_step += 1
def _get_checkpoint_dict(self):
checkpoint = {
"train_step": self.train_step,
"env_steps": self.env_steps,
"best_performance": self.best_performance,
"model": self.actor_critic.state_dict(),
"optimizer": self.optimizer.state_dict(),
"curr_lr": self.curr_lr,
}
return checkpoint
def _save_impl(self, name_prefix, name_suffix, keep_checkpoints, verbose=True) -> bool:
    """Write a checkpoint and prune older checkpoints of the same family.

    The file is ``{name_prefix}_{train_step:09d}_{env_steps}{name_suffix}.pth`` inside
    ``checkpoint_dir(cfg, policy_id)``. Afterwards only the newest ``keep_checkpoints`` regular
    files matching ``{name_prefix}_*`` are kept.

    Args:
        name_prefix: checkpoint family, e.g. ``"checkpoint"`` or ``"best"``.
        name_suffix: extra text inserted right before ``.pth``.
        keep_checkpoints: number of files of this family to keep.
        verbose: log the save and every removal.

    Returns:
        False if the learner is not initialized yet (nothing is written), True otherwise.
    """
    if not self.is_initialized:
        return False

    checkpoint = self._get_checkpoint_dict()
    assert checkpoint is not None

    checkpoint_dir = self.checkpoint_dir(self.cfg, self.policy_id)
    # The temp name must NOT match the f"{name_prefix}_*" glob. Otherwise a temp file left behind by a
    # crash mid-save sorts after every real checkpoint ("temp" > digits), and load_from_checkpoint
    # would pick it as the latest checkpoint. It would also be counted and pruned like a real one.
    tmp_filepath = join(checkpoint_dir, f"temp_{name_prefix}")
    checkpoint_name = f"{name_prefix}_{self.train_step:09d}_{self.env_steps}{name_suffix}.pth"
    filepath = join(checkpoint_dir, checkpoint_name)

    if verbose:
        log.info("Saving %s...", filepath)
    # This should protect us from a rare case where something goes wrong mid-save and we end up with a corrupted
    # checkpoint file. It better be a corrupted temp file.
    torch.save(checkpoint, tmp_filepath)
    # os.replace is atomic and, unlike os.rename, also overwrites an existing destination on Windows
    os.replace(tmp_filepath, filepath)

    # Prune oldest-first (names sort chronologically thanks to the zero-padded train_step). Only regular
    # files are considered: a directory matching the glob used to make the pruning loop spin forever.
    existing = [p for p in self.get_checkpoints(checkpoint_dir, f"{name_prefix}_*") if os.path.isfile(p)]
    num_to_remove = max(len(existing) - keep_checkpoints, 0)
    for oldest_checkpoint in existing[:num_to_remove]:
        if verbose:
            log.debug("Removing %s", oldest_checkpoint)
        os.remove(oldest_checkpoint)

    return True
def save(self) -> bool:
return self._save_impl("checkpoint", "", self.cfg.keep_checkpoints)
def save_milestone(self):
    """Save a snapshot to ``<checkpoint_dir>/milestones/checkpoint_{train_step:09d}_{env_steps}.pth``.

    Milestones live in a subdirectory, so the ``checkpoint_*`` pruning in ``_save_impl`` does not
    touch them. Does nothing if the learner has not been initialized yet (like ``_save_impl``).
    """
    if not self.is_initialized:
        return

    checkpoint = self._get_checkpoint_dict()
    assert checkpoint is not None

    checkpoint_dir = self.checkpoint_dir(self.cfg, self.policy_id)
    milestones_dir = ensure_dir_exists(join(checkpoint_dir, "milestones"))
    checkpoint_name = f"checkpoint_{self.train_step:09d}_{self.env_steps}.pth"
    milestone_path = join(milestones_dir, checkpoint_name)
    log.info("Saving a milestone %s", milestone_path)

    # write to a temp file, then atomically move it into place, so an interrupted save
    # never leaves a truncated milestone file behind
    tmp_path = join(milestones_dir, "temp_milestone")
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, milestone_path)
def save_best(self, policy_id, metric, metric_value) -> bool:
if policy_id != self.policy_id:
return False
p = 3 # precision, number of significant digits
if metric_value - self.best_performance > 1 / 10**p:
log.info(f"Saving new best policy, {metric}={metric_value:.{p}f}!")
self.best_performance = metric_value
name_suffix = f"_{metric}_{metric_value:.{p}f}"
return self._save_impl("best", name_suffix, 1, verbose=False)
return False
def set_new_cfg(self, new_cfg: Dict) -> None:
self.new_cfg = new_cfg
def set_policy_to_load(self, policy_to_load: PolicyID) -> None:
self.policy_to_load = policy_to_load
def _maybe_update_cfg(self) -> None:
if self.new_cfg is not None:
for key, value in self.new_cfg.items():
if self.cfg[key] != value:
log.debug("Learner %d replacing cfg parameter %r with new value %r", self.policy_id, key, value)
self.cfg[key] = value
if self.cfg.lr_schedule == "constant" and self.curr_lr != self.cfg.learning_rate:
# PBT-optimized learning rate, only makes sense if we use constant LR
# in case of more advanced LR scheduling we should update the parameters of the scheduler, not the
# learning rate directly
log.debug(f"Updating learning rate from {self.curr_lr} to {self.cfg.learning_rate}")
self.curr_lr = self.cfg.learning_rate
self._apply_lr(self.curr_lr)
for param_group in self.optimizer.param_groups:
param_group["betas"] = (self.cfg.adam_beta1, self.cfg.adam_beta2)
log.debug("Optimizer lr value %.7f, betas: %r", param_group["lr"], param_group["betas"])
self.new_cfg = None
def _maybe_load_policy(self) -> None:
if self.policy_to_load is not None:
with self.param_server.policy_lock:
# don't re-load progress if we are loading from another policy checkpoint
self.load_from_checkpoint(self.policy_to_load, load_progress=False)
# make sure everything (such as policy weights) is committed to shared device memory
synchronize(self.cfg, self.device)
# this will force policy update on the inference worker (policy worker)
# we add max_policy_lag steps so that all experience currently in batches is invalidated
self.train_step += self.cfg.max_policy_lag + 1
self.policy_versions_tensor[self.policy_id] = self.train_step
self.policy_to_load = None
@staticmethod
def _policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high, valids, num_invalids: int):
    """PPO clipped surrogate objective, returned as a loss.

    Per sample: ``min(ratio * adv, clamp(ratio, low, high) * adv)``. The result is the negated
    mean over the valid samples (selected with ``masked_select``).
    """
    surrogate = torch.min(ratio * adv, torch.clamp(ratio, clip_ratio_low, clip_ratio_high) * adv)
    return -masked_select(surrogate, valids, num_invalids).mean()
def _value_loss(
    self,
    new_values: Tensor,
    old_values: Tensor,
    target: Tensor,
    clip_value: float,
    valids: Tensor,
    num_invalids: int,
) -> Tensor:
    """Clipped value loss scaled by ``cfg.value_loss_coeff``.

    The new value is clipped to within ``clip_value`` of the old value. The per-sample loss is the
    larger of the unclipped and clipped squared errors, averaged over valid samples.
    """
    value_clipped = old_values + torch.clamp(new_values - old_values, -clip_value, clip_value)
    unclipped_sq_err = (new_values - target).pow(2)
    clipped_sq_err = (value_clipped - target).pow(2)
    per_sample = torch.max(unclipped_sq_err, clipped_sq_err)
    return masked_select(per_sample, valids, num_invalids).mean() * self.cfg.value_loss_coeff
def _kl_loss(
    self, action_space, action_logits, action_distribution, valids, num_invalids: int
) -> Tuple[Tensor, Tensor]:
    """Fixed KL penalty between the current policy and the behaviour policy.

    Returns:
        ``(kl_old, kl_loss)``: ``kl_old`` is ``action_distribution.kl_divergence(behaviour)`` for the
        valid samples; ``kl_loss`` is its mean scaled by ``cfg.kl_loss_coeff``.
    """
    behaviour_distribution = get_action_distribution(action_space, action_logits)
    per_sample_kl = action_distribution.kl_divergence(behaviour_distribution)
    kl_old = masked_select(per_sample_kl, valids, num_invalids)
    kl_loss = kl_old.mean() * self.cfg.kl_loss_coeff
    return kl_old, kl_loss
def _entropy_exploration_loss(self, action_distribution, valids, num_invalids: int) -> Tensor:
    """Entropy bonus as a loss: ``-cfg.exploration_loss_coeff * mean entropy`` over the valid samples."""
    valid_entropy = masked_select(action_distribution.entropy(), valids, num_invalids)
    return -self.cfg.exploration_loss_coeff * valid_entropy.mean()
def _symmetric_kl_exploration_loss(self, action_distribution, valids, num_invalids: int) -> Tensor:
    """Exploration loss: symmetric KL between the policy and a uniform prior over actions.

    The mean over valid samples is clamped to at most 30 and scaled by ``cfg.exploration_loss_coeff``.
    A non-finite mean is replaced by zero, so this minibatch contributes no exploration gradient.
    """
    kl_prior = action_distribution.symmetric_kl_with_uniform_prior()
    kl_prior = masked_select(kl_prior, valids, num_invalids).mean()
    if not torch.isfinite(kl_prior):
        # zeros_like keeps the device and dtype of the loss; torch.zeros(shape) would always
        # produce a CPU tensor of the default dtype, unlike every other loss term
        kl_prior = torch.zeros_like(kl_prior)
    kl_prior = torch.clamp(kl_prior, max=30)
    kl_prior_loss = self.cfg.exploration_loss_coeff * kl_prior
    return kl_prior_loss
def _optimizer_lr(self):
for param_group in self.optimizer.param_groups:
return param_group["lr"]
def _apply_lr(self, lr: float) -> None:
"""Change learning rate in the optimizer."""
if lr != self._optimizer_lr():
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def _get_minibatches(self, batch_size, experience_size):
"""Generating minibatches for training."""
assert self.cfg.rollout % self.cfg.recurrence == 0
assert experience_size % batch_size == 0, f"experience size: {experience_size}, batch size: {batch_size}"
minibatches_per_epoch = self.cfg.num_batches_per_epoch
if minibatches_per_epoch == 1:
return [None] # single minibatch is actually the entire buffer, we don't need indices
if self.cfg.shuffle_minibatches:
# indices that will start the mini-trajectories from the same episode (for bptt)
indices = np.arange(0, experience_size, self.cfg.recurrence)
indices = np.random.permutation(indices)
# complete indices of mini trajectories, e.g. with recurrence==4: [4, 16] -> [4, 5, 6, 7, 16, 17, 18, 19]
indices = [np.arange(i, i + self.cfg.recurrence) for i in indices]
indices = np.concatenate(indices)
assert len(indices) == experience_size
num_minibatches = experience_size // batch_size
minibatches = np.split(indices, num_minibatches)
else:
minibatches = list(slice(i * batch_size, (i + 1) * batch_size) for i in range(0, minibatches_per_epoch))
# this makes sense but I'd like to do some testing before enabling it
# random.shuffle(minibatches) # same minibatches between epochs, but in random order
return minibatches
@staticmethod
def _get_minibatch(buffer, indices):
if indices is None:
# handle the case of a single batch, where the entire buffer is a minibatch
return buffer
mb = buffer[indices]
return mb
def _calculate_losses(
    self, mb: AttrDict, num_invalids: int
) -> Tuple[ActionDistribution, Tensor, Tensor | float, Optional[Tensor], Tensor | float, Tensor, Dict]:
    """Run the actor-critic on a minibatch and compute all PPO loss terms.

    Args:
        mb: minibatch of experience. Fields read here: ``normalized_obs``, ``valids``, ``dones_cpu``,
            ``rnn_states``, ``actions``, ``log_prob_actions``, ``rewards_cpu``, ``advantages``,
            ``returns``, ``action_logits`` and ``values``.
        num_invalids: number of invalid samples, forwarded to ``masked_select``.

    Returns:
        ``(action_distribution, policy_loss, exploration_loss, kl_old, kl_loss, value_loss, loss_summaries)``.
        ``kl_old``/``kl_loss`` come from ``self.kl_loss_func`` and are ``(None, 0.0)`` when the KL loss
        is disabled. ``loss_summaries`` holds intermediate tensors that ``_train`` passes on to
        ``_record_summaries``.
    """
    with torch.no_grad(), self.timing.add_time("losses_init"):
        recurrence: int = self.cfg.recurrence

        # PPO clipping
        clip_ratio_high = 1.0 + self.cfg.ppo_clip_ratio  # e.g. 1.1
        # this still works with e.g. clip_ratio = 2, while PPO's 1-r would give negative ratio
        clip_ratio_low = 1.0 / clip_ratio_high
        clip_value = self.cfg.ppo_clip_value

        valids = mb.valids

    # calculate policy head outside of recurrent loop
    with self.timing.add_time("forward_head"):
        head_outputs = self.actor_critic.forward_head(mb.normalized_obs)
        minibatch_size: int = head_outputs.size(0)

    # initial rnn states
    with self.timing.add_time("bptt_initial"):
        if self.cfg.use_rnn:
            # this is the only way to stop RNNs from backpropagating through invalid timesteps
            # (i.e. experience collected by another policy)
            done_or_invalid = torch.logical_or(mb.dones_cpu, ~valids.cpu()).float()
            head_output_seq, rnn_states, inverted_select_inds = build_rnn_inputs(
                head_outputs,
                done_or_invalid,
                mb.rnn_states,
                recurrence,
            )
        else:
            # without an RNN only the state stored at the start of each segment is passed on
            rnn_states = mb.rnn_states[::recurrence]

    # calculate RNN outputs for each timestep in a loop
    with self.timing.add_time("bptt"):
        if self.cfg.use_rnn:
            with self.timing.add_time("bptt_forward_core"):
                core_output_seq, _ = self.actor_critic.forward_core(head_output_seq, rnn_states)
            # map the sequence outputs back to the flat minibatch order
            core_outputs = build_core_out_from_seq(core_output_seq, inverted_select_inds)
            del core_output_seq
        else:
            core_outputs, _ = self.actor_critic.forward_core(head_outputs, rnn_states)

        del head_outputs

    num_trajectories = minibatch_size // recurrence
    assert core_outputs.shape[0] == minibatch_size

    with self.timing.add_time("tail"):
        # calculate policy tail outside of recurrent loop
        result = self.actor_critic.forward_tail(core_outputs, values_only=False, sample_actions=False)
        action_distribution = self.actor_critic.action_distribution()
        log_prob_actions = action_distribution.log_prob(mb.actions)
        ratio = torch.exp(log_prob_actions - mb.log_prob_actions)  # pi / pi_old

        # super large/small values can cause numerical problems and are probably noise anyway
        ratio = torch.clamp(ratio, 0.05, 20.0)

        values = result["values"].squeeze()

        del core_outputs

    # these computations are not the part of the computation graph
    with torch.no_grad(), self.timing.add_time("advantages_returns"):
        if self.cfg.with_vtrace:
            # V-trace parameters: truncation levels for the importance ratios
            rho_hat = torch.Tensor([self.cfg.vtrace_rho])
            c_hat = torch.Tensor([self.cfg.vtrace_c])

            # the recursion below runs on CPU
            ratios_cpu = ratio.cpu()
            values_cpu = values.cpu()
            rewards_cpu = mb.rewards_cpu
            dones_cpu = mb.dones_cpu

            vtrace_rho = torch.min(rho_hat, ratios_cpu)
            vtrace_c = torch.min(c_hat, ratios_cpu)

            vs = torch.zeros((num_trajectories * recurrence))
            adv = torch.zeros((num_trajectories * recurrence))

            # Value after the last step of each segment: estimated as (V_last - r_last) / gamma.
            # NOTE(review): this inverts V(s_last) ~= r_last + gamma * V(s_next); confirm it is the intended bootstrap.
            next_values = values_cpu[recurrence - 1 :: recurrence] - rewards_cpu[recurrence - 1 :: recurrence]
            next_values /= self.cfg.gamma
            next_vs = next_values

            # backwards over timesteps; [i::recurrence] selects timestep i of every segment at once
            for i in reversed(range(self.cfg.recurrence)):
                rewards = rewards_cpu[i::recurrence]
                dones = dones_cpu[i::recurrence]
                not_done = 1.0 - dones
                not_done_gamma = not_done * self.cfg.gamma

                curr_values = values_cpu[i::recurrence]
                curr_vtrace_rho = vtrace_rho[i::recurrence]
                curr_vtrace_c = vtrace_c[i::recurrence]

                delta_s = curr_vtrace_rho * (rewards + not_done_gamma * next_values - curr_values)
                adv[i::recurrence] = curr_vtrace_rho * (rewards + not_done_gamma * next_vs - curr_values)
                next_vs = curr_values + delta_s + not_done_gamma * curr_vtrace_c * (next_vs - next_values)
                vs[i::recurrence] = next_vs

                next_values = curr_values

            targets = vs.to(self.device)
            adv = adv.to(self.device)
        else:
            # using regular GAE (advantages and returns are already stored in the minibatch)
            adv = mb.advantages
            targets = mb.returns

        # normalization statistics come from the valid samples only
        adv_std, adv_mean = torch.std_mean(masked_select(adv, valids, num_invalids))
        adv = (adv - adv_mean) / torch.clamp_min(adv_std, 1e-7)  # normalize advantage

    with self.timing.add_time("losses"):
        # noinspection PyTypeChecker
        policy_loss = self._policy_loss(ratio, adv, clip_ratio_low, clip_ratio_high, valids, num_invalids)
        exploration_loss = self.exploration_loss_func(action_distribution, valids, num_invalids)
        kl_old, kl_loss = self.kl_loss_func(
            self.actor_critic.action_space, mb.action_logits, action_distribution, valids, num_invalids
        )
        old_values = mb["values"]
        value_loss = self._value_loss(values, old_values, targets, clip_value, valids, num_invalids)

    loss_summaries = dict(
        ratio=ratio,
        clip_ratio_low=clip_ratio_low,
        clip_ratio_high=clip_ratio_high,
        values=result["values"],
        adv=adv,
        adv_std=adv_std,
        adv_mean=adv_mean,
    )

    return action_distribution, policy_loss, exploration_loss, kl_old, kl_loss, value_loss, loss_summaries
def _train(
    self, gpu_buffer: TensorDict, batch_size: int, experience_size: int, num_invalids: int
) -> Optional[AttrDict]:
    """Run up to ``cfg.num_epochs`` epochs of minibatch SGD over the experience in ``gpu_buffer``.

    After every optimizer step the new policy version (``train_step``) is published to
    ``policy_versions_tensor``. Training stops early when the mean actor loss changes by less than
    1e-6 between consecutive epochs.

    Args:
        gpu_buffer: experience buffer, sliced into minibatches by ``_get_minibatch``.
        batch_size: minibatch size in samples.
        experience_size: total number of samples in ``gpu_buffer``.
        num_invalids: number of invalid samples; used to mask the losses and to scale down the LR.

    Returns:
        Stats from the last ``_record_summaries`` call made during this call, or None if no
        summaries were recorded.
    """
    timing = self.timing
    with torch.no_grad():
        early_stopping_tolerance = 1e-6
        early_stop = False
        prev_epoch_actor_loss = 1e9
        # NOTE(review): sized by num_batches_per_epoch but indexed by batch_num over len(minibatches);
        # the shuffled path of _get_minibatches yields experience_size // batch_size minibatches. Confirm they agree.
        epoch_actor_losses = [0] * self.cfg.num_batches_per_epoch

        # recent mean KL-divergences per minibatch, this used by LR schedulers
        recent_kls = []

        if self.cfg.with_vtrace:
            assert (
                self.cfg.recurrence == self.cfg.rollout and self.cfg.recurrence > 1
            ), "V-trace requires to recurrence and rollout to be equal"

        num_sgd_steps = 0
        stats_and_summaries: Optional[AttrDict] = None

        # When it is time to record train summaries, we randomly sample epoch/batch for which the summaries are
        # collected to get equal representation from different stages of training.
        # Half the time, we record summaries from the very large step of training. There we will have the highest
        # KL-divergence and ratio of PPO-clipped samples, which makes this data even more useful for analysis.
        # Something to consider: maybe we should have these last-batch metrics in a separate summaries category?
        with_summaries = self._should_save_summaries()
        if np.random.rand() < 0.5:
            summaries_epoch = np.random.randint(0, self.cfg.num_epochs)
            summaries_batch = np.random.randint(0, self.cfg.num_batches_per_epoch)
        else:
            summaries_epoch = self.cfg.num_epochs - 1
            summaries_batch = self.cfg.num_batches_per_epoch - 1

        assert self.actor_critic.training

    for epoch in range(self.cfg.num_epochs):
        with timing.add_time("epoch_init"):
            if early_stop:
                break

            force_summaries = False
            minibatches = self._get_minibatches(batch_size, experience_size)

        for batch_num in range(len(minibatches)):
            with torch.no_grad(), timing.add_time("minibatch_init"):
                indices = minibatches[batch_num]

                # current minibatch consisting of short trajectory segments with length == recurrence
                mb = self._get_minibatch(gpu_buffer, indices)

                # enable syntactic sugar that allows us to access dict's keys as object attributes
                mb = AttrDict(mb)

            with timing.add_time("calculate_losses"):
                (
                    action_distribution,
                    policy_loss,
                    exploration_loss,
                    kl_old,
                    kl_loss,
                    value_loss,
                    loss_summaries,
                ) = self._calculate_losses(mb, num_invalids)

            with timing.add_time("losses_postprocess"):
                # noinspection PyTypeChecker
                actor_loss: Tensor = policy_loss + exploration_loss + kl_loss
                critic_loss = value_loss
                loss: Tensor = actor_loss + critic_loss

                # per-minibatch actor losses feed the early-stopping check at the end of the epoch
                epoch_actor_losses[batch_num] = float(actor_loss)

                high_loss = 30.0
                if torch.abs(loss) > high_loss:
                    log.warning(
                        "High loss value: l:%.4f pl:%.4f vl:%.4f exp_l:%.4f kl_l:%.4f (recommended to adjust the --reward_scale parameter)",
                        to_scalar(loss),
                        to_scalar(policy_loss),
                        to_scalar(value_loss),
                        to_scalar(exploration_loss),
                        to_scalar(kl_loss),
                    )

                    # perhaps something weird is happening, we definitely want summaries from this step
                    force_summaries = True

            with torch.no_grad(), timing.add_time("kl_divergence"):
                # if kl_old is not None it is already calculated above
                if kl_old is None:
                    # calculate KL-divergence with the behaviour policy action distribution
                    old_action_distribution = get_action_distribution(
                        self.actor_critic.action_space,
                        mb.action_logits,
                    )
                    kl_old = action_distribution.kl_divergence(old_action_distribution)
                    kl_old = masked_select(kl_old, mb.valids, num_invalids)

                kl_old_mean = float(kl_old.mean().item())
                recent_kls.append(kl_old_mean)
                if kl_old.numel() > 0 and kl_old.max().item() > 100:
                    log.warning(f"KL-divergence is very high: {kl_old.max().item():.4f}")

            # update the weights
            with timing.add_time("update"):
                # following advice from https://youtu.be/9mS1fIYj1So set grad to None instead of optimizer.zero_grad()
                for p in self.actor_critic.parameters():
                    p.grad = None

                loss.backward()

                if self.cfg.max_grad_norm > 0.0:
                    with timing.add_time("clip"):
                        torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.cfg.max_grad_norm)

                curr_policy_version = self.train_step  # policy version before the weight update

                actual_lr = self.curr_lr
                if num_invalids > 0:
                    # if we have masked (invalid) data we should reduce the learning rate accordingly
                    # this prevents a situation where most of the data in the minibatch is invalid
                    # and we end up doing SGD with super noisy gradients
                    actual_lr = self.curr_lr * (experience_size - num_invalids) / experience_size
                self._apply_lr(actual_lr)

                # the parameter server's lock guards the weights while they are being modified
                with self.param_server.policy_lock:
                    self.optimizer.step()

                num_sgd_steps += 1

            with torch.no_grad(), timing.add_time("after_optimizer"):
                self._after_optimizer_step()

                if self.lr_scheduler.invoke_after_each_minibatch():
                    self.curr_lr = self.lr_scheduler.update(self.curr_lr, recent_kls)

                # collect and report summaries
                should_record_summaries = with_summaries
                should_record_summaries &= epoch == summaries_epoch and batch_num == summaries_batch
                should_record_summaries |= force_summaries
                if should_record_summaries:
                    # hacky way to collect all of the intermediate variables for summaries
                    # (_record_summaries reads these locals by name, e.g. mb, loss, actual_lr, epoch,
                    # batch_num, minibatches, kl_old_mean; renaming them here breaks summaries)
                    summary_vars = {**locals(), **loss_summaries}
                    stats_and_summaries = self._record_summaries(AttrDict(summary_vars))
                    del summary_vars
                    force_summaries = False

                # make sure everything (such as policy weights) is committed to shared device memory
                synchronize(self.cfg, self.device)
                # this will force policy update on the inference worker (policy worker)
                self.policy_versions_tensor[self.policy_id] = self.train_step

        # end of an epoch
        if self.lr_scheduler.invoke_after_each_epoch():
            self.curr_lr = self.lr_scheduler.update(self.curr_lr, recent_kls)

        # early stopping: the mean actor loss barely changed compared to the previous epoch
        new_epoch_actor_loss = float(np.mean(epoch_actor_losses))
        loss_delta_abs = abs(prev_epoch_actor_loss - new_epoch_actor_loss)
        if loss_delta_abs < early_stopping_tolerance:
            early_stop = True
            log.debug(
                "Early stopping after %d epochs (%d sgd steps), loss delta %.7f",
                epoch + 1,
                num_sgd_steps,
                loss_delta_abs,
            )
            break

        prev_epoch_actor_loss = new_epoch_actor_loss

    return stats_and_summaries
def _record_summaries(self, train_loop_vars) -> AttrDict:
    """Collect training statistics for the current minibatch and return them as scalars.

    Args:
        train_loop_vars: AttrDict snapshot of the training loop's local variables merged with
            the loss summaries (built in `_train` from `locals()`), so fields such as `mb`,
            `loss`, `values`, `ratio`, `epoch`, `batch_num` are read from it by name.

    Returns:
        AttrDict mapping stat names to values, each passed through `to_scalar`.

    Side effects:
        Updates `self.last_summary_time` to the current wall-clock time.
    """
    var = train_loop_vars

    self.last_summary_time = time.time()
    stats = AttrDict()

    stats.lr = self.curr_lr
    stats.actual_lr = var.actual_lr  # potentially scaled because of masked data

    stats.update(self.actor_critic.summaries())

    stats.valids_fraction = var.mb.valids.float().mean()
    stats.same_policy_fraction = (var.mb.policy_id == self.policy_id).float().mean()

    # global L2 norm over all parameter gradients currently stored on the model
    grad_norm = (
        sum(p.grad.data.norm(2).item() ** 2 for p in self.actor_critic.parameters() if p.grad is not None) ** 0.5
    )
    stats.grad_norm = grad_norm
    stats.loss = var.loss
    stats.value = var.values.mean()
    stats.entropy = var.action_distribution.entropy().mean()
    stats.policy_loss = var.policy_loss
    stats.kl_loss = var.kl_loss
    stats.value_loss = var.value_loss
    stats.exploration_loss = var.exploration_loss

    stats.act_min = var.mb.actions.min()
    stats.act_max = var.mb.actions.max()

    stats.adv_min = var.mb.advantages.min()
    stats.adv_max = var.mb.advantages.max()
    stats.adv_std = var.adv_std
    stats.adv_mean = var.adv_mean

    stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

    if hasattr(var.action_distribution, "summaries"):
        stats.update(var.action_distribution.summaries())

    if var.epoch == self.cfg.num_epochs - 1 and var.batch_num == len(var.minibatches) - 1:
        # we collect these stats only for the last PPO batch, or every time if we're only doing one batch, IMPALA-style
        valid_ratios = masked_select(var.ratio, var.mb.valids, var.num_invalids)
        ratio_mean = torch.abs(1.0 - valid_ratios).mean().detach()
        ratio_min = valid_ratios.min().detach()
        ratio_max = valid_ratios.max().detach()

        value_delta = torch.abs(var.values - var.mb.values)
        value_delta_avg, value_delta_max = value_delta.mean(), value_delta.max()

        stats.kl_divergence = var.kl_old_mean
        stats.kl_divergence_max = var.kl_old.max()
        stats.value_delta = value_delta_avg
        stats.value_delta_max = value_delta_max
        # fraction of valid samples whose importance ratio fell outside the clipping range
        # noinspection PyUnresolvedReferences
        stats.fraction_clipped = (
            (valid_ratios < var.clip_ratio_low).float() + (valid_ratios > var.clip_ratio_high).float()
        ).mean()
        stats.ratio_mean = ratio_mean
        stats.ratio_min = ratio_min
        stats.ratio_max = ratio_max
        stats.num_sgd_steps = var.num_sgd_steps

    # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
    adam_max_second_moment = 0.0
    for tensor_state in self.optimizer.state.values():
        if "exp_avg_sq" in tensor_state:
            adam_max_second_moment = max(tensor_state["exp_avg_sq"].max().item(), adam_max_second_moment)
    stats.adam_max_second_moment = adam_max_second_moment

    # policy lag (in versions) of the minibatch samples that were collected by this policy
    version_diff = (var.curr_policy_version - var.mb.policy_version)[var.mb.policy_id == self.policy_id]
    if version_diff.numel() > 0:
        # the boolean mask can select nothing when no sample in the minibatch belongs to this policy;
        # min()/max() raise on an empty tensor, so skip these stats instead of crashing the learner
        stats.version_diff_avg = version_diff.mean()
        stats.version_diff_min = version_diff.min()
        stats.version_diff_max = version_diff.max()

    for key, value in stats.items():
        stats[key] = to_scalar(value)

    return stats
def _prepare_and_normalize_obs(self, obs: TensorDict) -> TensorDict:
    """Normalize a dict of observation tensors that carry two leading batch-like dims.

    The first two dimensions of every tensor (time and envs, per the original code's comment)
    are merged into one via `view`, the merged dict is passed to `prepare_and_normalize_obs`,
    and each resulting tensor is viewed back to the shape its input had.

    Note: `obs` is modified in place — its entries are replaced by the flattened views.
    The normalizer runs while holding `self.param_server.policy_lock`, because its state may
    be changed here and is shared with other processes.

    Returns:
        The dict produced by `prepare_and_normalize_obs`, with each entry reshaped back to the
        original shape of the same key in `obs`.
    """
    # remember every entry's shape before collapsing, so the output can be restored key-by-key
    shapes_before = {name: tensor.shape for name, tensor in obs.items()}
    collapsed = {
        name: tensor.view(tensor.shape[0] * tensor.shape[1], *tensor.shape[2:]) for name, tensor in obs.items()
    }
    for name, flat_tensor in collapsed.items():
        obs[name] = flat_tensor

    # normalizer statistics can be mutated here and are shared across processes, hence the lock
    with self.param_server.policy_lock:
        result = prepare_and_normalize_obs(self.actor_critic, obs)

    reshaped = {name: tensor.view(shapes_before[name]) for name, tensor in result.items()}
    for name, restored_tensor in reshaped.items():
        result[name] = restored_tensor
    return result
def _prepare_batch(self, batch: TensorDict) -> Tuple[TensorDict, int, int]:
with torch.no_grad():
# create a shallow copy so we can modify the dictionary
# we still reference the same buffers though
buff = shallow_recursive_copy(batch)
# ignore experience from other agents (i.e. on episode boundary) and from inactive agents
valids: Tensor = buff["policy_id"] == self.policy_id
# ignore experience that was older than the threshold even before training started
curr_policy_version: int = self.train_step
buff["valids"][:, :-1] = valids & (curr_policy_version - buff["policy_version"] < self.cfg.max_policy_lag)
# for last T+1 step, we want to use the validity of the previous step
buff["valids"][:, -1] = buff["valids"][:, -2]
# ensure we're in train mode so that normalization statistics are updated
if not self.actor_critic.training:
self.actor_critic.train()
buff["normalized_obs"] = self._prepare_and_normalize_obs(buff["obs"])
del buff["obs"] # don't need non-normalized obs anymore
# calculate estimated value for the next step (T+1)
normalized_last_obs = buff["normalized_obs"][:, -1]
next_values = self.actor_critic(normalized_last_obs, buff["rnn_states"][:, -1], values_only=True)["values"]
buff["values"][:, -1] = next_values
if self.cfg.normalize_returns:
# Since our value targets are normalized, the values will also have normalized statistics.
# We need to denormalize them before using them for GAE caculation and value bootstrapping.
# rl_games PPO uses a similar approach, see:
# https://github.com/Denys88/rl_games/blob/7b5f9500ee65ae0832a7d8613b019c333ecd932c/rl_games/algos_torch/models.py#L51
denormalized_values = buff["values"].clone() # need to clone since normalizer is in-place
self.actor_critic.returns_normalizer(denormalized_values, denormalize=True)
else:
# values are not normalized in this case, so we can use them as is
denormalized_values = buff["values"]
if self.cfg.value_bootstrap:
# Value bootstrapping is a technique that reduces the surprise for the critic in case
# we're ending the episode by timeout. Intuitively, in this case the cumulative return for the last step
# should not be zero, but rather what the critic expects. This improves learning in many envs
# because otherwise the critic cannot predict the abrupt change in rewards in a timed-out episode.
# What we really want here is v(t+1) which we don't have because we don't have obs(t+1) (since
# the episode ended). Using v(t) is an approximation that requires that rew(t) can be generally ignored.
# Multiply by both time_out and done flags to make sure we count only timeouts in terminal states.
# There was a bug in older versions of isaacgym where timeouts were reported for non-terminal states.
buff["rewards"].add_(self.cfg.gamma * denormalized_values[:, :-1] * buff["time_outs"] * buff["dones"])
if not self.cfg.with_vtrace:
# calculate advantage estimate (in case of V-trace it is done separately for each minibatch)
buff["advantages"] = gae_advantages(
buff["rewards"],
buff["dones"],
denormalized_values,
buff["valids"],
self.cfg.gamma,
self.cfg.gae_lambda,
)
# here returns are not normalized yet, so we should use denormalized values
buff["returns"] = buff["advantages"] + buff["valids"][:, :-1] * denormalized_values[:, :-1]
# remove next step obs, rnn_states, and values from the batch, we don't need them anymore
for key in ["normalized_obs", "rnn_states", "values", "valids"]: