diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a4e6e22f5..9be6a46c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.13"] + python-version: ["3.11", "3.13"] steps: - uses: actions/checkout@v2 @@ -36,15 +36,12 @@ jobs: pip install -r docs/requirements.txt pip freeze - name: Lint with mypy and ruff - if: matrix.python-version != '3.9' run: | make lint - name: Build documentation - if: matrix.python-version != '3.9' run: | make docs - name: Test documentation - if: matrix.python-version != '3.9' run: | make doctest python -m doctest -v README.md @@ -56,7 +53,7 @@ jobs: needs: lint strategy: matrix: - python-version: ["3.9", "3.13"] + python-version: ["3.11", "3.13"] steps: - uses: actions/checkout@v2 @@ -107,7 +104,7 @@ jobs: needs: lint strategy: matrix: - python-version: ["3.9", "3.13"] + python-version: ["3.11", "3.13"] steps: - uses: actions/checkout@v2 diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 10cc1c9e8..791fa998a 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,14 +4,17 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Optional, Protocol, Union, runtime_checkable +from typing import ( + Any, + Optional, + ParamSpec, + Protocol, + TypeAlias, + Union, + runtime_checkable, +) import weakref -try: - from typing import ParamSpec, TypeAlias -except ImportError: - from typing_extensions import ParamSpec, TypeAlias - import numpy as np import jax diff --git a/numpyro/distributions/gof.py b/numpyro/distributions/gof.py index b1f917303..3b6236829 100644 --- a/numpyro/distributions/gof.py +++ b/numpyro/distributions/gof.py @@ -209,19 +209,11 @@ def volume_of_sphere(dim, radius): def get_nearest_neighbor_distances(samples): - try: - # This version scales as O(N log(N)). - from scipy.spatial import cKDTree - - distances, indices = cKDTree(samples).query(samples, k=2) - return distances[:, 1] - except ImportError: - # This version scales as O(N^2). - x = samples - x2 = (x * x).sum(-1) - d2 = x2[:, None] + x2 - 2 * x @ x.T - min_d2 = np.partition(d2, 1)[:, 1] - return np.sqrt(np.clip(min_d2, 0, None)) + # This version scales as O(N log(N)). + from scipy.spatial import cKDTree + + distances, indices = cKDTree(samples).query(samples, k=2) + return distances[:, 1] def vector_density_goodness_of_fit(samples, probs, *, dim=None, plot=False): diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index ab70c6d4e..99147b283 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -6,14 +6,9 @@ from collections import OrderedDict, defaultdict from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, TypeVar import warnings -try: - from typing import TypeAlias -except ImportError: - from typing_extensions import TypeAlias - import jax from jax import eval_shape, random, vmap from jax.lax import stop_gradient diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 21b0fe7ea..46ea75d56 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -2,39 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import jax -from jax.api_util import flatten_fun, shaped_abstractify - -try: - from jax.experimental.pjit import pjit_p -except ImportError: - from jax.extend.core.primitives import jit_p as pjit_p -try: - import jax.extend.linear_util as lu -except ImportError: - import jax.linear_util as lu - -try: - from jax.extend.core import Literal -except ImportError: - from jax.core import Literal - -try: - from jax.extend.core.primitives import call_p, closed_call_p -except ImportError: - from jax.core import call_p, closed_call_p - -try: - from jax.api_util import debug_info -except ImportError: - debug_info = None - +from jax.api_util import debug_info, flatten_fun, shaped_abstractify +from jax.extend.core import Literal +from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p +import jax.extend.linear_util as lu from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic -try: - from jax.extend.core.primitives import xla_pmap_p -except ImportError: - from jax.interpreters.pxla import xla_pmap_p - # Adapted from definition in jax v0.5.0 def _safe_map(f, *args): @@ -151,8 +124,8 @@ def track_deps_closed_call_rule(eqn, provenance_inputs): track_deps_rules[closed_call_p] = track_deps_closed_call_rule -def track_deps_pjit_rule(eqn, provenance_inputs): +def track_deps_jit_rule(eqn, provenance_inputs): return track_deps_jaxpr(eqn.params["jaxpr"].jaxpr, provenance_inputs) -track_deps_rules[pjit_p] = track_deps_pjit_rule +track_deps_rules[jit_p] = track_deps_jit_rule diff --git a/setup.py b/setup.py index 33b3eea55..f0bef9f5a 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ from setuptools import find_packages, setup PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.4.25" -_jaxlib_version_constraints = ">=0.4.25" +_jax_version_constraints = ">=0.7.0" +_jaxlib_version_constraints = ">=0.7.0" # Find version for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): @@ -38,7 +38,6 @@ "multipledispatch", "numpy", "tqdm", - "typing_extensions; python_version < '3.10'", ], extras_require={ "doc": [ @@ -60,8 +59,7 @@ "ty>=0.0.4", ], "dev": [ - "dm-haiku>=0.0.14; python_version >= '3.10'", - "dm-haiku<0.0.14; python_version < '3.10'", + "dm-haiku>=0.0.14", "equinox", "flax", "funsor>=0.4.1", @@ -103,10 +101,9 @@ "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], ) diff --git a/test/ops/test_provenance.py b/test/ops/test_provenance.py index caae230a0..bc0b2a5ab 100644 --- a/test/ops/test_provenance.py +++ b/test/ops/test_provenance.py @@ -7,18 +7,10 @@ import jax from jax.api_util import flatten_fun_nokwargs - -try: - import jax.extend.linear_util as lu -except ImportError: - import jax.linear_util as lu +from jax.extend.core.primitives import call_p, closed_call_p +import jax.extend.linear_util as lu import jax.numpy as jnp -try: - from jax.extend.core.primitives import call_p, closed_call_p -except ImportError: - from jax.core import call_p, closed_call_p - try: from jax.api_util import debug_info except ImportError: