
Adjoint method to find the gradient of the Laplace approximation/mode #343

Open

theorashid opened this issue May 15, 2024 · 3 comments

@theorashid commented May 15, 2024

This is part of INLA roadmap #340.

From the Stan paper:

> One of the main bottlenecks is differentiating the estimated mode, $\theta^*$. In theory, it is straightforward to apply automatic differentiation by brute-force propagating derivatives through $\theta^*$, that is, sequentially differentiating the iterations of a numerical optimizer.
>
> But this approach, termed the direct method, is prohibitively expensive. A much faster alternative is to use the implicit function theorem. Given any accurate numerical solver, we can always use the implicit function theorem to get derivatives. One side effect is that the numerical optimizer is treated as a black box. By contrast, Rasmussen and Williams [34] define a bespoke Newton method to compute $\theta^*$, meaning we can store relevant variables from the final Newton step when computing derivatives. In our experience, this leads to important computational savings. But overall this method is much less flexible, working well only when the hyperparameters are low-dimensional, and it requires the user to pass the tensor of derivatives.

I think the jax implementation uses the tensor of derivatives, but I'm not 100% sure.
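
For concreteness, here is a minimal sketch of the implicit function theorem approach in jax. The toy objective, the hand-rolled Newton loop and all names are purely illustrative (this is not the Stan or any library implementation). The inner optimiser is treated as a black box; only the Hessian and the mixed partials at the mode enter the backward pass.

```python
import jax
import jax.numpy as jnp


def f(theta, phi):
    # toy "negative log joint"; its minimiser is theta* = sin(phi)
    return 0.5 * jnp.sum(jnp.exp(phi) * (theta - jnp.sin(phi)) ** 2)


def newton_solve(phi, theta0, n_steps=10):
    # black-box inner optimiser -- its iterations are never differentiated
    def step(theta, _):
        g = jax.grad(f)(theta, phi)
        H = jax.hessian(f)(theta, phi)
        return theta - jnp.linalg.solve(H, g), None

    theta_star, _ = jax.lax.scan(step, theta0, None, length=n_steps)
    return theta_star


@jax.custom_vjp
def mode(phi, theta0):
    return newton_solve(phi, theta0)


def mode_fwd(phi, theta0):
    theta_star = newton_solve(phi, theta0)
    return theta_star, (theta_star, phi)


def mode_bwd(res, v):
    # implicit function theorem at the stationary point grad_theta f = 0:
    #   d theta*/d phi = -H^{-1} @ (d/d phi) grad_theta f
    # so for a cotangent v: phi_bar = -((d/d phi) grad_theta f)^T H^{-1} v
    theta_star, phi = res
    H = jax.hessian(f)(theta_star, phi)
    w = jnp.linalg.solve(H, v)
    _, pullback = jax.vjp(lambda p: jax.grad(f)(theta_star, p), phi)
    (phi_bar,) = pullback(-w)
    return phi_bar, jnp.zeros_like(theta_star)


mode.defvjp(mode_fwd, mode_bwd)

# gradient of a summary of the mode w.r.t. the hyperparameter phi
print(jax.grad(lambda p: jnp.sum(mode(p, jnp.zeros(3))))(0.5))  # ~ 3 * cos(0.5)
```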

@theorashid commented:

The rewrites in optimistix might be helpful here to understand what is going on. Also see the docs.
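
For what it's worth, a rough sketch of what this looks like through optimistix. The `minimise`/`BFGS`/`ImplicitAdjoint` names below are my reading of the optimistix docs, so treat the exact signatures as assumptions; the point is that the solution is differentiated via an implicit-function-theorem adjoint rather than by unrolling the solver's iterations.

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def objective(theta, phi):
    # toy objective whose minimiser is theta* = phi
    return jnp.sum((theta - phi) ** 2)


def mode(phi):
    solver = optx.BFGS(rtol=1e-8, atol=1e-8)
    # ImplicitAdjoint is the default: gradients flow through the solution
    # via the implicit function theorem, not through the BFGS iterations
    sol = optx.minimise(
        objective, solver, jnp.zeros(3), args=phi, adjoint=optx.ImplicitAdjoint()
    )
    return sol.value


print(jax.jacobian(mode)(jnp.ones(3)))  # ~ identity matrix
```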

@theorashid commented:

Some notes on how they use this in Stan.

@theorashid commented:

An example of this for the fixed point optimiser in jax.
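
Roughly, in the spirit of that example (condensed here as a sketch; the contraction map and tolerance are made up):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(g, a, x0):
    # plain forward iteration x <- g(a, x) until (approximate) convergence
    def cond(carry):
        x_prev, x = carry
        return jnp.max(jnp.abs(x_prev - x)) > 1e-8

    def body(carry):
        _, x = carry
        return x, g(a, x)

    _, x_star = jax.lax.while_loop(cond, body, (x0, g(a, x0)))
    return x_star


def fixed_point_fwd(g, a, x0):
    x_star = fixed_point(g, a, x0)
    return x_star, (a, x_star)


def fixed_point_bwd(g, res, v):
    # at the fixed point x* = g(a, x*), the adjoint satisfies another fixed
    # point, w = v + (dg/dx)^T w, and then a_bar = (dg/da)^T w
    a, x_star = res

    def adjoint_iter(packed, w):
        a, x_star, v = packed
        _, vjp_x = jax.vjp(lambda x: g(a, x), x_star)
        return v + vjp_x(w)[0]

    w = fixed_point(adjoint_iter, (a, x_star, v), v)
    _, vjp_a = jax.vjp(lambda a_: g(a_, x_star), a)
    return vjp_a(w)[0], jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# example: x* solves x = tanh(a * x); differentiate x* with respect to a
g = lambda a, x: jnp.tanh(a * x)
print(jax.grad(lambda a: fixed_point(g, a, jnp.array(1.0)))(2.0))
```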
