Skip to content

Commit

Permalink
use example_input_array
Browse files Browse the repository at this point in the history
add to changelog
  • Loading branch information
lezwon committed Jul 13, 2020
1 parent b3d3e7a commit 6c6e233
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [unreleased] - YYYY-MM-DD

### Added

- Added exporting model to ONNX format.

### Changed

Expand Down
19 changes: 8 additions & 11 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,26 +1718,23 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
else:
self._hparams = hp

def to_onnx(self, file_path: str, input: Optional[Union[DataLoader, Tensor]] = None, verbose: Optional[bool] = False):
def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
"""Saves the model in ONNX format
Args:
file_path: The path of the file the model should be saved to.
input: Either a PyTorch DataLoader with training samples or an input tensor for tracing.
input_sample: A sample of an input tensor for tracing.
verbose: Boolean value to indicate if the ONNX output should be printed
"""

if isinstance(input, DataLoader):
batch = next(iter(input))
input_data = batch[0]
elif isinstance(input, Tensor):
input_data = input
if isinstance(input_sample, Tensor):
input_data = input_sample
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
self.prepare_data()
batch = next(iter(self.train_dataloader()))
input_data = batch[0]
raise ValueError(f'input_sample and example_input_array tensors are both missing.')

torch.onnx.export(self, input_data, file_path, verbose=verbose)
torch.onnx.export(self, input_data, file_path, **kwargs)

@property
def hparams(self) -> Union[AttributeDict, str]:
Expand Down
3 changes: 1 addition & 2 deletions tests/base/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(

# if you specify an example input, the summary will show input/output for each layer
# TODO: to be fixed in #1773
# self.example_input_array = torch.rand(5, 28 * 28)
self.example_input_array = torch.rand(5, 28 * 28)

# build model
self.__build_model()
Expand All @@ -89,7 +89,6 @@ def __build_model(self):
)

def forward(self, x):
x = x.view(x.size(0), -1)
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
Expand Down
1 change: 1 addition & 0 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
"""Lightning calls this inside the training loop"""
# forward pass
x, y = batch
x = x.view(x.size(0), -1)

y_hat = self(x)

Expand Down
34 changes: 13 additions & 21 deletions tests/models/test_onxx_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@
from tests.base import EvalModelTemplate


def test_model_saves_on_cpu(tmpdir):
"""Test that ONNX model saves on CPU and size is greater than 3 MB"""
def test_model_saves_with_input_sample(tmpdir):
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
model = EvalModelTemplate()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onxx")
input_sample = torch.randn((1, 28 * 28))
model.to_onnx(file_path, input_sample)
assert os.path.exists(file_path) is True
assert os.path.getsize(file_path) > 3e+06


def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = EvalModelTemplate()
file_path = os.path.join(tmpdir, "model.onxx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
Expand Down Expand Up @@ -50,22 +60,4 @@ def test_verbose_param(tmpdir, capsys):
file_path = os.path.join(tmpdir, "model.onxx")
model.to_onnx(file_path, verbose=True)
captured = capsys.readouterr()
assert "graph(%0" in captured.out


def test_input_param_with_dataloader(tmpdir):
"""Test that ONXX model is saved when a dataloader is passed in as input"""
model = EvalModelTemplate()
dataloader = model.dataloader(train=True)
file_path = os.path.join(tmpdir, "model.onxx")
model.to_onnx(file_path, input=dataloader)
assert os.path.exists(file_path) is True


def test_input_param_with_tensor(tmpdir):
"""Test that ONXX model is saved when a a tensor is passed in as input"""
model = EvalModelTemplate()
tensor = torch.randn((1, 28, 28))
file_path = os.path.join(tmpdir, "model.onxx")
model.to_onnx(file_path, input=tensor)
assert os.path.exists(file_path) is True
assert "graph(%" in captured.out

0 comments on commit 6c6e233

Please sign in to comment.