@@ -50,6 +50,8 @@ def compile_metal(src, target):
5050
5151
5252def prepare_input ():
53+ from torchvision import transforms
54+
5355 img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
5456 img_name = "cat.png"
5557 synset_url = "" .join (
@@ -65,13 +67,17 @@ def prepare_input():
6567 synset_path = download_testdata (synset_url , synset_name , module = "data" )
6668 with open (synset_path ) as f :
6769 synset = eval (f .read ())
68- image = Image .open (img_path ).resize ((224 , 224 ))
69-
70- image = np .array (image ) - np .array ([123.0 , 117.0 , 104.0 ])
71- image /= np .array ([58.395 , 57.12 , 57.375 ])
72- image = image .transpose ((2 , 0 , 1 ))
73- image = image [np .newaxis , :]
74- return image .astype ("float32" ), synset
70+ input_image = Image .open (img_path )
71+
72+ preprocess = transforms .Compose ([
73+ transforms .Resize (256 ),
74+ transforms .CenterCrop (224 ),
75+ transforms .ToTensor (),
76+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
77+ ])
78+ input_tensor = preprocess (input_image )
79+ input_batch = input_tensor .unsqueeze (0 )
80+ return input_batch .detach ().cpu ().numpy (), synset
7581
7682
7783def get_model (model_name , data_shape ):
0 commit comments