Slow jax.grad of CubicSpline #41

Open
odstrcilt opened this issue Sep 5, 2024 · 1 comment

@odstrcilt commented Sep 5, 2024

interpax is a very nice package. However, when it is used inside code wrapped in jax.grad, it is about 10x slower than my simple cubic spline interpolation below; the interpolation itself (the forward pass) is not significantly different. The code was written with the help of ChatGPT.
This is not a high-priority issue.


```python
import jax
import jax.numpy as jnp


# tridiagonal_solve supporting autodiff
@jax.jit
def tridiagonal_solve(a, b, c, d):
    """
    Solves a tridiagonal system Ax = d for x (Thomas algorithm), where A is a
    tridiagonal matrix with diagonals a, b, and c.
    a - subdiagonal (length n-1)
    b - main diagonal (length n)
    c - superdiagonal (length n-1)
    d - right-hand side vector (length n)
    """
    n = len(d)
    ac, bc, cc, dc = map(jnp.asarray, (a, b, c, d))  # ensure all inputs are arrays

    # Forward sweep (eliminate the subdiagonal) using lax.fori_loop
    def forward_sweep(i, val):
        bc, dc = val
        w = ac[i - 1] / bc[i - 1]
        bc = bc.at[i].set(bc[i] - w * cc[i - 1])
        dc = dc.at[i].set(dc[i] - w * dc[i - 1])
        return bc, dc

    bc, dc = jax.lax.fori_loop(1, n, forward_sweep, (bc, dc))

    # Back substitution using lax.fori_loop
    def back_substitution(i, xc):
        idx = n - 2 - i
        xc = xc.at[idx].set((dc[idx] - cc[idx] * xc[idx + 1]) / bc[idx])
        return xc

    xc = jnp.zeros_like(dc)
    xc = xc.at[-1].set(dc[-1] / bc[-1])
    xc = jax.lax.fori_loop(0, n - 1, back_substitution, xc)

    return xc
```
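
A minimal sanity-check sketch for the solver above (illustrative only, not part of the original report): compare the Thomas solve against a dense jnp.linalg.solve on a small, diagonally dominant system.

```python
# Sanity-check sketch (illustrative only): compare the Thomas solve
# against a dense solve on a small, diagonally dominant system.
import numpy as np

rng = np.random.default_rng(0)
n = 8
sub = rng.uniform(0.1, 1.0, n - 1)   # subdiagonal
diag = rng.uniform(2.0, 3.0, n)      # main diagonal (diagonally dominant)
sup = rng.uniform(0.1, 1.0, n - 1)   # superdiagonal
rhs = rng.uniform(-1.0, 1.0, n)

A = (jnp.diag(jnp.asarray(diag))
     + jnp.diag(jnp.asarray(sub), -1)
     + jnp.diag(jnp.asarray(sup), 1))
x_dense = jnp.linalg.solve(A, jnp.asarray(rhs))
x_thomas = tridiagonal_solve(sub, diag, sup, rhs)
assert jnp.allclose(x_thomas, x_dense, atol=1e-5)
```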



```python
# jax implementation of a natural cubic spline
@jax.jit
def compute_cubic_spline_coeffs(x, y):
    n = len(x)
    h = x[1:] - x[:-1]  # knot spacings

    # Tridiagonal system for the quadratic coefficients c,
    # with natural boundary conditions c[0] = c[-1] = 0
    lower = jnp.zeros(n)
    upper = jnp.zeros(n)
    diag = jnp.ones(n)

    lower = lower.at[:-2].set(h[:-1])
    upper = upper.at[2:].set(h[1:])
    diag = diag.at[1:-1].set((h[:-1] + h[1:]) * 2)

    # Right-hand side vector B
    B = jnp.zeros(n)
    B = B.at[1:-1].set(3 * ((y[2:] - y[1:-1]) / h[1:] - (y[1:-1] - y[:-2]) / h[:-1]))

    c = tridiagonal_solve(lower[:-1], diag, upper[1:], B)

    # Compute the linear (b) and cubic (d) coefficients
    b = (y[1:] - y[:-1]) / h - h * (2 * c[:-1] + c[1:]) / 3
    d = (c[1:] - c[:-1]) / (3 * h)

    return x, y[:-1], b, c[:-1], d


@jax.jit
def evaluate_cubic_spline(coeffs, x_new):
    x, a, b, c, d = coeffs
    idx = jnp.searchsorted(x, x_new) - 1  # interval containing each query point
    idx = jnp.clip(idx, 0, len(a) - 1)
    dx = x_new - x[idx]

    # Horner evaluation of a + b*dx + c*dx**2 + d*dx**3
    return a[idx] + dx * (b[idx] + dx * (c[idx] + d[idx] * dx))


@jax.jit
def cubic_spline(xi, x, y):
    c = compute_cubic_spline_coeffs(x, y)
    return evaluate_cubic_spline(c, xi)
```
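
A minimal usage sketch for these functions (my addition for illustration): the natural boundary conditions above correspond to SciPy's bc_type="natural", so the values can be cross-checked directly; tolerances are arbitrary.

```python
# Usage sketch: build/evaluate the spline, cross-check against SciPy's
# natural cubic spline, and take a gradient w.r.t. the knot values.
import numpy as np
from scipy.interpolate import CubicSpline as ScipyCubicSpline

xk = jnp.linspace(0.0, 2.0 * jnp.pi, 20)   # knots
yk = jnp.sin(xk)                           # knot values
xq = jnp.linspace(0.1, 6.0, 100)           # query points

yq = cubic_spline(xq, xk, yk)
ref = ScipyCubicSpline(np.asarray(xk), np.asarray(yk), bc_type="natural")
assert np.allclose(yq, ref(np.asarray(xq)), atol=1e-4)

# Gradient w.r.t. the knot values -- the use case where the slowdown appears
g = jax.grad(lambda yk: jnp.sum(cubic_spline(xq, xk, yk)))(yk)
```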


```python
@jax.jit
def integrate_cubic_spline(coeffs, A, B):
    # assume A <= B
    # NOTE extrapolates the spline by zeros on both ends

    x, a, b, c, d = coeffs

    # Find the indices of the intervals containing A and B
    idx = jnp.searchsorted(x, jnp.array([A, B])) - 1

    # Clip indices to ensure they are within bounds
    idx = jnp.clip(idx, 0, len(x) - 2)

    # Integration limits for each interval: clipping the knots to [A, B]
    # makes intervals outside the integration range contribute zero
    dx_a = A - x[idx[0]]

    x_clipped = jnp.clip(x, A, B)
    dx_b = x_clipped[1:] - x_clipped[:-1]
    dx_b = dx_b.at[idx[0]].add(dx_a)

    # Antiderivative a*dx + b*dx**2/2 + c*dx**3/3 + d*dx**4/4,
    # evaluated by Horner's scheme on each interval
    int_b = jnp.zeros_like(dx_b)
    int_a = 0.0
    for i, coeff in enumerate([d, c, b, a]):
        int_b = (int_b + coeff / (4 - i)) * dx_b
        int_a = (int_a + coeff[idx[0]] / (4 - i)) * dx_a

    # Sum the contributions and subtract the piece from x[idx[0]] to A
    total_integral = jnp.sum(int_b) - int_a

    return total_integral
```
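
A rough timing sketch of the comparison described above (my addition, assuming interpax's CubicSpline class is built inside the differentiated function, as in the report; boundary conditions differ, so this compares timings, not values, and absolute numbers are machine dependent):

```python
# Rough timing sketch (machine dependent; assumes interpax.CubicSpline
# can be constructed and evaluated under jit/grad, as in the report).
import time
from interpax import CubicSpline

xk = jnp.linspace(0.0, 1.0, 100)
yk = jnp.cos(10 * xk)
xq = jnp.linspace(0.0, 1.0, 1000)

def loss_simple(yk):
    return jnp.sum(cubic_spline(xq, xk, yk))

def loss_interpax(yk):
    return jnp.sum(CubicSpline(xk, yk)(xq))

for name, loss in [("simple", loss_simple), ("interpax", loss_interpax)]:
    g = jax.jit(jax.grad(loss))
    g(yk).block_until_ready()  # compile once before timing
    t0 = time.perf_counter()
    for _ in range(100):
        g(yk).block_until_ready()
    print(name, (time.perf_counter() - t0) / 100, "s per grad")
```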


@f0uriest (Owner) commented Sep 6, 2024

Do you have an example of the code above being benchmarked against interpax? Also, are you running on CPU or GPU? On CPU the tridiagonal solve is likely faster, but on GPU the loops will cause a lot of overhead compared to just calling out to cusolver on a full matrix. Though if that were the cause, I would expect to see a performance difference in the forward pass, not just in the gradient.
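
One way to check this (a sketch, reusing loss_interpax and yk from the timing sketch above): report the JAX backend and time the forward pass and the gradient separately.

```python
# Sketch: report the backend and time forward vs. gradient separately,
# to see whether the slowdown is specific to the backward pass.
import time

print(jax.default_backend())  # "cpu" or "gpu"

def timeit(f, arg, reps=100):
    f(arg).block_until_ready()  # warm-up / compile
    t0 = time.perf_counter()
    for _ in range(reps):
        f(arg).block_until_ready()
    return (time.perf_counter() - t0) / reps

fwd = jax.jit(loss_interpax)            # forward pass only
bwd = jax.jit(jax.grad(loss_interpax))  # gradient
print("forward:", timeit(fwd, yk), "grad:", timeit(bwd, yk))
```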
