Can an API for dealing with composite operations be defined? #781
I'd love to know why they didn't work.
@leofang for For
```python
import numpy as np
import numba
import jax
import jax.numpy as jnp

# Here `xp` and `compiler` can be injected
# (`jax` + `jnp`, or `numba` + `np` for the NumPy + Numba variant)
compiler = jax  # numba
xp = jnp  # np

@compiler.jit
def abs2(x):
    """Squared absolute value"""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

x = xp.arange(3000, dtype=xp.complex64)

# %timeit abs2(x)
# %timeit xp.abs(x)**2
```

For NumPy + Numba:

```
>>> %timeit abs2(x)
2.33 µs ± 10.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
3.82 µs ± 56.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

For JAX:

```
>>> %timeit abs2(x)
2.58 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
86.1 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

For PyTorch, I tried with

There are other PyTorch compilers though, and this really should work given the right one (e.g., TorchDynamo: https://github.com/pytorch/torchdynamo#usage-example). There's also Transonic (see https://transonic.readthedocs.io/en/latest/backends/pythran.html), which has

CuPy will work with Numba too: https://docs.cupy.dev/en/stable/user_guide/interoperability.html#numba. Triton has a

I think the point of this issue is not that we should add an API for this, or that someone should write a separate package for these (that is possible, but I'm not sure it's high-value). It's more that this is the right way of doing it in principle, so every function that can easily be written this way probably should be, rather than expanding the API surface of the standard.
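As a minimal sketch of the "inject `xp` and `compiler`" pattern, assuming only that the namespace provides the standard element-wise `real` and `imag` functions, the helper below builds `abs2` for any namespace and falls back to a pass-through decorator when no compiler is available. `make_abs2` and `no_jit` are hypothetical names used for illustration; they are not part of the array API standard or of any library.

```python
import numpy as np


def no_jit(func):
    """Hypothetical pass-through 'compiler': returns the function unchanged."""
    return func


def make_abs2(xp, jit=no_jit):
    """Build abs2 for the array namespace `xp`, wrapped by the decorator `jit`.

    Illustrative helper only; not an API of any library or of the standard.
    """

    @jit
    def abs2(x):
        # |x|**2 == real(x)**2 + imag(x)**2, composed only from element-wise
        # functions in the standard; no square root is computed.
        re = xp.real(x)
        im = xp.imag(x)
        return re * re + im * im

    return abs2


# Eager NumPy variant: no compiler, so this runs as plain array expressions.
abs2_np = make_abs2(np)
x = np.arange(5, dtype=np.complex64) * (1 + 2j)
result = abs2_np(x)
```

Passing a compiler decorator instead of `no_jit` (for example `jax.jit` together with `jax.numpy` as `xp`) should yield the compiled variant; whether a given compiler accepts a function that closes over a module in this way is an assumption worth checking per compiler.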
We've had discussions several times about whether some function that is a composite of other functions already present in the standard can be added. The most recent example is gh-460, which proposed adding `abs2` (`= abs(x) ** 2`). Typically this isn't worth it, unless there are very compelling reasons. Fusing function calls is something that multiple libraries have compilers for, so there may be no gain for those libraries in adding a function like `abs2`, just extra API surface. And performance-wise, the gain must be quite large for adding a function to be justified.

The discussion then turned to "is it possible to write, for example, a standardized way for array API consumers to element-wise apply arbitrary functions to arrays?" Something along the lines of `np.vectorize`.

Another direction could be to try and write a portable JIT-able set of functions. Something like:

There are multiple options for compilers here, and they work with different array libraries. Perhaps a more future-proof direction (`np.vectorize` wasn't quite a success, and `numexpr`-type string expressions are a bit of a hack as well ...).

This is not a worked-out proposal, just opening it here as a follow-up to the discussion in gh-460 and as a tracker issue to collect more ideas and serve as a reference for when other composite functions are proposed to be added to the standard.
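To make the `np.vectorize` comparison concrete, here is a small NumPy sketch contrasting the two styles: an element-wise apply of a scalar Python function (the NumPy documentation describes `np.vectorize` as provided for convenience rather than performance, essentially a for loop), versus the same `abs2` written as a composite of whole-array functions that a JIT compiler could fuse. The function names are illustrative only.

```python
import numpy as np


def abs2_scalar(z):
    # Called once per element by np.vectorize, so it runs at Python speed.
    return z.real * z.real + z.imag * z.imag


# np.vectorize-style element-wise apply; otypes fixes the output dtype.
abs2_vectorized = np.vectorize(abs2_scalar, otypes=[np.float64])


def abs2_composite(x):
    # Same operation built from whole-array functions; a compiler such as
    # Numba or JAX can fuse these into a single pass over the data.
    return np.real(x) * np.real(x) + np.imag(x) * np.imag(x)


x = np.array([3 + 4j, 1 - 1j, 0j])
vec_result = abs2_vectorized(x)
comp_result = abs2_composite(x)
```

Both produce the same values here; the difference is that the composite form stays within functions the standard already defines, which is the property the discussion above argues for.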