Skip to content

Commit

Permalink
Switch to nvidia's JAX container
Browse files Browse the repository at this point in the history
  • Loading branch information
rainx0r committed Aug 20, 2024
1 parent 4bc2dd6 commit 8a55033
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 14 deletions.
6 changes: 3 additions & 3 deletions metaworld-jax/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
FROM python:3.12.5-slim
FROM ghcr.io/nvidia/jax:jax-2024-08-19
LABEL maintainer="[email protected]"

RUN apt-get update && apt install -y --no-install-recommends git python3-pip libglfw3 libglfw3-dev
RUN apt-get update && apt install -y --no-install-recommends git libglfw3 libglfw3-dev

WORKDIR /usr/src/app
COPY requirements.txt ./
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
pip install --no-cache-dir --upgrade -r requirements.txt

ENTRYPOINT ["python"]
12 changes: 1 addition & 11 deletions metaworld-jax/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
# Cuda
nvidia-cublas-cu12~=12.2.0
nvidia-cuda-cupti-cu12~=12.2.0
nvidia-cuda-nvcc-cu12~=12.2.0
nvidia-cuda-runtime-cu12~=12.2.0
nvidia-cusparse-cu12~=12.2.0
nvidia-nvjitlink-cu12~=12.2.0

# Jax
jax[cuda12]==0.4.31
flax==0.8.5
distrax==0.1.5

# Metaworld
Expand All @@ -21,5 +11,5 @@ torch==2.4.0
# Logging
wandb==0.17.6
tensorboard==2.17.1
orbax-checkpoint @ git+https://github.com/google/orbax/@2ce2fb27f9786442b08ba14c8767c460dd6e8a0a#subdirectory=checkpoint
orbax-checkpoint==0.6.0

0 comments on commit 8a55033

Please sign in to comment.