diff --git a/metaworld-jax/Dockerfile b/metaworld-jax/Dockerfile new file mode 100644 index 0000000..89dfe0f --- /dev/null +++ b/metaworld-jax/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.10-slim-buster +LABEL maintainer="me@evangelos.ai" + +RUN apt-get update && apt install -y --no-install-recommends git python3-pip libglfw3 libglfw3-dev + +WORKDIR /usr/src/app +COPY requirements.txt ./ +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -r requirements.txt + +ENTRYPOINT ["python"] diff --git a/metaworld-jax/README.md b/metaworld-jax/README.md new file mode 100644 index 0000000..6f4483b --- /dev/null +++ b/metaworld-jax/README.md @@ -0,0 +1,11 @@ +# metaworld-jax + +A base Dockerfile containing all the required dependencies for running my Metaworld RL experiments with JAX. + +Includes: + +- Python@3.10 +- CUDA & cuDNN for GPU acceleration. +- (TEMPORARY) Custom versions of Metaworld and Multi-task Cleanrl. +- JAX, Flax and Orbax. +- Pytorch/tensorboard/wandb for logging. diff --git a/metaworld-jax/requirements.txt b/metaworld-jax/requirements.txt new file mode 100644 index 0000000..d802543 --- /dev/null +++ b/metaworld-jax/requirements.txt @@ -0,0 +1,16 @@ +# Jax +jax[cuda12_pip]==0.4.31 +flax==0.8.5 +distrax==0.1.5 + +# Metaworld +metaworld @ git+https://github.com/rainx0r/Metaworld.git@f131964 + +# Pytorch CPU for logging +torch==2.4.0 + +# Logging +wandb==0.17.6 +tensorboard==2.17.1 +orbax-checkpoint @ git+https://github.com/google/orbax/@2ce2fb27f9786442b08ba14c8767c460dd6e8a0a#subdirectory=checkpoint +