How are PJRT asynchronous executions throttled by torch_xla? #8380
Comments
This is a proposed fix for pytorch#8380. However, I'm not sure whether the lack of throttling is by design, so waiting on executions is gated behind an environment variable.
We controlled it by
We can do something similar for Neuron. How is
We defined TPU as a plugin and define the client create options in `xla/torch_xla/_internal/tpu.py` (lines 352 to 360 at commit 91f5c8a).
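For context, here is a minimal sketch of what a device plugin's client create options could look like, assuming the `DevicePlugin` interface from `torch_xla.experimental.plugins`. The option name `max_inflight_computations`, the default of 32, the environment variables, and the `MyNeuronPlugin` class are assumptions inferred from this thread, not an excerpt of tpu.py:

```python
# Hypothetical sketch of a PJRT device plugin that caps in-flight executions
# through client create options. The option name `max_inflight_computations`,
# the default of 32, and the env vars are assumptions, not verified torch_xla code.
import os

from torch_xla.experimental import plugins


class MyNeuronPlugin(plugins.DevicePlugin):  # hypothetical plugin class

  def library_path(self) -> str:
    # Path to the PJRT plugin shared library (placeholder path).
    return os.environ.get("MY_PJRT_PLUGIN_LIBRARY", "/path/to/pjrt_plugin.so")

  def client_create_options(self) -> dict:
    # Options handed to the PJRT client at creation time; a cap here is what
    # would throttle asynchronous executions on the device.
    return {
        "max_inflight_computations":
            int(os.environ.get("MAX_INFLIGHT_COMPUTATIONS", "32")),
    }
```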
So it sounds like for TPUs you want 32 inflight. Does that mean it's by design that you are breaking torch lazy_graph_executor's contract of "only one asynchronous operation can execute at the same time, on a given device"?
Ever since we moved to the PJRT runtime, the runtime itself supports async transfers and async execution, so it is better to let the runtime handle this kind of thing. We want to make sure the program is not tracing-bound, so it is better to unblock tracing and dispatch as many graphs as possible.
Should we move the max-inflight logic to torch_xla's pjrt_computation_client.cc? PJRT's Execute APIs are asynchronous, and asynchronous APIs should not block, ideally...
Hmm, torch_xla's execute and PJRT's execute are both async.
🐛 Bug
Here at AWS we have a single PJRT device plugin for both PyTorch and JAX, and recently we've made improvements to our device plugin to make it work better with JAX. I.e. now `PJRT_LoadedExecutable_Execute()` is fully asynchronous: we queue up an execution and return immediately, and expect the caller to wait on the `returned_future`, whereas before, execution was synchronous and was complete when `PJRT_LoadedExecutable_Execute()` returned.

As soon as we switched to the new implementation, we noticed that torch_xla queues up as many executions as it can, without any throttling in PJRT or torch_xla, which causes us to easily exhaust device memory. It appears that there are no internal throttling mechanisms, only explicit ones that need to be triggered by user code, e.g. when `xm.wait_device_ops()` is called, which calls down to `WaitDeviceOps()`.
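For reference, here is a minimal sketch of that kind of explicit, user-driven throttling: periodically calling `xm.wait_device_ops()` from the training loop. The loop structure and the 10-step interval are placeholders, not a recommendation:

```python
# Sketch: user code draining pending device ops every few steps to bound the
# number of queued executions. The 10-step interval is an arbitrary placeholder.
import torch_xla.core.xla_model as xm


def train_loop(loader, model, optimizer, loss_fn):
  device = xm.xla_device()
  for step, (data, target) in enumerate(loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)
    xm.mark_step()  # cuts the trace and dispatches the execution asynchronously
    if step % 10 == 0:
      # Heavy hammer: blocks until *all* pending executions on the device
      # have completed.
      xm.wait_device_ops()
```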
However, `WaitDeviceOps()` is a heavy hammer, because it pauses the world until the entire pipeline is drained. Ideally we do not want to rely on this mechanism for throttling, and we do not want the user to have to guess when to insert these calls to avoid running out of memory. Some sensible internal throttling mechanism is needed.
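To make that last point concrete, here is an illustrative, framework-agnostic sketch (written in Python only for brevity) of one possible internal throttling mechanism: cap the number of in-flight executions with a semaphore that is released when each execution's future completes. This is not torch_xla's actual code; `execute_async` and the cap of 32 are hypothetical stand-ins for PJRT's asynchronous execute and a configurable limit:

```python
# Illustrative sketch only: bound in-flight asynchronous executions with a
# semaphore released from each execution's completion callback.
import threading
from concurrent.futures import Future


class ThrottledExecutor:

  def __init__(self, execute_async, max_inflight: int = 32):
    # `execute_async` is assumed to return a Future, analogous to PJRT's
    # returned_future; both the name and the default cap are placeholders.
    self._execute_async = execute_async
    self._inflight = threading.Semaphore(max_inflight)

  def execute(self, computation, *args) -> Future:
    # Blocks the caller (e.g. the tracing thread) only once `max_inflight`
    # executions are already queued on the device.
    self._inflight.acquire()
    future = self._execute_async(computation, *args)
    future.add_done_callback(lambda _: self._inflight.release())
    return future
```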
The main issue here is that pjrt_computation_client.cc does not await the `returned_future` from PJRT; it simply throws it away.

However, according to torch's lazy_graph_executor, "only one asynchronous operation can execute at the same time, on a given device." This is enforced by a device lock, which is supposed to be held for the entire duration of the asynchronous execution. However, in torch_xla's xla_graph_executor.cpp, the device locks acquired by torch are released as soon as `ExecuteComputation()` returns, and `ExecuteComputation()` does not actually wait for the computation to complete. Therefore, torch lazy_graph_executor's throttling mechanism is defeated here.