|
95 | 95 | # ----------------------------- |
96 | 96 | # Back to the host machine, which should have a full TVM installed (with LLVM). |
97 | 97 | # |
98 | | -# We will use pre-trained model from |
99 | | -# `MXNet Gluon model zoo <https://mxnet.apache.org/api/python/gluon/model_zoo.html>`_. |
100 | | -# You can found more details about this part at tutorial :ref:`tutorial-from-mxnet`. |
| 98 | +# We will use pre-trained model from torchvision |
101 | 99 |
|
102 | | -import sys |
103 | | - |
104 | | -from mxnet.gluon.model_zoo.vision import get_model |
| 100 | +import torch |
| 101 | +import torchvision |
105 | 102 | from PIL import Image |
106 | 103 | import numpy as np |
107 | 104 |
|
108 | 105 | # one line to get the model |
109 | | -try: |
110 | | - block = get_model("resnet18_v1", pretrained=True) |
111 | | -except RuntimeError: |
112 | | - print("Downloads from mxnet no longer supported", file=sys.stderr) |
113 | | - sys.exit(0) |
| 106 | +model_name = "resnet18" |
| 107 | +model = getattr(torchvision.models, model_name)(pretrained=True) |
| 108 | +model = model.eval() |
| 109 | + |
| 110 | +# We grab the TorchScripted model via tracing |
| 111 | +input_shape = [1, 3, 224, 224] |
| 112 | +input_data = torch.randn(input_shape) |
| 113 | +scripted_model = torch.jit.trace(model, input_data).eval() |
114 | 114 |
|
115 | 115 | ###################################################################### |
116 | 116 | # In order to test our model, here we download an image of cat and |
@@ -148,12 +148,12 @@ def transform_image(image): |
148 | 148 | synset = eval(f.read()) |
149 | 149 |
|
150 | 150 | ###################################################################### |
151 | | -# Now we would like to port the Gluon model to a portable computational graph. |
| 151 | +# Now we would like to port the PyTorch model to a portable computational graph. |
152 | 152 | # It's as easy as several lines. |
153 | 153 |
|
154 | | -# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon |
155 | | -shape_dict = {"data": x.shape} |
156 | | -mod, params = relay.frontend.from_mxnet(block, shape_dict) |
| 154 | +input_name = "input0" |
| 155 | +shape_list = [(input_name, x.shape)] |
| 156 | +mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) |
157 | 157 | # we want a probability so add a softmax operator |
158 | 158 | func = mod["main"] |
159 | 159 | func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) |
@@ -226,7 +226,7 @@ def transform_image(image): |
226 | 226 | dev = remote.cpu(0) |
227 | 227 | module = runtime.GraphModule(rlib["default"](dev)) |
228 | 228 | # set input data |
229 | | -module.set_input("data", tvm.nd.array(x.astype("float32"))) |
| 229 | +module.set_input(input_name, tvm.nd.array(x.astype("float32"))) |
230 | 230 | # run |
231 | 231 | module.run() |
232 | 232 | # get output |
|
0 commit comments