diff --git a/notebooks/brax_training.ipynb b/notebooks/brax_training.ipynb index 45b97336..b58adc38 100644 --- a/notebooks/brax_training.ipynb +++ b/notebooks/brax_training.ipynb @@ -18,7 +18,7 @@ }, "outputs": [], "source": [ - "!pip install git+https://github.com/Denys88/rl_games" + "!pip install git+https://github.com/Denys88/rl_games ray" ] }, { @@ -33,15 +33,9 @@ "#@markdown ## ⚠️ PLEASE NOTE:\n", "#@markdown This colab runs using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.\n", "\n", - "from datetime import datetime\n", - "import functools\n", - "import os\n", - "\n", - "from IPython.display import HTML, clear_output\n", + "from IPython.display import display, clear_output\n", "\n", "import jax\n", - "import jax.numpy as jnp\n", - "import matplotlib.pyplot as plt\n", "\n", "try:\n", " import brax\n", @@ -50,10 +44,10 @@ " clear_output()\n", " import brax\n", "\n", - "from brax import envs\n", - "from brax import jumpy as jp\n", "from brax.io import html\n", - "from brax.io import model" + "\n", + "from rl_games.torch_runner import Runner\n", + "from rl_games.envs.brax import BraxEnv" ] }, { @@ -213,9 +207,6 @@ }, "outputs": [], "source": [ - "import yaml\n", - "from rl_games.torch_runner import Runner\n", - "\n", "env_name = 'ant' # @param ['ant', 'humanoid']\n", "configs = {\n", " 'ant' : ant_config,\n", @@ -238,20 +229,6 @@ "})" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3s-95B-KlqE1" - }, - "outputs": [], - "source": [ - "from rl_games.envs.brax import BraxEnv\n", - "\n", - "from IPython.display import HTML, IFrame, display, clear_output\n", - "import os" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/pyproject.toml b/pyproject.toml index e73c4c42..c56096b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ psutil = "^5.9.0" setproctitle = "^1.2.2" opencv-python = "^4.5.5" wandb = "^0.12.11" +gymnasium = "^1.0" 
ale-py = {version = "^0.7", optional = true} AutoROM = {version = "^0.4.2", optional = true, extras = ["accept-rom-license"]} @@ -25,6 +26,7 @@ brax = {version = "^0.0.13", optional = true} jax = {version = "^0.3.13", optional = true} mujoco-py = {version = "^2.1.2", optional = true} envpool = {version = "^0.6.1", optional = true} +ray = {version = "^2.2.0", optional = true} [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/rl_games/envs/brax.py b/rl_games/envs/brax.py index 31c7629f..22915bdf 100644 --- a/rl_games/envs/brax.py +++ b/rl_games/envs/brax.py @@ -18,8 +18,7 @@ def torch_to_jax(tensor): class BraxEnv(IVecEnv): def __init__(self, config_name, num_actors, **kwargs): - from brax import envs - import jax.numpy as jnp + from brax.v1 import envs self.batch_size = num_actors env_name=kwargs.pop('env_name', 'ant') diff --git a/setup.py b/setup.py index d3c36193..2b89a87a 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ # this setup is only for pytorch # 'gym>=0.17.2', + 'gymnasium', 'torch>=1.7.0', 'numpy>=1.16.0', 'tensorboard>=1.14.0',