You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Interpax is a very nice package. However, when it is used inside code wrapped in `jax.grad`, it is about 10x slower than my simple cubic spline interpolation below. Without differentiation, the interpolation speed is not significantly different.
Code was written with the help of chatGPT.
It is not a high-priority issue.
#tridiagonal_solve supporting autodiff
@jax.jit
def tridiagonal_solve(a, b, c, d):
    """Solve the tridiagonal system ``A x = d`` with the Thomas algorithm.

    Both sweeps use ``jax.lax.fori_loop`` so the solver can be jitted and
    differentiated (e.g. inside ``jax.grad``).

    Args:
        a: Subdiagonal of ``A`` (length n-1); ``a[i]`` is ``A[i+1, i]``.
        b: Main diagonal of ``A`` (length n).
        c: Superdiagonal of ``A`` (length n-1); ``c[i]`` is ``A[i, i+1]``.
        d: Right-hand side vector (length n).

    Returns:
        The solution ``x`` (length n). Integer/bool inputs are promoted to the
        default floating dtype; inexact inputs keep their (promoted) dtype.

    Note:
        No pivoting is performed: each step divides by the current diagonal
        entry, so a zero or tiny pivot produces inf/nan. Intended for
        well-conditioned systems such as the spline equations in this file.
    """
    ac, bc, cc, dc = map(jnp.asarray, (a, b, c, d))
    # Promote to an inexact dtype. The sweeps write fractional values back
    # into these arrays via .at[].set, which would truncate for int inputs.
    dtype = jnp.result_type(ac, bc, cc, dc, 0.0)
    ac, bc, cc, dc = (arr.astype(dtype) for arr in (ac, bc, cc, dc))
    n = dc.shape[0]

    def forward_sweep(i, carry):
        # Eliminate the subdiagonal entry of row i using the updated row i-1.
        diag, rhs = carry
        w = ac[i - 1] / diag[i - 1]
        diag = diag.at[i].set(diag[i] - w * cc[i - 1])
        rhs = rhs.at[i].set(rhs[i] - w * rhs[i - 1])
        return diag, rhs

    bc, dc = jax.lax.fori_loop(1, n, forward_sweep, (bc, dc))

    def back_substitution(i, xc):
        # i counts up from 0, so rows are visited as n-2, n-3, ..., 0.
        idx = n - 2 - i
        return xc.at[idx].set((dc[idx] - cc[idx] * xc[idx + 1]) / bc[idx])

    xc = jnp.zeros_like(dc).at[-1].set(dc[-1] / bc[-1])
    return jax.lax.fori_loop(0, n - 1, back_substitution, xc)
#jax implementation of cubic spline
@jax.jit
def compute_cubic_spline_coeffs(x, y):
    """Compute natural cubic spline coefficients through the points ``(x, y)``.

    The second-derivative-type coefficients ``c`` come from a tridiagonal
    system whose first and last rows are identity rows with a zero right-hand
    side, which forces ``c[0] = c[-1] = 0`` (natural boundary conditions).

    Args:
        x: Knot positions (length n >= 2); expected sorted ascending.
        y: Values at the knots (length n).

    Returns:
        Tuple ``(x, a, b, c, d)`` where, on interval ``[x[i], x[i+1]]``, the
        spline is ``a[i] + b[i]*t + c[i]*t**2 + d[i]*t**3`` with
        ``t = u - x[i]``. ``a, b, c, d`` each have length n-1.
    """
    h = jnp.diff(x)
    slopes = (y[1:] - y[:-1]) / h
    zero = jnp.zeros(1)
    one = jnp.ones(1)

    # Interior row i: h[i-1]*c[i-1] + 2*(h[i-1]+h[i])*c[i] + h[i]*c[i+1] = rhs[i].
    # Boundary rows are [1] on the diagonal with zeros off-diagonal.
    sub_diag = jnp.concatenate([h[:-1], zero])
    super_diag = jnp.concatenate([zero, h[1:]])
    main_diag = jnp.concatenate([one, (h[:-1] + h[1:]) * 2, one])
    rhs = jnp.concatenate([zero, 3 * (slopes[1:] - slopes[:-1]), zero])

    c_full = tridiagonal_solve(sub_diag, main_diag, super_diag, rhs)

    lin = slopes - h * (2 * c_full[:-1] + c_full[1:]) / 3
    cubic = (c_full[1:] - c_full[:-1]) / (3 * h)
    return x, y[:-1], lin, c_full[:-1], cubic
@jax.jit
def evaluate_cubic_spline(coeffs, x_new):
    """Evaluate a piecewise cubic at the query point(s) ``x_new``.

    Args:
        coeffs: Tuple ``(x, a, b, c, d)`` such as the one returned by
            ``compute_cubic_spline_coeffs``: knots ``x`` (searchsorted needs
            them sorted ascending) and per-interval coefficients ``a, b, c, d``.
        x_new: Scalar or array of query points.

    Returns:
        Spline values broadcast to the shape of ``x_new``. Queries outside the
        knot range reuse the first/last interval's cubic (extrapolation).
    """
    knots, a, b, c, d = coeffs
    last_interval = a.shape[0] - 1
    # Interval k such that knots[k] < x_new <= knots[k+1], clamped to the ends.
    k = jnp.clip(jnp.searchsorted(knots, x_new) - 1, 0, last_interval)
    t = x_new - knots[k]
    # Horner scheme for a + b*t + c*t**2 + d*t**3.
    acc = c[k] + d[k] * t
    acc = b[k] + t * acc
    return a[k] + t * acc
@jax.jit
def cubic_spline(xi, x, y):
    """Fit a cubic spline through ``(x, y)`` and evaluate it at ``xi``.

    Thin convenience wrapper: the coefficients from
    ``compute_cubic_spline_coeffs`` are fed straight into
    ``evaluate_cubic_spline`` and are not cached between calls.
    """
    return evaluate_cubic_spline(compute_cubic_spline_coeffs(x, y), xi)
@jax.jit
def integrate_cubic_spline(coeffs, A, B):
    """Integrate a piecewise cubic over ``[A, B]`` with zero extrapolation.

    The spline is treated as zero outside ``[x[0], x[-1]]``, so the result is
    the integral over ``[max(A, x[0]), min(B, x[-1])]`` and 0 when that range
    is empty.

    Args:
        coeffs: Tuple ``(x, a, b, c, d)`` such as the one returned by
            ``compute_cubic_spline_coeffs``; on interval ``j`` the spline is
            ``a[j] + b[j]*t + c[j]*t**2 + d[j]*t**3`` with ``t = u - x[j]``.
        A: Lower integration limit.
        B: Upper integration limit. ``A <= B`` is assumed; reversed limits do
            not produce the negated integral.

    Returns:
        Scalar value of the integral.
    """
    x, a, b, c, d = coeffs
    # Zero extrapolation on the left. Without this clamp, A < x[0] gives a
    # negative dx_a while x_clipped[0] stays at x[0], and the bookkeeping below
    # integrates a shifted window instead of [x[0], B].
    A = jnp.maximum(A, x[0])
    # Index of the interval containing A (and B), clamped to valid intervals.
    idx = jnp.clip(jnp.searchsorted(x, jnp.array([A, B])) - 1, 0, len(x) - 2)
    start = idx[0]
    # Offset of A from the left knot of its interval.
    dx_a = A - x[start]
    # Portion of each interval inside [A, B]: zero for intervals left of A
    # or right of B, which also yields zero extrapolation beyond x[-1].
    x_clipped = jnp.clip(x, A, B)
    dx_b = x_clipped[1:] - x_clipped[:-1]
    # In A's interval, integrate from the left knot rather than from A, then
    # subtract the [x[start], A] piece, so every term is the antiderivative
    # F(t) = a t + b t^2/2 + c t^3/3 + d t^4/4 at an offset from a knot.
    dx_b = dx_b.at[start].add(dx_a)
    # Evaluate F via Horner, for all intervals (int_b) and for A's offset (int_a).
    int_b = jnp.zeros_like(dx_b)
    int_a = 0.0
    for power, coeff in zip((4, 3, 2, 1), (d, c, b, a)):
        int_b = (int_b + coeff / power) * dx_b
        int_a = (int_a + coeff[start] / power) * dx_a
    return jnp.sum(int_b) - int_a
The text was updated successfully, but these errors were encountered:
Do you have an example comparing the code above with interpax? Also, are you running on CPU or GPU? On CPU the tridiagonal solve is likely faster, but on GPU the loops cause a lot of overhead compared with simply calling cuSOLVER on a full matrix. If that were the cause, though, I would expect to see a performance difference in the forward pass, not just in the gradient.
Interpax is a very nice package. However, when it is used inside code wrapped in `jax.grad`, it is about 10x slower than my simple cubic spline interpolation below. Without differentiation, the interpolation speed is not significantly different.
Code was written with the help of chatGPT.
It is not a high-priority issue.
The text was updated successfully, but these errors were encountered: