
How to measure computational cost in JAX? #14828

Answered by jakevdp
nalzok asked this question in Q&A

You can get an estimate of this from XLA's cost analysis, which is exposed through JAX's ahead-of-time (AOT) compilation APIs. For example:

```python
import jax

def f(M, x):
  for i in range(10):
    x = M @ x
  return x

M = jax.numpy.ones((60, 60))
x = jax.numpy.ones(60)

# Lower and compile ahead of time, then ask XLA for its static cost estimate.
# Note: older JAX versions return a list of dicts (as shown below);
# newer versions return a single dict.
compiled = jax.jit(f).lower(M, x).compile()
print(compiled.cost_analysis())
```

```
[{'bytes accessed': 148800.0,
  'bytes accessed operand 1 {}': 2400.0,
  'utilization operand 0 {}': 10.0,
  'utilization operand 1 {}': 10.0,
  'bytes accessed output {}': 2400.0,
  'bytes accessed operand 0 {}': 144000.0,
  'flops': 72000.0}]
```
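These numbers can be sanity-checked by hand: each 60×60 matrix-vector product is 60·60 multiply-accumulates, which XLA counts as 2 flops each, so 10 steps give 10 · 2 · 60 · 60 = 72000 flops; reading the float32 `M` once per step gives 10 · 3600 · 4 = 144000 bytes. The sketch below compares XLA's estimate with that hand count. `xla_cost` is a hypothetical helper (not part of JAX), the list-vs-dict handling assumes the return type of `cost_analysis()` varies between JAX versions, and the exact dictionary keys may differ by backend and version, so missing keys are read with `.get`.

```python
import jax
import jax.numpy as jnp


def f(M, x):
    # Ten chained matrix-vector products, unrolled at trace time.
    for _ in range(10):
        x = M @ x
    return x


def xla_cost(fn, *args):
    """Hypothetical helper: return XLA's cost estimate for fn(*args) as one dict.

    Assumes cost_analysis() returns either a list with one dict per program
    (older JAX) or a single dict (newer JAX).
    """
    cost = jax.jit(fn).lower(*args).compile().cost_analysis()
    if isinstance(cost, (list, tuple)):
        cost = cost[0]
    return cost


n, steps = 60, 10
M = jnp.ones((n, n))  # float32 by default: 4 bytes per element
x = jnp.ones(n)

# Hand count: XLA counts a multiply-accumulate as 2 flops, n*n of them per matvec.
expected_flops = steps * 2 * n * n
# Each step reads all of M once: n*n elements * 4 bytes.
expected_M_bytes = steps * n * n * 4

cost = xla_cost(f, M, x)
print("XLA flops:", cost.get("flops"), "| hand count:", expected_flops)
print("XLA bytes for M:", cost.get("bytes accessed operand 0 {}"),
      "| hand count:", expected_M_bytes)
```

Keep in mind these are XLA's static estimates of the compiled program, not measured runtime; fusion and other optimizations can make them diverge from a naive count for more complex functions.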

Answer selected by nalzok