From e9589e4218e838d8a1377380c27ad45feb44f4ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartek=20Cupia=C5=82?= <92169405+BartekCupial@users.noreply.github.com> Date: Tue, 19 Nov 2024 09:28:00 +0100 Subject: [PATCH] env steps (#307) --- sample_factory/algo/learning/learner.py | 1 + sample_factory/utils/wandb_utils.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/sample_factory/algo/learning/learner.py b/sample_factory/algo/learning/learner.py index b40f1cb38..ed19b5ffc 100644 --- a/sample_factory/algo/learning/learner.py +++ b/sample_factory/algo/learning/learner.py @@ -843,6 +843,7 @@ def _record_summaries(self, train_loop_vars) -> AttrDict: self.last_summary_time = time.time() stats = AttrDict() + stats.env_steps = self.env_steps stats.lr = self.curr_lr stats.actual_lr = train_loop_vars.actual_lr # potentially scaled because of masked data diff --git a/sample_factory/utils/wandb_utils.py b/sample_factory/utils/wandb_utils.py index f0208c8b4..ec64d3b1a 100644 --- a/sample_factory/utils/wandb_utils.py +++ b/sample_factory/utils/wandb_utils.py @@ -56,6 +56,14 @@ def init_wandb_func(): wandb.config.update(cfg, allow_val_change=True) + wandb.define_metric("train/env_steps") + wandb.define_metric("train/*", step_metric="train/env_steps") + wandb.define_metric("perf/*", step_metric="train/env_steps") + wandb.define_metric("len/*", step_metric="train/env_steps") + wandb.define_metric("policy_stats/*", step_metric="train/env_steps") + wandb.define_metric("reward/*", step_metric="train/env_steps") + wandb.define_metric("stats/*", step_metric="train/env_steps") + def finish_wandb(cfg): if cfg.with_wandb: