Add onnx export #2596

Merged · 25 commits · Jul 31, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596))

- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

### Changed
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -99,6 +99,7 @@ PyTorch Lightning Documentation
   transfer_learning
   tpu
   test_set
   production_inference

.. toctree::
   :maxdepth: 1
28 changes: 28 additions & 0 deletions docs/source/production_inference.rst
@@ -0,0 +1,28 @@
Inference in Production
=======================
PyTorch Lightning eases the process of deploying models into production.


Exporting to ONNX
-----------------
PyTorch Lightning provides a handy function to quickly export your model to ONNX format, which makes the model independent of PyTorch and able to run with ONNX Runtime.

To export your model to ONNX format, call the `to_onnx` method on your LightningModule with a file path and an input sample.

.. code-block:: python

    filepath = 'model.onnx'
    model = SimpleModel()
    input_sample = torch.randn((1, 64))
    model.to_onnx(filepath, input_sample, export_params=True)

You can also skip passing the input sample if the `example_input_array` attribute is set in your LightningModule, as in the sketch below.
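A minimal sketch, assuming the `SimpleModel` used above (the layer sizes are illustrative):

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            # With example_input_array set, to_onnx() can trace without an explicit sample.
            self.example_input_array = torch.randn((1, 64))

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

    model = SimpleModel()
    model.to_onnx('model.onnx', export_params=True)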

Once you have the exported model, you can run it with ONNX Runtime as follows:

.. code-block:: python

    import numpy as np
    import onnxruntime

    ort_session = onnxruntime.InferenceSession(filepath)
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
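To sanity-check an export, you can compare the ONNX Runtime output against the PyTorch forward pass. A sketch assuming the `SimpleModel` above; the tolerances mirror the ones used in this PR's tests:

.. code-block:: python

    import numpy as np
    import onnxruntime
    import torch

    model = SimpleModel()
    model.eval()
    input_sample = torch.randn((1, 64))
    with torch.no_grad():
        torch_out = model(input_sample)

    model.to_onnx(filepath, input_sample, export_params=True)
    ort_session = onnxruntime.InferenceSession(filepath)
    ort_inputs = {ort_session.get_inputs()[0].name: input_sample.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # The exported graph should reproduce the PyTorch output up to small numeric error.
    assert np.allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)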
1 change: 1 addition & 0 deletions environment.yml
@@ -48,3 +48,4 @@ dependencies:
- wandb>=0.8.21
- neptune-client>=0.4.109
- horovod>=0.19.1
- onnxruntime>=1.3.0
39 changes: 39 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -2,6 +2,7 @@
import inspect
import os
import re
import tempfile
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -1723,6 +1724,44 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
        else:
            self._hparams = hp

    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
        """Saves the model in ONNX format.

        Args:
            file_path: The path of the file the model should be saved to.
            input_sample: A sample of an input tensor for tracing.
            **kwargs: Will be passed to the torch.onnx.export function.

        Example:
            >>> class SimpleModel(LightningModule):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            ...
            ...     def forward(self, x):
            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))

            >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
            ...     model = SimpleModel()
            ...     input_sample = torch.randn((1, 64))
            ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
            ...     os.path.isfile(tmpfile.name)
            True
        """

        if isinstance(input_sample, Tensor):
            input_data = input_sample
        elif self.example_input_array is not None:
            input_data = self.example_input_array
        else:
            raise ValueError('input_sample and example_input_array tensors are both missing.')

        if 'example_outputs' not in kwargs:
            self.eval()
            kwargs['example_outputs'] = self(input_data)

        torch.onnx.export(self, input_data, file_path, **kwargs)

    @property
    def hparams(self) -> Union[AttributeDict, str]:
        if not hasattr(self, '_hparams'):
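Since `to_onnx` forwards `**kwargs` to `torch.onnx.export`, callers can tune the export without any Lightning-specific API. A sketch using standard `torch.onnx.export` options; the option values (and the `SimpleModel` from the docstring) are illustrative, not part of this PR:

    model = SimpleModel()
    input_sample = torch.randn((1, 64))
    model.to_onnx(
        'model.onnx',
        input_sample,
        export_params=True,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
        # dynamic_axes lets the exported graph accept a variable batch size.
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )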
2 changes: 2 additions & 0 deletions requirements/extra.txt
@@ -12,3 +12,5 @@ omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
torchtext>=0.3.1, <0.7 # TODO: temporary fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0
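Because `onnx` and `onnxruntime` live in the extras file rather than the base requirements, code that consumes the exported models may want a guarded import. A minimal sketch (the flag name is a hypothetical convention, not from this PR):

    try:
        import onnxruntime
        ONNX_RUNTIME_AVAILABLE = True
    except ImportError:
        # onnxruntime is an optional extra; callers should degrade gracefully.
        ONNX_RUNTIME_AVAILABLE = False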
4 changes: 1 addition & 3 deletions tests/base/model_template.py
@@ -73,9 +73,7 @@ def __init__(
        self.test_step_end_called = False
        self.test_epoch_end_called = False

        # if you specify an example input, the summary will show input/output for each layer
        # TODO: to be fixed in #1773
        # self.example_input_array = torch.rand(5, 28 * 28)
        self.example_input_array = torch.rand(5, 28 * 28)
Member: I thought we were talking about having it as property, right? @awaelchli

Contributor: you mean the user overrides the property function?

Member: yes, at least we talked about it, but maybe there was some pickling issue?

        # build model
        self.__build_model()
114 changes: 114 additions & 0 deletions tests/models/test_onnx_save.py
@@ -0,0 +1,114 @@
import os

import numpy as np
import onnxruntime
import pytest
import torch

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


def test_model_saves_with_input_sample(tmpdir):
    """Test that the ONNX model saves with an input sample and the file size is greater than 3 MB."""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 28 * 28))
    model.to_onnx(file_path, input_sample)
    assert os.path.isfile(file_path)
    assert os.path.getsize(file_path) > 3e+06


def test_model_saves_with_example_output(tmpdir):
    """Test that the ONNX model saves when provided with an example output."""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 28 * 28))
    model.eval()
    example_outputs = model.forward(input_sample)
    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
    assert os.path.exists(file_path)


def test_model_saves_with_example_input_array(tmpdir):
    """Test that the ONNX model saves with `example_input_array` and the file size is greater than 3 MB."""
    model = EvalModelTemplate()
    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path)
    assert os.path.exists(file_path)
    assert os.path.getsize(file_path) > 3e+06


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_model_saves_on_multi_gpu(tmpdir):
    """Test that the ONNX model saves on a distributed backend."""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend='ddp_spawn',
        progress_bar_refresh_rate=0,
    )

    model = EvalModelTemplate()

    tpipes.run_model_test(trainer_options, model)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path)
    assert os.path.exists(file_path)


def test_verbose_param(tmpdir, capsys):
    """Test that output is present when the verbose parameter is set."""
    model = EvalModelTemplate()
    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, verbose=True)
    captured = capsys.readouterr()
    assert "graph(%" in captured.out


def test_error_if_no_input(tmpdir):
    """Test that an exception is raised when there is no input tensor."""
    model = EvalModelTemplate()
    model.example_input_array = None
    file_path = os.path.join(tmpdir, "model.onnx")
    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
        model.to_onnx(file_path)


def test_if_inference_output_is_valid(tmpdir):
    """Test that the output inferred from the ONNX model is the same as from PyTorch."""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=5)
    trainer.fit(model)

    model.eval()
    with torch.no_grad():
        torch_out = model(model.example_input_array)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, model.example_input_array, export_params=True)

    ort_session = onnxruntime.InferenceSession(file_path)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)