
How are PJRT asynchronous executions throttled by torch_xla? #8380

Open
mcuiaws opened this issue Nov 14, 2024 · 7 comments

Comments

@mcuiaws
Contributor

mcuiaws commented Nov 14, 2024

🐛 Bug

Here at AWS we have a single PJRT device plugin for both PyTorch and JAX, and recently we've made improvements to our device plugin to make it work better with JAX. I.e., PJRT_LoadedExecutable_Execute() is now fully asynchronous: we queue up an execution and return immediately, and expect the caller to wait on the returned_future, whereas before, execution was synchronous and was already complete by the time PJRT_LoadedExecutable_Execute() returned.

As soon as we switched to the new implementation, we noticed that torch_xla now queues up as many executions as it can, without any throttling in PJRT or torch_xla, which makes it easy to exhaust device memory. It appears that there are no internal throttling mechanisms, only explicit ones that need to be triggered by user code:

  1. when xm.wait_device_ops() is called, which calls down to WaitDeviceOps()
  2. when a tensor is read back to the host, which internally calls WaitDeviceOps()

However, WaitDeviceOps() is a heavy hammer: it pauses the world until the entire pipeline is drained, so ideally we do not want to rely on it for throttling. We also do not want the user to have to guess where to insert these calls to avoid running out of memory (see the sketch below). Some sensible internal throttling mechanism is needed.
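For illustration, a minimal sketch of the kind of manual throttling a user would otherwise have to add to an ordinary torch_xla training loop; the loss function is a placeholder and the step interval is an arbitrary guess, not a recommendation:

    import torch_xla.core.xla_model as xm

    def train_loop(loader, model, optimizer, device):
        for step, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)  # loss_fn is a placeholder
            loss.backward()
            xm.optimizer_step(optimizer)
            xm.mark_step()  # cut the graph and dispatch the execution asynchronously
            if step % 10 == 0:
                # "Heavy hammer": drain every queued execution on the device so
                # pending buffers cannot pile up and exhaust device memory.
                xm.wait_device_ops()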

The main issue here is that pjrt_computation_client.cc does not wait on the returned_future from PJRT; it simply throws it away.

However, according to torch's lazy_graph_executor, "only one asynchronous operation can execute at the same time, on a given device." This is controlled by a device lock, which is supposed to be held for the entire duration of the asynchronous execution. However, in torch_xla's xla_graph_executor.cpp, the device locks acquired from torch are released as soon as ExecuteComputation() returns, and ExecuteComputation() does not actually wait for the computation to complete. Therefore, torch lazy_graph_executor's throttling mechanism is defeated here.
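To make the intended contract concrete, here is a minimal Python sketch (an illustration of the lazy_graph_executor contract, not torch_xla's actual C++ code) in which the per-device lock is released only when the execution's future completes, so a second execution on the same device cannot be queued until the first one finishes:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    _device_locks = {"device:0": threading.Lock()}
    _executor = ThreadPoolExecutor(max_workers=4)

    def execute_async(device, computation, *args):
        lock = _device_locks[device]
        lock.acquire()  # the next execution on this device blocks here
        future = _executor.submit(computation, *args)
        # Release the device lock only when the computation actually completes,
        # not as soon as it has been scheduled.
        future.add_done_callback(lambda _: lock.release())
        return future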

mcuiaws added a commit to mcuiaws/pytorch-xla that referenced this issue Nov 14, 2024
This is a proposed fix for pytorch#8380. However, I'm not sure if
the lack of throttling is by design. Therefore, waiting on
executions is enabled with an environment variable.
@JackCaoG
Collaborator

We controlled it by XLA_TPU_MAX_INFLIGHT_COMPUTATIONS, but I guess that's a TPU-specific flag.

@mcuiaws
Contributor Author

mcuiaws commented Nov 18, 2024

We can do something similar for Neuron. How is XLA_TPU_MAX_INFLIGHT_COMPUTATIONS implemented? Do you block inside PJRT_LoadedExecutable_Execute() if the client queues up too many executions?

@JackCaoG
Collaborator

We defined TPU as a plugin and define the client create options in

    def client_create_options(self):
      return {
          'max_inflight_computations':
              xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 32),
          'ml_framework_name':
              'PyTorch/XLA',
          'ml_framework_version':
              __version__
      }
My understanding is that as long as you specify that option, the PJRT client will handle the rest.
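For reference, a hedged sketch of how another plugin might pass the same option. The NeuronPlugin class name, the library path, and the NEURON_MAX_INFLIGHT_COMPUTATIONS variable are hypothetical, and it assumes the torch_xla.experimental.plugins.DevicePlugin base class:

    import torch_xla.experimental.plugins as plugins
    import torch_xla.utils.utils as xu

    class NeuronPlugin(plugins.DevicePlugin):

      def library_path(self):
        return "/path/to/libneuronpjrt.so"  # placeholder path

      def client_create_options(self):
        return {
            'max_inflight_computations':
                xu.getenv_as('NEURON_MAX_INFLIGHT_COMPUTATIONS', int, 4),
        }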

@mcuiaws
Contributor Author

mcuiaws commented Nov 18, 2024

So it sounds like for TPUs you want 32 executions in flight. Does that mean it's by design that you are breaking torch lazy_graph_executor's contract of "only one asynchronous operation can execute at the same time, on a given device"?

@JackCaoG
Collaborator

"Only one asynchronous operation can execute at the same time" was an old design choice for the XRT runtime. Back then, async execution was implemented at the torch_xla level. It was more of a design constraint than a design choice.

Ever since we moved to the PJRT runtime, which itself supports async transfer and async execution, it is better to let the runtime handle this kind of thing. We want to make sure the program is not tracing-bound, so it is better to unblock tracing and dispatch as many graphs as possible.

@mcuiaws
Contributor Author

mcuiaws commented Nov 21, 2024

Should we move the max inflight logic into torch_xla's pjrt_computation_client.cc? PJRT's Execute APIs are asynchronous, and asynchronous APIs should not block, ideally...

@JackCaoG
Collaborator

Hmm, torch_xla's execute and PJRT's execute are both async. The max_inflight_computations option in PJRT is the standard way to control the maximum number of async executions; XLA uses it for TPU, GPU and CPU. In our case, when an async execution in PJRT blocks, it won't release the device lock, hence the main thread that does the tracing will also block, which is exactly the behavior you need, I think.
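As a rough Python illustration of that back-pressure (the idea behind max_inflight_computations, not PJRT's implementation): once the cap is reached, the submitting thread blocks inside execute(), and because torch_xla still holds the device lock at that point, the tracing thread stalls until a slot frees up.

    import threading
    from concurrent.futures import ThreadPoolExecutor

    MAX_INFLIGHT = 32
    _inflight = threading.BoundedSemaphore(MAX_INFLIGHT)
    _executor = ThreadPoolExecutor(max_workers=MAX_INFLIGHT)

    def execute(computation, *args):
        _inflight.acquire()  # blocks the caller once MAX_INFLIGHT executions are queued
        future = _executor.submit(computation, *args)
        future.add_done_callback(lambda _: _inflight.release())
        return future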
