4
4
from typing import ClassVar
5
5
6
6
import jax .numpy as jnp
7
+ from jax .nn import softplus
7
8
from jax .scipy .linalg import solve_triangular
8
9
from jaxtyping import Array , ArrayLike , Shaped
9
10
10
- from flowjax import wrappers
11
11
from flowjax .bijections .bijection import AbstractBijection
12
- from flowjax .bijections . softplus import SoftPlus
13
- from flowjax .utils import arraylike_to_array
12
+ from flowjax .utils import arraylike_to_array , inv_softplus
13
+ from flowjax .wrappers import AbstractUnwrappable , Parameterize , unwrap
14
14
15
15
16
16
class Affine (AbstractBijection ):
@@ -29,7 +29,7 @@ class Affine(AbstractBijection):
29
29
shape : tuple [int , ...]
30
30
cond_shape : ClassVar [None ] = None
31
31
loc : Array
32
- scale : Array | wrappers . AbstractUnwrappable [Array ]
32
+ scale : Array | AbstractUnwrappable [Array ]
33
33
34
34
def __init__ (
35
35
self ,
@@ -40,7 +40,7 @@ def __init__(
40
40
* (arraylike_to_array (a , dtype = float ) for a in (loc , scale )),
41
41
)
42
42
self .shape = scale .shape
43
- self .scale = wrappers . BijectionReparam ( scale , SoftPlus ( ))
43
+ self .scale = Parameterize ( softplus , inv_softplus ( scale ))
44
44
45
45
def transform (self , x , condition = None ):
46
46
return x * self .scale + self .loc
@@ -92,15 +92,15 @@ class Scale(AbstractBijection):
92
92
93
93
shape : tuple [int , ...]
94
94
cond_shape : ClassVar [None ] = None
95
- scale : Array | wrappers . AbstractUnwrappable [Array ]
95
+ scale : Array | AbstractUnwrappable [Array ]
96
96
97
97
def __init__ (
98
98
self ,
99
99
scale : ArrayLike ,
100
100
):
101
101
scale = arraylike_to_array (scale , "scale" , dtype = float )
102
- self .scale = wrappers . BijectionReparam ( scale , SoftPlus ( ))
103
- self .shape = jnp .shape (wrappers . unwrap (scale ))
102
+ self .scale = Parameterize ( softplus , inv_softplus ( scale ))
103
+ self .shape = jnp .shape (unwrap (scale ))
104
104
105
105
def transform (self , x , condition = None ):
106
106
return x * self .scale
@@ -120,7 +120,7 @@ class TriangularAffine(AbstractBijection):
120
120
121
121
Transformation has the form :math:`Ax + b`, where :math:`A` is a lower or upper
122
122
triangular matrix, and :math:`b` is the bias vector. We assume the diagonal
123
- entries are positive, and constrain the values using SoftPlus . Other
123
+ entries are positive, and constrain the values using softplus . Other
124
124
parameterizations can be achieved by e.g. replacing ``self.triangular``
125
125
after construction.
126
126
@@ -135,7 +135,7 @@ class TriangularAffine(AbstractBijection):
135
135
shape : tuple [int , ...]
136
136
cond_shape : ClassVar [None ] = None
137
137
loc : Array
138
- triangular : Array | wrappers . AbstractUnwrappable [Array ]
138
+ triangular : Array | AbstractUnwrappable [Array ]
139
139
lower : bool
140
140
141
141
def __init__ (
@@ -152,12 +152,12 @@ def __init__(
152
152
raise ValueError ("arr must be a square, 2-dimensional matrix." )
153
153
dim = arr .shape [0 ]
154
154
155
- def _to_triangular (diag , arr ):
156
- tri = jnp .tril (arr , k = - 1 ) if lower else jnp .triu (arr , k = 1 )
157
- return jnp .diag ( diag ) + tri
155
+ def _to_triangular (arr ):
156
+ tri = jnp .tril (arr ) if lower else jnp .triu (arr )
157
+ return jnp .fill_diagonal ( tri , softplus ( jnp . diag ( tri )), inplace = False )
158
158
159
- diag = wrappers . BijectionReparam ( jnp .diag (arr ), SoftPlus () )
160
- self .triangular = wrappers . Lambda (_to_triangular , diag = diag , arr = arr )
159
+ arr = jnp . fill_diagonal ( arr , inv_softplus ( jnp .diag (arr )), inplace = False )
160
+ self .triangular = Parameterize (_to_triangular , arr )
161
161
self .lower = lower
162
162
self .shape = (dim ,)
163
163
self .loc = jnp .broadcast_to (loc , (dim ,))
0 commit comments