Say I want to compare the efficiency of SGD and a quasi-Newton optimizer and see which one converges faster. I cannot simply count the number of iterations, because a quasi-Newton step is generally more expensive than a gradient update, which makes that an unfair comparison. On the other hand, I also want to avoid measuring wall-clock time, since that makes the whole experiment unreproducible. Is there a way to measure the computational cost of a JAX function, perhaps by counting the number of instructions executed on the XLA device?
Replies: 1 comment
You can get a sense of this with the Ahead-of-Time Compilation APIs. For example:

```python
import jax

def f(M, x):
    for i in range(10):
        x = M @ x
    return x

M = jax.numpy.ones((60, 60))
x = jax.numpy.ones(60)

compiled = jax.jit(f).lower(M, x).compile()
print(compiled.cost_analysis())
```
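Applied to the original question, the same API can be used to compare the per-step FLOP estimates of two update rules. Below is a minimal sketch: `sgd_step` and the parameter sizes are hypothetical, chosen only for illustration, and note that depending on the JAX version `cost_analysis()` may return either a dict or a list containing one dict.

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr=0.1):
    # Plain SGD update: params - lr * grads (hypothetical example function)
    return params - lr * grads

params = jnp.ones(1000)
grads = jnp.ones(1000)

compiled = jax.jit(sgd_step).lower(params, grads).compile()
analysis = compiled.cost_analysis()
# Normalize across JAX versions: older versions returned a one-element list.
if isinstance(analysis, list):
    analysis = analysis[0]
print(analysis["flops"])
```

The `"flops"` entry gives XLA's static estimate of floating-point operations for the compiled computation, which is reproducible across runs, unlike wall-clock time.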