
Feature request: Bessel functions in scipy (J0, J1, J2, Y0, Y1, Y2) as implemented in autograd #12402

Open
xichaoqiang opened this issue Sep 17, 2022 · 30 comments

@xichaoqiang

Feature request: Bessel functions in scipy (J0, J1, J2, Y0, Y1, Y2), as implemented in autograd.

@xichaoqiang xichaoqiang added the enhancement New feature or request label Sep 17, 2022
@jakevdp
Collaborator

jakevdp commented Sep 19, 2022

Thanks for the request. Related feature requests are at #9956 and #11002.

@benjaminpope

Seconded - hoping to use J1 and J0 in simple interferometry applications, without necessarily needing the higher orders. Is there a reason these are hard to do in JAX?

@KostadinovShalon

Hi, I'd love to see this implemented. I had to move to another framework because I needed J0 and J1 for a NeRF-related project.

@LouisDesdoigts

Seconding the request for J0, J1 functions! Would implementing these in Jax be similar to other packages such as scipy?

@benjaminpope

benjaminpope commented Jan 17, 2023

I actually see Bessel Jv has now been implemented! But it seems to be numerically unstable:

import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import vmap

def j1(x):
    # jax version: bessel_jn returns orders 0..v, so take index 1
    return jax.scipy.special.bessel_jn(x, v=1)[1]

from scipy.special import j1 as j1s  # scipy version

z = np.linspace(0, 100, 1000)

plt.plot(z, vmap(j1)(z))
plt.plot(z, j1s(z))

This is improved by increasing n_iter.

[plot: JAX bessel_jn J1 vs scipy j1, diverging at larger z]
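
For instance, a quick check along these lines (assuming the imports above; the n_iter value here is an arbitrary choice):

def j1_hi(x):
    # same as j1 above, with more iterations of the backward recurrence
    return jax.scipy.special.bessel_jn(x, v=1, n_iter=200)[1]

plt.plot(z, vmap(j1_hi)(z))
plt.plot(z, j1s(z), ls="--")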

@jecampagne
Contributor

Hi. I also notice instabilities at high x, which would argue for a high n_iter, but that comes at the price of nan at low x:

import jax.numpy as jnp
import jax.scipy.special as jscs

N = 300
x = jnp.linspace(0, N, 50 * N)
jscs.bessel_jn(x, v=0, n_iter=200)

leads to

DeviceArray([[       nan,        nan,        nan, ..., 0.88876336,
              0.89054682, 0.8923063 ]], dtype=float64)

while

import scipy.special as scs  # reference
plt.plot(x, jscs.bessel_jn(x, v=0, n_iter=200).squeeze())
plt.plot(x, scs.jn(0, x), ls="--")

gives
[plot: JAX bessel_jn(v=0) vs scipy.special.jn(0, x), showing nan at low x]

@jakevdp
Collaborator

jakevdp commented Jan 24, 2023

cc/ @tlu7

@jecampagne
Contributor

Here is a possible solution for J0(x), based on discussion #14132 and the Boost code:

import jax
import jax.numpy as jnp

def _evaluate_rational(z, num, denom):

    count = len(num)

    def true_fn_update(z):

      def body_true(val1):
        # decode val1
        s1, s2, i = val1 
        s1 = s1 * z + num[i]
        s2 = s2 * z + denom[i]
        return s1, s2, i-1

      def cond_true(val1):
        s1, s2, i = val1 
        return i>=0

      val1_init = (num[-1], denom[-1], count-2)
      s1, s2, _ = jax.lax.while_loop(cond_true, body_true, val1_init)

      return s1/s2

    def false_fn_update(z):
      def body_false(val1):
        # decode val1
        s1, s2, i = val1 
        s1 = s1 * z + num[i]
        s2 = s2 * z + denom[i]
        return s1, s2, i+1

      def cond_false(val1):
        s1, s2, i = val1 
        return i<count

      val1_init = (num[0], denom[0],1)
      s1, s2, _ = jax.lax.while_loop(cond_false, body_false, val1_init)

      return s1/s2


    # NB: jnp.where evaluates both branches, so false_fn_update sees 1/z
    return jnp.where(z <= 1, true_fn_update(z), false_fn_update(1 / z))

v_ratio = jax.vmap(_evaluate_rational, in_axes=(0,None, None))

def J0(x):
  P1 = jnp.array([-4.1298668500990866786e+11,
                  2.7282507878605942706e+10,
                  -6.2140700423540120665e+08,
                  6.6302997904833794242e+06,
                  -3.6629814655107086448e+04,
                  1.0344222815443188943e+02,
                  -1.2117036164593528341e-01]) 
  Q1 = jnp.array([2.3883787996332290397e+12,
                  2.6328198300859648632e+10,
                  1.3985097372263433271e+08,
                  4.5612696224219938200e+05,
                  9.3614022392337710626e+02,
                  1.0,
                  0.0])
  assert len(P1) == len(Q1)

  P2 = jnp.array([-1.8319397969392084011e+03,
                  -1.2254078161378989535e+04,
                  -7.2879702464464618998e+03,
                  1.0341910641583726701e+04,
                  1.1725046279757103576e+04,
                  4.4176707025325087628e+03,
                  7.4321196680624245801e+02,
                  4.8591703355916499363e+01])
  Q2 = jnp.array([-3.5783478026152301072e+05,
                  2.4599102262586308984e+05,
                  -8.4055062591169562211e+04,
                  1.8680990008359188352e+04,
                  -2.9458766545509337327e+03,
                  3.3307310774649071172e+02,
                  -2.5258076240801555057e+01,
                  1.0])
  assert len(P2) == len(Q2)


  PC = jnp.array([2.2779090197304684302e+04,
                  4.1345386639580765797e+04,
                  2.1170523380864944322e+04,
                  3.4806486443249270347e+03,
                  1.5376201909008354296e+02,
                  8.8961548424210455236e-01])
  QC = jnp.array([2.2779090197304684318e+04,
                  4.1370412495510416640e+04,
                  2.1215350561880115730e+04,
                  3.5028735138235608207e+03,
                  1.5711159858080893649e+02,
                  1.0])
  
  assert len(PC) == len(QC)


  PS = jnp.array([-8.9226600200800094098e+01,
                  -1.8591953644342993800e+02,
                  -1.1183429920482737611e+02,
                  -2.2300261666214198472e+01,
                  -1.2441026745835638459e+00,
                  -8.8033303048680751817e-03])
  QS = jnp.array([5.7105024128512061905e+03,
                  1.1951131543434613647e+04,
                  7.2642780169211018836e+03,
                  1.4887231232283756582e+03,
                  9.0593769594993125859e+01,
                  1.0])
  assert len(PS) == len(QS)


  x1 = 2.4048255576957727686e+00
  x2 = 5.5200781102863106496e+00
  x11 = 6.160e+02
  x12 = -1.42444230422723137837e-03
  x21 = 1.4130e+03
  x22 = 5.46860286310649596604e-04
  one_div_root_pi =  5.641895835477562869480794515607725858e-01



  def t1(x):  # x<=4
    y = x * x
    r = v_ratio(y, P1, Q1)
    factor = (x + x1) * ((x - x11/256) - x12)
    return factor * r
  
  def t2(x): # x<=8
    y = 1 - (x * x)/64
    r = v_ratio(y, P2, Q2)
    factor = (x + x2) * ((x - x21/256) - x22)
    return factor * r

  def t3(x): #x>8
      y = 8 / x
      y2 = y * y
      rc = v_ratio(y2, PC, QC)
      rs = v_ratio(y2, PS, QS)
      factor = one_div_root_pi / jnp.sqrt(x)
      sx = jnp.sin(x)
      cx = jnp.cos(x)
      return factor * (rc * (cx + sx) - y * rs * (sx - cx))



  x = jnp.abs(x)
  return jnp.select(
      [x == 0, x <= 4, x <= 8, x>8],
      [1, t1(x), t2(x), t3(x)],
      default = x)

Test:

import scipy as sc  # reference: sc.special.j0
plt.plot(x, J0(x))
plt.plot(x, sc.special.j0(x), ls="--")

[plot: J0(x) vs scipy.special.j0(x)]

plt.plot(x, J0(x)-sc.special.j0(x))

[plot: residual J0(x) - scipy.special.j0(x)]
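
One caveat with the final jnp.where in _evaluate_rational: both branches are always evaluated, so 1/z is computed even when z <= 1, and at z = 0 the resulting inf can poison gradients. A minimal sketch of the usual "double-where" guard, shown on a hypothetical toy function rather than the code above:

import jax
import jax.numpy as jnp

def safe_recip_branch(z):
    # Substitute a harmless value before dividing, so the unselected
    # branch never sees z == 0 and the gradient stays finite.
    z_safe = jnp.where(z <= 1, 1.0, z)
    return jnp.where(z <= 1, z, 1.0 / z_safe)

print(jax.grad(safe_recip_branch)(0.0))  # 1.0, no nan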

@jecampagne
Contributor

jecampagne commented Feb 2, 2023

If you can simplify or improve the code, I would be glad... Or maybe there is already a better implementation somewhere. Moreover, scipy may not be the most accurate reference to compare against.

@shashankdholakia

Just bumping this thread -- the current implementations of J0 and J1 yield nans and have numerically unstable gradients:

Here is a MWE showing the issue:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, vmap

from scipy.special import j0, j1, jn

def j1_jax(x):
    return jax.scipy.special.bessel_jn(x,v=1,n_iter=50)[1]

def j0_jax(x):
    return jax.scipy.special.bessel_jn(x,v=0,n_iter=50)[0]

grad_jax = vmap(grad(j1_jax))

def grad_scipy(x):
    # analytic derivative: J1'(x) = (J0(x) - J2(x)) / 2
    return (j0(x) - jn(2, x)) / 2

x = jnp.linspace(-10*jnp.pi,10*jnp.pi,1000)

fig, (ax1,ax2) = plt.subplots(2,1, figsize=(10,10))
ax1.set_title("J1 comparison")
ax1.plot(x,j1(x), label='scipy J1')
ax1.plot(x,j1_jax(x), label='jax J1', zorder=2)
ax1.legend()

ax2.set_title("J1 Gradient Comparison")
ax2.plot(x, grad_scipy(x), label='analytic grad')
ax2.plot(x,grad_jax(x), label="jax grad", zorder=2)
ax2.legend()

[plot: J1 comparison (top) and J1 gradient comparison (bottom), JAX vs scipy]

@benjaminpope

benjaminpope commented Apr 3, 2023

Just to bump this - for our use case we seem OK with a bit of a hack, patching together the core JAX Bessel function for small x with a different series expansion for large x. We currently only need J1 out to moderate values, but a similar approach can be used to generate arbitrary Jn, and perhaps should be implemented.

I should add that the core JAX implementation is very numerically unstable in 32-bit but much better in 64-bit.

Quick demo:

https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
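
The large-x half of such a patch could use the leading-order asymptotic form; a minimal sketch (the crossover point at which to switch from the series to this form is a tuning choice not shown here):

import jax.numpy as jnp

def j1_large_x(x):
    # Leading-order asymptotic: J1(x) ~ sqrt(2/(pi*x)) * cos(x - 3*pi/4),
    # valid for large x; higher-order correction terms sharpen accuracy.
    return jnp.sqrt(2.0 / (jnp.pi * x)) * jnp.cos(x - 3.0 * jnp.pi / 4.0)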

@benjaminpope

benjaminpope commented Apr 4, 2023

Just to update - I went the full way and just implemented the CEPHES version of J1 in pure-Jax, and it is now much faster and more numerically stable. Can do this for J0 too - @jecampagne has taken a similar approach there.

Thoughts?

@benjaminpope

I've now done a full pure-JAX implementation of the scipy Bessel function routine for J0 and J1, and a non-iterative version of Jv which uses the standard recurrence relation to construct v > 1 from J0 and J1, via J_{v+1}(x) = (2v/x) J_v(x) - J_{v-1}(x). It has poorer precision close to x = 0 this way, but is at machine precision w.r.t. scipy for larger values and is fast.
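
A minimal sketch of that upward recurrence with lax.scan (jv_upward, j0x, and j1x are hypothetical names, with j0x/j1x standing for precomputed J0(x) and J1(x) values; upward recurrence is known to lose accuracy when x is much smaller than v, consistent with the small-x caveat above):

import jax
import jax.numpy as jnp

def jv_upward(v, x, j0x, j1x):
    # J_{k+1}(x) = (2k/x) * J_k(x) - J_{k-1}(x), stepping k = 1 .. v-1
    def body(carry, k):
        jkm1, jk = carry
        jkp1 = (2.0 * k / x) * jk - jkm1
        return (jk, jkp1), None
    (_, jv), _ = jax.lax.scan(body, (j0x, j1x), jnp.arange(1.0, v))
    return jv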

https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb

@jakevdp - what do you think? I would do a pull request - wanted to check this is the sort of thing you might want.

@femtomc

femtomc commented Apr 4, 2023

Just seconding to say I would love this functionality - I'm trying to use J1 for some heat equation simulations. Thanks for the links already @benjaminpope!

@axch
Contributor

axch commented Jun 2, 2023

Assigning to @jakevdp for further triage.

@benjaminpope

Anyway, yes, happy to do the PR.

@TobiBu

TobiBu commented Mar 6, 2025

Hi,
I just came across this while searching for the Bessel functions of the second kind in JAX. It seems these still don't exist, right?
Is there any progress on this somewhere else?

@jakevdp
Collaborator

jakevdp commented Mar 6, 2025

No, they don't exist yet. The main blocker is the lack of an efficient way to compute them given available XLA operations.

@TobiBu

TobiBu commented Mar 7, 2025

Thanks for the very quick answer.
If there is any plan for how to proceed or how to achieve an implementation, I am happy to give it a try...

@jakevdp
Collaborator

jakevdp commented Mar 10, 2025

I suspect the best way forward would be to implement series expansions for these functions directly in JAX. That said, we've found this to be quite brittle in practice, which is why I haven't pursued this.
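
For concreteness, a truncated power series for J0 is the kind of expansion meant here (the term count is an arbitrary choice; this converges quickly for small |x| but is exactly the sort of thing that gets brittle at large |x|):

import jax.numpy as jnp
from jax.scipy.special import gammaln

def j0_series(x, terms=20):
    # J0(x) = sum_k (-1)^k (x/2)^(2k) / (k!)^2, truncated at `terms` terms
    k = jnp.arange(terms)
    log_k_fact_sq = 2.0 * gammaln(k + 1.0)
    return jnp.sum((-1.0) ** k * (x / 2.0) ** (2 * k) / jnp.exp(log_k_fact_sq))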

@benjaminpope

It works fine doing this directly in Jax: https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb

@jakevdp
Collaborator

jakevdp commented Mar 11, 2025

> It works fine doing this directly in Jax: https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb

Looks great! Once this kind of implementation is in JAX, however, people tend to find dark corners where things fail to converge or otherwise behave strangely (e.g. #16996, #20769).

That doesn't prevent us from adding these kinds of implementations, but this kind of experience makes me more hesitant to do so because of the support burden down the road.

@NeilGirdhar
Contributor

If these are in scipy, then isn't the scipy team looking to implement everything in the Array API eventually? If all you want are primals, you could do the implementation in SciPy. That would at least give you the functions you want, though you may not get forward- or reverse-mode differentiation.

Depending on how you implement them, you may also get differentiation?

@jakevdp
Collaborator

jakevdp commented Mar 11, 2025

> If these are in scipy, then isn't the scipy team looking to implement everything in the Array API eventually?

For functions like these, Scipy plans to support the array API via host-side callbacks to the scipy implementation, which is more-or-less useless to general JAX users (it won't execute on accelerators, won't work with autodiff, won't work with vmap, will cause significant slowdowns under JIT, etc.), so unfortunately I don't think we can hope for the Array API to save the day here.
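
To make that concrete, a host-side callback wrapper would look roughly like the sketch below (j0_via_callback is a hypothetical name; it runs under jit, but gradients are unavailable without a custom rule):

import jax
import jax.numpy as jnp
from scipy.special import j0 as scipy_j0

def j0_via_callback(x):
    # Evaluates scipy's j0 on the host: opaque to autodiff, never runs
    # on the accelerator, and forces a device-to-host round trip.
    x = jnp.asarray(x)
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(scipy_j0, out, x)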

@benjaminpope

I think we basically need a standalone package for JAX special functions that Jake and co aren't on the hook for maintaining. This is on the critical path for @shashankdholakia's PhD - probably for some others too? Who's going to lead it? Happy to contribute and/or lead...

@dfm
Collaborator

dfm commented Mar 12, 2025

Drive-by comment: I know some folks started hacking on something with this goal here https://github.com/JAXtronomy/spexial, but it looks like it's not very active. Perhaps @cgiovanetti, @nstarman, or @jnibauer want to chime in.

@pearu
Collaborator

pearu commented Mar 12, 2025

FYI, this and similar issues are at the top of my todo list, to address the incorrectness/missing-functionality/accuracy issues with JAX special functions. See also #27088.

@TobiBu

TobiBu commented Mar 12, 2025

Awesome!

Just dropping this very recent reference for a GPU-optimized algorithm for Bessel functions here. Maybe useful for some of you…

@pearu
Collaborator

pearu commented Mar 17, 2025

As a step forward here, I have validated scipy.special j0 and j1 against mpmath.besselj, with the conclusion that the accuracy of scipy.special.jv is relatively good, although it could be better: there exist inputs where scipy.special.j0 and j1 are moderately inaccurate and/or incorrect (x < sqrt(eps) or x > 1/sqrt(eps)).

As reported in scipy/scipy#22705 and scipy/scipy#22706, the accuracy/correctness of scipy.special.j0 and j1 for small and large x can be easily improved.

The higher-order Bessel functions j2, j3, ... in scipy.special have the same accuracy issues as j0 and j1 (see pearu/functional_algorithms#102 (comment)); these will most likely be fixed after fixing the corresponding issues for j0 and j1.
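
A minimal version of that kind of cross-check (the sample point is an arbitrary pick in the large-x regime):

import mpmath
from scipy.special import j0

x = 1e9  # large x, where double-precision argument reduction is strained
ref = float(mpmath.besselj(0, mpmath.mpf(x)))
print(j0(x), ref, abs(j0(x) - ref))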

@lucascolley

> Scipy plans to support the array API via host-side callbacks to the scipy implementation

In the short term, yes, but that is only part of the story. The long-term plan is a special extension to the standard, similar to xp.fft and xp.linalg - see data-apis/array-api#725. So @benjaminpope's suggestion of a standalone package for JAX is along the right lines. The proposal there has focused on things that are already implemented in JAX, so there is this to say about the Bessel functions:

> One point I'll add is that, of course, there are many other special functions that are widely used. Indeed, many of the ones important to me (like the Bessel functions) are not covered by the above list. Why? We started with a minimal set of functions that are easily implementable everywhere. It's not super helpful to propose a standard for a special function that other array libraries will not implement because it's too much effort.

See scipy/xsf#1 for context on the interoperability work here. @steppi said in a footnote there:

> It will be difficult, but developing a GPU-capable implementation of Miller's recurrence algorithm to perform argument reduction with a recurrence, in cases where the recurrence is only stable in one direction (and not the one we want), is needed for us to implement accurate Bessel and hypergeometric functions which work on the GPU. This is something I've been hoping to work on for a while now, but I don't think it should be prioritized over the four items below.

Not sure how related that is to the recent paper you linked @TobiBu .
