Repository for model-based reinforcement learning (RL).
Requirements:
- Python >= 3.11
- CUDA >= 12.1
- cuDNN >= 8.9
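Before installing anything, a small check can confirm that the interpreter meets the Python requirement. This is a minimal sketch that only covers the Python version (not CUDA or cuDNN), and meets_python_requirement is a hypothetical helper name.

```python
import sys

# Minimum interpreter version from the requirements list above.
MIN_PYTHON = (3, 11)


def meets_python_requirement(version_info=sys.version_info, minimum=MIN_PYTHON) -> bool:
    """Return True if the given (major, minor, ...) version is at least `minimum`."""
    # Compare only as many components as `minimum` has, e.g. (3, 12, 1) -> (3, 12).
    return tuple(version_info[: len(minimum)]) >= tuple(minimum)


if __name__ == "__main__":
    current = sys.version.split()[0]
    if meets_python_requirement():
        print(f"Python {current} satisfies >= {'.'.join(map(str, MIN_PYTHON))}")
    else:
        print(f"Python {current} is too old; >= {'.'.join(map(str, MIN_PYTHON))} is required")
```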
Install JAX for either CPU or GPU:
pip install -U "jax[cpu]"      # CPU only
pip install -U "jax[cuda12]"   # NVIDIA GPU (CUDA 12)
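After installing, a short Python check shows which backend JAX actually picked up, which is a quick way to tell the CPU and CUDA installs apart. This is a minimal sketch: jax_backend_summary is a hypothetical helper, while jax.default_backend() and jax.devices() are standard JAX calls.

```python
import importlib.util


def jax_backend_summary() -> str:
    """Describe which backend JAX uses by default, or report that JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed"
    import jax

    # default_backend() returns the default platform name,
    # e.g. "cpu" for the jax[cpu] install or "gpu" when CUDA devices are visible.
    backend = jax.default_backend()
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    print(jax_backend_summary())
```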
Install with a conda environment:
conda create -n mbrl python=3.11 -y
conda activate mbrl
git clone https://github.com/lasgroup/model-based-rl.git
cd model-based-rl
pip install .
Set up wandb (Weights & Biases) for experiment logging, e.g. by running:
wandb login
Add mbrl to your Python path:
export PYTHONPATH=$PYTHONPATH:/path/to/model-based-rl
You can also add this line to your .bashrc.
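To confirm the path took effect, one can check from Python whether the checkout directory appears in sys.path, which is where PYTHONPATH entries end up when the interpreter starts. This is a minimal sketch: on_python_path is a hypothetical helper, and /path/to/model-based-rl is the same placeholder used above.

```python
import os
import sys
from typing import Iterable, Optional


def on_python_path(repo_dir: str, entries: Optional[Iterable[str]] = None) -> bool:
    """Return True if `repo_dir` appears among the import search paths."""
    # PYTHONPATH entries are copied into sys.path at interpreter startup.
    entries = sys.path if entries is None else entries
    target = os.path.realpath(os.path.expanduser(repo_dir))
    # Skip empty entries (they stand for the current directory), then compare
    # normalized absolute paths so trailing slashes and "~" do not matter.
    return any(
        bool(entry) and os.path.realpath(os.path.expanduser(entry)) == target
        for entry in entries
    )


if __name__ == "__main__":
    repo = "/path/to/model-based-rl"  # placeholder: replace with your checkout
    print("on sys.path" if on_python_path(repo) else "not on sys.path; check PYTHONPATH")
```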
Launch experiments with the launcher:
python path/to/model-based-rl/experiments/experiment_name/launcher.py
Remote Deployment on euler.ethz.ch
Set up remote development from your computer to Euler in either PyCharm or VS Code.
Set up git on Euler to access GitHub over SSH (see GitHub's guide "Connecting to GitHub with SSH").
Set up a .mbrl_setup file on your login node:
export XLA_PYTHON_CLIENT_MEM_FRACTION=.7
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_DETERMINISTIC_OPS=0
module load stack/2024-06
module load gcc/12.2.0
module load eth_proxy
module load python/3.11.6
PYTHONPATH=$PYTHONPATH:/path/on/euler/to/model-based-rl
export PYTHONPATH
Source it with:
source .mbrl_setup
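If an experiment is started from a Python entry point that did not go through a shell which sourced .mbrl_setup, the exported variables can be filled in from Python instead, as long as this happens before jax (or tensorflow) is imported. XLA_PYTHON_CLIENT_MEM_FRACTION=.7 limits JAX's GPU memory preallocation to 70% of the device, and TF_FORCE_GPU_ALLOW_GROWTH=true lets TensorFlow allocate GPU memory on demand instead of up front. This is a minimal sketch with apply_mbrl_env as a hypothetical helper; the module load lines have no Python equivalent and still require the shell setup.

```python
import os
from typing import MutableMapping, Optional

# The exports from .mbrl_setup; the `module load` lines cannot be reproduced here.
MBRL_ENV = {
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".7",  # JAX preallocates at most 70% of GPU memory
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",     # TensorFlow grows GPU memory on demand
    "TF_DETERMINISTIC_OPS": "0",             # do not force deterministic TF kernels
}


def apply_mbrl_env(env: Optional[MutableMapping[str, str]] = None) -> MutableMapping[str, str]:
    """Fill in the .mbrl_setup variables without overriding values that are already set."""
    env = os.environ if env is None else env
    for key, value in MBRL_ENV.items():
        env.setdefault(key, value)
    return env


if __name__ == "__main__":
    apply_mbrl_env()
    for key in MBRL_ENV:
        print(f"{key}={os.environ[key]}")
```

Call apply_mbrl_env() at the very top of the entry script, before import jax. Using setdefault means values already exported by .mbrl_setup take precedence over the defaults in the script.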
Create a Miniconda environment or a Python virtual environment.
Activate the virtual environment:
source path/on/euler/to/venv/bin/activate
Install JAX for GPU (see the JAX documentation):
pip install "jax[cuda12]"
Clone and install the mbrl library:
git clone https://github.com/lasgroup/model-based-rl.git
cd model-based-rl
pip install .
Set up wandb on Euler, e.g. by running:
wandb login
Add mbrl to your Python path:
export PYTHONPATH=$PYTHONPATH:/path/on/euler/to/model-based-rl
You can also add this line to your .bashrc or .mbrl_setup file.
Launch experiments with the launcher:
python path/on/euler/to/model-based-rl/experiments/experiment_name/launcher.py
Troubleshooting

If pip reports the following during installation:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
This warning is generally not a problem. Make sure you use the most up-to-date versions of jax, chex, and flax. Here is a combination that works on CentOS:
brax==0.10.0
chex>=0.1.82
flax==0.7.2
jax==0.4.14
jaxlib==0.4.14
jaxopt==0.8
jaxtyping==0.2.21
jaxutils==0.0.8
ml-collections==0.1.0
numba
numpy>=1.25.2
scipy==1.11.2
setuptools>=68.1.2
tensorflow>=2.13.0
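To compare an existing environment against this combination, a small script can read the installed versions with importlib.metadata. This is a sketch under simplifying assumptions: it handles only == and >= pins, compares the numeric release part of each version (so suffixes such as rc1 or +cuda12 are ignored), and check_pins, release and KNOWN_GOOD are hypothetical names. For full PEP 440 semantics, the packaging library is the usual choice.

```python
import re
from importlib import metadata
from typing import Callable, Dict

# The CentOS combination listed above; "numba" is unpinned, so only its presence is checked.
KNOWN_GOOD: Dict[str, str] = {
    "brax": "==0.10.0", "chex": ">=0.1.82", "flax": "==0.7.2",
    "jax": "==0.4.14", "jaxlib": "==0.4.14", "jaxopt": "==0.8",
    "jaxtyping": "==0.2.21", "jaxutils": "==0.0.8", "ml-collections": "==0.1.0",
    "numba": "", "numpy": ">=1.25.2", "scipy": "==1.11.2",
    "setuptools": ">=68.1.2", "tensorflow": ">=2.13.0",
}


def release(version: str) -> tuple:
    """Numeric release part with trailing zeros dropped, e.g. "0.8.0" -> (0, 8)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    while len(parts) > 1 and parts[-1] == 0:  # so "0.8" and "0.8.0" compare equal
        parts.pop()
    return tuple(parts)


def check_pins(
    pins: Dict[str, str] = KNOWN_GOOD,
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Return {package: problem} for every package that does not match its pin."""
    problems: Dict[str, str] = {}
    for name, spec in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if spec.startswith("=="):
            ok = release(installed) == release(spec[2:])
        elif spec.startswith(">="):
            ok = release(installed) >= release(spec[2:])
        else:
            ok = True  # unpinned: being installed is enough
        if not ok:
            problems[name] = f"installed {installed}, expected {spec}"
    return problems


if __name__ == "__main__":
    for package, problem in check_pins().items():
        print(f"{package}: {problem}")
```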
If the launcher prints:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
about to launch 5 jobs with 1 cores each. proceed? [yes/no]
This is a normal warning when launching from the Euler login node, since GPUs are only available on the compute nodes. The job should use the GPU once it runs on a compute node.
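Because this warning is expected on the login node but would mean a job is running on CPU if it appeared on a compute node, it can help to report the visible GPUs at the start of each job. This is a minimal sketch: gpu_devices is a hypothetical helper, and it relies on jax.devices("gpu") raising RuntimeError when no GPU backend is available.

```python
def gpu_devices() -> list:
    """Return the GPU devices JAX can see, or an empty list if there are none."""
    try:
        import jax
    except ImportError:
        return []
    try:
        # Raises RuntimeError when no GPU backend is available (e.g. on a login node).
        return jax.devices("gpu")
    except RuntimeError:
        return []


if __name__ == "__main__":
    gpus = gpu_devices()
    if gpus:
        print(f"Running on {len(gpus)} GPU(s): {gpus}")
    else:
        print("No GPU visible to JAX: expected on the login node, not on a compute node.")
```

In a job script you may prefer to raise SystemExit when the list is empty, so the job stops instead of training on CPU.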
This is an error on the cluster. On Euler (Ubuntu), make sure you load the correct modules (python instead of python_cuda).
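The command module list shows which modules are loaded. To check the same thing from inside a job's Python process, a minimal sketch, assuming the module system is Lmod (which keeps the loaded modules in the colon-separated LOADEDMODULES environment variable), could flag a python_cuda module. The helper names loaded_modules and uses_python_cuda are hypothetical.

```python
import os
from typing import List, Mapping, Optional


def loaded_modules(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Parse the colon-separated LOADEDMODULES variable (assumed to be set by Lmod)."""
    env = os.environ if env is None else env
    return [m for m in env.get("LOADEDMODULES", "").split(":") if m]


def uses_python_cuda(env: Optional[Mapping[str, str]] = None) -> bool:
    """True if any module named python_cuda (any version) is loaded."""
    return any(m.split("/")[0] == "python_cuda" for m in loaded_modules(env))


if __name__ == "__main__":
    if uses_python_cuda():
        print("python_cuda is loaded; load the python module instead.")
    else:
        print("Loaded modules:", ", ".join(loaded_modules()) or "(none found)")
```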