Feature request: Bessel functions from scipy (J0, J1, J2, Y0, Y1, Y2), as implemented in autograd #12402
Comments
Seconded - hoping to use J1, J0 in simple interferometry applications, without necessarily needing the higher orders. Is there a reason these are hard to do in JAX?
Hi, I'd love to see this implementation. I had to move to another framework because I needed J0 and J1 for a NeRF-related project.
Seconding the request for J0, J1 functions! Would implementing these in JAX be similar to other packages such as scipy?
I actually see Bessel Jv has now been implemented! But it seems to be numerically unstable:

```python
import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import vmap
from scipy.special import j1 as j1s  # scipy version

def j1(x):
    # jax version
    return jax.scipy.special.bessel_jn(x, v=1)[1]

z = np.linspace(0, 100, 1000)
plt.plot(z, vmap(j1)(z))
plt.plot(z, j1s(z))
```

This is improved by moving to a higher `n_iter`.
Hi. Notice also instabilities at high x: with N = 300, `jscs.bessel_jn(x, v=0, n_iter=200)` diverges visibly from scipy at large x:

```python
import jax.numpy as jnp
import jax.scipy.special as jscs  # assumed alias
import scipy.special as scs       # assumed alias
from matplotlib.pyplot import plot

N = 300
x = jnp.linspace(0, N, 50 * N)
plot(x, jscs.bessel_jn(x, v=0, n_iter=200).squeeze())
plot(x, scs.jn(0, x), ls="--")  # scipy.special.jn for comparison
```
cc @tlu7
Here may be a solution for J0(x), using discussion #14132 and the Boost code:

```python
import jax
import jax.numpy as jnp

def _evaluate_rational(z, num, denom):
    # Horner evaluation of sum(num[i] * z**i) / sum(denom[i] * z**i),
    # following Boost's evaluate_rational: descending Horner for z <= 1,
    # evaluation in powers of 1/z for z > 1 (for numerical stability).
    count = len(num)

    def true_fn_update(z):  # z <= 1
        def body_true(val1):
            s1, s2, i = val1
            s1 = s1 * z + num[i]
            s2 = s2 * z + denom[i]
            return s1, s2, i - 1

        def cond_true(val1):
            _, _, i = val1
            return i >= 0

        val1_init = (num[-1], denom[-1], count - 2)
        s1, s2, _ = jax.lax.while_loop(cond_true, body_true, val1_init)
        return s1 / s2

    def false_fn_update(z):  # receives 1/z when z > 1
        def body_false(val1):
            s1, s2, i = val1
            s1 = s1 * z + num[i]
            s2 = s2 * z + denom[i]
            return s1, s2, i + 1

        def cond_false(val1):
            _, _, i = val1
            return i < count

        val1_init = (num[0], denom[0], 1)
        s1, s2, _ = jax.lax.while_loop(cond_false, body_false, val1_init)
        return s1 / s2

    return jnp.where(z <= 1, true_fn_update(z), false_fn_update(1 / z))

v_ratio = jax.vmap(_evaluate_rational, in_axes=(0, None, None))

def J0(x):
    # Rational-approximation coefficients from Boost, one set per interval.
    P1 = jnp.array([-4.1298668500990866786e+11,
                    2.7282507878605942706e+10,
                    -6.2140700423540120665e+08,
                    6.6302997904833794242e+06,
                    -3.6629814655107086448e+04,
                    1.0344222815443188943e+02,
                    -1.2117036164593528341e-01])
    Q1 = jnp.array([2.3883787996332290397e+12,
                    2.6328198300859648632e+10,
                    1.3985097372263433271e+08,
                    4.5612696224219938200e+05,
                    9.3614022392337710626e+02,
                    1.0,
                    0.0])
    assert len(P1) == len(Q1)
    P2 = jnp.array([-1.8319397969392084011e+03,
                    -1.2254078161378989535e+04,
                    -7.2879702464464618998e+03,
                    1.0341910641583726701e+04,
                    1.1725046279757103576e+04,
                    4.4176707025325087628e+03,
                    7.4321196680624245801e+02,
                    4.8591703355916499363e+01])
    Q2 = jnp.array([-3.5783478026152301072e+05,
                    2.4599102262586308984e+05,
                    -8.4055062591169562211e+04,
                    1.8680990008359188352e+04,
                    -2.9458766545509337327e+03,
                    3.3307310774649071172e+02,
                    -2.5258076240801555057e+01,
                    1.0])
    assert len(P2) == len(Q2)
    PC = jnp.array([2.2779090197304684302e+04,
                    4.1345386639580765797e+04,
                    2.1170523380864944322e+04,
                    3.4806486443249270347e+03,
                    1.5376201909008354296e+02,
                    8.8961548424210455236e-01])
    QC = jnp.array([2.2779090197304684318e+04,
                    4.1370412495510416640e+04,
                    2.1215350561880115730e+04,
                    3.5028735138235608207e+03,
                    1.5711159858080893649e+02,
                    1.0])
    assert len(PC) == len(QC)
    PS = jnp.array([-8.9226600200800094098e+01,
                    -1.8591953644342993800e+02,
                    -1.1183429920482737611e+02,
                    -2.2300261666214198472e+01,
                    -1.2441026745835638459e+00,
                    -8.8033303048680751817e-03])
    QS = jnp.array([5.7105024128512061905e+03,
                    1.1951131543434613647e+04,
                    7.2642780169211018836e+03,
                    1.4887231232283756582e+03,
                    9.0593769594993125859e+01,
                    1.0])
    assert len(PS) == len(QS)
    # First two zeros of J0, with high/low split parts (x1 = x11/256 + x12, etc.)
    # so that (x - x11/256) - x12 computes x - x1 with extra precision.
    x1 = 2.4048255576957727686e+00
    x2 = 5.5200781102863106496e+00
    x11 = 6.160e+02
    x12 = -1.42444230422723137837e-03
    x21 = 1.4130e+03
    x22 = 5.46860286310649596604e-04
    one_div_root_pi = 5.641895835477562869480794515607725858e-01

    def t1(x):  # x <= 4
        y = x * x
        r = v_ratio(y, P1, Q1)
        factor = (x + x1) * ((x - x11 / 256) - x12)
        return factor * r

    def t2(x):  # 4 < x <= 8
        y = 1 - (x * x) / 64
        r = v_ratio(y, P2, Q2)
        factor = (x + x2) * ((x - x21 / 256) - x22)
        return factor * r

    def t3(x):  # x > 8: asymptotic form with rational corrections
        y = 8 / x
        y2 = y * y
        rc = v_ratio(y2, PC, QC)
        rs = v_ratio(y2, PS, QS)
        factor = one_div_root_pi / jnp.sqrt(x)
        sx = jnp.sin(x)
        cx = jnp.cos(x)
        return factor * (rc * (cx + sx) - y * rs * (sx - cx))

    x = jnp.abs(x)  # J0 is even
    return jnp.select(
        [x == 0, x <= 4, x <= 8, x > 8],
        [1.0, t1(x), t2(x), t3(x)],
        default=x)
```

Test:

```python
import scipy as sc
import matplotlib.pyplot as plt

x = jnp.linspace(0, 100, 1000)  # assumed test grid
plt.plot(x, J0(x))
plt.plot(x, sc.special.j0(x), ls="--")
plt.plot(x, J0(x) - sc.special.j0(x))
```
If you can simplify or improve the code, I would be glad... or maybe it is already a reasonable implementation. Moreover, scipy may not be the most accurate reference to compare against.
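For what it's worth, one possible simplification: since the Boost coefficient tables above are in ascending order and each `num`/`denom` pair has equal length, the two Horner loops in `_evaluate_rational` can be collapsed into `jnp.polyval` calls. A minimal sketch under those assumptions (note that `jnp.where` evaluates both branches, so `1/z` at `z = 0` yields an `inf` that is masked in the value but would contaminate gradients):

```python
import jax.numpy as jnp

def evaluate_rational_polyval(z, num, denom):
    # num, denom: ascending-order coefficients of equal length.
    # jnp.polyval expects the highest-degree coefficient first.
    small = jnp.polyval(num[::-1], z) / jnp.polyval(denom[::-1], z)
    # For z > 1, evaluate in powers of 1/z; the common z**-(len-1)
    # factor cancels between numerator and denominator.
    large = jnp.polyval(num, 1.0 / z) / jnp.polyval(denom, 1.0 / z)
    return jnp.where(z <= 1, small, large)
```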
Just bumping this thread -- the current implementations of J0 and J1 yield NaNs and have numerically unstable gradients. Here is an MWE showing the issue:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, vmap
from scipy.special import j0, j1, jn

def j1_jax(x):
    return jax.scipy.special.bessel_jn(x, v=1, n_iter=50)[1]

def j0_jax(x):
    return jax.scipy.special.bessel_jn(x, v=0, n_iter=50)[0]

grad_jax = vmap(grad(j1_jax))

def grad_scipy(x):
    # analytic derivative: J1'(x) = (J0(x) - J2(x)) / 2
    return (j0(x) - jn(2, x)) / 2

x = jnp.linspace(-10 * jnp.pi, 10 * jnp.pi, 1000)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
ax1.set_title("J1 comparison")
ax1.plot(x, j1(x), label='scipy J1')
ax1.plot(x, j1_jax(x), label='jax J1', zorder=2)
ax1.legend()
ax2.set_title("J1 Gradient Comparison")
ax2.plot(x, grad_scipy(x), label='analytic grad')
ax2.plot(x, grad_jax(x), label="jax grad", zorder=2)
ax2.legend()
```
Just to bump this - for our use case we seem OK with a bit of a hack, patching together the core JAX Bessel function for small x with a different series expansion for large x. We currently mainly need J1 out to moderate values, but a similar approach can be used to generate up to arbitrary Jn and perhaps should be implemented. I should add that the core JAX implementation is very numerically unstable in 32 bit but much better in 64 bit. Quick demo: https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
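For reference, a minimal sketch of that kind of patching (my own illustration, not the notebook's exact code): use the series-based `bessel_jn` for small x and the leading-order asymptotic J1(x) ≈ sqrt(2/(πx)) cos(x − 3π/4) for large x. The cutoff of 15.0 is an arbitrary assumption, and 64-bit mode is enabled per the observation above:

```python
import jax
jax.config.update("jax_enable_x64", True)  # series is far more stable in 64 bit
import jax.numpy as jnp

def j1_patched(x, cutoff=15.0):  # cutoff chosen arbitrarily for illustration
    small = jax.scipy.special.bessel_jn(x, v=1, n_iter=50)[1]
    # Leading-order large-argument asymptotic of J1; clamp the masked
    # branch's input so it stays finite.
    xs = jnp.where(x > cutoff, x, cutoff)
    large = jnp.sqrt(2.0 / (jnp.pi * xs)) * jnp.cos(xs - 3.0 * jnp.pi / 4.0)
    return jnp.where(x > cutoff, large, small)
```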
Just to update - I went the full way and implemented the CEPHES version of J1 in pure JAX, and it is now much faster and more numerically stable. Can do this for J0 too - @jecampagne has taken a similar approach there. Thoughts?
I've now done a full pure-JAX implementation of the scipy Bessel function routine for J0 and J1, and a non-iterative version of Jv which uses the recurrence relation J_{v+1}(x) = (2v/x) J_v(x) - J_{v-1}(x) to construct v > 1 from J0 and J1. It does have poorer precision close to x = 0 this way, but it is at machine precision with respect to scipy for larger values and has good speed. https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb @jakevdp - what do you think? I would do a pull request - wanted to check this is the sort of thing you might want.
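For concreteness, a minimal sketch of that upward recurrence, assuming stable `j0` and `j1` base implementations are passed in as callables (upward recurrence is known to lose precision for x < v, consistent with the reduced accuracy near x = 0 noted above):

```python
import jax
import jax.numpy as jnp

def jv_up(v, x, j0, j1):
    # v: static Python int >= 0; j0, j1: callables for the base orders.
    if v == 0:
        return j0(x)
    # J_{k+1}(x) = (2k/x) J_k(x) - J_{k-1}(x)
    def step(k, carry):
        j_prev, j_cur = carry
        return j_cur, (2.0 * k / x) * j_cur - j_prev
    _, jv = jax.lax.fori_loop(1, v, step, (j0(x), j1(x)))
    return jv
```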
Just seconding to say I would love this functionality - I'm trying to use J1 for some heat equation simulations. Thanks for the links already @benjaminpope!
Assigning to @jakevdp for further triage.
Anyway, yes - happy to do the PR.
Hi - do these Bessel functions exist in JAX yet?
No, they don't exist yet. The main blocker is the lack of an efficient way to compute them given available XLA operations.
Thanks for the very quick answer.
I suspect the best way forward would be to implement series expansions for these functions directly in JAX. That said, we've found this to be quite brittle in practice, which is why I haven't pursued it.
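To make the brittleness concrete, here is what the direct route looks like: the textbook series J0(x) = Σₖ (−1)ᵏ (x/2)²ᵏ / (k!)² is trivial to write in JAX, but its terms grow large before they shrink, so cancellation destroys most float32 accuracy well before x ≈ 20. A minimal sketch (fixed term count chosen arbitrarily):

```python
import jax
import jax.numpy as jnp

def j0_series(x, num_terms=30):  # term count is an arbitrary choice
    def step(k, carry):
        term, total = carry
        # Ratio of consecutive series terms: t_k = -t_{k-1} * (x/2)^2 / k^2
        term = -term * (x * x / 4.0) / (k * k)
        return term, total + term
    init = (jnp.ones_like(x), jnp.ones_like(x))  # the k = 0 term is 1
    _, total = jax.lax.fori_loop(1, num_terms, step, init)
    return total
```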
It works fine doing this directly in JAX: https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
Looks great! Once this kind of implementation is in JAX, however, people tend to find dark corners where things fail to converge or otherwise behave strangely (e.g. #16996, #20769). That doesn't prevent us from adding these kinds of implementations, but this kind of experience makes me more hesitant to do so because of the support burden down the road.
If these are in scipy, then isn't the scipy team looking to implement everything in the Array API eventually? If all you want are primals, you could do the implementation in SciPy. That would at least give you the functions you want for primals, though you may not have forward- or reverse-mode differentiation. Depending on how you implement them, you might also get differentiation?
For functions like these, scipy plans to support the Array API via host-side callbacks to the scipy implementation, which is more or less useless to general JAX users (it won't execute on accelerators, won't work with autodiff, won't work with vmap, will cause significant slowdowns under JIT, etc.), so unfortunately I don't think we can hope for the Array API to save the day here.
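For illustration, this is roughly what the host-callback route looks like with `jax.pure_callback` (a sketch, not scipy's planned API): primal values are correct, but the computation runs on the host with none of the properties listed above:

```python
import jax
import jax.numpy as jnp
import scipy.special

def j0_host(x):
    # Executes scipy.special.j0 on the host: no accelerator execution,
    # and no gradients unless a custom_jvp rule is attached by hand.
    x = jnp.asarray(x)
    return jax.pure_callback(
        lambda v: scipy.special.j0(v).astype(v.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )
```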
I think we basically need a standalone package for JAX special functions that Jake and co aren't on the hook for maintaining. This is in the critical path for @shashankdholakia's PhD - probably some others' too? Who's going to lead it? Happy to contribute and/or lead...
Drive-by comment: I know some folks started hacking on something with this goal here: https://github.com/JAXtronomy/spexial, but it looks like it's not very active. Perhaps @cgiovanetti, @nstarman, or @jnibauer want to chime in.
FYI, this and similar issues are at the top of my todo list to address the incorrectness/missing-functionality/accuracy issues with JAX special functions. See also #27088.
Awesome! Just dropping this very recent reference for a GPU-optimized algorithm for Bessel functions here. Maybe useful for some of you…
As a step forward here, I have validated scipy.special j0 and j1 against mpmath.besselj, with the conclusion that the accuracy of scipy.special.jv is relatively good, although it could be better: there exist inputs where scipy.special.j0 and j1 are moderately inaccurate and/or incorrect. As reported in scipy/scipy#22705 and scipy/scipy#22706, the accuracy/correctness of scipy.special.j0 and j1 for small and large x can be easily improved. The higher-order Bessel functions j2, j3, ... in scipy.special have the same accuracy issues as j0 and j1 (see pearu/functional_algorithms#102 (comment)), which most likely will be fixed after fixing the corresponding issues for j0 and j1.
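A minimal sketch of that style of validation, comparing scipy.special.j0 against a high-precision mpmath reference (sample grid and precision chosen arbitrarily):

```python
import numpy as np
from scipy.special import j0
import mpmath

mpmath.mp.dps = 50  # 50 significant digits for the reference values
x = np.linspace(0.0, 100.0, 10001)
ref = np.array([float(mpmath.besselj(0, mpmath.mpf(float(v)))) for v in x])
err = np.abs(j0(x) - ref)
# Relative error blows up near the zeros of J0, so absolute error is
# the more informative metric on a uniform grid.
print("max abs error:", err.max(), "at x =", x[err.argmax()])
```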
In the short term, yes, but that is only part of the story; there is a longer-term plan as well. See scipy/xsf#1 for context on the interoperability work here - @steppi discusses it in a footnote there. Not sure how related that is to the recent paper you linked, @TobiBu.