1515# specific language governing permissions and limitations
1616# under the License.
1717
18- import logging
19- import pathlib
20- from pathlib import Path
21- from typing import Union
18+ import json
2219import os
2320from os import environ
24- import json
21+ from pathlib import Path
22+ from typing import Union
2523
2624import tvm
2725import tvm .relay as relay
28- from tvm .contrib import utils , ndk , graph_executor as runtime
29- from tvm .contrib .download import download_testdata , download
26+ from tvm .contrib import ndk
27+ from tvm .contrib .download import download , download_testdata
3028
3129target = "llvm -mtriple=arm64-linux-android"
3230target_host = None
@@ -50,15 +48,18 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False):
5048
5149def get_model (model_name , batch_size = 1 ):
5250 if model_name == "resnet18_v1" :
53- import mxnet as mx
54- from mxnet import gluon
55- from mxnet .gluon .model_zoo import vision
51+ import torch
52+ import torchvision
5653
57- gluon_model = vision .get_model (model_name , pretrained = True )
58- img_size = 224
59- data_shape = (batch_size , 3 , img_size , img_size )
60- net , params = relay .frontend .from_mxnet (gluon_model , {"data" : data_shape })
61- return (net , params )
54+ weights = torchvision .models .ResNet18_Weights .IMAGENET1K_V1
55+ torch_model = torchvision .models .resnet18 (weights = weights ).eval ()
56+ input_shape = [1 , 3 , 224 , 224 ]
57+ input_data = torch .randn (input_shape )
58+ scripted_model = torch .jit .trace (torch_model , input_data )
59+
60+ input_infos = [("data" , input_data .shape )]
61+ mod , params = relay .frontend .from_pytorch (scripted_model , input_infos )
62+ return (mod , params )
6263 elif model_name == "mobilenet_v2" :
6364 import keras
6465 from keras .applications .mobilenet_v2 import MobileNetV2
0 commit comments