From ae360c9b53465bc959108fb636c4718942b7e848 Mon Sep 17 00:00:00 2001 From: Vincent TEMPLIER Date: Thu, 6 Jun 2024 13:13:04 +0000 Subject: [PATCH] Fix bad number of model outputs in python api --- python/examples/export/export_cpp_resnet18.py | 2 +- python/n2d2/deepnet.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/examples/export/export_cpp_resnet18.py b/python/examples/export/export_cpp_resnet18.py index 46553c18..a60609a3 100644 --- a/python/examples/export/export_cpp_resnet18.py +++ b/python/examples/export/export_cpp_resnet18.py @@ -1,5 +1,5 @@ """ -Example of exporting a ResNet18 model to TDA4VM C7x +Example of exporting a ResNet18 model to any board which can support C++ Download the onnx from https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx diff --git a/python/n2d2/deepnet.py b/python/n2d2/deepnet.py index ce45509b..773e0274 100755 --- a/python/n2d2/deepnet.py +++ b/python/n2d2/deepnet.py @@ -108,11 +108,12 @@ def get_input_cells(self): def get_output_cells(self): """ - Return the last N2D2 cell in the deepNet + Return the last N2D2 cells in the deepNet """ output = [] - for cell in self.N2D2().getLayers()[-1]: - output.append(self._cells[cell]) + for cell in self.N2D2().getCells(): + if len(self._cells[cell].N2D2().getChildrenCells()) == 0: + output.append(self._cells[cell]) return output def draw(self, filename):