Can we make vectorized functions compatible with JAX's numpy interface? #310

Closed
AlexanderFabisch opened this issue Nov 27, 2024 · 11 comments

@AlexanderFabisch
Member

This would make JIT compilation and GPU parallelization possible.

@kopytjuk
Contributor

kopytjuk commented Dec 6, 2024

Today I experimented a little with JAX, trying to port pytransform3d.rotations._utils.norm_vector, and it looks like we have to adjust the existing code (see below). I achieved a 2x speedup. There is a complete guide on potential pitfalls when porting functions to JAX, so at least it is well documented what to do.

norm_vector was an easy function; it feels like adapting functions such as dual_quaternions_from_screw_parameters would be a lot of work. I think we first need to identify which functions are the slowest in the library and rewrite/compile those with JAX. Rewriting everything to JAX is a very large effort with little benefit, imho.

# test_jax.py
import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit

from pytransform3d.rotations._utils import norm_vector as norm_vector_original


# the same implementation would also work in numpy with np.linalg / np.where
def norm_vector_jax(v):
    """Normalize vector.

    Parameters
    ----------
    v : array-like, shape (n,)
        nd vector

    Returns
    -------
    u : array, shape (n,)
        nd unit vector with norm 1 or the zero vector
    """

    norm = jnp.linalg.norm(v, axis=0, ord=2)
    result = jnp.where(norm == 0.0, v, jnp.asarray(v) / norm)
    return result


if __name__ == "__main__":
    N = 1000000  # a large array
    v = np.random.randn(N)

    # JIT compilation here
    norm_vector_jax_jit = jit(norm_vector_jax)

    print("Original:", norm_vector_original(v))
    print("No-JIT:", norm_vector_jax(v))
    print("JIT:", norm_vector_jax_jit(v))

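    # NOTE: jax dispatches computations asynchronously; calling
    # .block_until_ready() on the results inside the benchmarked functions
    # would give more precise timings for the jax versions.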
    def original_func(): return norm_vector_original(v)
    def jax_func(): return norm_vector_jax(v)
    def jax_jit_func(): return norm_vector_jax_jit(v)

    num_executions = 100
    t_original = timeit.timeit(
        original_func, number=num_executions) / num_executions
    t_jax = timeit.timeit(jax_func, number=num_executions) / num_executions
    t_jax_jit = timeit.timeit(
        jax_jit_func, number=num_executions) / num_executions

    print("\nBenchmark:")
    print("Original:", t_original)
    print("No-JIT", t_jax)
    print("JIT", t_jax_jit)

The JIT version brings roughly 2x the performance of numpy in my small benchmark.

Original: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]
No-JIT: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]
JIT: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]

Benchmark:
Original: 0.0011862749996362253
No-JIT 0.0016969608300132677
JIT 0.0006298654095735401

@AlexanderFabisch
Member Author

Rewriting everything to JAX is a very large effort with little benefit, imho.

You are right.

It will be hard to implement more complex functionality with JAX support, though. In that case, every operation involved would have to be JAX-compatible (e.g., ScLERP, which uses a lot of basic vectorized functions).

I also tried to convert some of the more complex functions. The main problem is operations like

A[..., :3, :3] = ...

Instead of doing this, I'd recommend computing the parts and concatenating them later (see the sketch below). Not sure how well that works with conditional code, though.
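For illustration, a minimal sketch of that pattern (the function name compose_transforms is made up here, not pytransform3d API): instead of assigning into A[..., :3, :3], build the blocks separately and concatenate them.

import jax.numpy as jnp

def compose_transforms(R, t):
    """Build transforms of shape (..., 4, 4) from rotations R of shape
    (..., 3, 3) and translations t of shape (..., 3) without in-place
    slice assignment."""
    # rotation block next to the translation column: (..., 3, 4)
    top = jnp.concatenate([R, t[..., jnp.newaxis]], axis=-1)
    # bottom row [0, 0, 0, 1], broadcast to the batch shape: (..., 1, 4)
    bottom = jnp.broadcast_to(
        jnp.array([0.0, 0.0, 0.0, 1.0]), top.shape[:-2] + (1, 4))
    return jnp.concatenate([top, bottom], axis=-2)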

@kopytjuk
Contributor

kopytjuk commented Mar 6, 2025

It seems I missed that sentence in our last discussion, @AlexanderFabisch:

I'd recommend computing the parts and concatenating them later.

So you suggest, as a rough guideline, doing the processing in JAX and the assignment back in Python, e.g.:

# current implementation, contains `A[..., :3, :3]` assignments
compute_something(A)

# instead we would do:

def compute_something(A: np.ndarray):
    if A.ndim == 3:
        A[..., :3, :3] = compute_something_jax_3D(A)
    elif A.ndim == 4:
        A[..., :3, :3] = compute_something_jax_4D(A)

Something like that?

@AlexanderFabisch
Member Author

Direct assignment to parts of an array is not possible. With jax you have to use an unusual syntax like A.at[..., :3, :3].set(value), which produces a new array with the original shape of A and the part [..., :3, :3] updated. This is so different from standard numpy that I decided it is too hard to use in pytransform3d. I am currently working on a similar library that uses jax exclusively. It will have a smaller set of features and rely on pytransform3d for testing and visualization.
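For reference, a minimal example of this functional update idiom in plain jax (not pytransform3d code):

import jax.numpy as jnp

A = jnp.zeros((4, 4))
R = jnp.eye(3)

# jax arrays are immutable, so `A[:3, :3] = R` raises a TypeError;
# .at[...].set(...) instead returns a new array with the update applied
A_updated = A.at[:3, :3].set(R)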

@kopytjuk
Contributor

kopytjuk commented Mar 6, 2025

My idea was to still use numpy as the interface for users.

For compute-heavy functions with large amounts of data, the library could use JAX for some compute-heavy code sections. Before calling them, we would turn the numpy arrays into jax.Arrays and convert back after the compiled section is complete.

As I understood it, the overall goal of using JAX is to make the library faster than numpy through more advanced compute (GPU, or exotic CPU instructions). So we could make at least some code sections (maybe not complete functions) JAX-compiled, e.g. as sketched below.
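A minimal sketch of that wrapping pattern (the function names are made up for illustration):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def _heavy_section(v):
    # compute-heavy, jax-compiled core: normalize rows of v
    norm = jnp.linalg.norm(v, axis=-1, keepdims=True)
    return jnp.where(norm == 0.0, v, v / norm)

def norm_vectors(v: np.ndarray) -> np.ndarray:
    # numpy in, numpy out; jax is only used internally
    return np.asarray(_heavy_section(jnp.asarray(v)))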

Thinking out loud, I guess your initial goal was to expose everything as JAX, so that users can decide how to compile the logic themselves instead of us optimizing small code sections within the library.

@AlexanderFabisch
Member Author

Here is the prototype that I am working on: https://github.com/AlexanderFabisch/jaxtransform3d

Exposing the jax interface allows us to jax.jit any function, differentiate any function, and vectorize any function. Everything is still compatible with numpy, since we allow ArrayLike types. Take a look at some examples and tests to see what I mean.
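For example, a tiny illustration of that composability with plain jax (not the actual jaxtransform3d API):

import jax
import jax.numpy as jnp

def angle_between(u, v):
    # accepts array-like inputs (lists, numpy arrays, jax arrays)
    u, v = jnp.asarray(u), jnp.asarray(v)
    cos = jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))

fast = jax.jit(angle_between)                      # compile
grad_u = jax.grad(angle_between, argnums=0)        # differentiate w.r.t. u
batched = jax.vmap(angle_between, in_axes=(0, 0))  # vectorize over batches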

For compute-heavy functions with large amounts of data, the library could use JAX for some compute-heavy code sections. Before calling them, we would turn the numpy arrays into jax.Arrays and convert back after the compiled section is complete.

This could also be an option for pytransform3d, but I fear that this might make the code way too complicated.

@kopytjuk
Contributor

kopytjuk commented Mar 6, 2025

differentiate any function, and vectorize any function

OK, the auto-differentiation and vectorization argument convinced me; that makes much more sense! With JAX, the compiler can look at the complete user logic plus library logic and maybe even compile/fuse things more efficiently. You are right, thanks for taking the time to clarify the motivation!

@AlexanderFabisch
Member Author

Sure, it's still good to see another perspective! I think I will leave the issue open for now.

@AlexanderFabisch
Member Author

Or maybe it is rather a discussion than an issue?

@kopytjuk
Contributor

kopytjuk commented Mar 6, 2025

Or maybe it is rather a discussion than an issue?

I would also see it as a discussion. Since you have already started jaxtransform3d and we both agree on the complexity of rewriting the current code, I am in favor of closing. If someone brings it up again, we can still reopen it :)

Usually in larger projects (like pandas) one can always look through closed issues to get some information on a particular topic.

@AlexanderFabisch
Member Author

OK!

@AlexanderFabisch AlexanderFabisch moved this from Backlog to Done in pytransform3d Mar 6, 2025
@AlexanderFabisch AlexanderFabisch self-assigned this Mar 6, 2025