Matrix support #6

Open
shoyer opened this issue Dec 29, 2021 · 6 comments

Comments

@shoyer
Member

shoyer commented Dec 29, 2021

I'd like to add a Matrix class to complement Vector.

A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?

If we only need to support a single "tree axis", then most Matrix operations can be implemented essentially by calling vmap on a Vector, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use-cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
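Here is a minimal sketch of the single tree-axis idea. It uses only `jax.tree_util` and `jax.vmap`, not tree-math's own `Vector` class. It assumes the "matrix" is a pytree whose leaves all carry a leading axis of size k, with one slice per stored vector, as in an L-BFGS history. The helper names `tree_dot`, `rmatvec` and `matvec` are hypothetical and are not part of tree-math.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def tree_dot(a, b):
    # Inner product of two pytrees with identical structure:
    # the sum over leaves of each leafwise dot product.
    products = tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(tree_util.tree_leaves(products))


def rmatvec(stacked, v):
    # `stacked` holds k vectors; every leaf has a leading "tree axis" of size k.
    # vmap over that axis yields the k inner products <s_i, v> as a (k,) array.
    return jax.vmap(tree_dot, in_axes=(0, None))(stacked, v)


def matvec(stacked, coeffs):
    # Linear combination sum_i coeffs[i] * s_i, returned as one pytree vector.
    return tree_util.tree_map(lambda x: jnp.tensordot(coeffs, x, axes=1), stacked)


# Two stored vectors, each of the form {"a": shape (2,), "b": scalar}.
stacked = {"a": jnp.array([[1.0, 0.0], [0.0, 1.0]]), "b": jnp.array([1.0, 2.0])}
v = {"a": jnp.array([3.0, 4.0]), "b": jnp.array(5.0)}

dots = rmatvec(stacked, v)                       # inner products 8.0 and 14.0
combo = matvec(stacked, jnp.array([1.0, 2.0]))   # a = [1.0, 2.0], b = 5.0
```

In this sketch, keeping the tree axis "at the end" instead of "at the start" would only change the `in_axes` and `tensordot` axes used.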

In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of jax.jacobian on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: google/jax#3263.
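The sketch below makes the two tree axes concrete. It shows how `jax.jacobian` lays out its result for a pytree -> pytree function, using a made-up function `f` assumed for illustration. The output's tree structure is on the outside, and a copy of the input's tree structure sits at each output leaf.

```python
import jax
import jax.numpy as jnp


def f(params):
    # A pytree -> pytree function whose input and output trees differ.
    return {"u": 2.0 * params["a"], "v": jnp.sum(params["b"] ** 2)}


params = {"a": jnp.ones(2), "b": jnp.ones(3)}
jac = jax.jacobian(f)(params)

# The input tree is nested inside every leaf of the output tree, and each
# block's shape is (output leaf shape) + (input leaf shape).
for out_key, row in jac.items():
    for in_key, block in row.items():
        print(out_key, in_key, block.shape)
```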

My inclination is to only implement the "single tree-axis" version of Matrix, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use case of using tree-math to implement jax.jacobian (and variations).
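Here is a rough, illustrative estimate of the O(n^2) concern. The sizes are assumptions, not taken from the discussion: n = 10^6 float32 parameters (4 bytes each) and an L-BFGS history of m = 10 vector pairs.

```latex
\begin{aligned}
\text{dense } n \times n \text{ matrix:} &\quad n^2 \cdot 4\,\text{B} = (10^6)^2 \cdot 4\,\text{B} = 4 \times 10^{12}\,\text{B} \approx 4\,\text{TB} \\
\text{L-BFGS history } (2m \text{ vectors}): &\quad 2m \cdot n \cdot 4\,\text{B} = 2 \cdot 10 \cdot 10^6 \cdot 4\,\text{B} = 8 \times 10^{7}\,\text{B} = 80\,\text{MB}
\end{aligned}
```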

@shoyer
Member Author

shoyer commented Dec 30, 2021

@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees.

@shoyer shoyer mentioned this issue Dec 30, 2021
@patrick-kidger

Weighing in here at @shoyer's request --

I'd put myself in the multiple-tree-axes camp as well, I think. (I am saying that primarily as an end user rather than a developer for the feature, though... !)

This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0)) (an outer product), in which two different BatchTraces interact.
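For concreteness, this sketch shows what that nested-vmap expression computes on ordinary arrays. It is plain JAX with no tree axes involved, and the inputs are chosen for illustration. The inner and outer vmaps each introduce their own batch axis, so the result has one axis per input.

```python
import operator

import jax
import jax.numpy as jnp

# Inner vmap: maps over x (axis 0) with y held fixed.
# Outer vmap: maps over y (axis 0) with x held fixed.
# The two batch axes are independent, so outer[j, i] = x[i] * y[j].
outer = jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0])
result = outer(x, y)  # shape (2, 3); same values as jnp.outer(y, x)
```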

For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here:

https://github.com/patrick-kidger/diffrax/blob/10b652e1d91518ac182e8d832ff309f7c199a9a0/diffrax/solver/milstein.py#L104

This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact.

@geoff-davis

+1 to the @Sohl-Dickstein use case. Some more detail of where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used jax.jacfwd(jax.jacrev(f))(x) to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats or to be able to perform matrix operations directly on the pytree of pytrees.
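One workaround available in JAX itself (not tree-math) is `jax.flatten_util.ravel_pytree`. It flattens the pytree argument into a single 1-D vector, so `jax.hessian` returns an ordinary 2-D matrix that can be solved or inverted directly. The objective `f` below is made up for illustration. Note that this materializes the dense n×n Hessian, so it only makes sense for modest n.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params):
    # Illustrative objective over a dict-of-arrays pytree.
    return jnp.sum(params["w"] ** 2) + 3.0 * jnp.sum(params["b"] ** 2)


params = {"w": jnp.ones(2), "b": jnp.ones(1)}

# Flatten the pytree into one 1-D vector and keep the function that undoes it.
flat, unravel = ravel_pytree(params)


def f_flat(v):
    # The same objective, expressed on the flat vector.
    return f(unravel(v))


hess = jax.hessian(f_flat)(flat)  # ordinary (n, n) matrix; n = 3 here
grad = jax.grad(f_flat)(flat)
# Newton step solved in flat space, then mapped back to the original pytree.
step = unravel(jnp.linalg.solve(hess, grad))
```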

@njwfish

njwfish commented May 18, 2023

Just wanted to chime in and say that I'd love this feature, and for my use cases (which are primarily about numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses for a multi-axis implementation if that does get developed.

@patrick-kidger

So it's not a documented feature, but Equinox actually has a tree-math-like sublibrary built in, which can be used to do this kind of multi-axis stuff.

To set the scene, here is how it is used just to broadcast vector operations together:

```python
from equinox.internal import ω

vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed)  # [4, 6, (8, 10)]
```
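For comparison, the same elementwise sum can be written with plain `jax.tree_util.tree_map`, without ω. This is the kind of leafwise operation that the ω syntax above expresses more compactly.

```python
import operator

import jax

vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
# Leafwise addition of two pytrees with the same structure
# (lists and tuples are pytree nodes; the ints are leaves).
summed = jax.tree_util.tree_map(operator.add, vector1, vector2)
print(summed)  # [4, 6, (8, 10)]
```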

But with a bit of thinking you can nest these to accomplish higher-order operations:

```python
# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,)     ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec)  # [23, 86]
```

The reason this works is that ω is not a PyTree. This means that matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])]) doesn't have the outer ω looking inside the inner ωs. (I believe tree-math's Vector is a PyTree and that the same trick wouldn't work in this library, though.)
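The sketch below shows why not being a PyTree matters. It uses a hypothetical minimal wrapper, `Wrap`, which is not part of Equinox and is not Equinox's implementation. `Wrap` mimics only the pieces of ω used in the matrix-vector snippet. Because `Wrap` is not registered with `jax.tree_util`, `tree_map` treats each inner `Wrap` as a single leaf instead of descending into it.

```python
import operator

import jax


class Wrap:
    # Hypothetical stand-in for a non-PyTree wrapper like ω. It is NOT
    # registered with jax.tree_util, so tree_map treats instances as leaves.
    def __init__(self, tree):
        self.tree = tree

    def __mul__(self, other):
        # Leafwise product of the two wrapped trees.
        return Wrap(jax.tree_util.tree_map(operator.mul, self.tree, other.tree))

    def call(self, fn):
        # Apply fn to every leaf of the wrapped tree; inner Wraps count as leaves.
        return Wrap(jax.tree_util.tree_map(fn, self.tree))


matrix = Wrap([Wrap([0, 1, 2]), Wrap([3, 4, 5])])
vector = Wrap([6, 7, 8])
# The outer tree has exactly two leaves (the row wrappers), so fn sees whole rows.
matvec = matrix.call(lambda row: sum((row * vector).tree)).tree
print(matvec)  # [23, 86]
```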

Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries (whilst you can with tree-math). ω is only meant to be used as a convenient syntax within the bounds of a single function.

@shoyer
Member Author

shoyer commented May 19, 2023

I do still think matrix support would be awesome to have, and I actually had a use case for this just last week.

That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try, that would be very welcome!

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants