Commit 4810da7
Remove deprecated jax apis (#113)
`jax.experimental.host_callback` was recently removed in [jax 0.4.35](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-35-oct-22-2024):

> - `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we removed it. See jax-ml/jax#20385 for a discussion of alternatives.

This PR replaces `jax.experimental.host_callback` with `jax.debug.callback` and `jax.experimental.io_callback`.

In addition, it adapts to the batching change in `jax.numpy.linalg.solve()` introduced in [jax 0.5.0](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-5-0-jan-17-2025):

> - [`jax.numpy.linalg.solve()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.solve.html#jax.numpy.linalg.solve) no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, use `solve(a, b[..., None]).squeeze(-1)`.

Lastly, jax's PRNG changed in 0.5.0, so the generated random numbers are different:

> Enable jax_threefry_partitionable by default (see [the update note](jax-ml/jax#18480)).

This turned out to be unlucky for the power method test, so I've adjusted the relative tolerance slightly.

(Not heavily tested beyond the test suite, but at least it passes the tests. It seems tqdm has a bit of a footgun: if a jax array is passed to `info`, it will freeze. Perhaps this should be documented in the `kwargs` argument.)

---------

Co-authored-by: Andres Potapczynski <[email protected]>
1 parent 994f801 commit 4810da7
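For reference, a minimal sketch of the migration pattern applied throughout this commit (the helper names below are illustrative, not CoLA code): `jax.debug.callback` covers the fire-and-forget uses of `host_callback.id_tap`, while `jax.experimental.io_callback` is used where the host function's return value has to flow back into the traced computation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


# Hypothetical host-side helpers, for illustration only.
def log_progress(i):
    print(f"step {int(i)}")  # pure side effect, nothing returned to the trace


def add_one_on_host(x):
    return x + 1  # host computation whose result is needed back in the trace


@jax.jit
def step(i, x):
    # Side-effect-only hook: previously host_callback.id_tap, now jax.debug.callback.
    jax.debug.callback(log_progress, i)
    # Result-returning hook: io_callback needs the output shape/dtype up front.
    return io_callback(add_one_on_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


step(0, jnp.zeros(3))
```

The practical difference is that `jax.debug.callback` discards the host function's return value, whereas `io_callback` requires a `ShapeDtypeStruct` describing the result it hands back.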

6 files changed (+38, -36 lines)

README.md (+1 -1)

@@ -22,7 +22,7 @@
 [![Downloads](https://static.pepy.tech/badge/cola-ml)](https://pepy.tech/project/cola-ml)
 <!-- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wilson-labs/cola/blob/master/docs/notebooks/colabs/all.ipynb) -->

-CoLA is a framework for scalable linear algebra in machine learning and beyond, providing:
+CoLA is a framework for scalable linear algebra in machine learning and beyond, providing:

 (1) Fast hardware-sensitive (GPU accelerated) iterative algorithms for general matrix operations; <br>
 (2) Algorithms that can exploit matrix structure for efficiency; <br>

cola/linalg/inverse/gmres.py (+1 -1)

@@ -115,7 +115,7 @@ def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pb
     zero_thresh = 10 * tol * overall_max[:, None]
     padding = xnp.where(largest_vals < zero_thresh, xnp.ones_like(largest_vals), xnp.zeros_like(largest_vals))
     added_diag = xnp.vmap(xnp.diag)(padding)
-    y = xnp.solve(HT @ H + added_diag, HT[..., 0]) * beta[:, None]
+    y = xnp.solve(HT @ H + added_diag, HT[..., 0, None]).squeeze(-1) * beta[:, None]
     zeros = xnp.zeros_like(y)
     y = xnp.where(largest_vals < zero_thresh, zeros, y)
     pred = xnp.permute(Q @ y[..., None], axes=[1, 0, 2])[:, :, 0]
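The `xnp.solve` change above is the promote-then-squeeze workaround quoted in the commit message. A standalone sketch of the same pattern in plain `jax.numpy` (not CoLA's `xnp` backend wrapper), with made-up shapes:

```python
import jax.numpy as jnp

a = 4.0 * jnp.eye(3) + jnp.ones((2, 3, 3))  # (batch, n, n), well conditioned
b = jnp.ones((2, 3))                        # batch of 1D right-hand sides

# jax >= 0.5.0 no longer treats a (batch, n) rhs as a batch of vectors, so
# promote it to a column, solve, then squeeze the extra axis back out.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)  # (batch, n)
assert jnp.allclose(a @ x[..., None], b[..., None], atol=1e-5)
```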

cola/linalg/tbd/nullspace.py (+1 -1)

@@ -106,7 +106,7 @@ def cond_fn(state):
     if final_error > 5 * tol:
         logging.warning(f"Normalized basis has too high error {final_error:.2e} for tol {tol:.2e}")
     scutoff = (S[rank] if r > rank else 0)
-    text = f"Singular value gap too small: {S[rank-1]:.2e}"
+    text = f"Singular value gap too small: {S[rank - 1]:.2e}"
     text += "above cutoff {scutoff:.2e} below cutoff. Final L, earlier {S[rank-5:rank]}"
     assert rank == 0 or scutoff < S[rank - 1] / 100, text

cola/linalg/unary/unary.py (+1 -1)

@@ -81,7 +81,7 @@ def _matmat(self, V): # (n,bs)
         norms = self.xnp.norm(V, axis=0)

         e0 = self.xnp.canonical(0, (P.shape[1], V.shape[-1]), dtype=P.dtype, device=self.device)
-        Pinv0 = self.xnp.solve(P, e0.T) # (bs, m, m) vs (bs, m)
+        Pinv0 = self.xnp.solve(P, e0.T[..., None]).squeeze(-1) # (bs, m, m) vs (bs, m)
         out = Pinv0 * norms[:, None] # (bs, m)
         Q = self.xnp.cast(Q, dtype=P.dtype) # (bs, n, m)
         # (bs,n,m) @ (bs,m,m) @ (bs, m) -> (bs, n)

cola/utils/jax_tqdm.py (+33 -31)

@@ -6,7 +6,8 @@

 import jax
 import numpy as np
-from jax.experimental import host_callback
+from jax.debug import callback
+from jax.experimental import io_callback
 from tqdm.auto import tqdm


@@ -94,45 +95,46 @@ def build_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Tuple[typ
     print_rate = 1
     remainder = n % print_rate

-    def _define_tqdm(arg, transform):
+    def _define_tqdm(arg):
         tqdm_bars[0] = tqdm(range(n))
         tqdm_bars[0].set_description(message, refresh=False)

-    def _update_tqdm(arg, transform):
-        tqdm_bars[0].update(arg)
+    def _update_tqdm(arg):
+        tqdm_bars[0].update(float(arg))

     def _update_progress_bar(iter_num):
         "Updates tqdm from a JAX scan or loop"
         _ = jax.jax.lax.cond(
             iter_num == 0,
-            lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
-            lambda _: iter_num,
+            lambda _: callback(_define_tqdm, None),
+            lambda _: None,
             operand=None,
         )

         _ = jax.lax.cond(
             # update tqdm every multiple of `print_rate` except at the end
             (iter_num % print_rate == 0) & (iter_num != n - remainder),
-            lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
-            lambda _: iter_num,
+            lambda _: callback(_update_tqdm, print_rate),
+            lambda _: None,
             operand=None,
         )

         _ = jax.lax.cond(
             # update tqdm by `remainder`
             iter_num == n - remainder,
-            lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
-            lambda _: iter_num,
+            lambda _: callback(_update_tqdm, remainder),
+            lambda _: None,
             operand=None,
         )

-    def _close_tqdm(arg, transform):
+    def _close_tqdm(arg):
         tqdm_bars[0].close()
+        return arg

     def close_tqdm(result, iter_num):
         return jax.lax.cond(
             iter_num == n - 1,
-            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
+            lambda _: io_callback(_close_tqdm, result),
             lambda _: result,
             operand=None,
         )

@@ -151,41 +153,41 @@ def new_while(cond_fun, body_fun, init_val):
         info = {'progval': 0, 'pbar': None}
         default_desc = f"Running {body_fun.__name__}"

-        def construct_tqdm(arg, transform):
+        def construct_tqdm(arg):
             _bar_format = "{l_bar}{bar}| {n:.3g}/{total_fmt} [{elapsed}<{remaining}, "
             _bar_format += "{rate_fmt}{postfix}]"
             info['pbar'] = tqdm(total=100, desc=f'{desc or default_desc}', bar_format=_bar_format)

-        def update_tqdm(arg, transform):
+        def update_tqdm(arg):
             error = errorfn(arg)
             errstart = info.setdefault('errstart', error)
             progress = max(100 * np.log(error / errstart) / np.log(tol / errstart) - info['progval'], 0)
             progress = min(100 - info['progval'], progress)
             if progress > 0:
                 info['progval'] += progress
-                info['pbar'].update(progress)
+                info['pbar'].update(float(progress))

-        def close_tqdm(arg, transform):
-            update_tqdm(arg, transform)
+        def close_tqdm(arg):
+            update_tqdm(arg)
             info['pbar'].close()
+            return False

         def newbody(ival):
             i, val = ival
             jax.lax.cond(
                 i % every == 0,
-                lambda _: host_callback.id_tap(update_tqdm, val, result=i),
-                lambda _: i,
+                lambda _: callback(update_tqdm, val),
+                lambda _: None,
                 operand=None,
             )
             return (i + 1, body_fun(val))

         def newcond(ival):
             i, val = ival
-            out = jax.lax.cond(cond_fun(val), lambda _: True,
-                               lambda _: host_callback.id_tap(close_tqdm, val, result=False), operand=None)
+            out = jax.lax.cond(cond_fun(val), lambda _: True, lambda _: io_callback(close_tqdm, val), operand=None)
             return out

-        host_callback.id_tap(construct_tqdm, None)
+        callback(construct_tqdm)
         _, val = jax.lax.while_loop(newcond, newbody, (0, init_val))
         return val

@@ -224,23 +226,23 @@ def construct_info(*_):
             bar_format += "{rate_fmt}{postfix}]"
             info['pbar'] = tqdm(total=100, desc=f'{desc or default_desc}', bar_format=bar_format)

-        def update_info(ival, _):
+        def update_info(ival):
             i, arg = ival
-            error = errorfn(arg)
+            error = float(errorfn(arg))
             info['errors'].append(error)
             if pbar:
                 errstart = info.setdefault('errstart', error)
                 howclose = np.log(error / errstart) / np.log(tol / errstart)
                 if max_iters is not None:
                     howclose = max((i + 1) / max_iters, howclose)
-                progress = min(100 - info['progval'], max(100 * howclose - info['progval'], 0))
+                progress = float(min(100 - info['progval'], max(100 * howclose - info['progval'], 0)))
                 if progress > 0:
                     info['progval'] += progress
                     info['pbar'].update(progress)

-        def close_info(arg, transform):
+        def close_info(arg):
             i, val = arg
-            update_info(arg, transform)
+            update_info(arg)
             info['iteration_time'] = (time.time() - info['iteration_time']) / (i + 1)
             if pbar:
                 info['pbar'].close()

@@ -254,8 +256,8 @@ def newbody(ival):
             i, val = ival
             jax.lax.cond(
                 i % every == 0,
-                lambda _: host_callback.id_tap(update_info, ival, result=i),
-                lambda _: i,
+                lambda _: callback(update_info, ival),
+                lambda _: None,
                 operand=None,
             )
             return (i + 1, body_fun(val))

@@ -269,9 +271,9 @@ def newcond(ival):
                                operand=None)
             return out

-        host_callback.id_tap(construct_info, None)
+        callback(construct_info, None)
         k, val = jax.lax.while_loop(newcond, newbody, (0, init_val))
-        host_callback.id_tap(close_info, (k, val))
+        callback(close_info, (k, val))
         return val

     return new_while, info
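The `float(...)` casts added in this file relate to the footgun noted in the commit message: values arriving in a host callback are still array scalars, and passing one straight into tqdm was reported to freeze the bar. Converting to a plain Python `float` first avoids it. A minimal sketch of that pattern, with made-up names:

```python
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

pbar = tqdm(total=100, desc="toy loop")


def bump_progress(amount):
    # Cast to a plain Python float before handing the value to tqdm; passing an
    # array scalar straight through is what was reported to hang.
    pbar.update(float(amount))


@jax.jit
def body(x):
    jax.debug.callback(bump_progress, jnp.float32(1.0))
    return 0.5 * x


x = jnp.ones(8)
for _ in range(100):
    x = body(x)
pbar.close()
```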

tests/test_linalg.py (+1 -1)

@@ -32,7 +32,7 @@ def test_power_iteration(backend):
     A = xnp.diag(xnp.array([10., 9.75, 3., 0.1], dtype=dtype, device=None))
     B = lazify(A)
     soln = xnp.array(10., dtype=dtype, device=None)
-    tol, max_iter = 1e-5, 500
+    tol, max_iter = 1e-6, 500
     _, approx, _ = power_iteration(B, tol=tol, max_iter=max_iter, momentum=0.)
     rel_error = relative_error(soln, approx)
     assert rel_error < tol * 100
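On the PRNG point: jax 0.5.0 enables `jax_threefry_partitionable` by default, which changes the generated random streams and is why this test needed retuning. If the old streams are ever needed for a side-by-side comparison, the flag can (at the time of writing) still be switched off; treat this as a debugging aid rather than something to rely on, since the old path is slated for removal:

```python
import jax

# Opt back into the pre-0.5.0 key derivation to compare against the old streams.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)  # key derivation (split/fold_in) is what the flag affects
print(jax.random.normal(keys[0], (4,)))
```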
