
add torch_xla.experimental.compile for eager mode #7246

Merged
merged 3 commits into master on Jun 12, 2024

Conversation

JackCaoG (Collaborator)

This should only be used in eager mode. The compile pretty much enables LTC before entering the function and disables it again afterward.

TODO

  1. add more unit tests
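
For context, a minimal sketch of the wrapper this PR adds, pieced together from the snippets quoted in the review below. The eager_mode toggle is an assumption about how tracing is switched on and off; the merged implementation may differ in detail.

import functools

import torch_xla


def compile(func):
  """Trace func under LTC and sync the resulting graph (eager mode only)."""

  @functools.wraps(func)  # Keep function's name, docstring, etc.
  def wrapper(*args, **kwargs):
    # Assumption: flip from eager execution back to LTC tracing here.
    torch_xla.experimental.eager_mode(False)
    try:
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    # Sync the graph generated by the target function.
    torch_xla.sync()
    # Assumption: restore eager mode on the way out.
    torch_xla.experimental.eager_mode(True)
    return result

  return wrapper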

@JackCaoG JackCaoG added the usability Bugs/features related to improving the usability of PyTorch/XLA label Jun 11, 2024
@JackCaoG JackCaoG requested review from lsy323 and qihqi June 11, 2024 23:05
@JackCaoG JackCaoG marked this pull request as ready for review June 12, 2024 00:18
JackCaoG (Collaborator, Author)

OK, test added; should be ready for review.

qihqi (Collaborator) left a comment:

LGTM, a few quick questions:

    try:
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")

qihqi (Collaborator):

Here the exception is a tracing exception, right?

JackCaoG (Collaborator, Author):

Yeah. Execution is async, so we won't be able to catch execution errors here.

      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    # Sync the graph generated by the target function.
    torch_xla.sync()

qihqi (Collaborator):

That actually runs the graph, and you might get exceptions here too.

JackCaoG (Collaborator, Author):

Right. The way LTC works is that async execution happens in a separate thread, and any runtime error is set in the unlocker. The next time we try to acquire the device lock, it will find that exception and throw:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L861-L864

That being said, I agree with you that there is no harm in putting the sync inside the try region. Let me update that in a follow-up PR.
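
For illustration, here is the wrapper body from the sketch above rearranged as discussed, with the sync moved inside the try region; this is an assumption about the follow-up change, not the code merged in this PR.

    try:
      result = func(*args, **kwargs)
      # With the sync inside the try, errors raised when the graph actually
      # executes (and rethrown on the next device-lock acquisition) are
      # caught here as well, not just tracing errors.
      torch_xla.sync()
    except Exception as e:
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception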

"""

@functools.wraps(func) # Keep function's name, docstring, etc.
def wrapper(*args, **kwargs):

qihqi (Collaborator):

So the mechanism for caching the graph is already there, and there is no need to do anything extra?

JackCaoG (Collaborator, Author):

We still need to trace the whole model (run all the Python code); we just skip the lowering-to-HLO and XLA-compilation steps.

This compile does not modify the function's bytecode or transform the function in any way besides enabling tracing mode. It would actually be more accurate to call it trace, but compile aligns better with the PyTorch API.
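
A usage sketch, hedged: torch_xla.experimental.compile is the API named in the PR title, while the model and call pattern are illustrative assumptions.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)

def step(inputs):
  return model(inputs)

# Wrapping enables tracing for the duration of each call; the traced graph
# is synced afterwards, compiled on the first run, and the cached executable
# is reused on later runs with the same shapes.
compiled_step = torch_xla.experimental.compile(step)
out = compiled_step(torch.randn(4, 8, device=device))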

@JackCaoG JackCaoG merged commit 90168e8 into master Jun 12, 2024
22 checks passed
@JackCaoG JackCaoG added the eager label Jun 20, 2024