From 19f9a449d8e0e89b3c7a815435ec197efe9c37dd Mon Sep 17 00:00:00 2001
From: lezwon
Date: Mon, 13 Jul 2020 16:11:08 +0530
Subject: [PATCH] use example_input_array

add to changelog
---
 CHANGELOG.md                        |  2 +-
 pytorch_lightning/core/lightning.py | 21 +++++++++------------
 tests/base/model_template.py        |  3 +--
 tests/base/model_train_steps.py     |  1 +
 tests/models/test_onxx_save.py      | 34 +++++++++++------------------
 5 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf200ea15f007d..1263e4b1b2dbac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [unreleased] - YYYY-MM-DD
 
 ### Added
-
+- Added exporting model to ONNX format.
 
 ### Changed
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index a0292fea4d0dc1..1c52e30b9b82d1 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1718,26 +1718,23 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
         else:
             self._hparams = hp
 
-    def to_onnx(self, file_path: str, input: Optional[Union[DataLoader, Tensor]] = None, verbose: Optional[bool] = False):
+    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
         """Saves the model in ONNX format
 
         Args:
             file_path: The path of the file the model should be saved to.
-            input: Either a PyTorch DataLoader with training samples or an input tensor for tracing.
-            verbose: Boolean value to indicate if the ONNX output should be printed
+            input_sample: A sample of an input tensor for tracing.
+            **kwargs: Will be passed to the torch.onnx.export function.
 
         """
-        if isinstance(input, DataLoader):
-            batch = next(iter(input))
-            input_data = batch[0]
-        elif isinstance(input, Tensor):
-            input_data = input
+        if isinstance(input_sample, Tensor):
+            input_data = input_sample
+        elif self.example_input_array is not None:
+            input_data = self.example_input_array
         else:
-            self.prepare_data()
-            batch = next(iter(self.train_dataloader()))
-            input_data = batch[0]
+            raise ValueError('input_sample and example_input_array tensors are both missing.')
 
-        torch.onnx.export(self, input_data, file_path, verbose=verbose)
+        torch.onnx.export(self, input_data, file_path, **kwargs)
 
     @property
     def hparams(self) -> Union[AttributeDict, str]:
diff --git a/tests/base/model_template.py b/tests/base/model_template.py
index cf5f0e55f6094b..d3fa349b9640b0 100644
--- a/tests/base/model_template.py
+++ b/tests/base/model_template.py
@@ -66,7 +66,7 @@ def __init__(
 
         # if you specify an example input, the summary will show input/output for each layer
         # TODO: to be fixed in #1773
-        # self.example_input_array = torch.rand(5, 28 * 28)
+        self.example_input_array = torch.rand(5, 28 * 28)
 
         # build model
         self.__build_model()
@@ -89,7 +89,6 @@ def __build_model(self):
         )
 
     def forward(self, x):
-        x = x.view(x.size(0), -1)
         x = self.c_d1(x)
         x = torch.tanh(x)
         x = self.c_d1_bn(x)
diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py
index 6576eb42c0f2d8..fcd020d852126f 100644
--- a/tests/base/model_train_steps.py
+++ b/tests/base/model_train_steps.py
@@ -15,6 +15,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
         """Lightning calls this inside the training loop"""
         # forward pass
         x, y = batch
+        x = x.view(x.size(0), -1)
 
         y_hat = self(x)
 
diff --git a/tests/models/test_onxx_save.py b/tests/models/test_onxx_save.py
index be483f5e849341..a518129a9b52e2 100644
--- a/tests/models/test_onxx_save.py
+++ b/tests/models/test_onxx_save.py
@@ -8,12 +8,22 @@
 from tests.base import EvalModelTemplate
 
 
-def test_model_saves_on_cpu(tmpdir):
-    """Test that ONNX model saves on CPU and size is greater than 3 MB"""
+def test_model_saves_with_input_sample(tmpdir):
+    """Test that ONNX model saves with input sample and size is greater than 3 MB"""
     model = EvalModelTemplate()
     trainer = Trainer(max_epochs=1)
     trainer.fit(model)
 
+    file_path = os.path.join(tmpdir, "model.onxx")
+    input_sample = torch.randn((1, 28 * 28))
+    model.to_onnx(file_path, input_sample)
+    assert os.path.exists(file_path) is True
+    assert os.path.getsize(file_path) > 3e+06
+
+
+def test_model_saves_with_example_input_array(tmpdir):
+    """Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
+    model = EvalModelTemplate()
     file_path = os.path.join(tmpdir, "model.onxx")
     model.to_onnx(file_path)
     assert os.path.exists(file_path) is True
@@ -50,22 +60,4 @@ def test_verbose_param(tmpdir, capsys):
     file_path = os.path.join(tmpdir, "model.onxx")
     model.to_onnx(file_path, verbose=True)
     captured = capsys.readouterr()
-    assert "graph(%0" in captured.out
-
-
-def test_input_param_with_dataloader(tmpdir):
-    """Test that ONXX model is saved when a dataloader is passed in as input"""
-    model = EvalModelTemplate()
-    dataloader = model.dataloader(train=True)
-    file_path = os.path.join(tmpdir, "model.onxx")
-    model.to_onnx(file_path, input=dataloader)
-    assert os.path.exists(file_path) is True
-
-
-def test_input_param_with_tensor(tmpdir):
-    """Test that ONXX model is saved when a a tensor is passed in as input"""
-    model = EvalModelTemplate()
-    tensor = torch.randn((1, 28, 28))
-    file_path = os.path.join(tmpdir, "model.onxx")
-    model.to_onnx(file_path, input=tensor)
-    assert os.path.exists(file_path) is True
+    assert "graph(%" in captured.out
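
A minimal usage sketch of the to_onnx API introduced by this patch, pieced together from the tests above. The SimpleModel class, the layer sizes, and the output path are illustrative assumptions rather than part of the change; any LightningModule with a matching forward() would be exported the same way.

import torch
from torch import nn
import pytorch_lightning as pl


class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)
        # With example_input_array set, to_onnx() can trace the model
        # even when no explicit input_sample is passed.
        self.example_input_array = torch.rand(5, 28 * 28)

    def forward(self, x):
        return self.layer(x)


model = SimpleModel()

# Option 1: pass an explicit input sample for tracing.
model.to_onnx("model.onnx", input_sample=torch.randn(1, 28 * 28))

# Option 2: rely on self.example_input_array; extra keyword arguments are
# forwarded to torch.onnx.export (e.g. verbose=True prints the exported graph).
model.to_onnx("model.onnx", verbose=True)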