diff --git a/apps/wasm-standalone/README.md b/apps/wasm-standalone/README.md
index e40d218634aa..b8a977f6ae50 100644
--- a/apps/wasm-standalone/README.md
+++ b/apps/wasm-standalone/README.md
@@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all
 - Build DL library in the WebAssembly format.
 
-  - Download model
+  - Compile the model
 
     ```
-      cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx
-    ```
-
-  - Compile
-
-    ```
-      LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx
+      cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3
     ```
 
 ### Build wasm-graph package
@@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn
 $ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv
 original image dimensions: (256, 256)
 resized image dimensions: (224, 224)
-input image belongs to the class `tabby, tabby cat`
+input image belongs to the class `tiger cat`
 ```
 
+Note: this example also works without WASI support. To run in pure wasm32, change the target to
+`wasm32-unknown-unknown` in `wasm-graph/.cargo/config` and uncomment the raw wasm engine in
+`wasm-runtime/src/graph.rs`. SIMD may not be available without WASI, so you may also need to
+remove ` -mattr=+simd128` in the [build script](wasm-graph/tools/build_graph_lib.py).
+
 ## Future Work
 
 ### More networks support
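For orientation, the host-side flow that the README's `test_graph_resnet50` run exercises boils down to the `GraphExecutor` API migrated in the graph.rs hunks below. A minimal sketch, not part of the patch; the `wasm_runtime` module paths are assumptions about this app's crate layout:

```rust
// Sketch only: how a host binary drives the GraphExecutor changed below.
use wasm_runtime::graph::GraphExecutor; // assumed path
use wasm_runtime::Tensor;               // assumed path

fn run_graph(wasm_file: &str, input: Tensor) -> anyhow::Result<Tensor> {
    let mut exec = GraphExecutor::new();
    exec.instantiate(wasm_file.to_string())?; // load and link the wasm module
    exec.set_input(input)?;                   // serialize the tensor into wasm memory
    exec.run()?;                              // call the exported `run(wasm_addr, in_size)`
    exec.get_output()                         // read the serialized output tensor back
}
```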
diff --git a/apps/wasm-standalone/wasm-graph/src/lib.rs b/apps/wasm-standalone/wasm-graph/src/lib.rs
index 2b4187849edc..92a3d5c2f3b0 100644
--- a/apps/wasm-standalone/wasm-graph/src/lib.rs
+++ b/apps/wasm-standalone/wasm-graph/src/lib.rs
@@ -48,6 +48,7 @@ lazy_static! {
         "/lib/graph.json"
     )))
     .unwrap();
+
     let params_bytes = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params"));
     let params = tvm_graph_rt::load_param_dict(params_bytes)
@@ -57,6 +58,7 @@ lazy_static! {
         .collect::<HashMap<String, TVMTensor>>();
 
     let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
+
     exec.load_params(params);
     Mutex::new(exec)
@@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 {
     let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) };
     let input: TVMTensor = in_tensor.as_dltensor().into();
 
-    GRAPH_EXECUTOR.lock().unwrap().set_input("data", input);
-    GRAPH_EXECUTOR.lock().unwrap().run();
-    let output = GRAPH_EXECUTOR
-        .lock()
-        .unwrap()
-        .get_output(0)
-        .unwrap()
-        .as_dltensor(false);
+    // Since this executor is not multi-threaded, we can acquire the lock once.
+    let mut executor = GRAPH_EXECUTOR.lock().unwrap();
+
+    executor.set_input("data", input);
+
+    executor.run();
+
+    let output = executor.get_output(0).unwrap().as_dltensor(false);
 
     let out_tensor: Tensor = output.into();
     let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) };
diff --git a/apps/wasm-standalone/wasm-graph/src/types.rs b/apps/wasm-standalone/wasm-graph/src/types.rs
index a3761a758cff..f08b7be84990 100644
--- a/apps/wasm-standalone/wasm-graph/src/types.rs
+++ b/apps/wasm-standalone/wasm-graph/src/types.rs
@@ -24,7 +24,7 @@ use std::{
 };
 pub use tvm_sys::ffi::DLTensor;
 use tvm_sys::ffi::{
-    DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU,
+    DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU,
 };
 
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
diff --git a/apps/wasm-standalone/wasm-graph/src/utils.rs b/apps/wasm-standalone/wasm-graph/src/utils.rs
index fd4a71745f4f..92d386e3062a 100644
--- a/apps/wasm-standalone/wasm-graph/src/utils.rs
+++ b/apps/wasm-standalone/wasm-graph/src/utils.rs
@@ -24,13 +24,20 @@ use std::ptr;
 
 pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor {
     let in_addr = in_addr as *mut u8;
 
-    let mut data_vec = Vec::new();
-    for i in 0..in_size {
-        data_vec.push(ptr::read(in_addr.offset(i as isize)));
-    }
-    let input: Tensor = serde_json::from_slice(&data_vec).unwrap();
+    println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size);
 
-    input
+    let data_vec = std::slice::from_raw_parts(in_addr, in_size);
+
+    let input = serde_json::from_slice(data_vec);
+    match input {
+        Ok(result) => {
+            println!("DEBUG: input tensor deserialized successfully");
+            result
+        }
+        Err(e) => {
+            panic!("DEBUG: failed to deserialize input tensor: {:?}", &e);
+        }
+    }
 }
 
 pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize {
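The guest-side `load_input` above and the host-side `set_input`/`get_output` in graph.rs exchange tensors as JSON bytes in wasm linear memory. A self-contained sketch of that round trip; the `Tensor` fields here are illustrative only (the real struct lives in `wasm-graph/src/types.rs`):

```rust
use serde::{Deserialize, Serialize};

// Illustrative stand-in for the app's Tensor type.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Tensor {
    shape: Vec<i64>,
    data: Vec<f32>,
}

fn main() -> serde_json::Result<()> {
    // Host side: serialize the tensor, then copy the bytes into wasm memory
    // (the patch does this with `memory.write`).
    let input = Tensor { shape: vec![1, 3], data: vec![0.1, 0.2, 0.3] };
    let bytes = serde_json::to_vec(&input)?;

    // Guest side: view the written region as a byte slice and deserialize,
    // which is what `load_input` now does via `slice::from_raw_parts`.
    let decoded: Tensor = serde_json::from_slice(&bytes)?;
    assert_eq!(decoded, input);
    Ok(())
}
```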
-"""Builds a simple graph for testing.""" +"""Builds a simple resnet50 graph for testing.""" import argparse import os import subprocess @@ -25,47 +25,78 @@ import onnx import tvm from tvm import relay, runtime +from tvm.contrib.download import download_testdata +from tvm.contrib import graph_executor +from PIL import Image +import numpy as np +import tvm.relay as relay -def _get_mod_and_params(model_file): - onnx_model = onnx.load(model_file) - shape_dict = {} - for input in onnx_model.graph.input: - shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim] +# This example uses resnet50-v2-7 model +model_url = "".join( + [ + "https://github.com/onnx/models/raw/", + "master/vision/classification/resnet/model/", + "resnet50-v2-7.onnx", + ] +) - return relay.frontend.from_onnx(onnx_model, shape_dict) - -def build_graph_lib(model_file, opt_level): +def build_graph_lib(opt_level): """Compiles the pre-trained model with TVM""" out_dir = os.path.join(sys.path[0], "../lib") if not os.path.exists(out_dir): os.makedirs(out_dir) - # Compile the relay mod - mod, params = _get_mod_and_params(model_file) + # Follow the tutorial to download and compile the model + model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx") + onnx_model = onnx.load(model_path) + + img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg" + img_path = download_testdata(img_url, "imagenet_cat.png", module="data") + + # Resize it to 224x224 + resized_image = Image.open(img_path).resize((224, 224)) + img_data = np.asarray(resized_image).astype("float32") + + # Our input image is in HWC layout while ONNX expects CHW input, so convert the array + img_data = np.transpose(img_data, (2, 0, 1)) + + # Normalize according to the ImageNet input specification + imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) + imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev + + # Add the batch dimension, as we are expecting 4-dimensional input: NCHW. 
diff --git a/apps/wasm-standalone/wasm-runtime/Cargo.toml b/apps/wasm-standalone/wasm-runtime/Cargo.toml
index 99f6db54431f..d3f860170d4e 100644
--- a/apps/wasm-standalone/wasm-runtime/Cargo.toml
+++ b/apps/wasm-standalone/wasm-runtime/Cargo.toml
@@ -26,8 +26,8 @@ license = "Apache-2.0"
 keywords = ["wasm", "machine learning", "wasmtime"]
 
 [dependencies]
-wasmtime = "0.16.0"
-wasmtime-wasi = "0.16.0"
+wasmtime = "0.28.0"
+wasmtime-wasi = "0.28.0"
 anyhow = "1.0.31"
 serde = "1.0.53"
 serde_json = "1.0.53"
diff --git a/apps/wasm-standalone/wasm-runtime/src/graph.rs b/apps/wasm-standalone/wasm-runtime/src/graph.rs
index e7c39cbb0687..bfa1c2f19c56 100644
--- a/apps/wasm-standalone/wasm-runtime/src/graph.rs
+++ b/apps/wasm-standalone/wasm-runtime/src/graph.rs
@@ -19,7 +19,7 @@
 
 use anyhow::Result;
 use wasmtime::*;
-use wasmtime_wasi::{Wasi, WasiCtx};
+use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};
 
 use super::Tensor;
 
@@ -27,6 +27,9 @@ pub struct GraphExecutor {
     pub(crate) wasm_addr: i32,
     pub(crate) input_size: i32,
     pub(crate) output_size: i32,
+    pub(crate) store: Option<Store<WasiCtx>>,
+    // Non-WASI version:
+    // pub(crate) store: Option<Store<()>>,
     pub(crate) instance: Option<Instance>,
 }
 
@@ -37,25 +40,44 @@ impl GraphExecutor {
             wasm_addr: 0,
             input_size: 0,
             output_size: 0,
+            store: None,
             instance: None,
         }
     }
 
     pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> {
-        let engine = Engine::new(Config::new().wasm_simd(true));
-        let store = Store::new(&engine);
+        // WASI does not seem to be strictly necessary in this example.
+        // Non-WASI version (works, but without SIMD):
+        //
+        // let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
+        // let mut store = Store::new(&engine, ());
+        // let module = Module::from_file(store.engine(), &wasm_graph_file)?;
+        //
+        // let instance = Instance::new(&mut store, &module, &[])?;
+        //
+        // self.instance = Some(instance);
+        // self.store = Some(store);
+        //
+        // Ok(())
+
+        // WASI version:
+        let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
 
         // First set up our linker which is going to be linking modules together. We
         // want our linker to have wasi available, so we set that up here as well.
-        let mut linker = Linker::new(&store);
+        let mut linker = Linker::new(&engine);
+        wasmtime_wasi::add_to_linker(&mut linker, |s| s)?;
         // Create an instance of `Wasi` which contains a `WasiCtx`. Note that
         // `WasiCtx` provides a number of ways to configure what the target program
         // will have access to.
-        let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?);
-        wasi.add_to_linker(&mut linker)?;
+        let wasi = WasiCtxBuilder::new()
+            .inherit_stdio()
+            .inherit_args()?
+            .build();
+        let mut store = Store::new(&engine, wasi);
 
-        let module = Module::from_file(&store, &wasm_graph_file)?;
-        self.instance = Some(linker.instantiate(&module)?);
+        let module = Module::from_file(&engine, &wasm_graph_file)?;
+        self.instance = Some(linker.instantiate(&mut store, &module)?);
+        self.store = Some(store);
 
         Ok(())
     }
@@ -65,26 +87,24 @@ impl GraphExecutor {
             .instance
             .as_ref()
             .unwrap()
-            .get_memory("memory")
+            .get_memory(self.store.as_mut().unwrap(), "memory")
             .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
 
         // Specify the wasm address to access the wasm memory.
-        let wasm_addr = memory.data_size();
+        let wasm_addr = memory.data_size(self.store.as_mut().unwrap());
+
         // Serialize the data into a JSON string.
         let in_data = serde_json::to_vec(&input_data)?;
         let in_size = in_data.len();
+
         // Grow up memory size according to in_size to avoid memory leak.
-        memory.grow((in_size >> 16) as u32 + 1)?;
+        memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?;
 
-        // Insert the input data into wasm memory.
-        for i in 0..in_size {
-            unsafe {
-                memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap();
-            }
-        }
+        memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?;
 
         self.wasm_addr = wasm_addr as i32;
         self.input_size = in_size as i32;
+
         Ok(())
     }
 
@@ -94,11 +114,12 @@ impl GraphExecutor {
             .instance
             .as_ref()
             .unwrap()
-            .get_func("run")
-            .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?
-            .get2::<i32, i32, i32>()?;
+            .get_func(self.store.as_mut().unwrap(), "run")
+            .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?;
 
-        let out_size = run(self.wasm_addr, self.input_size)?;
+        let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)];
+        let out_size = run.call(self.store.as_mut().unwrap(), &params[..])?;
+        let out_size = out_size[0].unwrap_i32();
         if out_size == 0 {
             panic!("graph run failed!");
         }
@@ -107,18 +128,22 @@ impl GraphExecutor {
         Ok(())
     }
 
-    pub fn get_output(&self) -> Result<Tensor> {
+    pub fn get_output(&mut self) -> Result<Tensor> {
         let memory = self
             .instance
             .as_ref()
             .unwrap()
-            .get_memory("memory")
+            .get_memory(self.store.as_mut().unwrap(), "memory")
             .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
 
-        let out_data = unsafe {
-            &memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize]
-        };
-        let out_vec: Tensor = serde_json::from_slice(out_data).unwrap();
+        let mut out_data = vec![0u8; self.output_size as _];
+        memory.read(
+            self.store.as_mut().unwrap(),
+            self.wasm_addr as _,
+            &mut out_data,
+        )?;
+
+        let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap();
         Ok(out_vec)
     }
 }
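For the `run` export invoked above, wasmtime 0.28 also offers typed function handles, which check the `(i32, i32) -> i32` signature once at lookup time rather than boxing `Val`s on every call. A sketch of an alternative body for `GraphExecutor::run` under that assumption (not part of the patch):

```rust
// Sketch: a typed alternative to the `Val`-based call in the patch.
// Assumes the same `run(i32, i32) -> i32` export and the `self.store` /
// `self.instance` fields introduced above.
let run = self
    .instance
    .as_ref()
    .unwrap()
    .get_typed_func::<(i32, i32), i32>(self.store.as_mut().unwrap(), "run")?;

let out_size = run.call(
    self.store.as_mut().unwrap(),
    (self.wasm_addr, self.input_size),
)?;
```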
diff --git a/rust/tvm-graph-rt/src/module/mod.rs b/rust/tvm-graph-rt/src/module/mod.rs
index 511ba4b37132..a345758deca1 100644
--- a/rust/tvm-graph-rt/src/module/mod.rs
+++ b/rust/tvm-graph-rt/src/module/mod.rs
@@ -52,6 +52,7 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
             values.len() as i32,
             &mut ret_val,
             &mut ret_type_code,
+            std::ptr::null_mut(),
         );
         if exit_code == 0 {
             Ok(RetValue::from_tvm_value(ret_val, ret_type_code))
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index d80bd9598246..170ccce0a9f1 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -84,21 +84,26 @@ fn main() -> Result<()> {
     println!("cargo:rerun-if-changed={}", build_path.display());
     println!("cargo:rerun-if-changed={}/include", source_path.display());
 
-    if cfg!(feature = "static-linking") {
-        println!("cargo:rustc-link-lib=static=tvm");
-        // TODO(@jroesch): move this to tvm-build as library_path?
-        println!(
-            "cargo:rustc-link-search=native={}/build",
-            build_path.display()
-        );
-    }
-
-    if cfg!(feature = "dynamic-linking") {
-        println!("cargo:rustc-link-lib=dylib=tvm");
-        println!(
-            "cargo:rustc-link-search=native={}/build",
-            build_path.display()
-        );
+    match &std::env::var("CARGO_CFG_TARGET_ARCH").unwrap()[..] {
+        "wasm32" => {}
+        _ => {
+            if cfg!(feature = "static-linking") {
+                println!("cargo:rustc-link-lib=static=tvm");
+                // TODO(@jroesch): move this to tvm-build as library_path?
+                println!(
+                    "cargo:rustc-link-search=native={}/build",
+                    build_path.display()
+                );
+            }
+
+            if cfg!(feature = "dynamic-linking") {
+                println!("cargo:rustc-link-lib=dylib=tvm");
+                println!(
+                    "cargo:rustc-link-search=native={}/build",
+                    build_path.display()
+                );
+            }
+        }
     }
 
     let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h");
diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs
index f874e672bb66..f9ac3b461c69 100644
--- a/rust/tvm-sys/src/lib.rs
+++ b/rust/tvm-sys/src/lib.rs
@@ -40,6 +40,7 @@ pub mod ffi {
         num_args: c_int,
         out_ret_value: *mut TVMValue,
         out_ret_tcode: *mut u32,
+        resource_handle: *mut c_void,
     ) -> c_int;
 }
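The last two hunks track an upstream TVM ABI change: backend packed functions now take a trailing opaque `resource_handle`, and callers with nothing to pass supply a null pointer. An illustrative declaration of that shape; the names here are assumptions, not the exact tvm-sys bindings:

```rust
use std::os::raw::{c_int, c_void};
use tvm_sys::ffi::TVMValue; // assumes the tvm-sys crate is available

// Illustrative only: the shape of a backend packed function after this change.
// Callers with no per-call resource, like wrap_backend_packed_func above,
// pass std::ptr::null_mut() for `resource_handle`.
type BackendPackedCFunc = unsafe extern "C" fn(
    args: *const TVMValue,
    type_codes: *const c_int,
    num_args: c_int,
    out_ret_value: *mut TVMValue,
    out_ret_tcode: *mut u32,
    resource_handle: *mut c_void,
) -> c_int;
```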