From 90b209c9673837290b89f822606c37e0f68b18c5 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Mon, 13 Nov 2023 09:44:04 +0100
Subject: [PATCH] logging fix

---
 examples/a2c/a2c_atari.py  | 4 ++--
 examples/a2c/a2c_mujoco.py | 4 ++--
 examples/ppo/ppo_atari.py  | 4 ++--
 examples/ppo/ppo_mujoco.py | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py
index 44a37cb3ce6..4598c11844b 100644
--- a/examples/a2c/a2c_atari.py
+++ b/examples/a2c/a2c_atari.py
@@ -117,9 +117,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
+            episode_length = data["next", "step_count"][data["next", "terminated"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py
index 7f9e588bbf6..48844dee6b6 100644
--- a/examples/a2c/a2c_mujoco.py
+++ b/examples/a2c/a2c_mujoco.py
@@ -101,9 +101,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
+            episode_length = data["next", "step_count"][data["next", "terminated"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py
index eb2ce15ec5a..1bfbccdeba4 100644
--- a/examples/ppo/ppo_atari.py
+++ b/examples/ppo/ppo_atari.py
@@ -134,9 +134,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
+            episode_length = data["next", "step_count"][data["next", "terminated"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py
index ff6aeda51d2..988bc5300bf 100644
--- a/examples/ppo/ppo_mujoco.py
+++ b/examples/ppo/ppo_mujoco.py
@@ -120,9 +120,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
+            episode_length = data["next", "step_count"][data["next", "terminated"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
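
For context (not part of the patch): below is a minimal, self-contained sketch of the masking pattern the changed lines rely on, where episode statistics are gathered only at steps flagged as terminated. The tensor values and the "train/episode_length" key are made up for illustration; only torch is required to run it.

# Sketch with assumed values (plain torch, no TorchRL): statistics are taken
# only at steps where the episode actually terminated.
import torch

# Hypothetical batch of 6 transitions (stand-ins for data["next", ...] entries).
episode_reward = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0, 7.0])
step_count = torch.tensor([10.0, 30.0, 20.0, 50.0, 40.0, 70.0])
terminated = torch.tensor([False, True, False, True, False, False])

log_info = {}
episode_rewards = episode_reward[terminated]
if len(episode_rewards) > 0:
    episode_length = step_count[terminated]
    log_info.update(
        {
            "train/reward": episode_rewards.mean().item(),
            "train/episode_length": episode_length.mean().item(),
        }
    )
print(log_info)  # {'train/reward': 4.0, 'train/episode_length': 40.0}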