-
Hi JAX team, I'm currently building a distributed reinforcement learning platform. The rough idea is that an actor worker performs rollouts (it samples actions, accumulates training data, then transfers the training data to the learner), and the learner performs gradient updates on that data. The actor resides on GPU 0 and the learner on GPU 1. I did some profiling and found that the learner blocks the actor thread whenever it fetches data from the actor's device to the learner's device. Here is the reproduction code, and below is a shortened version.

```python
def prepare_data(obs, actions, logprobs):
    # Stack the per-step lists into single arrays; since this function is
    # jitted on the learner's device, the inputs are transferred there first.
    obs = jnp.asarray(obs)
    actions = jnp.asarray(actions)
    logprobs = jnp.asarray(logprobs)
    # Dummy operation to keep the learner's device busy.
    a = jnp.ones((1000, 1000))
    b = a @ a
    # Flatten the (time, env) leading dimensions into one batch dimension.
    b_obs = obs.reshape((-1,) + obs.shape[2:])
    b_actions = actions.reshape(-1)
    b_logprobs = logprobs.reshape(-1)
    return b_obs, b_actions, b_logprobs

def rollout(params, rollout_queue, key):
    num_envs = 20
    cpu_next_obs = np.zeros((num_envs, 4, 84, 84))
    for update in range(20):
        if update == 4:
            jax.profiler.start_trace('./profile')
        obs = []
        actions = []
        logprobs = []
        for _ in range(384):
            next_obs, action, logprob, key = sample(params, cpu_next_obs, key)
            cpu_action = jax.device_get(action)
            # env.send(cpu_action)
            obs.append(next_obs)
            actions.append(action)
            logprobs.append(logprob)
        # Blocks until the learner has consumed the previous batch (maxsize=1).
        rollout_queue.put((obs, actions, logprobs))
    jax.profiler.stop_trace()

if __name__ == "__main__":
    devices = jax.devices()
    assert len(devices) >= 2
    key = jax.random.PRNGKey(0)
    network = Network()
    params = network.init(key, np.zeros((1, 4, 84, 84)))
    rollout_queue = queue.Queue(maxsize=1)
    threading.Thread(
        target=rollout,
        args=(
            params,
            rollout_queue,
            key,
        )
    ).start()
    # Compile prepare_data for the learner's device (GPU 1).
    prepare_data = jax.jit(prepare_data, device=devices[1])
    for update in range(20):
        obs, actions, logprobs = rollout_queue.get()
        b_obs, b_actions, b_logprobs = prepare_data(obs, actions, logprobs)
        print(update)
```

The profiling shows that the actor thread's computation stalls while the learner fetches the rollout data from the actor's device. Is there any way to keep the device transfers from blocking the actor thread's computation? Thanks a lot!

----

Quick update: I realized that when switching to a different machine, the transfers appear to become P2P transfers. However, they still block the actor thread.
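To make the kind of variant I have in mind concrete, here is a minimal sketch (not part of the repro above; `enqueue_for_learner` is a hypothetical helper) where the actor thread explicitly pushes each batch to the learner's device with `jax.device_put` before enqueueing it. Whether this actually keeps the transfer from stalling the actor's compute stream is exactly what I'm unsure about.

```python
import jax
import jax.numpy as jnp

def enqueue_for_learner(rollout_queue, obs, actions, logprobs, learner_device):
    """Stack the per-step arrays and push them to the learner's device."""
    # jax.device_put dispatches asynchronously, but the copy may still
    # contend with compute on the source device.
    obs = jax.device_put(jnp.stack(obs), learner_device)
    actions = jax.device_put(jnp.stack(actions), learner_device)
    logprobs = jax.device_put(jnp.stack(logprobs), learner_device)
    rollout_queue.put((obs, actions, logprobs))
```

In the shortened code this would replace the plain `rollout_queue.put((obs, actions, logprobs))` call, so `prepare_data` would receive arrays already committed to `devices[1]`.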
-
Without digging deeply into the reproduction, my guess is that the issue is the GPU memory allocator JAX uses. Fundamentally memory allocation is synchronized to the compute stream, so it's common for transfers to block waiting for compute or vice versa if an allocated block is not known to be free.
I'll need to dig more to be sure that's it, though.
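One low-effort way to test that hypothesis, assuming the allocator is the culprit, is to switch away from the default allocator using the environment variables documented in JAX's GPU memory allocation guide and re-run the profile. This is a diagnostic sketch, not a fix: the platform allocator is slower, but it allocates and frees on demand instead of reusing preallocated blocks.

```python
import os

# Must be set before JAX initializes its GPU backend, i.e. before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't grab a large chunk of GPU memory up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate on demand instead of via the BFC allocator

import jax  # noqa: E402
```

If the actor's compute no longer stalls during the learner-side transfers under this configuration, that points at the allocator; if it still stalls, the blocking is more likely in the transfer path itself.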