diff --git a/numpyro/_typing.py b/numpyro/_typing.py index c9ad8819a..3be15dcb8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,13 +4,15 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, Union, runtime_checkable try: from typing import ParamSpec, TypeAlias except ImportError: from typing_extensions import ParamSpec, TypeAlias +import numpy as np + import jax from jax.typing import ArrayLike @@ -21,6 +23,18 @@ TraceT: TypeAlias = OrderedDict[str, Message] +NonScalarArray = Union[np.ndarray, jax.Array] +"""An alias for array-like types excluding scalars.""" + + +NumLike = Union[NonScalarArray, np.number, int, float, complex] +"""An alias for array-like types excluding `np.bool_` and `bool`.""" + + +PyTree: TypeAlias = Any +"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" + + @runtime_checkable class ConstraintT(Protocol): is_discrete: bool = ... @@ -87,20 +101,25 @@ def is_discrete(self) -> bool: ... @runtime_checkable class TransformT(Protocol): - domain = ConstraintT - codomain = ConstraintT - _inv: "TransformT" = None + domain: ConstraintT = ... + codomain: ConstraintT = ... + _inv: Optional["TransformT"] = ... - def __call__(self, x: ArrayLike) -> ArrayLike: ... - def _inverse(self, y: ArrayLike) -> ArrayLike: ... + def __call__(self, x: NumLike) -> NumLike: ... + def _inverse(self, y: NumLike) -> NumLike: ... def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: ... - def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ... + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: ... + def call_with_intermediates( + self, x: NumLike + ) -> tuple[NumLike, Optional[PyTree]]: ... def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... @property def inv(self) -> "TransformT": ... @property - def sign(self) -> ArrayLike: ... + def sign(self) -> NumLike: ... diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7671b0685..317b56b2e 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -801,18 +801,18 @@ def tree_flatten(self): corr_cholesky: ConstraintT = _CorrCholesky() corr_matrix: ConstraintT = _CorrMatrix() dependent: ConstraintT = _Dependent() -greater_than: ConstraintT = _GreaterThan -greater_than_eq: ConstraintT = _GreaterThanEq -less_than: ConstraintT = _LessThan -less_than_eq: ConstraintT = _LessThanEq -independent: ConstraintT = _IndependentConstraint -integer_interval: ConstraintT = _IntegerInterval -integer_greater_than: ConstraintT = _IntegerGreaterThan -interval: ConstraintT = _Interval +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +less_than_eq = _LessThanEq +independent = _IndependentConstraint +integer_interval = _IntegerInterval +integer_greater_than = _IntegerGreaterThan +interval = _Interval l1_ball: ConstraintT = _L1Ball() lower_cholesky: ConstraintT = _LowerCholesky() scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() -multinomial: ConstraintT = _Multinomial +multinomial = _Multinomial nonnegative: ConstraintT = _Nonnegative() nonnegative_integer: ConstraintT = _IntegerNonnegative() ordered_vector: ConstraintT = _OrderedVector() @@ -830,5 +830,5 @@ def tree_flatten(self): softplus_positive: ConstraintT = _SoftplusPositive() sphere: ConstraintT = _Sphere() unit_interval: ConstraintT = _UnitInterval() -open_interval: ConstraintT = _OpenInterval -zero_sum: ConstraintT = _ZeroSum +open_interval = _OpenInterval +zero_sum = _ZeroSum diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 082b2df3a..6f187efd1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1,8 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + import math -from typing import Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union, cast import warnings import weakref @@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import TransformT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -59,7 +60,7 @@ ] -def _clipped_expit(x: ArrayLike) -> ArrayLike: +def _clipped_expit(x: NumLike) -> NumLike: finfo = jnp.finfo(jnp.result_type(x)) return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) @@ -67,36 +68,39 @@ def _clipped_expit(x: ArrayLike) -> ArrayLike: class Transform(object): domain = constraints.real codomain = constraints.real - _inv = None + _inv: Optional[Union[TransformT, weakref.ReferenceType]] = None def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) @property - def inv(self) -> TransformT: + def inv(self: TransformT) -> TransformT: inv = None - if self._inv is not None: + if (self._inv is not None) and isinstance(self._inv, weakref.ReferenceType): inv = self._inv() if inv is None: - inv = _InverseTransform(self) - self._inv = weakref.ref(inv) + inv = cast(TransformT, _InverseTransform(self)) + self._inv = cast(TransformT, weakref.ref(inv)) return inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: raise NotImplementedError - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: Union[NonScalarArray, Any]) -> Union[NonScalarArray, Any]: raise NotImplementedError def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: Union[NonScalarArray, Any], + y: Union[NonScalarArray, Any], + intermediates: Optional[PyTree] = None, + ) -> Union[NonScalarArray, Any]: raise NotImplementedError def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + self, x: Union[NonScalarArray, Any] + ) -> Tuple[Union[NonScalarArray, Any], Optional[PyTree]]: return self(x), None def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -114,7 +118,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: """ Sign of the derivative of the transform if it is bijective. """ @@ -148,37 +152,40 @@ class ParameterFreeTransform(Transform): def tree_flatten(self): return (), ((), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) class _InverseTransform(Transform): def __init__(self, transform: TransformT) -> None: super().__init__() - self._inv = transform + self._inv: TransformT = transform @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return self._inv.codomain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return self._inv.domain @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return self._inv.sign @property def inv(self) -> TransformT: return self._inv - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return self._inv._inverse(x) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # NB: we don't use intermediates for inverse transform return -self._inv.log_abs_det_jacobian(y, x, None) @@ -191,7 +198,7 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self._inv,), (("_inv",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _InverseTransform): return False return self._inv == other._inv @@ -201,13 +208,13 @@ class AbsTransform(ParameterFreeTransform): domain = constraints.real codomain = constraints.positive - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, AbsTransform) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.abs(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: warnings.warn( "AbsTransform is not a bijective transform." " The inverse of `y` will be `y`.", @@ -226,14 +233,14 @@ def __init__( self, loc: ArrayLike, scale: ArrayLike, - domain: constraints.Constraint = constraints.real, + domain: ConstraintT = constraints.real, ): self.loc = loc self.scale = scale self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -244,35 +251,40 @@ def codomain(self) -> constraints.Constraint: return constraints.greater_than(self(self.domain.lower_bound)) elif isinstance(self.domain, constraints.less_than): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.greater_than(self(self.domain.upper_bound)) + return constraints.greater_than(self(self.domain.upper_bound)) # type: ignore[arg-type] # we suppose scale > 0 for any tracer else: - return constraints.less_than(self(self.domain.upper_bound)) + return constraints.less_than(self(self.domain.upper_bound)) # type: ignore[arg-type] elif isinstance(self.domain, constraints.interval): if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)): - return constraints.interval( - self(self.domain.upper_bound), self(self.domain.lower_bound) + return constraints.interval( # type: ignore[arg-type] + self(self.domain.upper_bound), # type: ignore[arg-type] + self(self.domain.lower_bound), # type: ignore[arg-type] ) else: - return constraints.interval( - self(self.domain.lower_bound), self(self.domain.upper_bound) + return constraints.interval( # type: ignore[arg-type] + self(self.domain.lower_bound), # type: ignore[arg-type] + self(self.domain.upper_bound), # type: ignore[arg-type] ) else: raise NotImplementedError @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.scale) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return self.loc + self.scale * x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return (y - self.loc) / self.scale def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -288,13 +300,13 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale, self.domain), (("loc", "scale", "domain"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, AffineTransform): return False return ( - jnp.array_equal(self.loc, other.loc) - & jnp.array_equal(self.scale, other.scale) - & (self.domain == other.domain) + jnp.array_equal(self.loc, other.loc) # type: ignore[return-value] + & jnp.array_equal(self.scale, other.scale) # type: ignore[return-value] + & (self.domain == other.domain) # type: ignore[return-value] ) @@ -321,19 +333,19 @@ def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim if input_event_dim == first_input_event_dim: return self.parts[0].domain else: - return constraints.independent( + return constraints.independent( # type: ignore[return-value] self.parts[0].domain, input_event_dim - first_input_event_dim ) @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim @@ -342,28 +354,31 @@ def codomain(self) -> constraints.Constraint: else: return constraints.independent( self.parts[-1].codomain, output_event_dim - last_output_event_dim - ) + ) # type: ignore[return-value] @property - def sign(self) -> ArrayLike: - sign = 1 + def sign(self) -> NumLike: + sign: NumLike = 1 for transform in self.parts: sign *= transform.sign return sign - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: for part in self.parts: x = part(x) return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: for part in self.parts[::-1]: y = part.inv(y) return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: if intermediates is not None: if len(intermediates) != len(self.parts): raise ValueError( @@ -389,10 +404,8 @@ def log_abs_det_jacobian( result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim) return result - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: - intermediates = [] + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: + intermediates: list[Optional[PyTree]] = [] for part in self.parts[:-1]: x, inter = part.call_with_intermediates(x) intermediates.append([x, inter]) @@ -414,10 +427,10 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.parts,), (("parts",), {}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ComposeTransform): return False - return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) + return jnp.logical_and(*(p1 == p2 for p1, p2 in zip(self.parts, other.parts))) # type: ignore[return-value] def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]: @@ -452,15 +465,18 @@ class CholeskyTransform(ParameterFreeTransform): domain = constraints.positive_definite codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.linalg.cholesky(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.matmul(y, jnp.swapaxes(y, -2, -1)) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = jnp.shape(x)[-1] order = -jnp.arange(n, 0, -1) @@ -499,12 +515,12 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a domain = constraints.real_vector codomain = constraints.corr_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we interchange step 1 and step 2.a for a better performance t = jnp.tanh(x) return signed_stick_breaking_tril(t) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1) pad_width = [(0, 0)] * y.ndim @@ -519,8 +535,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # NB: because domain and codomain are two spaces with different dimensions, determinant of # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the # flatten lower triangular part of `y`. @@ -552,8 +571,11 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): codomain = constraints.corr_cholesky def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # NB: see derivation in LKJCholesky implementation n = jnp.shape(x)[-1] order = -jnp.arange(n - 1, -1, -1) @@ -565,11 +587,11 @@ class ExpTransform(Transform): # TODO: refine domain/codomain logic through setters, especially when # transforms for inverses are supported - def __init__(self, domain=constraints.real): + def __init__(self, domain: ConstraintT = constraints.real): self.domain = domain @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -577,29 +599,32 @@ def codomain(self) -> constraints.Constraint: elif isinstance(self.domain, constraints.greater_than): return constraints.greater_than(self.__call__(self.domain.lower_bound)) elif isinstance(self.domain, constraints.interval): - return constraints.interval( - self.__call__(self.domain.lower_bound), - self.__call__(self.domain.upper_bound), + return constraints.interval( # type: ignore[arg-type] + self.__call__(self.domain.lower_bound), # type: ignore[arg-type] + self.__call__(self.domain.upper_bound), # type: ignore[arg-type] ) else: raise NotImplementedError - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: # XXX consider to clamp from below for stability if necessary return jnp.exp(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.log(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return x def tree_flatten(self): return (self.domain,), (("domain",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ExpTransform): return False return self.domain == other.domain @@ -608,15 +633,18 @@ def __eq__(self, other: TransformT) -> bool: class IdentityTransform(ParameterFreeTransform): sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return x - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return y def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros_like(x) @@ -628,7 +656,7 @@ class IndependentTransform(Transform): """ def __init__( - self, base_transform: TransformT, reinterpreted_batch_ndims: int + self, base_transform: Transform, reinterpreted_batch_ndims: int ) -> None: assert isinstance(base_transform, Transform) assert isinstance(reinterpreted_batch_ndims, int) @@ -638,26 +666,29 @@ def __init__( super().__init__() @property - def domain(self) -> constraints.Constraint: + def domain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.domain, self.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] @property - def codomain(self) -> constraints.Constraint: + def codomain(self) -> ConstraintT: # type: ignore[override] return constraints.independent( self.base_transform.codomain, self.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return self.base_transform(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return self.base_transform._inverse(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: result = self.base_transform.log_abs_det_jacobian( x, y, intermediates=intermediates ) @@ -666,9 +697,7 @@ def log_abs_det_jacobian( raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}") return sum_rightmost(result, self.reinterpreted_batch_ndims) - def call_with_intermediates( - self, x: ArrayLike - ) -> Tuple[ArrayLike, Optional[ArrayLike]]: + def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]: return self.base_transform.call_with_intermediates(x) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -683,12 +712,12 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, IndependentTransform): return False return (self.base_transform == other.base_transform) & ( self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims - ) + ) # type: ignore[return-value] class L1BallTransform(ParameterFreeTransform): @@ -699,7 +728,7 @@ class L1BallTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.l1_ball - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # transform to (-1, 1) interval t = jnp.tanh(x) @@ -709,7 +738,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) return t * remainder - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # inverse stick-breaking remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] @@ -723,8 +752,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.arctanh(t) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # compute stick-breaking logdet # t1 -> t1 # t2 -> t2 * (1 - abs(t1)) @@ -765,7 +797,7 @@ class LowerCholeskyAffine(Transform): domain = constraints.real_vector codomain = constraints.real_vector - def __init__(self, loc: ArrayLike, scale_tril: Array): + def __init__(self, loc: NonScalarArray, scale_tril: NonScalarArray): if jnp.ndim(scale_tril) != 2: raise ValueError( "Only support 2-dimensional scale_tril matrix. " @@ -775,12 +807,12 @@ def __init__(self, loc: ArrayLike, scale_tril: Array): self.loc = loc self.scale_tril = scale_tril - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return self.loc + jnp.squeeze( jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1 ) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - self.loc original_shape = jnp.shape(y) yt = jnp.reshape(y, (-1, original_shape[-1])).T @@ -788,8 +820,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return jnp.reshape(xt.T, original_shape) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.broadcast_to( jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1), jnp.shape(x)[:-1], @@ -808,12 +843,12 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, LowerCholeskyAffine): return False return jnp.array_equal(self.loc, other.loc) & jnp.array_equal( self.scale_tril, other.scale_tril - ) + ) # type: ignore[return-value] class LowerCholeskyTransform(ParameterFreeTransform): @@ -827,21 +862,24 @@ class LowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = jnp.exp(x[..., -n:]) return add_diag(z, diag) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) return jnp.concatenate( [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1 ) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) return x[..., -n:].sum(-1) @@ -869,20 +907,23 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform): domain = constraints.real_vector codomain = constraints.scaled_unit_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) - return add_diag(z, 1) * diag[..., None] + return add_diag(z, 1) * diag[..., None] # type: ignore[arg-type] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: diag = jnp.diagonal(y, axis1=-2, axis2=-1) z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1) return jnp.concatenate([z, _softplus_inv(diag)], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) diag = x[..., -n:] diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1) @@ -913,17 +954,20 @@ class OrderedTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.ordered_vector - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1) return jnp.cumsum(z, axis=-1) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: x = jnp.log(y[..., 1:] - y[..., :-1]) return jnp.concatenate([y[..., :1], x], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.sum(x[..., 1:], -1) @@ -934,12 +978,12 @@ class PermuteTransform(Transform): def __init__(self, permutation: Array) -> None: self.permutation = permutation - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return x[..., self.permutation] - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: size = self.permutation.size - permutation_inv = ( + permutation_inv: NonScalarArray = ( jnp.zeros(size, dtype=jnp.result_type(int)) .at[self.permutation] .set(jnp.arange(size)) @@ -947,17 +991,20 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return y[..., permutation_inv] def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.full(jnp.shape(x)[:-1], 0.0) def tree_flatten(self): return (self.permutation,), (("permutation",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PermuteTransform): return False - return jnp.array_equal(self.permutation, other.permutation) + return jnp.array_equal(self.permutation, other.permutation) # type: ignore[return-value] class PowerTransform(Transform): @@ -967,15 +1014,18 @@ class PowerTransform(Transform): def __init__(self, exponent: ArrayLike) -> None: self.exponent = exponent - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return jnp.power(x, self.exponent) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return jnp.power(y, 1 / self.exponent) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.log(jnp.abs(self.exponent * y / x)) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -987,13 +1037,13 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def tree_flatten(self): return (self.exponent,), (("exponent",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PowerTransform): return False - return jnp.array_equal(self.exponent, other.exponent) + return jnp.array_equal(self.exponent, other.exponent) # type: ignore[return-value] @property - def sign(self) -> ArrayLike: + def sign(self) -> NumLike: return jnp.sign(self.exponent) @@ -1001,16 +1051,19 @@ class SigmoidTransform(ParameterFreeTransform): codomain = constraints.unit_interval sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return _clipped_expit(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return logit(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: - return -softplus(x) - softplus(-x) + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: + return -softplus(x) - softplus(-x) # type: ignore[operator] class SimplexToOrderedTransform(Transform): @@ -1045,12 +1098,12 @@ class SimplexToOrderedTransform(Transform): def __init__(self, anchor_point: ArrayLike = 0.0) -> None: self.anchor_point = anchor_point - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: s = jnp.cumsum(x[..., :-1], axis=-1) y = logit(s) + jnp.expand_dims(self.anchor_point, -1) return y - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y = y - jnp.expand_dims(self.anchor_point, -1) s = expit(y) # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] @@ -1061,8 +1114,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` J_logdet = (softplus(y) + softplus(-y)).sum(-1) @@ -1071,10 +1127,10 @@ def log_abs_det_jacobian( def tree_flatten(self): return (self.anchor_point,), (("anchor_point",), dict()) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SimplexToOrderedTransform): return False - return jnp.array_equal(self.anchor_point, other.anchor_point) + return jnp.array_equal(self.anchor_point, other.anchor_point) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] - 1,) @@ -1083,8 +1139,8 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (shape[-1] + 1,) -def _softplus_inv(y: ArrayLike) -> ArrayLike: - return jnp.log(-jnp.expm1(-y)) + y +def _softplus_inv(y: ArrayLike) -> NumLike: + return jnp.log(-jnp.expm1(-y)) + y # type: ignore[operator] class SoftplusTransform(ParameterFreeTransform): @@ -1097,16 +1153,19 @@ class SoftplusTransform(ParameterFreeTransform): codomain = constraints.softplus_positive sign = 1 - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLike) -> NumLike: return softplus(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NumLike) -> NumLike: return _softplus_inv(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: - return -softplus(-x) + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> NumLike: + return -softplus(-x) # type: ignore[operator] class SoftplusLowerCholeskyTransform(ParameterFreeTransform): @@ -1119,20 +1178,23 @@ class SoftplusLowerCholeskyTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.softplus_lower_cholesky - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) z = vec_to_tril_matrix(x[..., :-n], diagonal=-1) diag = softplus(x[..., -n:]) return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: z = matrix_to_tril_vec(y, diagonal=-1) diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1)) return jnp.concatenate([z, diag], axis=-1) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # the jacobian is diagonal, so logdet is the sum of diagonal # `softplus` transform n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2) @@ -1149,7 +1211,7 @@ class StickBreakingTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.simplex - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K) x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) # convert to probabilities (relative to the remaining) of each fraction of the stick @@ -1165,7 +1227,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: ) return z_padded * z1m_cumprod_shifted - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: y_crop = y[..., :-1] z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod @@ -1173,8 +1235,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html # |det|(J) = Product(y * (1 - sigmoid(x))) # = Product(y * sigmoid(x) * exp(-x)) @@ -1207,7 +1272,7 @@ def __init__(self, unpack_fn, pack_fn=None): self.unpack_fn = unpack_fn self.pack_fn = pack_fn - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) @@ -1217,7 +1282,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike: else: return self.unpack_fn(x) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: if self.pack_fn is None: raise NotImplementedError( "pack_fn needs to be provided to perform UnpackTransform.inv." @@ -1238,8 +1303,11 @@ def _inverse(self, y: ArrayLike) -> ArrayLike: return self.pack_fn(y) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros(jnp.shape(x)[:-1]) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: @@ -1252,7 +1320,7 @@ def tree_flatten(self): # XXX: what if unpack_fn is a parametrized callable pytree? return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, UnpackTransform) and (self.unpack_fn is other.unpack_fn) @@ -1293,12 +1361,12 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, len(self._inverse_shape)) + def domain(self) -> ConstraintT: # type: ignore[override] + return constraints.independent(constraints.real, len(self._inverse_shape)) # type: ignore[return-value] @property - def codomain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, len(self._forward_shape)) + def codomain(self) -> ConstraintT: # type: ignore[override] + return constraints.independent(constraints.real, len(self._forward_shape)) # type: ignore[return-value] def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._forward_shape, self._inverse_shape) @@ -1306,15 +1374,18 @@ def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NonScalarArray) -> NonScalarArray: return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - def _inverse(self, y: ArrayLike) -> ArrayLike: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) def log_abs_det_jacobian( - self, x: ArrayLike, y: ArrayLike, intermediates=None - ) -> ArrayLike: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)]) def tree_flatten(self): @@ -1324,7 +1395,7 @@ def tree_flatten(self): } return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ReshapeTransform) and self._forward_shape == other._forward_shape @@ -1334,7 +1405,7 @@ def __eq__(self, other: TransformT) -> bool: def _normalize_rfft_shape( input_shape: tuple[int, ...], - shape: tuple[int, ...], + shape: Optional[tuple[int, ...]], ) -> tuple[int, ...]: if shape is None: return input_shape @@ -1365,11 +1436,11 @@ def __init__( self.transform_shape = transform_shape self.transform_ndims = transform_ndims - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.rfftn(x, self.transform_shape, axes) - def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: axes = tuple(range(-self.transform_ndims, 0)) return jnp.fft.irfftn(y, self.transform_shape, axes) @@ -1385,8 +1456,11 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return shape[:-1] + (size,) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None - ) -> jnp.ndarray: + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: batch_shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] ) @@ -1405,14 +1479,14 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, self.transform_ndims) + def domain(self) -> ConstraintT: # type: ignore[override] + return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] @property - def codomain(self) -> constraints.Constraint: - return constraints.independent(constraints.complex, self.transform_ndims) + def codomain(self) -> ConstraintT: # type: ignore[override] + return constraints.independent(constraints.complex, self.transform_ndims) # type: ignore[return-value] - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, RealFastFourierTransform) and self.transform_ndims == other.transform_ndims @@ -1428,13 +1502,13 @@ class PackRealFastFourierCoefficientsTransform(Transform): """ domain = constraints.real_vector - codomain = constraints.independent(constraints.complex, 1) + codomain = constraints.independent(constraints.complex, 1) # type: ignore[assignment] def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None: assert transform_shape is None or len(transform_shape) == 1, ( "Packing Fourier coefficients is only implemented for vectors." ) - self.shape = transform_shape + self.shape: tuple[int, ...] = transform_shape # type: ignore[assignment] def tree_flatten(self): return (), ((), {"shape": self.shape}) @@ -1457,25 +1531,28 @@ def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return (*batch_shape, n) def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, ) -> Array: shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1]) return jnp.zeros_like(x, shape=shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert self.shape is None or self.shape == x.shape[-1:] n = x.shape[-1] n_real = n // 2 + 1 n_imag = n - n_real complex_dtype = jnp.result_type(x.dtype, jnp.complex64) return ( - x[..., :n_real] + jnp.asarray(x)[..., :n_real] .astype(complex_dtype) .at[..., 1 : 1 + n_imag] .add(1j * x[..., n_real:]) ) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: (n,) = self.shape n_real = n // 2 + 1 n_imag = n - n_real @@ -1540,7 +1617,11 @@ class RecursiveLinearTransform(Transform): domain = constraints.real_matrix codomain = constraints.real_matrix - def __init__(self, transition_matrix: Array, initial_value: Array = None) -> None: + def __init__( + self, + transition_matrix: NonScalarArray, + initial_value: Optional[NonScalarArray] = None, + ) -> None: event_shape = transition_matrix.shape[-1:] if initial_value is None: @@ -1567,7 +1648,7 @@ def _get_initial_value(self, sample_shape) -> Array: return jnp.broadcast_to(self.initial_value, batch_shape + event_shape) - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it. sample_shape = x.shape[:-2] x = jnp.moveaxis(x, -2, 0) @@ -1581,7 +1662,7 @@ def f(y, x): _, y = lax.scan(f, initial_value, x) return jnp.moveaxis(y, 0, -2) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: # Move the time axis to the first position so we can scan over it in reverse. sample_shape = y.shape[:-2] y = jnp.moveaxis(y, -2, 0) @@ -1597,7 +1678,12 @@ def f(y, prev): ) return jnp.moveaxis(x, 0, -2) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None): + def log_abs_det_jacobian( + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, + ) -> NumLike: return jnp.zeros_like(x, shape=x.shape[:-2]) def tree_flatten(self): @@ -1606,10 +1692,10 @@ def tree_flatten(self): {}, ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, RecursiveLinearTransform): return False - return jnp.array_equal(self.transition_matrix, other.transition_matrix) + return jnp.array_equal(self.transition_matrix, other.transition_matrix) # type: ignore[return-value] class ZeroSumTransform(Transform): @@ -1627,26 +1713,26 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> constraints.Constraint: - return constraints.independent(constraints.real, self.transform_ndims) + def domain(self) -> ConstraintT: # type: ignore[override] + return constraints.independent(constraints.real, self.transform_ndims) # type: ignore[return-value] @property - def codomain(self) -> constraints.Constraint: - return constraints.zero_sum(self.transform_ndims) + def codomain(self) -> ConstraintT: # type: ignore[override] + return constraints.zero_sum(self.transform_ndims) # type: ignore[return-value] - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis(x, axis=axis) return x - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: NonScalarArray) -> NonScalarArray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis_rev(y, axis=axis) return y - def extend_axis_rev(self, array: Array, axis: int) -> Array: + def extend_axis_rev(self, array: NonScalarArray, axis: int) -> NonScalarArray: normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis n = array.shape[normalized_axis] @@ -1657,7 +1743,7 @@ def extend_axis_rev(self, array: Array, axis: int) -> Array: slice_before = (slice(None, None),) * normalized_axis return array[(*slice_before, slice(None, -1))] + norm - def extend_axis(self, array: Array, axis: int) -> Array: + def extend_axis(self, array: NonScalarArray, axis: int) -> NonScalarArray: n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) @@ -1668,7 +1754,10 @@ def extend_axis(self, array: Array, axis: int) -> Array: return out - norm def log_abs_det_jacobian( - self, x: Array, y: Array, intermediates: None = None + self, + x: NonScalarArray, + y: NonScalarArray, + intermediates: Optional[PyTree] = None, ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] @@ -1689,7 +1778,7 @@ def tree_flatten(self): aux_data = {"transform_ndims": self.transform_ndims} return (), ((), aux_data) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ZeroSumTransform) and self.transform_ndims == other.transform_ndims @@ -1704,14 +1793,19 @@ class ComplexTransform(ParameterFreeTransform): domain = constraints.real_vector codomain = constraints.complex - def __call__(self, x: Array) -> Array: + def __call__(self, x: NonScalarArray) -> NonScalarArray: assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." return lax.complex(x[..., 0], x[..., 1]) - def _inverse(self, y: Array) -> Array: + def _inverse(self, y: ArrayLike) -> Array: return jnp.stack([y.real, y.imag], axis=-1) - def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: + def log_abs_det_jacobian( + self, + x: NumLike, + y: NumLike, + intermediates: Optional[PyTree] = None, + ) -> Array: return jnp.zeros_like(y) def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: diff --git a/pyproject.toml b/pyproject.toml index d79a3e532..47925bf7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ extend-include = ["*.ipynb"] [tool.ruff.lint] select = ["ANN", "E", "F", "I", "W"] ignore = [ - "ANN002", # missing args type annotation - "ANN003", # missing kwargs type annotation - "ANN204", # missing type annotation for __call__ + "ANN002", # missing args type annotation + "ANN003", # missing kwargs type annotation + "ANN204", # missing type annotation for __call__ "E203", ] @@ -68,7 +68,9 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.ruff.lint.per-file-ignores] -"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = ["ANN"] # require type annotations in typed modules +"!numpyro/{diagnostics.py,handlers.py,optim.py,patch.py,primitives.py,infer/elbo.py}" = [ + "ANN", +] # require type annotations in typed modules [tool.ruff.lint.extend-per-file-ignores] "numpyro/contrib/tfp/distributions.py" = ["F811"] @@ -114,7 +116,6 @@ doctest_optionflags = [ [tool.mypy] ignore_errors = true ignore_missing_imports = true -plugins = ["numpy.typing.mypy_plugin"] [[tool.mypy.overrides]] module = [ @@ -129,5 +130,6 @@ module = [ "numpyro.primitives.*", "numpyro.patch.*", "numpyro.util.*", + "numpyro.distributions.transforms", ] ignore_errors = false