Add initial support for PyTorch backend #764
Conversation
Hey :-)
Great to have some code already!
I left a few comments while skimming through the code; just ignore them if they are not helpful right now.
The first approach looks okay, although some things that were copied from JAX don't necessarily make sense for PyTorch. We should remove those.
Also, it may be good to start adding some tests.
This is getting closer to ready. I suggest removing all dispatches that are not being tested, as we can add that functionality in follow-up PRs.
@HarshvirSandhu I'm sorry, I didn't mean that we should parametrize all tests with device, only those that are explicitly about inputs/outputs, which are the first 3 of test_basic:
- test_pytorch_FunctionGraph_once
- test_shared
- test_shared_updates
For all the others you don't have to worry about device.
Looks great!
Some failing tests: AssertionError: https://github.com/pymc-devs/pytensor/actions/runs/9547073282/job/26311326610?pr=764#step:6:287
I actually don't know if GitHub CI/CD gives us access to GPUs, so we may need to skip the GPU parametrization when the device is not available (so it can still be tested locally if someone has GPU support, but does not fail on GitHub).
Yeah, GPUs are only available on team/enterprise GitHub: https://docs.github.com/en/enterprise-cloud@latest/actions/using-github-hosted-runners/about-larger-runners/about-larger-runners We need to skip the GPU parametrizations when CUDA is not available. You should be able to use
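The intended skip condition can be sketched as a plain predicate (a sketch; the function name is hypothetical, and in the real tests the second argument would come from torch.cuda.is_available()):

```python
def skip_device(device: str, cuda_available: bool) -> bool:
    # Skip only the "cuda" parametrization, and only when no GPU is
    # present: CPU cases still run on CI, and CUDA cases still run
    # locally on machines that have GPU support.
    return device == "cuda" and not cuda_available
```

On a CI runner without a GPU, only the "cuda" parametrization is skipped; the "cpu" parametrization always runs.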
tests/link/pytorch/test_basic.py (Outdated)

```python
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    if torch.cuda.is_available() is False:
```
Only want to skip on the GPU parametrization
```diff
-    if torch.cuda.is_available() is False:
+    if device == "cuda" and not torch.cuda.is_available():
```
Same thing in the other tests
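Applied to the test above, the guard would look something like this (a sketch, not the PR's exact code; the cuda_available helper is a stand-in so the snippet does not hard-depend on torch being importable):

```python
import importlib.util

import pytest


def cuda_available() -> bool:
    # Stand-in for torch.cuda.is_available(); returns False when torch
    # is not even installed.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return torch.cuda.is_available()


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    # Skip only the CUDA parametrization when no GPU is present,
    # so the CPU case still runs on GitHub-hosted runners.
    if device == "cuda" and not cuda_available():
        pytest.skip("CUDA not available")
    ...  # actual test body elided; see the PR diff
```

pytest.skip marks just this parametrized case as skipped rather than failed, which is the behavior the reviewers ask for.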
Nice work @HarshvirSandhu!
Description
Add PyTorch support for a few basic PyTensor Ops, using the torch.compile method and following the existing JAX architecture. Also replicated this PyTorch sandbox example with this backend.
cc @ricardoV94 @aseyboldt
Type of change