
WIP pytorch/onnx backend #144 (Draft)

Wants to merge 5 commits into main.
Conversation

aseyboldt (Member)

This exports a torch module (or function) to ONNX using `dynamo_export`, and then uses onnxruntime to sample from the model.

Still WIP: not tested much, and the interface isn't cleaned up yet (this should eventually just turn into a `backend="torch-onnx"` kwarg in `compile_pymc_model` or so).
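For reference, a sketch of what that target interface might look like (purely hypothetical, the `backend="torch-onnx"` option does not exist yet):

```python
import pymc as pm
import nutpie

with pm.Model() as model:
    pm.Normal("x", shape=1000)

# Hypothetical target API: hide the torch -> ONNX -> onnxruntime pipeline
# behind a single backend kwarg, analogous to backend="numba" / "jax".
compiled = nutpie.compile_pymc_model(model, backend="torch-onnx")
trace = nutpie.sample(compiled)
```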

For experiments so far, it can be used like this:

```python
# Point ORT_DYLIB_PATH at the onnxruntime shared library.
# It should be somewhere in the target dir after running `maturin develop --release`.
%env ORT_DYLIB_PATH=/home/adr/build/cargo/release/libonnxruntime.so

import io

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import torch

import nutpie

N = 1000

# A torch module that computes the logp of the model and its gradient.
# Here: a standard normal, so logp(x) = -0.5 * sum(x**2) and grad(x) = -x.
class LogpModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # self.matrix = torch.nn.Parameter(torch.zeros(1_000, N), requires_grad=False)

    def forward(self, x, bias=None):
        # return -(x * x).sum() * 0.5 + (self.matrix @ x).sum(), -x
        return -(x * x).sum() * 0.5, -x

model = LogpModel()

# Export the module to ONNX with dynamo_export and serialize it to bytes
x0 = torch.ones(N)
onnx_program = torch.onnx.dynamo_export(model, x0)
onnx_code = io.BytesIO()
onnx_program.save(onnx_code)
onnx_code = onnx_code.getvalue()

# Set up the onnxruntime execution providers (CUDA first, CPU as fallback)
providers = nutpie._lib.OnnxProviders()
providers.add_cuda()
providers.add_cpu()

# Create the nutpie model and sample 6 chains
model = nutpie._lib.OnnxModel(N, onnx_code, providers)
settings = nutpie._lib.PyNutsSettings.Diag()
progress = nutpie._lib.ProgressType.none()
sampler = nutpie._lib.PySampler.from_onnx(settings, 6, model, progress)
tr = sampler.wait()
```
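As a quick sanity check, the exported bytes can also be evaluated directly with the Python onnxruntime bindings before handing them to nutpie (a sketch; assumes the `onnxruntime` package is installed, and looks up the input name at runtime since `dynamo_export` generates it):

```python
import numpy as np
import onnxruntime as ort

# Load the serialized model and inspect its single input.
sess = ort.InferenceSession(onnx_code, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name

# Evaluate logp and gradient at x = ones(N); expect logp == -N/2, grad == -x.
x = np.ones(N, dtype=np.float32)
logp, grad = sess.run(None, {input_name: x})
assert np.isclose(logp, -N / 2)
assert np.allclose(grad, -x)
```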

Current timing:

```python
%%time
progress = nutpie._lib.ProgressType.none()
sampler = nutpie._lib.PySampler.from_onnx(settings, 6, model, progress)
tr = sampler.wait()
```

```
CPU times: user 11.8 s, sys: 1.8 s, total: 13.6 s
Wall time: 2.6 s
```

While numpyro takes ~18s:

```python
with pm.Model() as model:
    pm.Normal("x", shape=N)

compiled = nutpie.compile_pymc_model(model, backend="numba")
```

```python
%%time
with model:
    tr = pm.sample(
        nuts_sampler="numpyro",
        nuts_sampler_kwargs=dict(chain_method="vectorized"),
        progressbar=False,
    )
```

```
CPU times: user 17 s, sys: 415 ms, total: 17.4 s
Wall time: 17.9 s
```

@ricardoV94 (Member)

What's needed to use the PyTorch backend of PyTensor?
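(A rough sketch of that direction, assuming PyTensor's PyTorch linker can be selected with `mode="PYTORCH"`; getting from the compiled function to a torch callable that `dynamo_export` accepts would be the open part:)

```python
import pytensor
import pytensor.tensor as pt

# Build the logp graph and its gradient in PyTensor.
x = pt.vector("x")
logp = -(x * x).sum() * 0.5
grad = pytensor.grad(logp, wrt=x)

# Assumption: mode="PYTORCH" selects PyTensor's PyTorch linker, so the
# compiled function evaluates via torch ops instead of C or numba.
logp_fn = pytensor.function([x], [logp, grad], mode="PYTORCH")
```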
