Hi all! I'm defining functions which accept a pytree of parameters, and I expose these functions to users. In the semantics of my functions, taking the gradient with respect to some components of that pytree doesn't make sense — they're effectively constants. Is there a way to expose a wrapped function whose gradient with respect to the full pytree simply ignores those components? Another solution would suffice: is it possible to take a function which accepts the full pytree and derive one that differentiates only the meaningful components? Potential other solution: if I allow users to pass the constant components separately, they'd stay out of the differentiated arguments entirely. The issue is that I'm doing gradient descent with the full pytree, so I'd rather not change the calling convention.
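For what it's worth, here is a minimal sketch of one workaround: wrap the loss so that `jax.lax.stop_gradient` blocks the autodiff path through the constant component. The `params` dict and the `"theta"`/`"scale"` names are hypothetical, just for illustration — gradients for the blocked leaf come back as zeros rather than raising an error.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree: only "theta" should receive gradients;
# "scale" is semantically a constant.
def loss(params):
    # stop_gradient blocks the autodiff path through "scale", so its
    # gradient is zeros even though it is still part of the pytree.
    scale = jax.lax.stop_gradient(params["scale"])
    return jnp.sum(scale * params["theta"] ** 2)

params = {"theta": jnp.array([1.0, 2.0]), "scale": jnp.array(3.0)}
grads = jax.grad(loss)(params)
# grads["theta"] = 2 * scale * theta = [6., 12.]; grads["scale"] = 0.
```

The downside is that the zero gradient still flows into your optimizer update, so gradient descent will "update" the constant leaf by adding zero — harmless, but it keeps the constant inside the optimizer state.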
Replies: 2 comments 3 replies
Ah, I found this: #12765. Hmm — so, I have control over the functions I pass over to the users. What if I define a custom pytree container that keeps the constant components out of the differentiable leaves? Has anyone tried this before?
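A minimal sketch of that idea, assuming a hypothetical `Masked` container: register it with `jax.tree_util.register_pytree_node` and carry the non-differentiable field as static aux data rather than as a leaf. Autodiff (and tracing generally) then never sees it — the trade-off is that aux data must be hashable and is treated as compile-time constant.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical container: "data" is a differentiable leaf, "indices"
# is static aux data and therefore invisible to autodiff.
class Masked:
    def __init__(self, data, indices):
        self.data = data          # differentiable array leaf
        self.indices = indices    # static metadata (must be hashable)

def _flatten(m):
    return (m.data,), m.indices   # (children, aux_data)

def _unflatten(indices, children):
    return Masked(children[0], indices)

tree_util.register_pytree_node(Masked, _flatten, _unflatten)

def loss(m):
    return jnp.sum(m.data ** 2)

m = Masked(jnp.array([1.0, 2.0]), indices=(0, 1))
g = jax.grad(loss)(m)
# g is a Masked: g.data holds the gradient, g.indices passes through.
```

Note the caveat: because aux data participates in trace-cache keys, changing `indices` triggers recompilation under `jit`, which may or may not be acceptable for your use case.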
Thanks for the question – unfortunately there's not a great solution for this at the moment. We encountered a similar problem in `jax.experimental.sparse`, in which we want autodiff to operate with respect to the data buffer, but not the indices buffer. Our solution there was to provide sparse-specific autodiff wrappers that know how to ignore the integer components in the sparse pytree: http://go/jax-github/blob/main/jax/experimental/sparse/ad.py

We've been talking about ways to improve on this, for example by allowing pytrees to register their preferred autodiff behavior in the same way that `register_vmappable` allows user-defined `vmap` behavior (see how it's used in e.g. http://go/jax-gith…).

Thanks for the feedback - I hope we can provide a lighter-weight solution to this kind of situation in the near future.