What is the best practice to track TPU memory usage? #9756
-
For GPU, it is common to use command-line tools to check memory usage. For TPU, using the profiling tool seems to be the only way to achieve this. However, JAX's support for the profiling tool seems weaker than TensorFlow's. For example, there is no data for profiles captured by JAX. Therefore, I would like to know: what is the best practice for tracking TPU memory usage?
Replies: 2 comments 3 replies
-
See https://github.com/ayaka14732/jax-smi
-
There's also a new CLI for checking TPU utilization metrics, which isn't specific to JAX.
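For a quick in-process check without an external tool, here is a minimal sketch. It assumes your JAX version and backend expose `Device.memory_stats()`. That method can return `None` on backends that do not support it, and the key names used here (`bytes_in_use`, `bytes_limit`, `peak_bytes_in_use`) are assumptions about what the backend reports. The helper names `format_memory_stats` and `report_device_memory` are hypothetical, chosen only for illustration.

```python
from typing import Optional


def format_memory_stats(stats: Optional[dict]) -> str:
    """Render a memory-stats dict as a one-line summary.

    The key names below are assumptions about what the backend reports;
    any key that is missing is shown as "n/a".
    """
    if not stats:
        return "memory stats unavailable"

    def gib(key: str) -> str:
        value = stats.get(key)
        return "n/a" if value is None else f"{value / 2**30:.2f} GiB"

    return (
        f"in use {gib('bytes_in_use')} / limit {gib('bytes_limit')} "
        f"(peak {gib('peak_bytes_in_use')})"
    )


def report_device_memory() -> None:
    """Print one summary line per local device."""
    import jax  # imported lazily so the formatter above works without JAX

    for device in jax.local_devices():
        # memory_stats() may return None on backends that do not support it.
        print(device, format_memory_stats(device.memory_stats()))
```

Calling `report_device_memory()` between training steps gives a rough view of per-device usage. It is not a substitute for a full memory profile.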