Skip to content

Commit

Permalink
move to scipy 1.11
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Jun 27, 2023
1 parent 43737da commit 6c94350
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions grgrjax/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
import scipy.sparse as ssp
from .helpers import val_and_jacfwd, amax

try:
ssp_csr_array = ssp.csr_array
ssp_lil_array = ssp.lil_array
except AttributeError:
ssp_csr_array = ssp._arrays.csr_array
ssp_lil_array = ssp._arrays.lil_array


def _newton_cond_func(carry):
(xi, eps, cnt), (func, verbose, maxit, tol) = carry
Expand Down Expand Up @@ -148,7 +155,7 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T
fval, jacval, aux = fout if len(fout) == 3 else (*fout, None)
# check for convergence or errors
jac_is_nan = jnp.isnan(jacval.data).any() if isinstance(
jacval, ssp._arrays.csr_array) else jnp.isnan(jacval).any()
jacval, ssp_csr_array) else jnp.isnan(jacval).any()
eps = jnp.abs(fval).max()
if _perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit):
break
Expand All @@ -159,22 +166,22 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T
info_str = f' Iteration {cnt:3d} | max. error {eps:.2e} | lapsed {ltime:3.4f}'
if verbose_jac:
jacval = jacval.toarray() if isinstance(
jacval, ssp._arrays.csr_array) else jacval
jacval, ssp_csr_array) else jacval
jacdet = jnp.linalg.det(jacval) if (
jacval.shape[0] == jacval.shape[1]) else 0
info_str += f' | det {jacdet:1.5g} | rank {jnp.linalg.matrix_rank(jacval)}/{jacval.shape[0]}'
print(info_str)

# assign suitable solver if not given
if solver is None:
if isinstance(jacval, ssp._arrays.csr_array):
if isinstance(jacval, ssp_csr_array):
solver = ssp.linalg.spsolve
else:
solver = jax.scipy.linalg.solve
xi -= solver(jacval, fval)

jacval = jacval.toarray() if isinstance(
jacval, (ssp._arrays.csr_array, ssp._arrays.lil_array)) else jacval
jacval, (ssp_csr_array, ssp_lil_array)) else jacval

res['x'], res['niter'] = xi, cnt
res['fun'], res['jac'] = fval, jacval
Expand Down

0 comments on commit 6c94350

Please sign in to comment.