# Can we make vectorized function compatible to JAX' numpy interface? #310
This would make JIT compilation and GPU parallelization possible.

Today I experimented a little with JAX, trying to port the `norm_vector` function:

```python
# test_jax.py
import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit

from pytransform3d.rotations._utils import norm_vector as norm_vector_original


# the same function would work with `np.linalg` or `np.where`
def norm_vector_jax(v):
    """Normalize vector.

    Parameters
    ----------
    v : array-like, shape (n,)
        nd vector

    Returns
    -------
    u : array, shape (n,)
        nd unit vector with norm 1 or the zero vector
    """
    norm = jnp.linalg.norm(v, axis=0, ord=2)
    result = jnp.where(norm == 0.0, v, jnp.asarray(v) / norm)
    return result


if __name__ == "__main__":
    N = 1000000  # a large array
    v = np.random.randn(N)

    # JIT compilation here
    norm_vector_jax_jit = jit(norm_vector_jax)

    print("Original:", norm_vector_original(v))
    print("No-JIT:", norm_vector_jax(v))
    print("JIT:", norm_vector_jax_jit(v))

    def original_func(): return norm_vector_original(v)
    def jax_func(): return norm_vector_jax(v)
    def jax_jit_func(): return norm_vector_jax_jit(v)

    num_executions = 100
    t_original = timeit.timeit(
        original_func, number=num_executions) / num_executions
    t_jax = timeit.timeit(jax_func, number=num_executions) / num_executions
    t_jax_jit = timeit.timeit(
        jax_jit_func, number=num_executions) / num_executions

    print("\nBenchmark:")
    print("Original:", t_original)
    print("No-JIT:", t_jax)
    print("JIT:", t_jax_jit)
```

The JIT version brings 2x the performance of numpy in my small benchmark.
You are right. It will be hard to implement more complex functionality with JAX support, though. In that case every operation involved would have to be JAX compatible (e.g., ScLERP, which uses a lot of basic vectorized functions). I also tried to convert some of the more complex functions. The main problem is operations like `A[..., :3, :3] = ...`. Instead of doing this, I'd recommend computing the parts and concatenating them later, as in the sketch below. Not sure how well that works with conditional code, though.
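For illustration, a minimal sketch of that compute-then-concatenate pattern for a homogeneous transform (function name and shapes are hypothetical, not existing pytransform3d API):

```python
import jax.numpy as jnp

def compose_transform(R, p):
    """Build (..., 4, 4) transforms from rotations R of shape (..., 3, 3)
    and translations p of shape (..., 3) without in-place assignment."""
    # instead of A[..., :3, :3] = R and A[..., :3, 3] = p,
    # compute the parts and concatenate them:
    top = jnp.concatenate([R, p[..., jnp.newaxis]], axis=-1)  # (..., 3, 4)
    bottom = jnp.broadcast_to(
        jnp.array([0.0, 0.0, 0.0, 1.0]), top.shape[:-2] + (1, 4)
    )  # (..., 1, 4)
    return jnp.concatenate([top, bottom], axis=-2)  # (..., 4, 4)
```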
It seems I missed that sentence in our last discussion, @AlexanderFabisch:

So you suggest, as a rough guideline, doing the processing in JAX and the assignment back in NumPy, e.g.:

```python
# current implementation, contains `A[..., :3, :3]` stuff
compute_something(A)

# instead we would do:
def compute_something(A: np.ndarray):
    if len(A.shape) == 3:
        A[..., :3, :3] = compute_something_jax_3D(A)
    if len(A.shape) == 4:
        A[..., :3, :3] = compute_something_jax_4D(A)
```

Something like that?
Direct assignment to parts of an array is not possible. You have to use a weird syntax like `A = A.at[..., :3, :3].set(...)` instead.
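A short runnable example of JAX's functional update syntax (`A` and `R` are just placeholder names):

```python
import jax.numpy as jnp

A = jnp.zeros((4, 4))
R = jnp.eye(3)

# A[:3, :3] = R would raise a TypeError, because JAX arrays are immutable;
# .at[...].set(...) returns a new, updated array instead:
A = A.at[:3, :3].set(R)
```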
My idea was to still use numpy as the interface for users. For compute-heavy functions with a large amount of data, the library could use JAX for some compute-heavy code sections. Before calling them, however, we would turn the numpy tensors into JAX arrays (and convert the results back afterwards).

As I understood it, the overall goal of using JAX is to make the library faster than numpy when more advanced compute is available (GPU, or exotic CPU instructions). So we could turn at least some code sections (maybe not complete functions) into JAX-compiled code.

Thinking out loud, I guess your initial goal was to expose everything as JAX, so that a user can decide how she compiles the logic instead of us optimizing small code sections within the library.
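A rough sketch of that boundary-conversion idea (a hypothetical wrapper, not existing pytransform3d code):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def _norm_vector_core(v):
    # compute-heavy section, JIT-compiled by JAX
    norm = jnp.linalg.norm(v, axis=0, ord=2)
    return jnp.where(norm == 0.0, v, v / norm)

def norm_vector(v: np.ndarray) -> np.ndarray:
    # numpy in, numpy out: users never see JAX arrays
    return np.asarray(_norm_vector_core(jnp.asarray(v)))
```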
Here is the prototype that I am working on: https://github.com/AlexanderFabisch/jaxtransform3d

Exposing the JAX interface allows us to `jax.jit` any function, differentiate any function, and vectorize any function. Everything is still compatible with numpy, since we allow `ArrayLike` types. Take a look at some examples and tests to see what I mean.
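To illustrate what exposing the JAX interface buys (a toy function, not jaxtransform3d's actual API):

```python
import jax
import jax.numpy as jnp

def norm_vector(v):
    return v / jnp.linalg.norm(v)

v = jnp.array([3.0, 4.0])
fast = jax.jit(norm_vector)(v)                   # JIT compilation
J = jax.jacobian(norm_vector)(v)                 # automatic differentiation
batch = jax.vmap(norm_vector)(jnp.ones((5, 2)))  # vectorization over a batch
```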
This could also be an option for pytransform3d, but I fear that it might make the code way too complicated.
OK, the auto-differentiation and vectorization argument convinced me; that makes much more sense! With JAX, the compiler can look at the complete user logic plus library logic and maybe even compile and optimize things more efficiently. You are right, thanks for taking the time to clarify the motivation!
Sure, it's still good to see another perspective! I think I will leave the issue open for now.
Or maybe it is more of a discussion than an issue?
I would also see it as a discussion, since you started it. Usually in larger projects (like …), such threads are kept in a separate discussions section.
OK!