Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.13"]
python-version: ["3.11", "3.13"]

steps:
- uses: actions/checkout@v2
Expand All @@ -36,15 +36,12 @@ jobs:
pip install -r docs/requirements.txt
pip freeze
- name: Lint with mypy and ruff
if: matrix.python-version != '3.9'
run: |
make lint
- name: Build documentation
if: matrix.python-version != '3.9'
run: |
make docs
- name: Test documentation
if: matrix.python-version != '3.9'
run: |
make doctest
python -m doctest -v README.md
Expand All @@ -56,7 +53,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.9", "3.13"]
python-version: ["3.11", "3.13"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -107,7 +104,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.9", "3.13"]
python-version: ["3.11", "3.13"]

steps:
- uses: actions/checkout@v2
Expand Down
15 changes: 9 additions & 6 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Optional, Protocol, Union, runtime_checkable
from typing import (
Any,
Optional,
ParamSpec,
Protocol,
TypeAlias,
Union,
runtime_checkable,
)
import weakref

try:
from typing import ParamSpec, TypeAlias
except ImportError:
from typing_extensions import ParamSpec, TypeAlias

import numpy as np

import jax
Expand Down
18 changes: 5 additions & 13 deletions numpyro/distributions/gof.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,11 @@ def volume_of_sphere(dim, radius):


def get_nearest_neighbor_distances(samples):
try:
# This version scales as O(N log(N)).
from scipy.spatial import cKDTree

distances, indices = cKDTree(samples).query(samples, k=2)
return distances[:, 1]
except ImportError:
# This version scales as O(N^2).
x = samples
x2 = (x * x).sum(-1)
d2 = x2[:, None] + x2 - 2 * x @ x.T
min_d2 = np.partition(d2, 1)[:, 1]
return np.sqrt(np.clip(min_d2, 0, None))
# This version scales as O(N log(N)).
from scipy.spatial import cKDTree

distances, indices = cKDTree(samples).query(samples, k=2)
return distances[:, 1]


def vector_density_goodness_of_fit(samples, probs, *, dim=None, plot=False):
Expand Down
7 changes: 1 addition & 6 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,9 @@
from collections import OrderedDict, defaultdict
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, TypeVar
import warnings

try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias

import jax
from jax import eval_shape, random, vmap
from jax.lax import stop_gradient
Expand Down
39 changes: 6 additions & 33 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

import jax
from jax.api_util import flatten_fun, shaped_abstractify

try:
from jax.experimental.pjit import pjit_p
except ImportError:
from jax.extend.core.primitives import jit_p as pjit_p
try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu

try:
from jax.extend.core import Literal
except ImportError:
from jax.core import Literal

try:
from jax.extend.core.primitives import call_p, closed_call_p
except ImportError:
from jax.core import call_p, closed_call_p

try:
from jax.api_util import debug_info
except ImportError:
debug_info = None

from jax.api_util import debug_info, flatten_fun, shaped_abstractify
from jax.extend.core import Literal
from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p
import jax.extend.linear_util as lu
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic

try:
from jax.extend.core.primitives import xla_pmap_p
except ImportError:
from jax.interpreters.pxla import xla_pmap_p


# Adapted from definition in jax v0.5.0
def _safe_map(f, *args):
Expand Down Expand Up @@ -151,8 +124,8 @@ def track_deps_closed_call_rule(eqn, provenance_inputs):
track_deps_rules[closed_call_p] = track_deps_closed_call_rule


def track_deps_pjit_rule(eqn, provenance_inputs):
def track_deps_jit_rule(eqn, provenance_inputs):
return track_deps_jaxpr(eqn.params["jaxpr"].jaxpr, provenance_inputs)


track_deps_rules[pjit_p] = track_deps_pjit_rule
track_deps_rules[jit_p] = track_deps_jit_rule
11 changes: 4 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from setuptools import find_packages, setup

PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
_jax_version_constraints = ">=0.4.25"
_jaxlib_version_constraints = ">=0.4.25"
_jax_version_constraints = ">=0.7.0"
_jaxlib_version_constraints = ">=0.7.0"

# Find version
for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")):
Expand Down Expand Up @@ -38,7 +38,6 @@
"multipledispatch",
"numpy",
"tqdm",
"typing_extensions; python_version < '3.10'",
],
extras_require={
"doc": [
Expand All @@ -60,8 +59,7 @@
"ty>=0.0.4",
],
"dev": [
"dm-haiku>=0.0.14; python_version >= '3.10'",
"dm-haiku<0.0.14; python_version < '3.10'",
"dm-haiku>=0.0.14",
"equinox",
"flax",
"funsor>=0.4.1",
Expand Down Expand Up @@ -103,10 +101,9 @@
"License :: OSI Approved :: Apache Software License",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS :: MacOS X",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
],
)
12 changes: 2 additions & 10 deletions test/ops/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,10 @@

import jax
from jax.api_util import flatten_fun_nokwargs

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu
from jax.extend.core.primitives import call_p, closed_call_p
import jax.extend.linear_util as lu
import jax.numpy as jnp

try:
from jax.extend.core.primitives import call_p, closed_call_p
except ImportError:
from jax.core import call_p, closed_call_p

try:
from jax.api_util import debug_info
except ImportError:
Expand Down