
Add initial support for PyTorch backend #764

Merged · 41 commits · Jun 20, 2024

Conversation

HarshvirSandhu (Contributor)

Description

Add PyTorch support for a few basic PyTensor Ops, using the torch.compile method and following the existing JAX backend architecture.
Also replicated this PyTorch sandbox example with this backend.

cc @ricardoV94 @aseyboldt
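As a rough illustration of the approach described above: the JAX backend converts each PyTensor Op to a callable via a `functools.singledispatch` registry, and this PR mirrors that pattern for PyTorch. The sketch below is a minimal, self-contained approximation; the `Op` classes are stand-ins (not PyTensor's real classes), and the plain-Python bodies stand in for the `torch.*` calls the real backend would return.

```python
from functools import singledispatch
import math

# Hypothetical stand-in Op classes; the real backend dispatches on
# PyTensor's Op subclasses (e.g. Elemwise, scalar Ops).
class Op: ...
class Add(Op): ...
class Exp(Op): ...

@singledispatch
def pytorch_funcify(op, **kwargs):
    """Return a callable implementing `op`; one register() per Op type."""
    raise NotImplementedError(f"No PyTorch conversion for: {type(op).__name__}")

@pytorch_funcify.register(Add)
def pytorch_funcify_Add(op, **kwargs):
    def add(x, y):
        return x + y  # torch.add(x, y) on tensors in the real backend
    return add

@pytorch_funcify.register(Exp)
def pytorch_funcify_Exp(op, **kwargs):
    def exp(x):
        return math.exp(x)  # torch.exp(x) on tensors in the real backend
    return exp

add_fn = pytorch_funcify(Add())
print(add_fn(2.0, 3.0))  # 5.0
```

In the real linker, the per-Op callables are composed in graph order and the composed function is wrapped with torch.compile.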

Type of change

  • New feature / enhancement


@ricardoV94 ricardoV94 added the torch PyTorch backend label May 13, 2024
@aseyboldt (Member) left a comment

Hey :-)
Great to have some code already!
I left a few comments while skimming through the code; just ignore them if they aren't helpful right now.

Review threads (outdated, resolved):
- pytensor/link/pytorch/dispatch/basic.py (3)
- pytensor/link/pytorch/dispatch/elemwise.py (3)
- pytensor/link/pytorch/dispatch/scalar.py (3)
@ricardoV94 (Member) left a comment

The first approach looks okay, although some of the things that were copied from JAX don't necessarily make sense for PyTorch. We should remove those.

It may also be good to start adding some tests.

Review threads (outdated, resolved):
- pytensor/compile/mode.py
- pytensor/link/pytorch/dispatch/basic.py
- pytensor/link/pytorch/linker.py
@ricardoV94 ricardoV94 changed the title Pytorch support to pytensor Ops Add PyTorch backend to PyTensor May 24, 2024
@ricardoV94 ricardoV94 changed the title Add PyTorch backend to PyTensor Add initial PyTorch backend May 24, 2024
Review threads (outdated, resolved):
- pytensor/tensor/basic.py
- pytensor/tensor/sharedvar.py (3)
- tests/link/pytorch/test_basic.py
@ricardoV94 ricardoV94 changed the title Add initial PyTorch backend Add initial support for PyTorch backend May 24, 2024
@ricardoV94 (Member) left a comment

This is getting closer to ready. I suggest removing all dispatches that are not being tested; we can add that functionality in follow-up PRs.

Review threads (outdated, resolved):
- .github/workflows/test.yml
- pytensor/tensor/basic.py
- pytensor/tensor/sharedvar.py
- pytensor/link/pytorch/dispatch/elemwise.py
- pytensor/link/pytorch/dispatch/scalar.py (3)
@ricardoV94 (Member) left a comment

@HarshvirSandhu I'm sorry, I didn't mean that we should parametrize all tests with device, only those that are explicitly about inputs/outputs, which are the first 3 in test_basic:

- test_pytorch_FunctionGraph_once
- test_shared
- test_shared_updates

For all the others you don't have to worry about the device.

Review threads (resolved):
- tests/link/pytorch/test_basic.py (2)
@ricardoV94 (Member) left a comment

Looks great!

@ricardoV94 ricardoV94 marked this pull request as ready for review June 17, 2024 13:49
@ricardoV94 (Member) commented Jun 17, 2024

Some failing tests: AssertionError: https://github.com/pymc-devs/pytensor/actions/runs/9547073282/job/26311326610?pr=764#step:6:287

Torch not compiled with CUDA enabled

I actually don't know if GitHub CI/CD gives us access to GPUs, so we may need to skip the GPU parametrization when the device is not available (so it can still be tested locally if someone has GPU support, but does not fail on GitHub).

@ricardoV94 (Member)
Yeah, GPUs are only available on team/enterprise GitHub: https://docs.github.com/en/enterprise-cloud@latest/actions/using-github-hosted-runners/about-larger-runners/about-larger-runners

We need to skip the GPU parametrizations when CUDA is not available. You should be able to use pytest.skip inside the test. This is also useful so that people can test locally even if they don't have CUDA.


```python
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    if torch.cuda.is_available() is False:
```

We only want to skip on the GPU parametrization:

Suggested change:

```diff
- if torch.cuda.is_available() is False:
+ if device == "cuda" and not torch.cuda.is_available():
```

Same thing in the other tests
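Putting the suggestion together, a complete version of the skip pattern might look like the sketch below. The `cuda_skip_reason` helper and the try/except around the torch import are illustrative additions, not part of the PR; the point is that only the "cuda" parametrization is skipped, so the "cpu" case still runs on GitHub-hosted runners.

```python
import pytest

try:
    import torch
    CUDA_AVAILABLE = torch.cuda.is_available()
except ImportError:
    # Keeps the sketch importable even without torch installed.
    CUDA_AVAILABLE = False

def cuda_skip_reason(device):
    """Hypothetical helper: return a skip reason for the 'cuda' case only."""
    if device == "cuda" and not CUDA_AVAILABLE:
        return "CUDA not available on this machine"
    return None

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    reason = cuda_skip_reason(device)
    if reason is not None:
        pytest.skip(reason)
    # ... actual test body compiling a FunctionGraph on `device` ...
```

An alternative with the same effect is `pytest.param("cuda", marks=pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available"))` in the parametrize list, which reports the case as skipped at collection time.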

@ricardoV94 ricardoV94 merged commit 320bac4 into pymc-devs:main Jun 20, 2024
56 of 57 checks passed
@ricardoV94 (Member)

Nice work @HarshvirSandhu !

@HAKSOAT HAKSOAT mentioned this pull request Jun 23, 2024
@ricardoV94 ricardoV94 added enhancement New feature or request major labels Jul 3, 2024
@Ch0ronomato Ch0ronomato mentioned this pull request Jul 15, 2024