Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, Union, runtime_checkable
import weakref

try:
from typing import ParamSpec, TypeAlias
except ImportError:
from typing_extensions import ParamSpec, TypeAlias

import numpy as np

import jax
from jax.typing import ArrayLike

Expand All @@ -21,10 +24,26 @@
TraceT: TypeAlias = OrderedDict[str, Message]


NonScalarArray = Union[np.ndarray, jax.Array]
"""An alias for array-like types excluding scalars."""


NumLike = Union[NonScalarArray, np.number, int, float, complex]
"""An alias for array-like types excluding `np.bool_` and `bool`."""


PyTree: TypeAlias = Any
"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays."""


@runtime_checkable
class ConstraintT(Protocol):
is_discrete: bool = ...
event_dim: int = ...
"""A protocol for typing constraints."""

@property
def is_discrete(self) -> bool: ...
@property
def event_dim(self) -> int: ...

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def __repr__(self) -> str: ...
Expand Down Expand Up @@ -87,20 +106,24 @@ def is_discrete(self) -> bool: ...

@runtime_checkable
class TransformT(Protocol):
domain = ConstraintT
codomain = ConstraintT
_inv: "TransformT" = None

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def _inverse(self, y: ArrayLike) -> ArrayLike: ...
def log_abs_det_jacobian(
self, x: ArrayLike, y: ArrayLike, intermediates=None
) -> ArrayLike: ...
def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ...
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
_inv: Optional[Union["TransformT", weakref.ref]] = ...

@property
def domain(self) -> ConstraintT: ...
@property
def codomain(self) -> ConstraintT: ...
@property
def inv(self) -> "TransformT": ...
@property
def sign(self) -> ArrayLike: ...
def sign(self) -> NumLike: ...

def __call__(self, x: NumLike) -> NumLike: ...
def _inverse(self, y: NumLike) -> NumLike: ...
def log_abs_det_jacobian(
self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
) -> NumLike: ...
def call_with_intermediates(
self, x: NumLike
) -> tuple[NumLike, Optional[PyTree]]: ...
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
2 changes: 1 addition & 1 deletion numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _vmap_over_affine_transform(
dist_axes = copy.copy(dist)
dist_axes.loc = loc
dist_axes.scale = scale
dist_axes.domain = domain
dist_axes._domain = domain
return dist_axes


Expand Down
Loading