Skip to content

Commit 3a320e7

Browse files
committed
use torchvision's mobilenet_v2 instead of mxnet
1 parent 73b138b commit 3a320e7

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

apps/ios_rpc/tests/ios_rpc_mobilenet.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import coremltools
2424
import numpy as np
2525
import tvm
26-
from mxnet import gluon
2726
from PIL import Image
2827
from tvm import relay, rpc
2928
from tvm.contrib import coreml_runtime, graph_executor, utils, xcode
@@ -51,6 +50,8 @@ def compile_metal(src, target):
5150

5251

5352
def prepare_input():
53+
from torchvision import transforms
54+
5455
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
5556
img_name = "cat.png"
5657
synset_url = "".join(
@@ -62,22 +63,36 @@ def prepare_input():
6263
]
6364
)
6465
synset_name = "imagenet1000_clsid_to_human.txt"
65-
img_path = download_testdata(img_url, "cat.png", module="data")
66+
img_path = download_testdata(img_url, img_name, module="data")
6667
synset_path = download_testdata(synset_url, synset_name, module="data")
6768
with open(synset_path) as f:
6869
synset = eval(f.read())
69-
image = Image.open(img_path).resize((224, 224))
70+
input_image = Image.open(img_path)
7071

71-
image = np.array(image) - np.array([123.0, 117.0, 104.0])
72-
image /= np.array([58.395, 57.12, 57.375])
73-
image = image.transpose((2, 0, 1))
74-
image = image[np.newaxis, :]
75-
return image.astype("float32"), synset
72+
preprocess = transforms.Compose(
73+
[
74+
transforms.Resize(256),
75+
transforms.CenterCrop(224),
76+
transforms.ToTensor(),
77+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
78+
]
79+
)
80+
input_tensor = preprocess(input_image)
81+
input_batch = input_tensor.unsqueeze(0)
82+
return input_batch.detach().cpu().numpy(), synset
7683

7784

7885
def get_model(model_name, data_shape):
79-
gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
80-
mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
86+
import torch
87+
import torchvision
88+
89+
torch_model = getattr(torchvision.models, model_name)(weights="IMAGENET1K_V1").eval()
90+
input_data = torch.randn(data_shape)
91+
scripted_model = torch.jit.trace(torch_model, input_data)
92+
93+
input_infos = [("data", input_data.shape)]
94+
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)
95+
8196
# we want a probability so add a softmax operator
8297
func = mod["main"]
8398
func = relay.Function(
@@ -90,7 +105,7 @@ def get_model(model_name, data_shape):
90105
def test_mobilenet(host, port, key, mode):
91106
temp = utils.tempdir()
92107
image, synset = prepare_input()
93-
model, params = get_model("mobilenetv2_1.0", image.shape)
108+
model, params = get_model("mobilenet_v2", image.shape)
94109

95110
def run(mod, target):
96111
with relay.build_config(opt_level=3):

0 commit comments

Comments
 (0)