2323import coremltools
2424import numpy as np
2525import tvm
26- from mxnet import gluon
2726from PIL import Image
2827from tvm import relay , rpc
2928from tvm .contrib import coreml_runtime , graph_executor , utils , xcode
@@ -51,6 +50,8 @@ def compile_metal(src, target):
5150
5251
5352def prepare_input ():
53+ from torchvision import transforms
54+
5455 img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
5556 img_name = "cat.png"
5657 synset_url = "" .join (
@@ -62,22 +63,36 @@ def prepare_input():
6263 ]
6364 )
6465 synset_name = "imagenet1000_clsid_to_human.txt"
65- img_path = download_testdata (img_url , "cat.png" , module = "data" )
66+ img_path = download_testdata (img_url , img_name , module = "data" )
6667 synset_path = download_testdata (synset_url , synset_name , module = "data" )
6768 with open (synset_path ) as f :
6869 synset = eval (f .read ())
69- image = Image .open (img_path ). resize (( 224 , 224 ) )
70+ input_image = Image .open (img_path )
7071
71- image = np .array (image ) - np .array ([123.0 , 117.0 , 104.0 ])
72- image /= np .array ([58.395 , 57.12 , 57.375 ])
73- image = image .transpose ((2 , 0 , 1 ))
74- image = image [np .newaxis , :]
75- return image .astype ("float32" ), synset
72+ preprocess = transforms .Compose (
73+ [
74+ transforms .Resize (256 ),
75+ transforms .CenterCrop (224 ),
76+ transforms .ToTensor (),
77+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
78+ ]
79+ )
80+ input_tensor = preprocess (input_image )
81+ input_batch = input_tensor .unsqueeze (0 )
82+ return input_batch .detach ().cpu ().numpy (), synset
7683
7784
7885def get_model (model_name , data_shape ):
79- gluon_model = gluon .model_zoo .vision .get_model (model_name , pretrained = True )
80- mod , params = relay .frontend .from_mxnet (gluon_model , {"data" : data_shape })
86+ import torch
87+ import torchvision
88+
89+ torch_model = getattr (torchvision .models , model_name )(weights = "IMAGENET1K_V1" ).eval ()
90+ input_data = torch .randn (data_shape )
91+ scripted_model = torch .jit .trace (torch_model , input_data )
92+
93+ input_infos = [("data" , input_data .shape )]
94+ mod , params = relay .frontend .from_pytorch (scripted_model , input_infos )
95+
8196 # we want a probability so add a softmax operator
8297 func = mod ["main" ]
8398 func = relay .Function (
@@ -90,7 +105,7 @@ def get_model(model_name, data_shape):
90105def test_mobilenet (host , port , key , mode ):
91106 temp = utils .tempdir ()
92107 image , synset = prepare_input ()
93- model , params = get_model ("mobilenetv2_1.0 " , image .shape )
108+ model , params = get_model ("mobilenet_v2 " , image .shape )
94109
95110 def run (mod , target ):
96111 with relay .build_config (opt_level = 3 ):
0 commit comments