Matrix support #6
@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees.
Weighing in here at @shoyer's request: I think I'd put myself in the multiple-tree-axes camp as well. (I am saying that primarily as an end user rather than as a developer of the feature, though!) This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of

For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here:

This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact.
+1 to the @Sohl-Dickstein use case. Some more detail of where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used `jax.jacfwd(jax.jacrev(f))(x)` to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats, or to perform matrix operations directly on the pytree of pytrees.
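One common workaround, sketched here with a made-up function `f` and input `x`, is to flatten the argument with `jax.flatten_util.ravel_pytree` before differentiating. The Hessian then comes out as an ordinary dense matrix, and the result of a linear solve can be mapped back to a pytree with the returned `unravel` function. This needs O(n^2) memory for n scalar parameters, so it only suits small pytrees.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params):
    # Hypothetical scalar function of a pytree argument.
    return jnp.sum(params["w"] ** 2) * params["b"] + params["b"] ** 3


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# Flatten the pytree into one 1-D vector, keeping a function that undoes it.
flat_x, unravel = ravel_pytree(x)


def flat_f(v):
    # The same function, viewed as a function of the flat vector.
    return f(unravel(v))


# jax.hessian is jacfwd(jacrev(...)); on a flat input it returns a dense (n, n) array.
H = jax.hessian(flat_f)(flat_x)
g = jax.grad(flat_f)(flat_x)

# Newton step H^{-1} g: solve rather than invert explicitly, then map back to a pytree.
step = unravel(jnp.linalg.solve(H, g))
```

Here `step` has the same structure as `x`, so it can be subtracted from the parameters leaf by leaf.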
Just wanted to chime in and say that I'd love this feature. For my use cases (which are primarily numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses for a multi-axis implementation if that does get developed.
So it's not a documented feature, but Equinox actually has a tree-math-like sublibrary built in, which can be used to do this kind of multi-axis stuff. To set the scene, here is how it is used just to broadcast vector operations together:

```python
from equinox.internal import ω

vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed)  # [4, 6, (8, 10)]
```

But with a bit of thinking you can nest these to accomplish higher-order operations:

```python
# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,)
# ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec)  # [23, 86]
```

The reason this works is that ω is not a PyTree. This means that each inner ω is treated as a single leaf of the outer ω. Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries (whilst you can with tree-math). ω is only meant to be used as a convenient syntax within the bounds of a single function.
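For comparison, here is a sketch of the same matrix-vector product using only `jax.tree_util`, with the "matrix" held as a plain Python list of row pytrees. Because dicts and tuples are genuine pytrees, the helper can be wrapped in `jax.jit`, which is exactly the kind of boundary ω objects cannot cross. `tree_dot` and the structured `rows`/`vector` values are hypothetical, not part of Equinox or tree-math.

```python
import jax
import jax.numpy as jnp


@jax.jit
def tree_dot(a, b):
    # Inner product of two pytrees with identical structure:
    # multiply leaf by leaf, sum each product, then add up the per-leaf sums.
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(products))


# The same (2, 3) @ (3,) product as above, but each row is a structured pytree.
vector = {"a": 6.0, "b": (7.0, 8.0)}
rows = [{"a": 0.0, "b": (1.0, 2.0)}, {"a": 3.0, "b": (4.0, 5.0)}]

matvec = [tree_dot(row, vector) for row in rows]
```

Only one level of nesting is handled here; the outer "tree axis" is an ordinary Python list iterated in a comprehension.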
I do still think matrix support would be awesome to have, and I actually had a use case for this just last week. That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try, that would be very welcome!
I'd like to add a `Matrix` class to complement `Vector`. A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?
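To make the two options concrete, here is a small sketch (with a hypothetical input `x` and function `g`) of what each kind of "matrix" looks like as plain JAX pytrees. With a single tree axis, k vectors are stacked so that every leaf gains a leading axis of size k. With two tree axes, as in the output of `jax.jacobian`, the result is a pytree shaped like the output whose leaves are pytrees shaped like the input.

```python
import jax
import jax.numpy as jnp

# A "vector" pytree with n = 2 + 3 = 5 scalar entries.
x = {"a": jnp.ones(2), "b": jnp.ones(3)}

# Single tree axis: k = 4 such vectors stored together, each leaf gaining a
# new leading axis of size 4 (e.g. the state vectors kept by L-BFGS or GMRES).
k = 4
stacked = jax.tree_util.tree_map(lambda leaf: jnp.stack([leaf] * k), x)


def g(p):
    # Hypothetical pytree -> pytree function.
    return {"u": p["a"].sum() * p["b"], "v": p["a"]}


# Two tree axes: the Jacobian is a pytree (shaped like the output) of
# pytrees (shaped like the input); each block has shape out_shape + in_shape.
J = jax.jacobian(g)(x)
```

In the stacked case one treedef describes the whole matrix; in the Jacobian case the output and input treedefs both have to be tracked.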
If we only need to support a single "tree axis", then most `Matrix` operations can be implemented essentially by calling `vmap` on a `Vector`, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.

In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of `jax.jacobian` on a pytree -> pytree function. Here the implementation would need to be more complex, to keep track of the separate tree definitions for inputs and outputs, similar to my first attempt at implementing a tree vectorizing transformation: google/jax#3263.

My inclination is to only implement the "single tree-axis" version of `Matrix`, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory (where n is the total number of scalar entries in the pytree). On the other hand, it does preclude the interesting use case of using tree-math to implement `jax.jacobian` (and variations).
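A minimal sketch of the single tree-axis design, assuming a matrix is stored as one pytree whose leaves share a leading axis of size k (the "tree axis at the start" case). `tree_vdot`, `matrix_dot_vector` and `combine_rows` are hypothetical helpers written with plain `jax.tree_util` and `jax.vmap`, not tree-math's API. They cover the kind of operations an L-BFGS or GMRES loop performs on its stored vectors: inner products against a new vector, and linear combinations of the stored ones. Neither ever materializes an n × n array.

```python
import jax
import jax.numpy as jnp


def tree_vdot(a, b):
    # Inner product of two pytrees with identical structure.
    per_leaf = jax.tree_util.tree_map(jnp.vdot, a, b)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def matrix_dot_vector(stacked, v):
    # "Matrix" = one pytree whose leaves all share a leading tree axis of size k.
    # vmap over that axis gives the k inner products <stacked[i], v>.
    return jax.vmap(tree_vdot, in_axes=(0, None))(stacked, v)


def combine_rows(stacked, coeffs):
    # Linear combination sum_i coeffs[i] * stacked[i], contracting each leaf's leading axis.
    return jax.tree_util.tree_map(lambda leaf: jnp.tensordot(coeffs, leaf, axes=1), stacked)


# k = 2 stored vectors, each shaped like {"a": (2,), "b": (1,)}.
stacked = {"a": jnp.array([[1.0, 0.0], [0.0, 1.0]]), "b": jnp.array([[1.0], [2.0]])}
v = {"a": jnp.array([3.0, 4.0]), "b": jnp.array([5.0])}

dots = matrix_dot_vector(stacked, v)  # [8., 14.]
combo = combine_rows(stacked, jnp.array([2.0, 3.0]))  # {"a": [2., 3.], "b": [8.]}
```

Putting the tree axis at the end instead would only change `in_axes` and the contracted axis, which is the single piece of bookkeeping the issue describes.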