diff --git a/sample_factory/algo/learning/learner.py b/sample_factory/algo/learning/learner.py index b40f1cb3..ed19b5ff 100644 --- a/sample_factory/algo/learning/learner.py +++ b/sample_factory/algo/learning/learner.py @@ -843,6 +843,7 @@ def _record_summaries(self, train_loop_vars) -> AttrDict: self.last_summary_time = time.time() stats = AttrDict() + stats.env_steps = self.env_steps stats.lr = self.curr_lr stats.actual_lr = train_loop_vars.actual_lr # potentially scaled because of masked data diff --git a/sample_factory/utils/wandb_utils.py b/sample_factory/utils/wandb_utils.py index f0208c8b..ec64d3b1 100644 --- a/sample_factory/utils/wandb_utils.py +++ b/sample_factory/utils/wandb_utils.py @@ -56,6 +56,14 @@ def init_wandb_func(): wandb.config.update(cfg, allow_val_change=True) + wandb.define_metric("train/env_steps") + wandb.define_metric("train/*", step_metric="train/env_steps") + wandb.define_metric("perf/*", step_metric="train/env_steps") + wandb.define_metric("len/*", step_metric="train/env_steps") + wandb.define_metric("policy_stats/*", step_metric="train/env_steps") + wandb.define_metric("reward/*", step_metric="train/env_steps") + wandb.define_metric("stats/*", step_metric="train/env_steps") + def finish_wandb(cfg): if cfg.with_wandb: