From 54439c868303f66919c9fb0b0d90f0b5b33fc829 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Mon, 3 Apr 2023 21:13:14 +0900 Subject: [PATCH 01/12] ddpg continuous + jax --- README.md | 10 +++++++++- benchmark/ddpg.sh | 0 cleanrl/ddpg_continuous_action.py | 23 ++++++++++++----------- cleanrl/ddpg_continuous_action_jax.py | 20 +++++++++----------- pyproject.toml | 2 +- 5 files changed, 31 insertions(+), 24 deletions(-) mode change 100644 => 100755 benchmark/ddpg.sh diff --git a/README.md b/README.md index 2a6ceb6f0..d1b827509 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,15 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/ CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy as CleanRL. -> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277). +> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277). + +Currently, `ppo_continuous_action_isaacgym.py`, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have been ported to gymnasium. + +Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like, + +``` +poetry run pip install sb3==2.0.0a1 +``` diff --git a/benchmark/ddpg.sh b/benchmark/ddpg.sh old mode 100644 new mode 100755 diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 00a821918..e2a92327b 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -5,9 +5,10 @@ import time from distutils.util import strtobool -import gym +import gymnasium as gym import numpy as np -import pybullet_envs # noqa + +# import pybullet_envs # noqa import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +72,6 @@ def thunk(): if capture_video: if idx == 0: env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") - env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) return env @@ -164,12 +164,12 @@ def forward(self, x): envs.single_observation_space, envs.single_action_space, device, - handle_timeout_termination=True, + handle_timeout_termination=False, ) start_time = time.time() # TRY NOT TO MODIFY: start the game - obs = envs.reset() + obs, _ = envs.reset(seed=args.seed) for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: @@ -181,11 +181,12 @@ def forward(self, x): actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, dones, infos = envs.step(actions) + next_obs, rewards, terminateds, truncateds, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes - for info in infos: - if "episode" in info.keys(): + + if "final_info" in infos: + for info in infos["final_info"]: print(f"global_step={global_step}, episodic_return={info['episode']['r']}") writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) @@ -193,10 +194,10 @@ def forward(self, x): # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(dones): + for idx, d in enumerate(truncateds): if d: - real_next_obs[idx] = infos[idx]["terminal_observation"] - rb.add(obs, real_next_obs, actions, rewards, dones, infos) + real_next_obs[idx] = infos["final_observation"][idx] + rb.add(obs, real_next_obs, actions, rewards, terminateds, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index b6291e4dc..4769b3f84 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -8,12 +8,11 @@ import flax import flax.linen as nn -import gym +import gymnasium as gym import jax import jax.numpy as jnp import numpy as np import optax -import pybullet_envs # noqa from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer from torch.utils.tensorboard import SummaryWriter @@ -70,7 +69,6 @@ def thunk(): if capture_video: if idx == 0: env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") - env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) return env @@ -150,11 +148,11 @@ class TrainState(TrainState): envs.single_observation_space, envs.single_action_space, device="cpu", - handle_timeout_termination=True, + handle_timeout_termination=False, ) # TRY NOT TO MODIFY: start the game - obs = envs.reset() + obs, _ = envs.reset() action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0) action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0) actor = Actor( @@ -235,11 +233,11 @@ def actor_loss(params): ) # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, dones, infos = envs.step(actions) + next_obs, rewards, terminateds, truncateds, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes - for info in infos: - if "episode" in info.keys(): + if "final_info" in infos: + for info in infos["final_info"]: print(f"global_step={global_step}, episodic_return={info['episode']['r']}") writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) @@ -247,10 +245,10 @@ def actor_loss(params): # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(dones): + for idx, d in enumerate(truncateds): if d: - real_next_obs[idx] = infos[idx]["terminal_observation"] - rb.add(obs, real_next_obs, actions, rewards, dones, infos) + real_next_obs[idx] = infos["final_observation"][idx] + rb.add(obs, real_next_obs, actions, rewards, terminateds, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/pyproject.toml b/pyproject.toml index 599d78b86..68b374937 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ wandb = "^0.13.6" gym = "0.23.1" torch = ">=1.12.1" stable-baselines3 = "1.2.0" -gymnasium = "^0.26.3" +gymnasium = "^0.28.1" moviepy = "^1.0.3" pygame = "2.1.0" huggingface-hub = "^0.11.1" From 6f4f072a49d4c7d5bf19ef6c6abdbb5fd95c652c Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Mon, 3 Apr 2023 22:01:50 +0900 Subject: [PATCH 02/12] fix video recording --- cleanrl/ddpg_continuous_action.py | 17 ++++++++++++++--- cleanrl/ddpg_continuous_action_jax.py | 14 ++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index e2a92327b..9cb577957 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -38,7 +38,7 @@ def parse_args(): help="whether to capture videos of the agent performances (check out `videos` folder)") # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0", + parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=1000000, help="total timesteps of the experiments") @@ -67,7 +67,10 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): - env = gym.make(env_id) + if capture_video: + env = gym.make(env_id, render_mode="rgb_array") + else: + env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: @@ -128,7 +131,7 @@ def forward(self, x): sync_tensorboard=True, config=vars(args), name=run_name, - monitor_gym=True, + # monitor_gym=True, # no longer works for gymnasium save_code=True, ) writer = SummaryWriter(f"runs/{run_name}") @@ -170,6 +173,8 @@ def forward(self, x): # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) + video_filenames = set() + for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: @@ -238,4 +243,10 @@ def forward(self, x): writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() + + if args.track and args.capture_video: + for filename in os.listdir(f"videos/{run_name}"): + if filename not in video_filenames and filename.endswith(".mp4"): + wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")}) + video_filenames.add(filename) writer.close() diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index 4769b3f84..ff7cafed2 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -64,7 +64,10 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): - env = gym.make(env_id) + if capture_video: + env = gym.make(env_id, render_mode="rgb_array") + else: + env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: @@ -122,7 +125,7 @@ class TrainState(TrainState): sync_tensorboard=True, config=vars(args), name=run_name, - monitor_gym=True, + monitor_gym=True, # does not work on gymnasium save_code=True, ) writer = SummaryWriter(f"runs/{run_name}") @@ -153,6 +156,7 @@ class TrainState(TrainState): # TRY NOT TO MODIFY: start the game obs, _ = envs.reset() + video_filenames = set() action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0) action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0) actor = Actor( @@ -280,4 +284,10 @@ def actor_loss(params): writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() + if args.track and args.capture_video: + for filename in os.listdir(f"videos/{run_name}"): + if filename not in video_filenames and filename.endswith(".mp4"): + wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")}) + video_filenames.add(filename) + writer.close() From 4171609e267e289517408e4350a15f9d9e1722cb Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Sat, 8 Apr 2023 00:17:32 +0900 Subject: [PATCH 03/12] remove pybullet --- cleanrl/ddpg_continuous_action.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 9cb577957..024a5ceb9 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -8,7 +8,6 @@ import gymnasium as gym import numpy as np -# import pybullet_envs # noqa import torch import torch.nn as nn import torch.nn.functional as F From f2608a393c8aa197a9f7cebf6e188bb808b4ef27 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Sat, 8 Apr 2023 00:24:04 +0900 Subject: [PATCH 04/12] move to usage docs --- README.md | 9 --------- docs/get-started/basic-usage.md | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ae3559501..8e11f9e83 100644 --- a/README.md +++ b/README.md @@ -32,15 +32,6 @@ CleanRL only contains implementations of **online** deep reinforcement learning > ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277). -Currently, `ppo_continuous_action_isaacgym.py`, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have been ported to gymnasium. - -Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like, - -``` -poetry run pip install sb3==2.0.0a1 -``` - - > ⚠️ **NOTE**: CleanRL is *not* a modular library and therefore it is not meant to be imported. At the cost of duplicate code, we make all implementation details of a DRL algorithm variant easy to understand, so CleanRL comes with its own pros and cons. You should consider using CleanRL if you want to 1) understand all implementation details of an algorithm's varaint or 2) prototype advanced features that other modular DRL libraries do not support (CleanRL has minimal lines of code so it gives you great debugging experience and you don't have do a lot of subclassing like sometimes in modular DRL libraries). diff --git a/docs/get-started/basic-usage.md b/docs/get-started/basic-usage.md index 0d2fe1ba7..5571c3e7e 100644 --- a/docs/get-started/basic-usage.md +++ b/docs/get-started/basic-usage.md @@ -44,6 +44,15 @@ the CleanRL script under the poetry virtual environments. **We will assume to run other commands (e.g. `tensorboard`) in the documentation within the poetry's shell.** +!!! note +Currently, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have been ported to gymnasium. + +Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like, + +``` +poetry run pip install sb3==2.0.0a1 +``` + !!! warning If you are using NVIDIA ampere GPUs (e.g., 3060 TI), you might meet the following error From 6e6a5b57da573a0c5886eb169523f7713a84a209 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Sat, 8 Apr 2023 00:24:16 +0900 Subject: [PATCH 05/12] isort --- cleanrl/ddpg_continuous_action.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 024a5ceb9..32f69dde5 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -7,7 +7,6 @@ import gymnasium as gym import numpy as np - import torch import torch.nn as nn import torch.nn.functional as F From d8dd8012f4ac4db59317551fa767eeb372e3af53 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 24 Apr 2023 15:45:55 -0400 Subject: [PATCH 06/12] update lock files --- poetry.lock | 80 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git a/poetry.lock b/poetry.lock index ea2c39ce6..8411a66d8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -845,6 +845,18 @@ typeguard = ">=2.6.1" [package.extras] test = ["mock (>=2.0.0)", "pytest (>=5.0)", "pytest-asyncio", "pytest-cov", "tensorflow"] +[[package]] +name = "farama-notifications" +version = "0.0.4" +description = "Notifications for all Farama Foundation maintained libraries." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, + {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, +] + [[package]] name = "fasteners" version = "0.15" @@ -1298,45 +1310,36 @@ test = ["gym (==0.17.2)", "gym-retro (==0.8.0)", "mpi4py (==3.0.3)", "pytest (== [[package]] name = "gymnasium" -version = "0.26.3" -description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)" +version = "0.28.1" +description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "Gymnasium-0.26.3-py3-none-any.whl", hash = "sha256:4be0085252759c65b09c9fb83970ceedd02fab03b075024d8ba22eaa1a11eda1"}, - {file = "Gymnasium-0.26.3.tar.gz", hash = "sha256:2a918e321fc0bb48f4ebf2936ccd8f20a049658f1509dea9c6e768b8030392ed"}, + {file = "gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf"}, + {file = "gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24"}, ] [package.dependencies] cloudpickle = ">=1.2.0" -gymnasium-notices = ">=0.0.1" +farama-notifications = ">=0.0.1" importlib-metadata = {version = ">=4.8.0", markers = "python_version < \"3.10\""} -numpy = ">=1.18.0" +jax-jumpy = ">=1.0.0" +numpy = ">=1.21.0" +typing-extensions = ">=4.3.0" [package.extras] accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"] -all = ["ale-py (>=0.8.0,<0.9.0)", "box2d-py (==2.3.5)", "gym (==0.26.2)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (>=4.0.0,<5.0.0)"] -atari = ["ale-py (>=0.8.0,<0.9.0)"] -box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.0)", "swig (>=4.0.0,<5.0.0)"] -classic-control = ["pygame (==2.1.0)"] -mujoco = ["imageio (>=2.14.1)", "mujoco (==2.2)"] -mujoco-py = ["mujoco-py (>=2.1,<2.2)"] -other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)"] -testing = ["box2d-py (==2.3.5)", "gym (==0.26.2)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (>=4.0.0,<5.0.0)"] -toy-text = ["pygame (==2.1.0)"] - -[[package]] -name = "gymnasium-notices" -version = "0.0.1" -description = "Notices for gymnasium" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "gymnasium-notices-0.0.1.tar.gz", hash = "sha256:3e8c868046f56dea84c949cc7e97383cccfab27152fc3f4968754e4c9c087ab9"}, - {file = "gymnasium_notices-0.0.1-py3-none-any.whl", hash = "sha256:be68c8399e88b554b6db1eb3c484b00f229cbe5c930f64f6ae9cd1a6e93db1c5"}, -] +all = ["box2d-py (==2.3.5)", "imageio (>=2.14.1)", "jax (==0.3.24)", "jaxlib (==0.3.24)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.2)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (>=4.0.0,<5.0.0)", "torch (>=1.0.0)"] +atari = ["shimmy[atari] (>=0.1.0,<1.0)"] +box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.3)", "swig (>=4.0.0,<5.0.0)"] +classic-control = ["pygame (==2.1.3)", "pygame (==2.1.3)"] +jax = ["jax (==0.3.24)", "jaxlib (==0.3.24)"] +mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.2)"] +mujoco-py = ["mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"] +other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"] +testing = ["pytest (==7.1.3)", "scipy (==1.7.3)"] +toy-text = ["pygame (==2.1.3)", "pygame (==2.1.3)"] [[package]] name = "huggingface-hub" @@ -1594,6 +1597,25 @@ cuda11-cudnn82 = ["jaxlib (==0.3.15+cuda11.cudnn82)"] minimum-jaxlib = ["jaxlib (==0.3.14)"] tpu = ["jaxlib (==0.3.15)", "libtpu-nightly (==0.1.dev20220723)", "requests"] +[[package]] +name = "jax-jumpy" +version = "1.0.0" +description = "Common backend for Jax or Numpy." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad"}, + {file = "jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e"}, +] + +[package.dependencies] +numpy = ">=1.18.0" + +[package.extras] +jax = ["jax (>=0.3.24)", "jaxlib (>=0.3.24)"] +testing = ["pytest (==7.1.3)"] + [[package]] name = "jaxlib" version = "0.3.15" @@ -4656,4 +4678,4 @@ pytest = ["pytest"] [metadata] lock-version = "2.0" python-versions = ">=3.7.1,<3.10" -content-hash = "921a8a7e4153e969e0cc03593f9739f8246c51d8f683f34eb4396db8863bc1b4" +content-hash = "03032e39ebcff13ae198823e2aab81af60c1d97a154d9c67ac674455e71a50b7" From d9825b453edd804f5375f1027ca619571ab62a1a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 25 Apr 2023 09:39:23 -0400 Subject: [PATCH 07/12] try trigger CI --- cleanrl/ddpg_continuous_action.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 32f69dde5..9a0fbbfee 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -119,6 +119,10 @@ def forward(self, x): if __name__ == "__main__": args = parse_args() + import stable_baselines3 as sb3 + if sb3.__version__ < "2.0": + raise ValueError("Ongoing migration: run `poetry run pip install sb3==2.0.0a1`") + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: import wandb From 03b3c7eb28d277bdb409a03be3102f66ab40dc28 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 May 2023 11:19:24 -0400 Subject: [PATCH 08/12] update ddpg default v4 environments --- cleanrl/ddpg_continuous_action.py | 13 +++++++++---- cleanrl/ddpg_continuous_action_jax.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 9a0fbbfee..14ccfd252 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -36,7 +36,7 @@ def parse_args(): help="whether to capture videos of the agent performances (check out `videos` folder)") # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + parser.add_argument("--env-id", type=str, default="HalfCheetah-v4", help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=1000000, help="total timesteps of the experiments") @@ -118,11 +118,16 @@ def forward(self, x): if __name__ == "__main__": - args = parse_args() import stable_baselines3 as sb3 + if sb3.__version__ < "2.0": - raise ValueError("Ongoing migration: run `poetry run pip install sb3==2.0.0a1`") - run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: + +poetry run pip install "stable_baselines3==2.0.0a1" +""" + ) + args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: import wandb diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index ff7cafed2..6ddb87ad4 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -35,7 +35,7 @@ def parse_args(): help="whether to capture videos of the agent performances (check out `videos` folder)") # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + parser.add_argument("--env-id", type=str, default="HalfCheetah-v4", help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=1000000, help="total timesteps of the experiments") @@ -114,6 +114,15 @@ class TrainState(TrainState): if __name__ == "__main__": + import stable_baselines3 as sb3 + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: + +poetry run pip install "stable_baselines3==2.0.0a1" +""" + ) args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: From b6d8598be5609c63fadcdfe0c6cc58e08447516b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 May 2023 13:34:23 -0400 Subject: [PATCH 09/12] trigger CI --- .github/workflows/tests.yaml | 94 +++++++++++++++++++++++++++++++ tests/test_mujoco.py | 10 ++++ tests/test_mujoco_gymnasium.py | 17 ++++++ tests/test_mujoco_py.py | 10 ---- tests/test_mujoco_py_gymnasium.py | 17 ++++++ tests/test_pybullet.py | 5 -- 6 files changed, 138 insertions(+), 15 deletions(-) create mode 100644 tests/test_mujoco_gymnasium.py create mode 100644 tests/test_mujoco_py_gymnasium.py diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7c4933195..2f53209f7 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -174,6 +174,38 @@ jobs: continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer` run: poetry run pytest tests/test_mujoco.py + test-mujoco-gymnasium-envs: + strategy: + fail-fast: false + matrix: + python-version: [3.8] + poetry-version: [1.3] + os: [ubuntu-22.04] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + + # mujoco tests + - name: Install dependencies + run: poetry install -E "pytest mujoco dm_control" + - name: Downgrade setuptools + run: poetry run pip install setuptools==59.5.0 + - name: Run gymnasium migration dependencies + run: poetry run pip install "stable_baselines3==2.0.0a1" + - name: install mujoco dependencies + run: | + sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3 + - name: Run mujoco tests + continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer` + run: poetry run pytest tests/test_mujoco_gymnasium.py + test-mujoco-envs-windows-mac: strategy: fail-fast: false @@ -200,6 +232,33 @@ jobs: - name: Run mujoco tests run: poetry run pytest tests/test_mujoco.py + test-mujoco-gymnasium-windows-mac: + strategy: + fail-fast: false + matrix: + python-version: [3.8] + poetry-version: [1.3] + os: [macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + + # mujoco tests + - name: Install dependencies + run: poetry install -E "pytest mujoco dm_control" + - name: Downgrade setuptools + run: poetry run pip install setuptools==59.5.0 + - name: Run gymnasium migration dependencies + run: poetry run pip install "stable_baselines3==2.0.0a1" + - name: Run mujoco tests + run: poetry run pytest tests/test_mujoco_gymnasium.py test-mujoco_py-envs: strategy: @@ -234,6 +293,41 @@ jobs: - name: Run mujoco_py tests run: poetry run pytest tests/test_mujoco_py.py + test-mujoco_py-envs-gymnasium: + strategy: + fail-fast: false + matrix: + python-version: [3.8] + poetry-version: [1.3] + os: [ubuntu-22.04] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + + # mujoco_py tests + - name: Install dependencies + run: poetry install -E "pytest pybullet mujoco_py mujoco jax" + - name: Run gymnasium migration dependencies + run: poetry run pip install "stable_baselines3==2.0.0a1" + - name: Downgrade setuptools + run: poetry run pip install setuptools==59.5.0 + - name: install mujoco_py dependencies + run: | + sudo apt-get update && sudo apt-get -y install wget unzip software-properties-common \ + libgl1-mesa-dev \ + libgl1-mesa-glx \ + libglew-dev \ + libosmesa6-dev patchelf + - name: Run mujoco_py tests + run: poetry run pytest tests/test_mujoco_py_gymnasium.py + test-envpool-envs: strategy: fail-fast: false diff --git a/tests/test_mujoco.py b/tests/test_mujoco.py index cc3acb9f6..dd7ead5f9 100644 --- a/tests/test_mujoco.py +++ b/tests/test_mujoco.py @@ -15,6 +15,16 @@ def test_mujoco(): shell=True, check=True, ) + subprocess.run( + "python cleanrl/td3_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) + subprocess.run( + "python cleanrl/td3_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) subprocess.run( "python cleanrl/rpo_continuous_action.py --env-id Hopper-v4 --num-envs 1 --num-steps 64 --total-timesteps 128", shell=True, diff --git a/tests/test_mujoco_gymnasium.py b/tests/test_mujoco_gymnasium.py new file mode 100644 index 000000000..a887dc0ad --- /dev/null +++ b/tests/test_mujoco_gymnasium.py @@ -0,0 +1,17 @@ +import subprocess + + +def test_mujoco(): + """ + Test mujoco + """ + subprocess.run( + "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) + subprocess.run( + "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) diff --git a/tests/test_mujoco_py.py b/tests/test_mujoco_py.py index c02389002..882e16808 100644 --- a/tests/test_mujoco_py.py +++ b/tests/test_mujoco_py.py @@ -10,16 +10,6 @@ def test_mujoco_py(): shell=True, check=True, ) - subprocess.run( - "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105", - shell=True, - check=True, - ) - subprocess.run( - "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105", - shell=True, - check=True, - ) subprocess.run( "python cleanrl/td3_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105", shell=True, diff --git a/tests/test_mujoco_py_gymnasium.py b/tests/test_mujoco_py_gymnasium.py new file mode 100644 index 000000000..6474c0e5b --- /dev/null +++ b/tests/test_mujoco_py_gymnasium.py @@ -0,0 +1,17 @@ +import subprocess + + +def test_mujoco_py(): + """ + Test mujoco_py + """ + subprocess.run( + "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) + subprocess.run( + "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105", + shell=True, + check=True, + ) diff --git a/tests/test_pybullet.py b/tests/test_pybullet.py index c9fabf700..365f71ccb 100644 --- a/tests/test_pybullet.py +++ b/tests/test_pybullet.py @@ -2,11 +2,6 @@ def test_pybullet(): - subprocess.run( - "python cleanrl/ddpg_continuous_action.py --learning-starts 100 --batch-size 32 --total-timesteps 105", - shell=True, - check=True, - ) subprocess.run( "python cleanrl/td3_continuous_action.py --learning-starts 100 --batch-size 32 --total-timesteps 105", shell=True, From 8f0029dd95b6db1de8369a47b8c6b68929b70cbc Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 May 2023 14:02:13 -0400 Subject: [PATCH 10/12] install jax dependency --- .github/workflows/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2f53209f7..43d56ba28 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -226,7 +226,7 @@ jobs: # mujoco tests - name: Install dependencies - run: poetry install -E "pytest mujoco dm_control" + run: poetry install -E "pytest mujoco dm_control jax" - name: Downgrade setuptools run: poetry run pip install setuptools==59.5.0 - name: Run mujoco tests @@ -252,7 +252,7 @@ jobs: # mujoco tests - name: Install dependencies - run: poetry install -E "pytest mujoco dm_control" + run: poetry install -E "pytest mujoco dm_control jax" - name: Downgrade setuptools run: poetry run pip install setuptools==59.5.0 - name: Run gymnasium migration dependencies From d6306034a06e7386530f3166169b8c39447bb781 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 May 2023 14:10:40 -0400 Subject: [PATCH 11/12] fix CI --- tests/test_mujoco.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/test_mujoco.py b/tests/test_mujoco.py index dd7ead5f9..cc3acb9f6 100644 --- a/tests/test_mujoco.py +++ b/tests/test_mujoco.py @@ -15,16 +15,6 @@ def test_mujoco(): shell=True, check=True, ) - subprocess.run( - "python cleanrl/td3_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", - shell=True, - check=True, - ) - subprocess.run( - "python cleanrl/td3_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", - shell=True, - check=True, - ) subprocess.run( "python cleanrl/rpo_continuous_action.py --env-id Hopper-v4 --num-envs 1 --num-steps 64 --total-timesteps 128", shell=True, From 65000a8c5d7402aba2567c02a976c073bc1d381c Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 3 May 2023 14:37:35 -0400 Subject: [PATCH 12/12] remove windows CI --- .github/workflows/tests.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 43d56ba28..4558c02ae 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -206,13 +206,13 @@ jobs: continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer` run: poetry run pytest tests/test_mujoco_gymnasium.py - test-mujoco-envs-windows-mac: + test-mujoco-envs-mac: strategy: fail-fast: false matrix: python-version: [3.8] poetry-version: [1.3] - os: [macos-latest, windows-latest] + os: [macos-latest] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2 @@ -232,13 +232,13 @@ jobs: - name: Run mujoco tests run: poetry run pytest tests/test_mujoco.py - test-mujoco-gymnasium-windows-mac: + test-mujoco-gymnasium-mac: strategy: fail-fast: false matrix: python-version: [3.8] poetry-version: [1.3] - os: [macos-latest, windows-latest] + os: [macos-latest] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2