diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index e4338b61..369bd2ee 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -157,11 +157,12 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \ {{ end }} # Install JAX +# b/316967430 Remove pin once new version of tensorflowjs is released (> 4.15.0) {{ if eq .Accelerator "gpu" }} -RUN pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \ +RUN pip install "jax[cuda11_local]==0.4.21" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \ /tmp/clean-layer.sh {{ else }} -RUN pip install jax[cpu] && \ +RUN pip install jax[cpu]==0.4.21 && \ /tmp/clean-layer.sh {{ end }}