
More flexible Hutchinson's implementation in Jax #14261

Answered by mattjj
yaroslavvb asked this question in Q&A

Cool!

One quick thought is to replace some of those loops with calls to jax.vmap. For example, here is a version that keeps only the loop over random sampling (since that part is currently implemented with numpy):

def hutchinsons(f, xs, d, k):
    """Estimates average Hessian trace using improved Hutchinson's estimator.

    f: R^d->R function to differentiate
    xs: batch of examples
    d: number of dimensions
    k: number of orthogonal vectors per pass, k=d gives exact result
    """

    s = 1 if k==d else 2   # number of stochastic samples to use per example
    trace = 0.

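    # Inner vmap: v @ hvp(...) = v^T H(x) v for every probe vector v in vs, summed over probes.
    # Outer vmap: repeat that for every example x in xs, then sum over examples.
    tr = lambda vs: jax.vmap(
        lambda x: jax.vmap(lambda v: v @ hvp(f, [x], [v]))(vs).sum()
    )(xs).sum()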
    for sample in range(s):
      t…
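The loop body above is cut off in the thread, so what follows is only a minimal, self-contained sketch of how the full estimator might look, not the poster's code. Assumptions: `hvp` is the forward-over-reverse helper from the JAX Autodiff Cookbook (the thread calls it without showing its definition); the orthonormal probes are drawn from a QR factorization of a Gaussian matrix; the `d/k` rescaling, the final averaging, and the `seed` parameter are my own additions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def hvp(f, primals, tangents):
    # Forward-over-reverse Hessian-vector product (as in the JAX Autodiff Cookbook).
    return jax.jvp(jax.grad(f), primals, tangents)[1]


def hutchinsons(f, xs, d, k, seed=0):
    """Estimates the average Hessian trace of f over the examples in xs.

    f: R^d -> R function to differentiate
    xs: batch of examples, shape (n, d)
    d: number of dimensions
    k: number of orthonormal probe vectors per pass; k == d gives the exact trace
    seed: hypothetical parameter added for reproducible sampling
    """
    xs = jnp.asarray(xs)
    rng = np.random.default_rng(seed)
    s = 1 if k == d else 2  # number of stochastic passes

    # Sum over examples x and probes v of v^T H(x) v, using two nested vmaps.
    tr = lambda vs: jax.vmap(
        lambda x: jax.vmap(lambda v: v @ hvp(f, [x], [v]))(vs).sum()
    )(xs).sum()

    total = 0.0
    for _ in range(s):
        # Assumption: k orthonormal probes = columns of Q from QR of a Gaussian (d, k) matrix.
        q, _ = np.linalg.qr(rng.standard_normal((d, k)))
        vs = jnp.asarray(q.T, dtype=xs.dtype)  # shape (k, d), rows are orthonormal
        total += tr(vs) * d / k  # rescale so that k < d stays unbiased
    # Average over passes and over examples.
    return total / (s * xs.shape[0])


# Usage: for f(x) = 0.5 x^T A x the Hessian is A, so the trace is 2 + 3 + 4 = 9.
A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 4.0]])
print(hutchinsons(lambda x: 0.5 * x @ A @ x, jnp.ones((4, 3)), d=3, k=3))
```

With `k == d` the probes form a complete orthonormal basis, so summing `v^T H v` over them recovers the trace exactly. For `k < d`, a random k-dimensional orthonormal frame Q satisfies E[Q Q^T] = (k/d) I, which is why the sketch multiplies each pass by `d/k`.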

Replies: 2 comments 2 replies

@yaroslavvb

@mattjj

Answer selected by yaroslavvb
Category
Q&A
2 participants