Skip to content

Commit bce5758

Browse files
BBufhhhfccz
andauthored
Add oneflow fronted tutorials (#11036)
* add relay.f.frontend.fm_oneflow support cnns * support cuda * fix mobilenetv2 and reviews * fix: model without meta info * support eager and yolo, add test * fix: license * add: tutorials * fix: support new graph * fix some comments * refine * fix concat op convert bug * refine * refine * change cuda to cpu * fix bug * fix ci error in tvm * fix pylint check * delete useless file * add skimage package in docker * fix ci error * fix bug * add oneflow fronted test in ci * merge conflict * fix tutorial * try to find error in ci * revert * merge conflict * black oneflow * Delete from_oneflow.py * fix bug when upgrade oneflow to 0.7.0 * add tutorials * add tutorials * try to fix * fix bug * add test * fix bug * fix flowvision bug * Update test_forward.py * Update test_forward.py Co-authored-by: hhhfccz <[email protected]>
1 parent 60e43e1 commit bce5758

File tree

2 files changed

+188
-11
lines changed

2 files changed

+188
-11
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
Compile OneFlow Models
19+
======================
20+
**Author**: `Xiaoyu Zhang <https://github.com/BBuf/>`_
21+
22+
This article is an introductory tutorial to deploy OneFlow models with Relay.
23+
24+
For us to begin with, OneFlow package should be installed.
25+
26+
A quick solution is to install via pip
27+
28+
.. code-block:: bash
29+
30+
pip install flowvision==0.1.0
31+
python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu
32+
33+
or please refer to official site:
34+
https://github.com/Oneflow-Inc/oneflow
35+
36+
Currently, TVM supports OneFlow 0.7.0. Other versions may be unstable.
37+
"""
38+
import os, math
39+
from matplotlib import pyplot as plt
40+
import numpy as np
41+
from PIL import Image
42+
43+
# oneflow imports
44+
import flowvision
45+
import oneflow as flow
46+
import oneflow.nn as nn
47+
48+
import tvm
49+
from tvm import relay
50+
from tvm.contrib.download import download_testdata
51+
52+
######################################################################
53+
# Load a pretrained OneFlow model and save model
54+
# ----------------------------------------------
55+
model_name = "resnet18"
56+
model = getattr(flowvision.models, model_name)(pretrained=True)
57+
model = model.eval()
58+
59+
model_dir = "resnet18_model"
60+
if not os.path.exists(model_dir):
61+
flow.save(model.state_dict(), model_dir)
62+
63+
######################################################################
64+
# Load a test image
65+
# -----------------
66+
# Classic cat example!
67+
from PIL import Image
68+
69+
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
70+
img_path = download_testdata(img_url, "cat.png", module="data")
71+
img = Image.open(img_path).resize((224, 224))
72+
73+
# Preprocess the image and convert to tensor
74+
from flowvision import transforms
75+
76+
my_preprocess = transforms.Compose(
77+
[
78+
transforms.Resize(256),
79+
transforms.CenterCrop(224),
80+
transforms.ToTensor(),
81+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
82+
]
83+
)
84+
img = my_preprocess(img)
85+
img = np.expand_dims(img.numpy(), 0)
86+
87+
######################################################################
88+
# Import the graph to Relay
89+
# -------------------------
90+
# Convert OneFlow graph to Relay graph. The input name can be arbitrary.
91+
class Graph(flow.nn.Graph):
92+
def __init__(self, module):
93+
super().__init__()
94+
self.m = module
95+
96+
def build(self, x):
97+
out = self.m(x)
98+
return out
99+
100+
101+
graph = Graph(model)
102+
_ = graph._compile(flow.randn(1, 3, 224, 224))
103+
104+
mod, params = relay.frontend.from_oneflow(graph, model_dir)
105+
106+
######################################################################
107+
# Relay Build
108+
# -----------
109+
# Compile the graph to llvm target with given input specification.
110+
target = tvm.target.Target("llvm", host="llvm")
111+
dev = tvm.cpu(0)
112+
with tvm.transform.PassContext(opt_level=3):
113+
lib = relay.build(mod, target=target, params=params)
114+
115+
######################################################################
116+
# Execute the portable graph on TVM
117+
# ---------------------------------
118+
# Now we can try deploying the compiled model on target.
119+
target = "cuda"
120+
with tvm.transform.PassContext(opt_level=10):
121+
intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target)
122+
123+
print(type(img))
124+
print(img.shape)
125+
tvm_output = intrp.evaluate()(tvm.nd.array(img.astype("float32")), **params)
126+
127+
#####################################################################
128+
# Look up synset name
129+
# -------------------
130+
# Look up prediction top 1 index in 1000 class synset.
131+
synset_url = "".join(
132+
[
133+
"https://raw.githubusercontent.com/Cadene/",
134+
"pretrained-models.pytorch/master/data/",
135+
"imagenet_synsets.txt",
136+
]
137+
)
138+
synset_name = "imagenet_synsets.txt"
139+
synset_path = download_testdata(synset_url, synset_name, module="data")
140+
with open(synset_path) as f:
141+
synsets = f.readlines()
142+
143+
synsets = [x.strip() for x in synsets]
144+
splits = [line.split(" ") for line in synsets]
145+
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
146+
147+
class_url = "".join(
148+
[
149+
"https://raw.githubusercontent.com/Cadene/",
150+
"pretrained-models.pytorch/master/data/",
151+
"imagenet_classes.txt",
152+
]
153+
)
154+
class_name = "imagenet_classes.txt"
155+
class_path = download_testdata(class_url, class_name, module="data")
156+
with open(class_path) as f:
157+
class_id_to_key = f.readlines()
158+
159+
class_id_to_key = [x.strip() for x in class_id_to_key]
160+
161+
# Get top-1 result for TVM
162+
top1_tvm = np.argmax(tvm_output.numpy()[0])
163+
tvm_class_key = class_id_to_key[top1_tvm]
164+
165+
# Convert input to OneFlow variable and get OneFlow result for comparison
166+
with flow.no_grad():
167+
torch_img = flow.from_numpy(img)
168+
output = model(torch_img)
169+
170+
# Get top-1 result for OneFlow
171+
top_oneflow = np.argmax(output.numpy())
172+
oneflow_class_key = class_id_to_key[top_oneflow]
173+
174+
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
175+
print(
176+
"OneFlow top-1 id: {}, class name: {}".format(top_oneflow, key_to_classname[oneflow_class_key])
177+
)

tests/python/frontend/oneflow/test_forward.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -710,14 +710,14 @@ def forward(self, x1, x2, x3):
710710
verify_concat(model, device=device)
711711

712712

713-
# if __name__ == "__main__":
714-
# test_conv2d()
715-
# test_pool2d()
716-
# test_normalization()
717-
# test_upsample()
718-
# test_convtran()
719-
# test_activation()
720-
# test_math()
721-
# test_slice()
722-
# test_concat()
723-
# rmdir("log")
713+
if __name__ == "__main__":
714+
test_conv2d()
715+
test_pool2d()
716+
test_normalization()
717+
test_upsample()
718+
test_convtran()
719+
test_activation()
720+
test_math()
721+
test_slice()
722+
test_concat()
723+
rmdir("log")

0 commit comments

Comments
 (0)