diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml index 0e57bba3c6..5e7fc3d212 100644 --- a/.github/workflows/jaxtests.yml +++ b/.github/workflows/jaxtests.yml @@ -58,7 +58,7 @@ jobs: - name: Install jax specific dependencies run: | conda activate pymc3-dev-py39 - pip install numpyro tensorflow_probability + pip install numpyro tensorflow_probability "jax<0.2.21" - name: Run tests run: | python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET