Specifying id_tap dependency without data copy #14775
Hi, I'm trying to implement a host callback that records the current time while running a computation on GPUs.

```python
import time

import jax
import jax.numpy as jnp
from jax import random
from jax.experimental import host_callback

TIME_RECORDS = []  # <- Q1. Race?


def record_time(args, transforms):
    """Record the current time and return."""
    TIME_RECORDS.append(time.time())


@jax.jit
def step(x):
    """Run computation on x and record time points in between."""
    x = host_callback.id_tap(record_time, None, result=x)
    y = x @ x @ x @ x @ x @ x @ x @ x @ x @ x @ x
    y = host_callback.id_tap(record_time, y, result=y)
    #                                     ^
    #                                     |
    # Q2. Extra GPU -> CPU copy? ---------+
    return y


if __name__ == "__main__":
    # Dispatch is asynchronous: wait for the result, then for all tap callbacks
    # to be processed before reading TIME_RECORDS.
    step(random.normal(random.PRNGKey(0), (10240, 10240))).block_until_ready()
    host_callback.barrier_wait()
    print(TIME_RECORDS)
```

Output:
I have two questions in this context:

Q1) Is this potentially a race condition on the global variable `TIME_RECORDS`?

Q2) I would like the second

Thanks a lot.
Could you try using `jax.debug.callback(..., ordered=True)`? This is our newer API, though, and by saying `ordered=True` you are forcing the callbacks to run sequentially relative to each other. `host_callback` behaves like an ordered debug callback.

In general, there is no way to enforce that a callback runs after a certain computation without adding a data dependency (and the extra copying), as you've done. As a result, callbacks aren't reliable for timing and measuring performance. I'd suggest using the profiler instead.
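Here is a minimal sketch of the suggested `jax.debug.callback(..., ordered=True)` approach, adapted from the code in the question. It assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`. The array size and the number of matmuls are arbitrary. The first callback takes no array arguments, so nothing is transferred to the host. The second receives `y`, which creates the data dependency (and the device-to-host copy) mentioned in the reply.

```python
import time

import jax
import jax.numpy as jnp

# Host-side list filled by the callbacks. With ordered=True the callbacks run
# one at a time, in the order they appear in `step`.
TIME_RECORDS = []


def record_time(*arrays):
    """Append the current wall-clock time; any array arguments are ignored."""
    TIME_RECORDS.append(time.time())


@jax.jit
def step(x):
    # No array arguments: nothing is copied to the host for this callback.
    jax.debug.callback(record_time, ordered=True)
    y = x @ x @ x @ x
    # Passing `y` makes this callback depend on `y`, so it cannot run before `y`
    # is computed, at the cost of a device-to-host transfer of `y`.
    jax.debug.callback(record_time, y, ordered=True)
    return y


if __name__ == "__main__":
    out = step(jnp.ones((1024, 1024), dtype=jnp.float32))
    out.block_until_ready()
    jax.effects_barrier()  # wait for all pending callbacks to finish
    print(TIME_RECORDS)
```

`ordered=True` only fixes the order of the callbacks relative to each other. It does not guarantee that the first timestamp is taken immediately before the matmuls start. For timing, the profiler (for example via the `jax.profiler.trace` context manager) remains the more reliable tool.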