Skip to content

Commit 9bcd74f

Browse files
committed
fix preprocessing
1 parent 12334f9 commit 9bcd74f

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

apps/ios_rpc/tests/ios_rpc_mobilenet.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def compile_metal(src, target):
5050

5151

5252
def prepare_input():
53+
from torchvision import transforms
54+
5355
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
5456
img_name = "cat.png"
5557
synset_url = "".join(
@@ -65,13 +67,17 @@ def prepare_input():
6567
synset_path = download_testdata(synset_url, synset_name, module="data")
6668
with open(synset_path) as f:
6769
synset = eval(f.read())
68-
image = Image.open(img_path).resize((224, 224))
69-
70-
image = np.array(image) - np.array([123.0, 117.0, 104.0])
71-
image /= np.array([58.395, 57.12, 57.375])
72-
image = image.transpose((2, 0, 1))
73-
image = image[np.newaxis, :]
74-
return image.astype("float32"), synset
70+
input_image = Image.open(img_path)
71+
72+
preprocess = transforms.Compose([
73+
transforms.Resize(256),
74+
transforms.CenterCrop(224),
75+
transforms.ToTensor(),
76+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
77+
])
78+
input_tensor = preprocess(input_image)
79+
input_batch = input_tensor.unsqueeze(0)
80+
return input_batch.detach().cpu().numpy(), synset
7581

7682

7783
def get_model(model_name, data_shape):

0 commit comments

Comments
 (0)