From a1503982e1fd69227222b950e888dafbdb45fc78 Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Tue, 9 Jan 2024 14:11:34 -0800 Subject: [PATCH] Pin jax to fix tensorflowjs issue (#1348) http://b/316967430 --- Dockerfile.tmpl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index e4338b61..369bd2ee 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -157,11 +157,12 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \ {{ end }} # Install JAX +# b/316967430 Remove pin once new version of tensorflowjs is released (> 4.15.0) {{ if eq .Accelerator "gpu" }} -RUN pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \ +RUN pip install "jax[cuda11_local]==0.4.21" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \ /tmp/clean-layer.sh {{ else }} -RUN pip install jax[cpu] && \ +RUN pip install jax[cpu]==0.4.21 && \ /tmp/clean-layer.sh {{ end }}