17 changes: 8 additions & 9 deletions apps/wasm-standalone/README.md
@@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all

 - Build DL library in the WebAssembly format.

-- Download model
+- Compile the model

-```
-cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx
-```
-
-- Compile
-
 ```
-LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx
+cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3
 ```

 ### Build wasm-graph package
@@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn
 $ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv
 original image dimensions: (256, 256)
 resized image dimensions: (224, 224)
-input image belongs to the class `tabby, tabby cat`
+input image belongs to the class `tiger cat`
 ```

+Note: this example also works without WASI support. To run in pure wasm32, modify `wasm-graph/.cargo/config` to
+change the target to `wasm32-unknown-unknown` and uncomment the raw wasm engine in `wasm-runtime/src/graph.rs`.
+SIMD may not be supported without WASI; in that case, also delete `-mattr=+simd128` in the
+[build script](apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py).

 ## Future Work

 ### More networks support
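For reference, a minimal sketch of the raw (non-WASI) engine path that the note above describes, assuming the wasmtime 0.28 API used elsewhere in this PR. The helper name and `.wasm` path are illustrative; the actual commented-out code lives in `wasm-runtime/src/graph.rs`:

```rust
use wasmtime::{Config, Engine, Instance, Module, Store};

// Hypothetical helper mirroring the commented-out path in graph.rs.
fn instantiate_raw(wasm_graph_file: &str) -> anyhow::Result<(Store<()>, Instance)> {
    // No WASI context, so the store's data is just ().
    // SIMD is disabled here because, per the note above, it may not be
    // supported on this path.
    let engine = Engine::new(Config::new().wasm_simd(false))?;
    let mut store = Store::new(&engine, ());
    let module = Module::from_file(store.engine(), wasm_graph_file)?;
    // Nothing is linked in, so the module must not require any imports.
    let instance = Instance::new(&mut store, &module, &[])?;
    Ok((store, instance))
}
```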
18 changes: 10 additions & 8 deletions apps/wasm-standalone/wasm-graph/src/lib.rs
@@ -48,6 +48,7 @@ lazy_static! {
"/lib/graph.json"
)))
.unwrap();

let params_bytes =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params"));
let params = tvm_graph_rt::load_param_dict(params_bytes)
Expand All @@ -57,6 +58,7 @@ lazy_static! {
.collect::<HashMap<String, TVMTensor<'static>>>();

let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();

exec.load_params(params);

Mutex::new(exec)
Expand All @@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 {
let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) };
let input: TVMTensor = in_tensor.as_dltensor().into();

GRAPH_EXECUTOR.lock().unwrap().set_input("data", input);
GRAPH_EXECUTOR.lock().unwrap().run();
let output = GRAPH_EXECUTOR
.lock()
.unwrap()
.get_output(0)
.unwrap()
.as_dltensor(false);
// since this executor is not multi-threaded, we can acquire lock once
let mut executor = GRAPH_EXECUTOR.lock().unwrap();

executor.set_input("data", input);

executor.run();

let output = executor.get_output(0).unwrap().as_dltensor(false);

let out_tensor: Tensor = output.into();
let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) };
2 changes: 1 addition & 1 deletion apps/wasm-standalone/wasm-graph/src/types.rs
@@ -24,7 +24,7 @@ use std::{
 };
 pub use tvm_sys::ffi::DLTensor;
 use tvm_sys::ffi::{
-    DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU,
+    DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU,
 };

 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
19 changes: 13 additions & 6 deletions apps/wasm-standalone/wasm-graph/src/utils.rs
@@ -24,13 +24,20 @@ use std::ptr;
 pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor {
     let in_addr = in_addr as *mut u8;

-    let mut data_vec = Vec::new();
-    for i in 0..in_size {
-        data_vec.push(ptr::read(in_addr.offset(i as isize)));
-    }
-    let input: Tensor = serde_json::from_slice(&data_vec).unwrap();
+    println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size);

-    input
+    let data_vec = std::slice::from_raw_parts(in_addr, in_size);
+
+    match serde_json::from_slice(data_vec) {
+        Ok(result) => {
+            println!("DEBUG: input tensor deserialized");
+            result
+        }
+        Err(e) => panic!("DEBUG: input tensor deserialization failed: {:?}", e),
+    }
 }

 pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize {
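To make the handoff above concrete, here is a self-contained sketch of the JSON-over-linear-memory protocol that `load_input` implements, assuming `serde` (with derive) and `serde_json` are available, as they are in this crate. `FakeTensor` is a hypothetical stand-in for the real `Tensor` in `wasm-graph/src/types.rs`, and both sides run in one process purely for illustration:

```rust
use serde::{Deserialize, Serialize};

// Stand-in for the real `Tensor`; only the round-trip matters here.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct FakeTensor {
    shape: Vec<i64>,
    data: Vec<f32>,
}

fn main() {
    // "Host" side: serialize the tensor and hand over (address, length),
    // as wasm-runtime does when it writes into the module's linear memory.
    let t = FakeTensor { shape: vec![1, 3], data: vec![0.1, 0.2, 0.3] };
    let buf = serde_json::to_vec(&t).unwrap();
    let (in_addr, in_size) = (buf.as_ptr(), buf.len());

    // "Guest" side: rebuild a byte slice from the raw address and
    // deserialize, mirroring load_input above.
    let bytes = unsafe { std::slice::from_raw_parts(in_addr, in_size) };
    let decoded: FakeTensor = serde_json::from_slice(bytes).unwrap();
    assert_eq!(decoded, t);
}
```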
68 changes: 50 additions & 18 deletions apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
100644 → 100755
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.

-"""Builds a simple graph for testing."""
+"""Builds a simple resnet50 graph for testing."""
 import argparse
 import os
 import subprocess
@@ -25,47 +25,79 @@
 import onnx
 import tvm
 from tvm import relay, runtime
+from tvm.contrib.download import download_testdata
+from tvm.contrib import graph_executor

+from PIL import Image
+import numpy as np

-def _get_mod_and_params(model_file):
-    onnx_model = onnx.load(model_file)
-    shape_dict = {}
-    for input in onnx_model.graph.input:
-        shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
+# This example uses the resnet50-v2-7 model
+model_url = "".join(
+    [
+        "https://github.com/onnx/models/raw/",
+        "master/vision/classification/resnet/model/",
+        "resnet50-v2-7.onnx",
+    ]
+)

-    return relay.frontend.from_onnx(onnx_model, shape_dict)
-
-
-def build_graph_lib(model_file, opt_level):
+def build_graph_lib(opt_level):
     """Compiles the pre-trained model with TVM"""
     out_dir = os.path.join(sys.path[0], "../lib")
     if not os.path.exists(out_dir):
         os.makedirs(out_dir)

-    # Compile the relay mod
-    mod, params = _get_mod_and_params(model_file)
+    # Follow the tutorial to download and compile the model
+    model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx")
+    onnx_model = onnx.load(model_path)
+
+    img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
+    img_path = download_testdata(img_url, "imagenet_cat.png", module="data")
+
+    # Resize it to 224x224
+    resized_image = Image.open(img_path).resize((224, 224))
+    img_data = np.asarray(resized_image).astype("float32")

+    # Our input image is in HWC layout, while ONNX expects CHW input, so convert the array
+    img_data = np.transpose(img_data, (2, 0, 1))
+
+    # Normalize according to the ImageNet input specification
+    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+    imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+    norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
+
+    # Add the batch dimension, as we are expecting 4-dimensional input: NCHW
+    img_data = np.expand_dims(norm_img_data, axis=0)
+
+    input_name = "data"
+    shape_dict = {input_name: img_data.shape}
+
+    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
     target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib"

     with tvm.transform.PassContext(opt_level=opt_level):
-        graph_json, lib, params = relay.build(mod, target=target, params=params)
+        factory = relay.build(mod, target=target, params=params)

     # Save the model artifacts to obj_file
     obj_file = os.path.join(out_dir, "graph.o")
-    lib.save(obj_file)
+    factory.get_lib().save(obj_file)

     # Run llvm-ar to archive obj_file into lib_file
     lib_file = os.path.join(out_dir, "libgraph_wasm32.a")
     cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file]
     subprocess.run(cmds)

     # Save the json and params
     with open(os.path.join(out_dir, "graph.json"), "w") as f_graph:
-        f_graph.write(graph_json)
+        f_graph.write(factory.get_graph_json())
     with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
-        f_params.write(runtime.save_param_dict(params))
+        f_params.write(runtime.save_param_dict(factory.get_params()))


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="ONNX model build example")
-    parser.add_argument("model_file", type=str, help="the path of onnx model file")
     parser.add_argument(
         "-O",
         "--opt-level",
@@ -75,4 +107,4 @@ def build_graph_lib(model_file, opt_level):
     )
     args = parser.parse_args()

-    build_graph_lib(args.model_file, args.opt_level)
+    build_graph_lib(args.opt_level)
4 changes: 2 additions & 2 deletions apps/wasm-standalone/wasm-runtime/Cargo.toml
@@ -26,8 +26,8 @@ license = "Apache-2.0"
 keywords = ["wasm", "machine learning", "wasmtime"]

 [dependencies]
-wasmtime = "0.16.0"
-wasmtime-wasi = "0.16.0"
+wasmtime = "0.28.0"
> **Reviewer (Contributor):** Can you check whether this version of wasmtime supports loading a wasm graph built with the `-mattr=+simd128` flag? Thanks!
>
> **Author (Member):** Yes, I've already tested it; wasmtime 0.28 supports it.

+wasmtime-wasi = "0.28.0"
 anyhow = "1.0.31"
 serde = "1.0.53"
 serde_json = "1.0.53"
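Following up on the review thread above, a minimal sketch of that SIMD check, assuming the wasmtime 0.28 API; the `.wasm` path is illustrative:

```rust
use wasmtime::{Config, Engine, Module};

fn main() -> anyhow::Result<()> {
    // SIMD must be enabled on the engine config, or a module compiled with
    // -mattr=+simd128 will fail validation at load time.
    let engine = Engine::new(Config::new().wasm_simd(true))?;
    let module = Module::from_file(&engine, "wasm_graph_resnet50.wasm")?;
    println!("module loaded with {} exports", module.exports().len());
    Ok(())
}
```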
79 changes: 52 additions & 27 deletions apps/wasm-standalone/wasm-runtime/src/graph.rs
@@ -19,14 +19,17 @@

 use anyhow::Result;
 use wasmtime::*;
-use wasmtime_wasi::{Wasi, WasiCtx};
+use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};

 use super::Tensor;

 pub struct GraphExecutor {
     pub(crate) wasm_addr: i32,
     pub(crate) input_size: i32,
     pub(crate) output_size: i32,
+    pub(crate) store: Option<Store<WasiCtx>>,
+    // Non-WASI version:
+    // pub(crate) store: Option<Store<()>>,
     pub(crate) instance: Option<Instance>,
 }

@@ -37,25 +40,44 @@ impl GraphExecutor {
             wasm_addr: 0,
             input_size: 0,
             output_size: 0,
+            store: None,
             instance: None,
         }
     }

     pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> {
-        let engine = Engine::new(Config::new().wasm_simd(true));
-        let store = Store::new(&engine);
+        // It seems WASI is not necessary in this example.
+
+        // Non-WASI version (works, but without SIMD):
+        // let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
+        // let mut store = Store::new(&engine, ());
+        // let module = Module::from_file(store.engine(), &wasm_graph_file)?;
+        // let instance = Instance::new(&mut store, &module, &[])?;
+        // self.instance = Some(instance);
+        // self.store = Some(store);
+        // Ok(())
+
+        // WASI version:
+        let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
         // First set up our linker which is going to be linking modules together. We
         // want our linker to have wasi available, so we set that up here as well.
-        let mut linker = Linker::new(&store);
+        let mut linker = Linker::new(&engine);
+        wasmtime_wasi::add_to_linker(&mut linker, |s| s)?;
         // Create an instance of `Wasi` which contains a `WasiCtx`. Note that
         // `WasiCtx` provides a number of ways to configure what the target program
         // will have access to.
-        let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?);
-        wasi.add_to_linker(&mut linker)?;
+        let wasi = WasiCtxBuilder::new()
+            .inherit_stdio()
+            .inherit_args()?
+            .build();
+        let mut store = Store::new(&engine, wasi);

-        let module = Module::from_file(&store, &wasm_graph_file)?;
-        self.instance = Some(linker.instantiate(&module)?);
+        let module = Module::from_file(&engine, &wasm_graph_file)?;
+        self.instance = Some(linker.instantiate(&mut store, &module)?);
+        self.store = Some(store);

         Ok(())
     }
@@ -65,26 +87,24 @@ impl GraphExecutor {
             .instance
             .as_ref()
             .unwrap()
-            .get_memory("memory")
+            .get_memory(self.store.as_mut().unwrap(), "memory")
             .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;

         // Specify the wasm address to access the wasm memory.
-        let wasm_addr = memory.data_size();
+        let wasm_addr = memory.data_size(self.store.as_mut().unwrap());

         // Serialize the data into a JSON string.
         let in_data = serde_json::to_vec(&input_data)?;
         let in_size = in_data.len();

-        // Grow up memory size according to in_size to avoid memory leak.
-        memory.grow((in_size >> 16) as u32 + 1)?;
+        // Grow the memory by enough 64 KiB pages to hold in_size bytes.
+        memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?;

         // Insert the input data into wasm memory.
-        for i in 0..in_size {
-            unsafe {
-                memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap();
-            }
-        }
+        memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?;

         self.wasm_addr = wasm_addr as i32;
         self.input_size = in_size as i32;

         Ok(())
     }
@@ -94,11 +114,12 @@ impl GraphExecutor {
             .instance
             .as_ref()
             .unwrap()
-            .get_func("run")
-            .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?
-            .get2::<i32, i32, i32>()?;
+            .get_func(self.store.as_mut().unwrap(), "run")
+            .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?;

-        let out_size = run(self.wasm_addr, self.input_size)?;
+        let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)];
+        let out_size = run.call(self.store.as_mut().unwrap(), &params[..])?;
+        let out_size = (*out_size)[0].unwrap_i32();
         if out_size == 0 {
             panic!("graph run failed!");
         }
@@ -107,18 +128,22 @@ impl GraphExecutor {
         Ok(())
     }

-    pub fn get_output(&self) -> Result<Tensor> {
+    pub fn get_output(&mut self) -> Result<Tensor> {
         let memory = self
             .instance
             .as_ref()
             .unwrap()
-            .get_memory("memory")
+            .get_memory(self.store.as_mut().unwrap(), "memory")
             .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;

-        let out_data = unsafe {
-            &memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize]
-        };
-        let out_vec: Tensor = serde_json::from_slice(out_data).unwrap();
+        let mut out_data = vec![0u8; self.output_size as _];
+        memory.read(
+            self.store.as_mut().unwrap(),
+            self.wasm_addr as _,
+            &mut out_data,
+        )?;
+
+        let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap();
         Ok(out_vec)
     }
 }
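Taken together, the updated executor is driven roughly as follows. This is a hedged sketch: the method names follow the diff above, the `.wasm` path is illustrative, and construction of the input `Tensor` is elided since its constructor is not shown in this diff:

```rust
// Hypothetical driver; GraphExecutor and Tensor come from wasm-runtime.
fn classify(input: Tensor) -> anyhow::Result<Tensor> {
    let mut exec = GraphExecutor::new();
    exec.instantiate("wasm_graph_resnet50.wasm".to_string())?;
    exec.set_input(input)?;
    exec.run()?;
    exec.get_output()
}
```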
1 change: 1 addition & 0 deletions rust/tvm-graph-rt/src/module/mod.rs
@@ -52,6 +52,7 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
             values.len() as i32,
             &mut ret_val,
             &mut ret_type_code,
+            std::ptr::null_mut(),
         );
         if exit_code == 0 {
             Ok(RetValue::from_tvm_value(ret_val, ret_type_code))