Commit b8eb028

Merge pull request #172 from danielward27/update_wrappers

Update wrappers

2 parents 157f867 + e53c865, commit b8eb028

12 files changed: +104 -177 lines

flowjax/bijections/affine.py (+15 -15)

@@ -4,13 +4,13 @@
 from typing import ClassVar

 import jax.numpy as jnp
+from jax.nn import softplus
 from jax.scipy.linalg import solve_triangular
 from jaxtyping import Array, ArrayLike, Shaped

-from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
-from flowjax.bijections.softplus import SoftPlus
-from flowjax.utils import arraylike_to_array
+from flowjax.utils import arraylike_to_array, inv_softplus
+from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap


 class Affine(AbstractBijection):
@@ -29,7 +29,7 @@ class Affine(AbstractBijection):
     shape: tuple[int, ...]
     cond_shape: ClassVar[None] = None
     loc: Array
-    scale: Array | wrappers.AbstractUnwrappable[Array]
+    scale: Array | AbstractUnwrappable[Array]

     def __init__(
         self,
@@ -40,7 +40,7 @@ def __init__(
             *(arraylike_to_array(a, dtype=float) for a in (loc, scale)),
         )
         self.shape = scale.shape
-        self.scale = wrappers.BijectionReparam(scale, SoftPlus())
+        self.scale = Parameterize(softplus, inv_softplus(scale))

     def transform(self, x, condition=None):
         return x * self.scale + self.loc
@@ -92,15 +92,15 @@ class Scale(AbstractBijection):

     shape: tuple[int, ...]
     cond_shape: ClassVar[None] = None
-    scale: Array | wrappers.AbstractUnwrappable[Array]
+    scale: Array | AbstractUnwrappable[Array]

     def __init__(
         self,
         scale: ArrayLike,
     ):
         scale = arraylike_to_array(scale, "scale", dtype=float)
-        self.scale = wrappers.BijectionReparam(scale, SoftPlus())
-        self.shape = jnp.shape(wrappers.unwrap(scale))
+        self.scale = Parameterize(softplus, inv_softplus(scale))
+        self.shape = jnp.shape(unwrap(scale))

     def transform(self, x, condition=None):
         return x * self.scale
@@ -120,7 +120,7 @@ class TriangularAffine(AbstractBijection):

     Transformation has the form :math:`Ax + b`, where :math:`A` is a lower or upper
     triangular matrix, and :math:`b` is the bias vector. We assume the diagonal
-    entries are positive, and constrain the values using SoftPlus. Other
+    entries are positive, and constrain the values using softplus. Other
     parameterizations can be achieved by e.g. replacing ``self.triangular``
     after construction.

@@ -135,7 +135,7 @@ class TriangularAffine(AbstractBijection):
     shape: tuple[int, ...]
     cond_shape: ClassVar[None] = None
     loc: Array
-    triangular: Array | wrappers.AbstractUnwrappable[Array]
+    triangular: Array | AbstractUnwrappable[Array]
     lower: bool

     def __init__(
@@ -152,12 +152,12 @@ def __init__(
             raise ValueError("arr must be a square, 2-dimensional matrix.")
         dim = arr.shape[0]

-        def _to_triangular(diag, arr):
-            tri = jnp.tril(arr, k=-1) if lower else jnp.triu(arr, k=1)
-            return jnp.diag(diag) + tri
+        def _to_triangular(arr):
+            tri = jnp.tril(arr) if lower else jnp.triu(arr)
+            return jnp.fill_diagonal(tri, softplus(jnp.diag(tri)), inplace=False)

-        diag = wrappers.BijectionReparam(jnp.diag(arr), SoftPlus())
-        self.triangular = wrappers.Lambda(_to_triangular, diag=diag, arr=arr)
+        arr = jnp.fill_diagonal(arr, inv_softplus(jnp.diag(arr)), inplace=False)
+        self.triangular = Parameterize(_to_triangular, arr)
         self.lower = lower
         self.shape = (dim,)
         self.loc = jnp.broadcast_to(loc, (dim,))
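For orientation (not part of the diff): judging from its usage above, Parameterize(fn, *args) stores a function together with its arguments, and unwrap replaces the wrapper with fn(*args). A minimal sketch of the positivity reparameterization, assuming Affine still accepts loc and scale arguments as in previous releases:

import jax.numpy as jnp

from flowjax.bijections import Affine
from flowjax.wrappers import unwrap

affine = Affine(loc=0.0, scale=2.0)
# scale is stored as Parameterize(softplus, inv_softplus(2.0)): the raw
# parameter is unconstrained, while the unwrapped value is always positive.
assert jnp.allclose(unwrap(affine.scale), 2.0)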

flowjax/bijections/block_autoregressive_network.py (+7 -9)

@@ -9,14 +9,14 @@
 import jax.numpy as jnp
 import jax.random as jr
 from jax import random
+from jax.nn import softplus
 from jaxtyping import PRNGKeyArray

 from flowjax import masks
 from flowjax.bijections.bijection import AbstractBijection
-from flowjax.bijections.softplus import SoftPlus
 from flowjax.bijections.tanh import LeakyTanh
 from flowjax.bisection_search import AutoregressiveBisectionInverter
-from flowjax.wrappers import BijectionReparam, WeightNormalization, Where
+from flowjax.wrappers import Parameterize, WeightNormalization


 class _CallableToBijection(AbstractBijection):
@@ -219,13 +219,11 @@ def block_autoregressive_linear(
     block_diag_mask = masks.block_diag_mask(block_shape, n_blocks)
     block_tril_mask = masks.block_tril_mask(block_shape, n_blocks)

-    weight = Where(block_tril_mask, linear.weight, 0)
-    weight = Where(
-        block_diag_mask,
-        BijectionReparam(weight, SoftPlus(), invert_on_init=False),
-        weight,
-    )
-    weight = WeightNormalization(weight)
+    def apply_mask(weight):
+        weight = jnp.where(block_tril_mask, weight, 0)
+        return jnp.where(block_diag_mask, softplus(weight), weight)
+
+    weight = WeightNormalization(Parameterize(apply_mask, linear.weight))
     linear = eqx.tree_at(lambda linear: linear.weight, linear, replace=weight)

     def linear_to_log_block_diagonal(linear: eqx.nn.Linear):
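A toy illustration (not from the repository) of what the apply_mask closure above computes, using small stand-in masks in place of masks.block_tril_mask and masks.block_diag_mask: entries outside the block triangle are zeroed, and block-diagonal entries are constrained positive via softplus.

import jax.numpy as jnp
from jax.nn import softplus

weight = jnp.arange(1.0, 10.0).reshape(3, 3)
block_tril_mask = jnp.tril(jnp.ones((3, 3), dtype=bool))  # stand-in mask
block_diag_mask = jnp.eye(3, dtype=bool)                  # stand-in mask

masked = jnp.where(block_tril_mask, weight, 0)                 # zero outside the block triangle
masked = jnp.where(block_diag_mask, softplus(masked), masked)  # positive (block-)diagonal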

flowjax/bijections/masked_autoregressive.py (+5 -3)

@@ -13,7 +13,7 @@
 from flowjax.bijections.jax_transforms import Vmap
 from flowjax.masks import rank_based_mask
 from flowjax.utils import get_ravelled_pytree_constructor
-from flowjax.wrappers import Where
+from flowjax.wrappers import Parameterize


 class MaskedAutoregressive(AbstractBijection):
@@ -135,7 +135,7 @@ def masked_autoregressive_mlp(
 ) -> eqx.nn.MLP:
     """Returns an equinox multilayer perceptron, with autoregressive masks.

-    The weight matrices are wrapped using :class:`~flowjax.wrappers.Where`, which
+    The weight matrices are wrapped using :class:`~flowjax.wrappers.Parameterize`, which
     will apply the masking when :class:`~flowjax.wrappers.unwrap` is called on the MLP.
     For details of how the masks are formed, see https://arxiv.org/pdf/1502.03509.pdf.

@@ -160,7 +160,9 @@ def masked_autoregressive_mlp(
     for i, linear in enumerate(mlp.layers):
         mask = rank_based_mask(ranks[i], ranks[i + 1], eq=i != len(mlp.layers) - 1)
         masked_linear = eqx.tree_at(
-            lambda linear: linear.weight, linear, Where(mask, linear.weight, 0)
+            lambda linear: linear.weight,
+            linear,
+            Parameterize(jnp.where, mask, linear.weight, 0),
         )
         masked_layers.append(masked_linear)
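A quick sanity check (illustrative only, and assuming unwrap can be applied directly to a wrapper, as in the Scale change earlier in this commit): unwrapping Parameterize(jnp.where, mask, weight, 0) simply evaluates jnp.where(mask, weight, 0), i.e. the masked weight matrix.

import jax.numpy as jnp

from flowjax.wrappers import Parameterize, unwrap

mask = jnp.array([[True, False], [True, True]])
weight = jnp.ones((2, 2))

masked = unwrap(Parameterize(jnp.where, mask, weight, 0))
assert jnp.allclose(masked, jnp.where(mask, weight, 0))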

flowjax/bijections/rational_quadratic_spline.py (+9 -8)

@@ -7,8 +7,9 @@
 import jax.numpy as jnp
 from jaxtyping import Array, Float

-from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
+from flowjax.utils import inv_softplus
+from flowjax.wrappers import AbstractUnwrappable, Parameterize


 def _real_to_increasing_on_interval(
@@ -62,9 +63,9 @@ class RationalQuadraticSpline(AbstractBijection):
     interval: tuple[int | float, int | float]
     softmax_adjust: float | int
     min_derivative: float
-    x_pos: Array | wrappers.AbstractUnwrappable[Array]
-    y_pos: Array | wrappers.AbstractUnwrappable[Array]
-    derivatives: Array | wrappers.AbstractUnwrappable[Array]
+    x_pos: Array | AbstractUnwrappable[Array]
+    y_pos: Array | AbstractUnwrappable[Array]
+    derivatives: Array | AbstractUnwrappable[Array]
     shape: ClassVar[tuple] = ()
     cond_shape: ClassVar[None] = None

@@ -89,11 +90,11 @@ def __init__(
             softmax_adjust=softmax_adjust,
         )

-        self.x_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots))
-        self.y_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots))
-        self.derivatives = wrappers.Lambda(
+        self.x_pos = Parameterize(pos_parameterization, jnp.zeros(knots))
+        self.y_pos = Parameterize(pos_parameterization, jnp.zeros(knots))
+        self.derivatives = Parameterize(
             lambda arr: jax.nn.softplus(arr) + self.min_derivative,
-            jnp.full(knots + 2, jnp.log(jnp.exp(1 - min_derivative) - 1)),
+            jnp.full(knots + 2, inv_softplus(1 - min_derivative)),
         )

     def transform(self, x, condition=None):
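Note that replacing jnp.log(jnp.exp(1 - min_derivative) - 1) with inv_softplus(1 - min_derivative) is a pure refactor: both compute the softplus inverse of the initial derivative, the new helper just uses a numerically stabler form. A quick illustrative check (not part of the diff):

import jax.numpy as jnp

from flowjax.utils import inv_softplus

min_derivative = 1e-3
old = jnp.log(jnp.exp(1 - min_derivative) - 1)  # expression removed by the diff
new = inv_softplus(1 - min_derivative)          # helper introduced in flowjax/utils.py
assert jnp.allclose(old, new)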

flowjax/distributions.py (+5 -5)

@@ -11,7 +11,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 from equinox import AbstractVar
-from jax.nn import log_softmax
+from jax.nn import log_softmax, softplus
 from jax.numpy import linalg
 from jax.scipy import stats as jstats
 from jax.scipy.special import logsumexp
@@ -24,15 +24,15 @@
     Chain,
     Exp,
     Scale,
-    SoftPlus,
     TriangularAffine,
 )
 from flowjax.utils import (
     _get_ufunc_signature,
     arraylike_to_array,
+    inv_softplus,
     merge_cond_shapes,
 )
-from flowjax.wrappers import AbstractUnwrappable, BijectionReparam, Lambda, unwrap
+from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap


 class AbstractDistribution(eqx.Module):
@@ -609,7 +609,7 @@ def __init__(self, df: ArrayLike):
         df = arraylike_to_array(df, dtype=float)
         df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.")
         self.shape = jnp.shape(df)
-        self.df = BijectionReparam(df, SoftPlus())
+        self.df = Parameterize(softplus, inv_softplus(df))

     def _log_prob(self, x, condition=None):
         return jstats.t.logpdf(x, df=self.df).sum()
@@ -761,7 +761,7 @@ def __init__(
     ):
         weights = eqx.error_if(weights, weights <= 0, "Weights must be positive.")
         self.dist = dist
-        self.log_normalized_weights = Lambda(lambda w: log_softmax(w), jnp.log(weights))
+        self.log_normalized_weights = Parameterize(log_softmax, jnp.log(weights))
         self.shape = dist.shape
         self.cond_shape = dist.cond_shape
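Both the old BijectionReparam(df, SoftPlus()) and the new Parameterize(softplus, inv_softplus(df)) keep the degrees of freedom positive during optimisation; the new form simply stores the unconstrained parameter directly. A hedged sketch of the behaviour, assuming StudentT keeps the df attribute shown above:

import jax.numpy as jnp

from flowjax.distributions import StudentT
from flowjax.wrappers import unwrap

dist = StudentT(df=jnp.array([1.0, 5.0]))
# df is stored as a Parameterize wrapper; unwrapping applies softplus to the
# underlying unconstrained parameter, recovering the requested positive values.
assert jnp.allclose(unwrap(dist.df), jnp.array([1.0, 5.0]))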

flowjax/flows.py (+5 -9)

@@ -14,6 +14,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 from equinox.nn import Linear
+from jax.nn import softplus
 from jax.nn.initializers import glorot_uniform
 from jaxtyping import PRNGKeyArray

@@ -27,27 +28,22 @@
     Flip,
     Invert,
     LeakyTanh,
-    Loc,
     MaskedAutoregressive,
     Permute,
     Planar,
     RationalQuadraticSpline,
     Scan,
-    SoftPlus,
     TriangularAffine,
     Vmap,
 )
 from flowjax.distributions import AbstractDistribution, Transformed
-from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable
+from flowjax.utils import inv_softplus
+from flowjax.wrappers import Parameterize, WeightNormalization


 def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine:
-    scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))])
-    return eqx.tree_at(
-        where=lambda aff: aff.scale,
-        pytree=Affine(),
-        replace=BijectionReparam(jnp.array(1), scale_reparam),
-    )
+    scale = Parameterize(lambda x: softplus(x) + min_scale, inv_softplus(1 - min_scale))
+    return eqx.tree_at(where=lambda aff: aff.scale, pytree=Affine(), replace=scale)


 def coupling_flow(
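The new _affine_with_min_scale maps an unconstrained parameter x to softplus(x) + min_scale and initialises it at inv_softplus(1 - min_scale), so the initial scale is exactly 1 while remaining bounded below by min_scale. A quick illustrative check of that arithmetic (not part of the diff):

import jax.numpy as jnp
from jax.nn import softplus

from flowjax.utils import inv_softplus

min_scale = 1e-2
init = inv_softplus(1 - min_scale)  # unconstrained initial value
scale = softplus(init) + min_scale  # value seen after unwrapping
assert jnp.isclose(scale, 1.0)      # initial scale of 1, bounded below by min_scale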

flowjax/train/variational_fit.py (+1 -1)

@@ -61,7 +61,7 @@ def fit_to_variational_target(
         params, opt_state, loss = step(
             params,
             static,
-            key,
+            key=key,
             optimizer=optimizer,
             opt_state=opt_state,
             loss_fn=loss_fn,

flowjax/utils.py (+12)

@@ -10,6 +10,18 @@
 import flowjax


+def inv_softplus(x: ArrayLike) -> Array:
+    """The inverse of the softplus function, checking for positive inputs."""
+    x = eqx.error_if(
+        x,
+        x < 0,
+        "Expected positive inputs to inv_softplus. If you are trying to use a negative "
+        "scale parameter, consider constructing with positive scales and modifying the "
+        "scale attribute post-construction, e.g., using eqx.tree_at.",
+    )
+    return jnp.log(-jnp.expm1(-x)) + x
+
+
 def merge_cond_shapes(shapes: Sequence[tuple[int, ...] | None]):
     """Merges shapes (tuples of ints or None) used in bijections and distributions.

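Two properties of the helper added above, shown as an illustrative check (not part of the diff): it round-trips with jax.nn.softplus, and the log(-expm1(-x)) + x form avoids the intermediate overflow that the naive log(exp(x) - 1) can hit for large x.

import jax.numpy as jnp
from jax.nn import softplus

from flowjax.utils import inv_softplus

x = jnp.array([0.1, 1.0, 10.0])
assert jnp.allclose(softplus(inv_softplus(x)), x)  # round-trips with softplus

# inv_softplus(-1.0) raises at runtime via eqx.error_if, pointing users towards
# constructing with a positive scale and editing the attribute afterwards.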
