diff --git a/grgrjax/newton.py b/grgrjax/newton.py index 01ab9a9..9919ddf 100644 --- a/grgrjax/newton.py +++ b/grgrjax/newton.py @@ -6,6 +6,13 @@ import scipy.sparse as ssp from .helpers import val_and_jacfwd, amax +try: + ssp_csr_array = ssp.csr_array + ssp_lil_array = ssp.lil_array +except AttributeError: + ssp_csr_array = ssp._arrays.csr_array + ssp_lil_array = ssp._arrays.lil_array + def _newton_cond_func(carry): (xi, eps, cnt), (func, verbose, maxit, tol) = carry @@ -148,7 +155,7 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T fval, jacval, aux = fout if len(fout) == 3 else (*fout, None) # check for convergence or errors jac_is_nan = jnp.isnan(jacval.data).any() if isinstance( - jacval, ssp._arrays.csr_array) else jnp.isnan(jacval).any() + jacval, ssp_csr_array) else jnp.isnan(jacval).any() eps = jnp.abs(fval).max() if _perform_checks_newton(res, eps, cnt, jac_is_nan, tol, rtol, maxit): break @@ -159,7 +166,7 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T info_str = f' Iteration {cnt:3d} | max. error {eps:.2e} | lapsed {ltime:3.4f}' if verbose_jac: jacval = jacval.toarray() if isinstance( - jacval, ssp._arrays.csr_array) else jacval + jacval, ssp_csr_array) else jacval jacdet = jnp.linalg.det(jacval) if ( jacval.shape[0] == jacval.shape[1]) else 0 info_str += f' | det {jacdet:1.5g} | rank {jnp.linalg.matrix_rank(jacval)}/{jacval.shape[0]}' @@ -167,14 +174,14 @@ def newton_jax(func, init, maxit=30, tol=1e-8, rtol=None, solver=None, verbose=T # assign suitable solver if not given if solver is None: - if isinstance(jacval, ssp._arrays.csr_array): + if isinstance(jacval, ssp_csr_array): solver = ssp.linalg.spsolve else: solver = jax.scipy.linalg.solve xi -= solver(jacval, fval) jacval = jacval.toarray() if isinstance( - jacval, (ssp._arrays.csr_array, ssp._arrays.lil_array)) else jacval + jacval, (ssp_csr_array, ssp_lil_array)) else jacval res['x'], res['niter'] = xi, cnt res['fun'], res['jac'] = fval, jacval