"""jax examples (https://github.com/google/jax).
Define custom derivatives via JVPs (forward mode).
The code examples are not useful production code since most derivatives that we
implement are already the correct default anyway (e.g. jax.grad(jax.numpy.sin)
-> jax.numpy.cos).
When we call
>>> grad(func)(x)
all x and v in each custom_jvp function have x's shape, which makes sense and
is not always the case in autograd.
>>> # scalar
>>> x=jnp.array(1.234)
>>> # array
>>> x=jnp.array(rand(3))
"""
from jax import grad, vmap, jacobian, random
import jax
import jax.numpy as jnp
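

# A minimal sketch (not part of the original file) of the shape claim from the
# module docstring: inside a custom_jvp rule, the tangent v has the same shape
# as the primal x, whether x is a scalar or an array. The name _shape_demo is
# hypothetical.
@jax.custom_jvp
def _shape_demo(x):
    return jnp.power(x, 2.0)


@_shape_demo.defjvp
def _shape_demo_jvp(primals, tangents):
    (x,) = primals
    (v,) = tangents
    assert v.shape == x.shape
    return _shape_demo(x), 2 * x * v


##grad(_shape_demo)(jnp.array(1.234))
##vmap(grad(_shape_demo))(jnp.ones(3))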


def elementwise_grad(func):
    """Emulate elementwise_grad() from autograd."""
    return vmap(grad(func))


@jax.custom_jvp
def pow2(x):
    return jnp.power(x, 2.0)


# pow2(), mysin(): We call defjvp() below in tests to check several JVP
# implementations. That's why we skip the decorators here.
##@pow2.defjvp
def pow2_jvp(primals, tangents):
    (x,) = primals
    (v,) = tangents
    return pow2(x), 2 * x * v


##@pow2.defjvp
def pow2_jvp_with_jac(primals, tangents):
    """JVP where we really build the intermediate Jacobian. Not necessary in
    practice, only for demonstration.
    """
    (x,) = primals
    (v,) = tangents
    # jacobian() works for scalar and 1d array input, diag() doesn't
    if x.shape == ():
        return pow2(x), 2 * x * v
    else:
        ##jac = jacobian(lambda x: jnp.power(x, 2))(x)
        jac = jnp.diag(2 * x)
        return pow2(x), jnp.dot(jac, v)


@jax.custom_jvp
def mysin(x):
    """Fake use case for a custom sin(): We pretend that we have a super fast
    approximation of sin(x): some terms of the Taylor series around x=0.
    Actually, this is much slower than np.sin() :-D
    """
    ##return jnp.sin(x)
    return x - x**3 / 6 + x**5 / 120 - x**7 / 5040 + x**9 / 362880


def mycos(x):
    """Here is a real use case for implementing a custom_jvp, in this case for
    mysin():

        "Approximate the derivative, not differentiate the approximation."

    The analytic (and thus the AD) derivative of mysin() is

        1 - x**2/2 + x**4/24 - x**6/720 + x**8/40320

    But that's not accurate enough! By AD-ing the approximate sine, we get an
    approximate cosine which is worse. We need the x**10 term as well. With
    that, |mycos(x) - cos(x)| is slightly better than |mysin(x) - sin(x)|;
    without it, slightly worse. Both grow beyond 1e-8 outside of ~[-1, 1]
    (see the accuracy check sketched after this function).
    """
    return (
        1 - x**2 / 2 + x**4 / 24 - x**6 / 720 + x**8 / 40320 - x**10 / 3628800
    )
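

# A rough numerical check (hypothetical helper, not part of the original file)
# of the accuracy claims in mycos()'s docstring: inside ~[-1, 1] both
# |mysin - sin| and |mycos - cos| stay around or below 1e-8. Could be called
# from test().
def _check_taylor_accuracy():
    xs = jnp.linspace(-1.0, 1.0, 101)
    err_sin = jnp.max(jnp.abs(mysin(xs) - jnp.sin(xs)))
    err_cos = jnp.max(jnp.abs(mycos(xs) - jnp.cos(xs)))
    # Conservative bound; the docstring states errors around 1e-8 near the
    # interval edges.
    assert err_sin < 1e-7
    assert err_cos < 1e-7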


##@mysin.defjvp
def mysin_jvp(primals, tangents):
    (x,) = primals
    (v,) = tangents
    ##return jnp.sin(x), jnp.cos(x) * v
    return mysin(x), mycos(x) * v


##@mysin.defjvp
def mysin_jvp_with_jac(primals, tangents):
    (x,) = primals
    (v,) = tangents
    # jacobian() works for scalar and 1d array input, diag() doesn't
    if x.shape == ():
        return mysin(x), mycos(x) * v
    else:
        # The same, using exact results:
        #   jac = jacobian(jnp.sin)(x)
        #   jac = jnp.diag(jnp.cos(x))
        # but:
        #   jac = jacobian(mysin)(x)
        # doesn't work b/c we can't use a function to calculate its own deriv
        # (jacobian() would call the JVP which we are about to define right
        # here).
        jac = jnp.diag(mycos(x))
        return mysin(x), jnp.dot(jac, v)


@jax.custom_jvp
def mysum(x):
    return jnp.sum(x)


@mysum.defjvp
def mysum_jvp(primals, tangents):
    """
    jac = jacobian(jnp.sum)(x) == jnp.ones_like(x), i.e. the 1st (and only)
    row of J b/c sum: R^n -> R, so dot(jac, v) == sum(v). However, note that
    when v is scalar, e.g. jnp.array(1.234), dot() does NOT perform a sum, but
    only multiplies (scalar-vector product); see the small demo after this
    function. Oddly enough, in this case returning either a scalar, e.g. one of

        sum(v)
        v

    or a vector, one of

        dot(jnp.ones_like(x), v)
        jnp.ones_like(x) * v

    works.
    """
    (x,) = primals
    (v,) = tangents
    # v scalar or array
    return jnp.sum(x), jnp.sum(v)
    ##return jnp.sum(x), jnp.dot(jnp.ones_like(x), v)
    # v scalar only
    ##return jnp.sum(x), jnp.ones_like(x) * v
    ##return jnp.sum(x), v
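

# A quick demonstration (hypothetical helper, not part of the original file) of
# the remark in mysum_jvp()'s docstring: dot() with an array tangent performs a
# true sum, while dot() with a 0-d scalar tangent only scales, following
# numpy's scalar handling in dot().
def _dot_scalar_vs_array_demo():
    x = jnp.ones(3)
    v_arr = jnp.array([1.0, 2.0, 3.0])
    v_scalar = jnp.array(1.234)
    assert jnp.dot(jnp.ones_like(x), v_arr).shape == ()  # sum -> 6.0
    assert jnp.dot(jnp.ones_like(x), v_scalar).shape == (3,)  # scaling only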


def func(x):
    """Composite function we wish to differentiate. Implemented using jax
    primitives."""
    return jnp.sum(jnp.power(jnp.sin(x), 2))
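

# For reference (hypothetical helper, not part of the original file): the
# closed-form gradient of func() is elementwise 2*sin(x)*cos(x), since
# d/dx_i sum(sin(x)**2) = 2*sin(x_i)*cos(x_i). It could serve as an extra
# reference value in test().
def func_grad_analytic(x):
    return 2 * jnp.sin(x) * jnp.cos(x)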


def func_with_jvp(x):
    """Composite function we wish to differentiate. Implemented using custom
    primitives for which we also defined custom JVPs."""
    return mysum(pow2(mysin(x)))


def test():
    assert jnp.allclose(grad(jnp.sin)(1.234), jnp.cos(1.234))

    # Keep slightly tighter than -pi/2 .. pi/2 to keep mysin() and mycos()
    # errors below 1e-8, else tune allclose() default thresholds.
    x = random.uniform(key=random.PRNGKey(123), shape=(10,)) * 2 - 1

    assert jnp.allclose(jacobian(jnp.sin)(x), jnp.diag(jnp.cos(x)))
    assert jnp.allclose(jacobian(jnp.sin)(x).sum(axis=0), jnp.cos(x))
    assert jnp.allclose(elementwise_grad(jnp.sin)(x), jnp.cos(x))
    assert (jacobian(jnp.sum)(x) == jnp.ones_like(x)).all()

    for p2_jvp, s_jvp in [
        (pow2_jvp, mysin_jvp),
        (pow2_jvp_with_jac, mysin_jvp_with_jac),
    ]:
        pow2.defjvp(p2_jvp)
        mysin.defjvp(s_jvp)

        assert jnp.allclose(
            jnp.array([func(xi) for xi in x]),
            jnp.array([func_with_jvp(xi) for xi in x]),
        )
        assert jnp.allclose(func(x), func_with_jvp(x))
        assert jnp.allclose(
            jnp.array([grad(func)(xi) for xi in x]),
            jnp.array([grad(func_with_jvp)(xi) for xi in x]),
        )
        assert jnp.allclose(
            elementwise_grad(func)(x), elementwise_grad(func_with_jvp)(x)
        )
        assert jnp.allclose(grad(func)(x), grad(func_with_jvp)(x))


if __name__ == "__main__":
    test()