Skip to content

Commit df06c58

Browse files
authored
[Bugfux] wasm32-standalone app repaired (#8563)
1 parent 00ad44e commit df06c58

File tree

10 files changed

+157
-86
lines changed

10 files changed

+157
-86
lines changed

apps/wasm-standalone/README.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all
116116
117117
- Build DL library in the WebAssembly format.
118118
119-
- Download model
119+
- Compile the model
120120
121121
```
122-
cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx
123-
```
124-
125-
- Compile
126-
127-
```
128-
LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx
122+
cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3
129123
```
130124
131125
### Build wasm-graph package
@@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn
170164
$ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv
171165
original image dimensions: (256, 256)
172166
resized image dimensions: (224, 224)
173-
input image belongs to the class `tabby, tabby cat`
167+
input image belongs to the class `tiger cat`
174168
```
175169

170+
Note: this example also works without WASI support. Please modify `wasm-graph/.cargo/config` to change the target to
171+
`wasm32-unknown-unknown` and uncomment the raw wasm engine in `wasm-runtime/src/graph.rs` to run in pure wasm32. SIMD
172+
may not be supported without WASI support. You may also need to delete ` -mattr=+simd128` in the
173+
[build script](apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py).
174+
176175
## Future Work
177176

178177
### More networks support

apps/wasm-standalone/wasm-graph/src/lib.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ lazy_static! {
4848
"/lib/graph.json"
4949
)))
5050
.unwrap();
51+
5152
let params_bytes =
5253
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params"));
5354
let params = tvm_graph_rt::load_param_dict(params_bytes)
@@ -57,6 +58,7 @@ lazy_static! {
5758
.collect::<HashMap<String, TVMTensor<'static>>>();
5859

5960
let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
61+
6062
exec.load_params(params);
6163

6264
Mutex::new(exec)
@@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 {
6870
let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) };
6971
let input: TVMTensor = in_tensor.as_dltensor().into();
7072

71-
GRAPH_EXECUTOR.lock().unwrap().set_input("data", input);
72-
GRAPH_EXECUTOR.lock().unwrap().run();
73-
let output = GRAPH_EXECUTOR
74-
.lock()
75-
.unwrap()
76-
.get_output(0)
77-
.unwrap()
78-
.as_dltensor(false);
73+
// since this executor is not multi-threaded, we can acquire lock once
74+
let mut executor = GRAPH_EXECUTOR.lock().unwrap();
75+
76+
executor.set_input("data", input);
77+
78+
executor.run();
79+
80+
let output = executor.get_output(0).unwrap().as_dltensor(false);
7981

8082
let out_tensor: Tensor = output.into();
8183
let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) };

apps/wasm-standalone/wasm-graph/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::{
2424
};
2525
pub use tvm_sys::ffi::DLTensor;
2626
use tvm_sys::ffi::{
27-
DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU,
27+
DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU,
2828
};
2929

3030
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]

apps/wasm-standalone/wasm-graph/src/utils.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,20 @@ use std::ptr;
2424
pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor {
2525
let in_addr = in_addr as *mut u8;
2626

27-
let mut data_vec = Vec::new();
28-
for i in 0..in_size {
29-
data_vec.push(ptr::read(in_addr.offset(i as isize)));
30-
}
31-
let input: Tensor = serde_json::from_slice(&data_vec).unwrap();
27+
println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size);
28+
29+
let data_vec = unsafe { std::slice::from_raw_parts(in_addr, in_size) };
3230

33-
input
31+
let input = serde_json::from_slice(&data_vec);
32+
match input {
33+
Ok(result) => {
34+
println!("DEBUG: SER SUCCEED!!! and Ok");
35+
result
36+
}
37+
Err(e) => {
38+
panic!("DEBUG: SER SUCCEED!!! but Err, {:?}", &e);
39+
}
40+
}
3441
}
3542

3643
pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize {

apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py

100644100755
Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# specific language governing permissions and limitations
1717
# under the License.
1818

19-
"""Builds a simple graph for testing."""
19+
"""Builds a simple resnet50 graph for testing."""
2020
import argparse
2121
import os
2222
import subprocess
@@ -25,47 +25,78 @@
2525
import onnx
2626
import tvm
2727
from tvm import relay, runtime
28+
from tvm.contrib.download import download_testdata
29+
from tvm.contrib import graph_executor
2830

31+
from PIL import Image
32+
import numpy as np
33+
import tvm.relay as relay
2934

30-
def _get_mod_and_params(model_file):
31-
onnx_model = onnx.load(model_file)
32-
shape_dict = {}
33-
for input in onnx_model.graph.input:
34-
shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
35+
# This example uses resnet50-v2-7 model
36+
model_url = "".join(
37+
[
38+
"https://github.com/onnx/models/raw/",
39+
"master/vision/classification/resnet/model/",
40+
"resnet50-v2-7.onnx",
41+
]
42+
)
3543

36-
return relay.frontend.from_onnx(onnx_model, shape_dict)
3744

38-
39-
def build_graph_lib(model_file, opt_level):
45+
def build_graph_lib(opt_level):
4046
"""Compiles the pre-trained model with TVM"""
4147
out_dir = os.path.join(sys.path[0], "../lib")
4248
if not os.path.exists(out_dir):
4349
os.makedirs(out_dir)
4450

45-
# Compile the relay mod
46-
mod, params = _get_mod_and_params(model_file)
51+
# Follow the tutorial to download and compile the model
52+
model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx")
53+
onnx_model = onnx.load(model_path)
54+
55+
img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
56+
img_path = download_testdata(img_url, "imagenet_cat.png", module="data")
57+
58+
# Resize it to 224x224
59+
resized_image = Image.open(img_path).resize((224, 224))
60+
img_data = np.asarray(resized_image).astype("float32")
61+
62+
# Our input image is in HWC layout while ONNX expects CHW input, so convert the array
63+
img_data = np.transpose(img_data, (2, 0, 1))
64+
65+
# Normalize according to the ImageNet input specification
66+
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
67+
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
68+
norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
69+
70+
# Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
71+
img_data = np.expand_dims(norm_img_data, axis=0)
72+
73+
input_name = "data"
74+
shape_dict = {input_name: img_data.shape}
75+
76+
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
4777
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib"
78+
4879
with tvm.transform.PassContext(opt_level=opt_level):
49-
graph_json, lib, params = relay.build(mod, target=target, params=params)
80+
factory = relay.build(mod, target=target, params=params)
5081

5182
# Save the model artifacts to obj_file
5283
obj_file = os.path.join(out_dir, "graph.o")
53-
lib.save(obj_file)
84+
factory.get_lib().save(obj_file)
85+
5486
# Run llvm-ar to archive obj_file into lib_file
5587
lib_file = os.path.join(out_dir, "libgraph_wasm32.a")
5688
cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file]
5789
subprocess.run(cmds)
5890

91+
# Save the json and params
5992
with open(os.path.join(out_dir, "graph.json"), "w") as f_graph:
60-
f_graph.write(graph_json)
61-
93+
f_graph.write(factory.get_graph_json())
6294
with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
63-
f_params.write(runtime.save_param_dict(params))
95+
f_params.write(runtime.save_param_dict(factory.get_params()))
6496

6597

6698
if __name__ == "__main__":
6799
parser = argparse.ArgumentParser(description="ONNX model build example")
68-
parser.add_argument("model_file", type=str, help="the path of onnx model file")
69100
parser.add_argument(
70101
"-O",
71102
"--opt-level",
@@ -75,4 +106,4 @@ def build_graph_lib(model_file, opt_level):
75106
)
76107
args = parser.parse_args()
77108

78-
build_graph_lib(args.model_file, args.opt_level)
109+
build_graph_lib(args.opt_level)

apps/wasm-standalone/wasm-runtime/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ license = "Apache-2.0"
2626
keywords = ["wasm", "machine learning", "wasmtime"]
2727

2828
[dependencies]
29-
wasmtime = "0.16.0"
30-
wasmtime-wasi = "0.16.0"
29+
wasmtime = "0.28.0"
30+
wasmtime-wasi = "0.28.0"
3131
anyhow = "1.0.31"
3232
serde = "1.0.53"
3333
serde_json = "1.0.53"

apps/wasm-standalone/wasm-runtime/src/graph.rs

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919

2020
use anyhow::Result;
2121
use wasmtime::*;
22-
use wasmtime_wasi::{Wasi, WasiCtx};
22+
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};
2323

2424
use super::Tensor;
2525

2626
pub struct GraphExecutor {
2727
pub(crate) wasm_addr: i32,
2828
pub(crate) input_size: i32,
2929
pub(crate) output_size: i32,
30+
pub(crate) store: Option<Store<WasiCtx>>,
31+
// None-WASI version:
32+
// pub(crate) store: Option<Store<()>>,
3033
pub(crate) instance: Option<Instance>,
3134
}
3235

@@ -37,25 +40,44 @@ impl GraphExecutor {
3740
wasm_addr: 0,
3841
input_size: 0,
3942
output_size: 0,
43+
store: None,
4044
instance: None,
4145
}
4246
}
4347

4448
pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> {
45-
let engine = Engine::new(Config::new().wasm_simd(true));
46-
let store = Store::new(&engine);
49+
// It seems WASI in this example is not necessary
4750

51+
// None WASI version: works with no SIMD
52+
// let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
53+
// let mut store = Store::new(&engine, ());
54+
// let module = Module::from_file(store.engine(), &wasm_graph_file)?;
55+
56+
// let instance = Instance::new(&mut store, &module, &[])?;
57+
58+
// self.instance = Some(instance);
59+
// self.store = Some(store);
60+
61+
// Ok(())
62+
63+
// WASI version:
64+
let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
4865
// First set up our linker which is going to be linking modules together. We
4966
// want our linker to have wasi available, so we set that up here as well.
50-
let mut linker = Linker::new(&store);
67+
let mut linker = Linker::new(&engine);
68+
wasmtime_wasi::add_to_linker(&mut linker, |s| s)?;
5169
// Create an instance of `Wasi` which contains a `WasiCtx`. Note that
5270
// `WasiCtx` provides a number of ways to configure what the target program
5371
// will have access to.
54-
let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?);
55-
wasi.add_to_linker(&mut linker)?;
72+
let wasi = WasiCtxBuilder::new()
73+
.inherit_stdio()
74+
.inherit_args()?
75+
.build();
76+
let mut store = Store::new(&engine, wasi);
5677

57-
let module = Module::from_file(&store, &wasm_graph_file)?;
58-
self.instance = Some(linker.instantiate(&module)?);
78+
let module = Module::from_file(&engine, &wasm_graph_file)?;
79+
self.instance = Some(linker.instantiate(&mut store, &module)?);
80+
self.store = Some(store);
5981

6082
Ok(())
6183
}
@@ -65,26 +87,24 @@ impl GraphExecutor {
6587
.instance
6688
.as_ref()
6789
.unwrap()
68-
.get_memory("memory")
90+
.get_memory(self.store.as_mut().unwrap(), "memory")
6991
.ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
7092

7193
// Specify the wasm address to access the wasm memory.
72-
let wasm_addr = memory.data_size();
94+
let wasm_addr = memory.data_size(self.store.as_mut().unwrap());
95+
7396
// Serialize the data into a JSON string.
7497
let in_data = serde_json::to_vec(&input_data)?;
7598
let in_size = in_data.len();
99+
76100
// Grow up memory size according to in_size to avoid memory leak.
77-
memory.grow((in_size >> 16) as u32 + 1)?;
101+
memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?;
78102

79-
// Insert the input data into wasm memory.
80-
for i in 0..in_size {
81-
unsafe {
82-
memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap();
83-
}
84-
}
103+
memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?;
85104

86105
self.wasm_addr = wasm_addr as i32;
87106
self.input_size = in_size as i32;
107+
88108
Ok(())
89109
}
90110

@@ -94,11 +114,12 @@ impl GraphExecutor {
94114
.instance
95115
.as_ref()
96116
.unwrap()
97-
.get_func("run")
98-
.ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?
99-
.get2::<i32, i32, i32>()?;
117+
.get_func(self.store.as_mut().unwrap(), "run")
118+
.ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?;
100119

101-
let out_size = run(self.wasm_addr, self.input_size)?;
120+
let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)];
121+
let out_size = run.call(self.store.as_mut().unwrap(), &params[..])?;
122+
let out_size = (*out_size)[0].unwrap_i32();
102123
if out_size == 0 {
103124
panic!("graph run failed!");
104125
}
@@ -107,18 +128,22 @@ impl GraphExecutor {
107128
Ok(())
108129
}
109130

110-
pub fn get_output(&self) -> Result<Tensor> {
131+
pub fn get_output(&mut self) -> Result<Tensor> {
111132
let memory = self
112133
.instance
113134
.as_ref()
114135
.unwrap()
115-
.get_memory("memory")
136+
.get_memory(self.store.as_mut().unwrap(), "memory")
116137
.ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
117138

118-
let out_data = unsafe {
119-
&memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize]
120-
};
121-
let out_vec: Tensor = serde_json::from_slice(out_data).unwrap();
139+
let mut out_data = vec![0 as u8; self.output_size as _];
140+
memory.read(
141+
self.store.as_mut().unwrap(),
142+
self.wasm_addr as _,
143+
&mut out_data,
144+
)?;
145+
146+
let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap();
122147
Ok(out_vec)
123148
}
124149
}

rust/tvm-graph-rt/src/module/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
5252
values.len() as i32,
5353
&mut ret_val,
5454
&mut ret_type_code,
55+
std::ptr::null_mut(),
5556
);
5657
if exit_code == 0 {
5758
Ok(RetValue::from_tvm_value(ret_val, ret_type_code))

0 commit comments

Comments
 (0)