diff --git a/metaworld-jax/Dockerfile b/metaworld-jax/Dockerfile index d0507e0..9c95316 100644 --- a/metaworld-jax/Dockerfile +++ b/metaworld-jax/Dockerfile @@ -1,7 +1,7 @@ -FROM ghcr.io/nvidia/jax:jax-2024-08-19 +FROM python:3.12.5-slim LABEL maintainer="me@evangelos.ai" -RUN apt-get update && apt install -y --no-install-recommends git libglfw3 libglfw3-dev +RUN apt-get update && apt install -y --no-install-recommends git python3-pip libglfw3 libglfw3-dev WORKDIR /usr/src/app COPY requirements.txt ./ diff --git a/metaworld-jax/requirements.txt b/metaworld-jax/requirements.txt index f84e4c3..c319a77 100644 --- a/metaworld-jax/requirements.txt +++ b/metaworld-jax/requirements.txt @@ -1,11 +1,13 @@ # Jax +jax[cuda12]==0.4.31 +flax==0.8.5 distrax==0.1.5 # Metaworld metaworld @ git+https://github.com/rainx0r/Metaworld.git@f131964 # Pytorch CPU for logging ---extra-index-url https://download.pytorch.org/whl/cpu +# --extra-index-url https://download.pytorch.org/whl/cpu torch==2.4.0 # Logging