Skip to content

Commit 0ca864b

Browse files
committed
oneflow cause crash, do test with change
1 parent 23e3112 commit 0ca864b

File tree

1 file changed

+0
-126
lines changed

1 file changed

+0
-126
lines changed

gallery/how_to/compile_models/from_oneflow.py

Lines changed: 0 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -49,129 +49,3 @@
4949
from tvm import relay
5050
from tvm.contrib.download import download_testdata
5151

52-
######################################################################
53-
# Load a pretrained OneFlow model and save model
54-
# ----------------------------------------------
55-
model_name = "resnet18"
56-
model = getattr(flowvision.models, model_name)(pretrained=True)
57-
model = model.eval()
58-
59-
model_dir = "resnet18_model"
60-
if not os.path.exists(model_dir):
61-
flow.save(model.state_dict(), model_dir)
62-
63-
######################################################################
64-
# Load a test image
65-
# -----------------
66-
# Classic cat example!
67-
from PIL import Image
68-
69-
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
70-
img_path = download_testdata(img_url, "cat.png", module="data")
71-
img = Image.open(img_path).resize((224, 224))
72-
73-
# Preprocess the image and convert to tensor
74-
from flowvision import transforms
75-
76-
my_preprocess = transforms.Compose(
77-
[
78-
transforms.Resize(256),
79-
transforms.CenterCrop(224),
80-
transforms.ToTensor(),
81-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
82-
]
83-
)
84-
img = my_preprocess(img)
85-
img = np.expand_dims(img.numpy(), 0)
86-
87-
######################################################################
88-
# Import the graph to Relay
89-
# -------------------------
90-
# Convert OneFlow graph to Relay graph. The input name can be arbitrary.
91-
class Graph(flow.nn.Graph):
92-
def __init__(self, module):
93-
super().__init__()
94-
self.m = module
95-
96-
def build(self, x):
97-
out = self.m(x)
98-
return out
99-
100-
101-
graph = Graph(model)
102-
_ = graph._compile(flow.randn(1, 3, 224, 224))
103-
104-
mod, params = relay.frontend.from_oneflow(graph, model_dir)
105-
106-
######################################################################
107-
# Relay Build
108-
# -----------
109-
# Compile the graph to llvm target with given input specification.
110-
target = tvm.target.Target("llvm", host="llvm")
111-
dev = tvm.cpu(0)
112-
with tvm.transform.PassContext(opt_level=3):
113-
lib = relay.build(mod, target=target, params=params)
114-
115-
######################################################################
116-
# Execute the portable graph on TVM
117-
# ---------------------------------
118-
# Now we can try deploying the compiled model on target.
119-
target = "cuda"
120-
with tvm.transform.PassContext(opt_level=10):
121-
intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target)
122-
123-
print(type(img))
124-
print(img.shape)
125-
tvm_output = intrp.evaluate()(tvm.nd.array(img.astype("float32")), **params)
126-
127-
#####################################################################
128-
# Look up synset name
129-
# -------------------
130-
# Look up prediction top 1 index in 1000 class synset.
131-
synset_url = "".join(
132-
[
133-
"https://raw.githubusercontent.com/Cadene/",
134-
"pretrained-models.pytorch/master/data/",
135-
"imagenet_synsets.txt",
136-
]
137-
)
138-
synset_name = "imagenet_synsets.txt"
139-
synset_path = download_testdata(synset_url, synset_name, module="data")
140-
with open(synset_path) as f:
141-
synsets = f.readlines()
142-
143-
synsets = [x.strip() for x in synsets]
144-
splits = [line.split(" ") for line in synsets]
145-
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
146-
147-
class_url = "".join(
148-
[
149-
"https://raw.githubusercontent.com/Cadene/",
150-
"pretrained-models.pytorch/master/data/",
151-
"imagenet_classes.txt",
152-
]
153-
)
154-
class_name = "imagenet_classes.txt"
155-
class_path = download_testdata(class_url, class_name, module="data")
156-
with open(class_path) as f:
157-
class_id_to_key = f.readlines()
158-
159-
class_id_to_key = [x.strip() for x in class_id_to_key]
160-
161-
# Get top-1 result for TVM
162-
top1_tvm = np.argmax(tvm_output.numpy()[0])
163-
tvm_class_key = class_id_to_key[top1_tvm]
164-
165-
# Convert input to OneFlow variable and get OneFlow result for comparison
166-
with flow.no_grad():
167-
torch_img = flow.from_numpy(img)
168-
output = model(torch_img)
169-
170-
# Get top-1 result for OneFlow
171-
top_oneflow = np.argmax(output.numpy())
172-
oneflow_class_key = class_id_to_key[top_oneflow]
173-
174-
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
175-
print(
176-
"OneFlow top-1 id: {}, class name: {}".format(top_oneflow, key_to_classname[oneflow_class_key])
177-
)

0 commit comments

Comments
 (0)