More flexible Hutchinson's implementation in Jax #14261
-
There's a flexible version of Hutchinson's which recovers exact trace computation as a special case. This was motivated by a discussion with @dpfau, where k=1 Hutchinson's seemed not accurate enough. The k=1 case should have the same complexity as 10 backward passes (1 hvp = 5 backprops, use 2 samples).
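For context, the plain (k=1-style) estimator being generalized here can be sketched in a few lines. This is a minimal version under my own naming, not the thread's colab code: `hvp` and `hutchinson_trace` are hypothetical helper names, and Rademacher probe vectors are one standard choice.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Hessian-vector product via forward-over-reverse autodiff:
    # one jvp of grad(f), i.e. roughly the cost of a few backward passes
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, key, num_samples=10):
    # Plain Hutchinson's estimator: E[v^T H v] = tr(H) for Rademacher v
    d = x.shape[0]
    vs = jax.random.rademacher(key, (num_samples, d)).astype(x.dtype)
    return jax.vmap(lambda v: v @ hvp(f, x, v))(vs).mean()
```

For a quadratic like `f(x) = sum(x**2)` the Hessian is `2I`, so every Rademacher probe returns the exact trace `2*d`; for general functions the variance of the estimate is what motivates the k>1 variant discussed above.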
Replies: 2 comments 2 replies
-
Cool! One quick thought is to replace some of those loops with calls to `jax.vmap`. For example, keeping only the loop over random sampling (because it's currently implemented with numpy):

```python
def hutchinsons(f, xs, d, k):
    """Estimates average Hessian trace using improved Hutchinson's estimator.

    f:  R^d -> R function to differentiate
    xs: batch of examples
    d:  number of dimensions
    k:  number of orthogonal vectors per pass; k=d gives the exact result
    """
    s = 1 if k == d else 2  # number of stochastic samples to use per example
    trace = 0.
    # hvp and random_ortho are helpers assumed defined elsewhere in the thread
    tr = lambda vs: jax.vmap(
        lambda x: jax.vmap(lambda v: v @ hvp(f, [x], [v]))(vs).sum()
    )(xs).sum()
    for sample in range(s):
        trace += tr(random_ortho(d)[:k])
    trace /= s * len(xs)  # average over examples and samples
    trace *= d / k  # bias correction
    return trace
```

Also, we should apply [...]. If we didn't use [...]. At some point we'll be limited by memory and so we can't keep vmapping; it would then make sense to use [...]. Is this the direction you had in mind?
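On the memory point above: one standard JAX pattern (an assumption on my part, since the inline code references were lost from this reply) is to swap the outer `jax.vmap` for `jax.lax.map`, which evaluates the per-example function sequentially and so keeps peak memory roughly constant in the batch size. A sketch with hypothetical names, using a self-contained `hvp`:

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Hessian-vector product: forward-over-reverse autodiff
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def batch_quadforms_vmap(f, xs, vs):
    # vmap over examples: fast, but materializes all intermediates at once
    per_x = lambda x: jax.vmap(lambda v: v @ hvp(f, x, v))(vs).sum()
    return jax.vmap(per_x)(xs).sum()

def batch_quadforms_seq(f, xs, vs):
    # lax.map runs per_x one example at a time: same result, lower peak memory
    per_x = lambda x: jax.vmap(lambda v: v @ hvp(f, x, v))(vs).sum()
    return jax.lax.map(per_x, xs).sum()
```

Both return identical values; the sequential version trades throughput for memory, which matters once `xs` is large.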
-
Thanks for the tips @mattjj. This approach seems faster for exact Hessian trace than `jax.hessian`; some timings in colab.
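For reference, the `jax.hessian` baseline being compared against looks like this. It materializes the full d×d Hessian, so it needs O(d²) memory versus O(d) for the HVP-based estimator; the timing claim itself is from the linked colab, not reproduced here.

```python
import jax
import jax.numpy as jnp

def exact_hessian_trace(f, x):
    # Baseline: build the full d x d Hessian, then take its trace.
    # O(d^2) memory and compute, versus O(k*d) for the k-probe estimator.
    return jnp.trace(jax.hessian(f)(x))
```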