Skip to content

Commit ec28b67

Browse files
authored
[Apps] Remove mxnet dependency from /apps/android_camera/models (#17297)
* use torchvision's resnet18 instead of mxnet * cleanup import statements
1 parent 4eafd00 commit ec28b67

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

apps/android_camera/models/prepare_model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import logging
19-
import pathlib
20-
from pathlib import Path
21-
from typing import Union
18+
import json
2219
import os
2320
from os import environ
24-
import json
21+
from pathlib import Path
22+
from typing import Union
2523

2624
import tvm
2725
import tvm.relay as relay
28-
from tvm.contrib import utils, ndk, graph_executor as runtime
29-
from tvm.contrib.download import download_testdata, download
26+
from tvm.contrib import ndk
27+
from tvm.contrib.download import download, download_testdata
3028

3129
target = "llvm -mtriple=arm64-linux-android"
3230
target_host = None
@@ -50,15 +48,18 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False):
5048

5149
def get_model(model_name, batch_size=1):
5250
if model_name == "resnet18_v1":
53-
import mxnet as mx
54-
from mxnet import gluon
55-
from mxnet.gluon.model_zoo import vision
51+
import torch
52+
import torchvision
5653

57-
gluon_model = vision.get_model(model_name, pretrained=True)
58-
img_size = 224
59-
data_shape = (batch_size, 3, img_size, img_size)
60-
net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
61-
return (net, params)
54+
weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
55+
torch_model = torchvision.models.resnet18(weights=weights).eval()
56+
input_shape = [1, 3, 224, 224]
57+
input_data = torch.randn(input_shape)
58+
scripted_model = torch.jit.trace(torch_model, input_data)
59+
60+
input_infos = [("data", input_data.shape)]
61+
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)
62+
return (mod, params)
6263
elif model_name == "mobilenet_v2":
6364
import keras
6465
from keras.applications.mobilenet_v2 import MobileNetV2
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
keras==2.9
2-
mxnet
32
scipy
43
tensorflow==2.9.3
4+
torch
5+
torchvision

0 commit comments

Comments
 (0)