diff --git a/cmake/modules/contrib/Mrvl.cmake b/cmake/modules/contrib/Mrvl.cmake index 03296336196b..8bf48e02ca21 100644 --- a/cmake/modules/contrib/Mrvl.cmake +++ b/cmake/modules/contrib/Mrvl.cmake @@ -20,6 +20,7 @@ if(USE_MRVL) message(STATUS "Build with Mrvl support") file(GLOB RUNTIME_MRVL_SRCS src/runtime/contrib/mrvl/mrvl_runtime.cc + src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc ) list(APPEND RUNTIME_SRCS ${RUNTIME_MRVL_SRCS}) file(GLOB COMPILER_MRVL_SRCS diff --git a/docker/Dockerfile.demo_mrvl b/docker/Dockerfile.demo_mrvl index a99345d07ffd..b50944d2c20e 100644 --- a/docker/Dockerfile.demo_mrvl +++ b/docker/Dockerfile.demo_mrvl @@ -17,3 +17,31 @@ # prebuild ci-cpu image FROM tlcpack/ci-cpu:20230604-060130-0af9ff90e + +# Cloning TVM's main repo +RUN echo "Cloning TVM source & submodules" +ENV TVM_PAR_DIR="/usr" +RUN mkdir -p TVM_PAR_DIR && \ + cd ${TVM_PAR_DIR} && \ + git clone --depth=1 https://github.com/apache/tvm tvm --recursive + +# Building TVM +RUN echo "Building TVM" +ENV TVM_HOME="/usr/tvm" +ENV TVM_BUILD_DIR="${TVM_HOME}/build" +RUN mkdir -p ${TVM_BUILD_DIR} && \ + cd ${TVM_HOME} && \ + ./tests/scripts/task_config_build_mrvl.sh build && \ + cd ${TVM_BUILD_DIR} && \ + cmake .. && \ + make -j$(nproc) + +RUN echo "Building Python package" +ENV PYTHONPATH=${TVM_HOME}/python:${PYTHONPATH} +RUN cd ${TVM_HOME}/python && python3 setup.py install --user + +# Fetching Marvell binaries +RUN cd /opt && \ + git clone https://github.com/MarvellEmbeddedProcessors/MarvellMLTools.git + +ENV PATH="/opt/MarvellMLTools/bin:$PATH" diff --git a/docs/how_to/deploy/mrvl.rst b/docs/how_to/deploy/mrvl.rst index 0b0b81ed3494..7b41e2ee3a74 100644 --- a/docs/how_to/deploy/mrvl.rst +++ b/docs/how_to/deploy/mrvl.rst @@ -32,7 +32,7 @@ compiles supported operations for accelerated execution on MLIP, or LLVM for general compute. For runtime, the library supports native execution on MLIP hardware -as well as Marvell's ML simulator (mlModel). +as well as Marvell's ML simulator (mrvl-mlsim). The library supports Marvell's Octeon family of processors with ML accelarators. @@ -54,21 +54,10 @@ https://tvm.apache.org/docs/install/from_source.html .. code:: bash - ./docker/build.sh demo_mrvl bash # Build the docker container - ./docker/bash.sh tvm.demo_mrvl --env PYTHONPATH=$PWD/python # Load the docker image + ./docker/build.sh demo_mrvl bash # Build the docker container + ./docker/bash.sh tvm.demo_mrvl # Load the docker image - -3. Build TVM inside the docker container with mrvl (inside tvm directory) -------------------------------------------------------------------------- - -.. code:: bash - - ./tests/scripts/task_config_build_mrvl.sh build - cd build - cmake .. - make -j$(nproc) # nproc = 4/8/.. (Number of Parallel jobs) - -4. Compiling a model using TVMC command line +3. Compiling a model using TVMC command line -------------------------------------------- Models can be compiled and run for mrvl target using TVMC which is optimized for performance. @@ -79,14 +68,14 @@ https://tvm.apache.org/docs/tutorial/tvmc_command_line_driver.html Additional mrvl-specific options may be added as attributes if necessary. The advanced usage is described in this document below. -4.1 TVMC Compilation Flow for a model +3.1 TVMC Compilation Flow for a model ------------------------------------- Refer to the following TVM documentation, for compilation flow https://tvm.apache.org/docs/arch/index.html#example-compilation-flow -4.2. TVMC - Command line option(s): Syntax for mrvl target +3.2. TVMC - Command line option(s): Syntax for mrvl target ---------------------------------------------------------- Compiling an ONNX model using the tvmc for mrvl target. @@ -115,8 +104,9 @@ integrated MLIP cn10ka processor, using only 4 tiles in the block. --output model.tar \ mnist-12.onnx +The runtime support for hardware acceleration is a WIP, it will be added in future PR. -4.3. TVMC Compiler: mrvl specific Command Line Options +3.3. TVMC Compiler: mrvl specific Command Line Options ------------------------------------------------------ .. code:: python @@ -151,30 +141,35 @@ integrated MLIP cn10ka processor, using only 4 tiles in the block. Optimize runtime by preloading a model's weights and bias into the on chip memory. Possible values = {0, 1}. Default is 0 (no preload) -5. Compilation - Generating model partitions --------------------------------------------- +4. Compile ONNX model for Simulator + LLVM / x86_64 target +---------------------------------------------------------- In the TVMC mrvl flow, the model is partitioned into Marvell and LLVM regions. Building each partitioned Marvell subgraph generates serialized nodes.json and const.json. Partitioned nodes.json is the representation of the model graph which is -suitable for the Marvell mmlc compiler. It is distributed separately via CDK +suitable for the Marvell compiler (mrvl-tmlc). The compiler compiles the model graph to +generate the model binary with MLIP instructions. -**Model Partition** +**Model Compilation for Simulator + LLVM / x86_64 target** -.. code:: bash +.. code:: python + + python3 -m tvm.driver.tvmc compile --target="mrvl, llvm" \ + --target-mrvl-num_tiles=4 --output model.tar model.onnx + +**Run TVM models on x86_64 host using MLIP Simulator** + +Generated model binary is simulated using Marvell's MLIP Simulator(mrvl-mlsim). - python3 -m tvm.driver.tvmc compile --target="mrvl, llvm \ - -mtriple=aarch64-linux-gnu -mcpu=neoverse-n2" \ - --cross-compiler aarch64-linux-gnu-gcc \ - --target-mrvl-num_tiles=4 --output model.tar model.onnx +.. code:: python + python3 -m tvm.driver.tvmc run --inputs infer.npz --outputs predict.npz model.tar --number=0 -6. Compiling a model using Python APIs +5. Compiling a model using Python APIs -------------------------------------- In addition to using TVMC, models can also be compiled and run using -TVM Python API. Below is an example to compile the MNIST model. Support -to run the model will be part of next PR by mrvl +TVM Python API. Below is an example to compile and run the MNIST model. **Download MNIST model from the web** @@ -187,9 +182,10 @@ to run the model will be part of next PR by mrvl .. code:: python - import tvm, onnx, os + import tvm, onnx import numpy as np import tvm.relay as relay + from tvm.contrib import graph_executor from tvm.relay.op.contrib.mrvl import partition_for_mrvl from tvm.relay.build_module import build from keras.datasets import mnist @@ -224,12 +220,33 @@ operations will go through the regular LLVM compilation and code generation for **Build the Relay Graph** Build the Relay graph, using the new module returned by partition_for_mrvl. -The target must always be a LLVM (ARM) target. ``partition_for_mrvl`` will -pass the options from dictionary into the config parameters needed by the -compiler backend, so there is no need to modify it - just pass it along -to the PassContext so the values can be read during compilation. .. code:: python with tvm.transform.PassContext(opt_level=3, config={"relay.ext.mrvl.options" : option_dict}): - model_lib = relay.build(mod, tvm_target, params=params) + model_lib = relay.build(mod, tvm_target, params=params) + +**Generate runtime graph of the model library** + +.. code:: python + + dev = tvm.cpu() + model_rt_graph = graph_executor.GraphModule(model_lib["default"](dev)) + +**Get test data and initialize model input** + +.. code:: python + + (train_X, train_y), (test_X, test_y) = mnist.load_data() + image = tvm.nd.array(test_X[0].reshape(1, 1, 28, 28).astype("float32") / 255) + inputs_dict = {} + inputs_dict["Input3"] = image + model_rt_graph.set_input(**inputs_dict) + +**Run Inference and print the output** + +.. code:: python + + model_rt_graph.run() + output_tensor = model_rt_graph.get_output(0).numpy() + print (output_tensor) diff --git a/python/tvm/contrib/mrvl.py b/python/tvm/contrib/mrvl.py index cd0dab05efe7..7004bb5b9db6 100644 --- a/python/tvm/contrib/mrvl.py +++ b/python/tvm/contrib/mrvl.py @@ -19,10 +19,41 @@ import os import json +import shutil +import tempfile +import base64 +import numpy as np import tvm import tvm._ffi +@tvm._ffi.register_func("tvm.mrvl.find_value_in_KV_pair") +def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: + """This function takes the graph_json string and key to be searched in + the json string, using json parser routine it loads the json string + and access the value using the given key. It raises exception if the + key is not found in the input json string. + + Parameters + ---------- + graph_json: String + This is the graph_json string + + Returns + ------- + value_string: string + This returns the value string for the given key string + """ + value = "" + try: + json_dict = json.loads(json_input) + value = json_dict[key_to_find] + except KeyError: + assert False, "Marvell-Compiler-ERROR-Internal:: Could not find matching key in json" + + return value + + @tvm._ffi.register_func("tvm.mrvl.GetNodesJSONString") def get_nodes_json_string(graph_json): """This takes the graph_json string from MrvlJSONSerializer and adds / modifies @@ -152,6 +183,7 @@ def get_nodes_json_string(graph_json): "kernel_const", "bias_const", "gamma_const", + "input_const", ]: iterator["attrs"][it2] = iterator["attrs"][it2][0] @@ -274,6 +306,18 @@ def modify_const_names(nodes_json_str, consts_json_str): var_map["dtype"] = const[new_name]["dtype"] var_map["name"] = new_name iterator["attrs"]["var_const"] = var_map + if attrs == "input_const_name": + new_name = iterator["name"] + "_const_0" + const[new_name] = const.pop(iterator["attrs"][attrs][0]) + const[new_name]["shape"] = list(map(int, iterator["attrs"]["input_const_shape"])) + iterator["attrs"][attrs][0] = new_name + map_const = {} + map_const["shape"] = const[new_name]["shape"] + map_const["dtype"] = const[new_name]["dtype"] + map_const["min"] = const[new_name]["min"] + map_const["max"] = const[new_name]["max"] + map_const["name"] = new_name + iterator["attrs"]["input_const"] = map_const nodes_mod_str = json.dumps(nodes, indent=2) const_mod_str = json.dumps(const, indent=2) @@ -283,3 +327,131 @@ def modify_const_names(nodes_json_str, consts_json_str): def get_working_dir(): """Obtain the current working directory from where tvm is invoked""" return os.getcwd() + + +@tvm._ffi.register_func("tvm.mrvl.WriteJsonFile") +def write_json_file(json_string, json_filename): + """Generate json file under working directory""" + working_dir = get_working_dir() + json_file = os.path.join(working_dir, json_filename) + with open(json_file, "w") as out_file: + out_file.write(json_string) + return json_file + + +def delete_temp_files(symbol_name): + """Delete temporary files generated by the Marvell compiler""" + working_dir = get_working_dir() + nodes_json_file = os.path.join(working_dir, f"{symbol_name}-nodes.json") + consts_json_file = os.path.join(working_dir, f"{symbol_name}-consts.json") + os.remove(nodes_json_file) + os.remove(consts_json_file) + bin_folder = os.path.join(working_dir, "bin_" + symbol_name) + if "MRVL_SAVE_MODEL_BIN" not in os.environ: + shutil.rmtree(bin_folder) + + +@tvm._ffi.register_func("tvm.mrvl.CompileModel") +def compile_model( + symbol_name, + nodes_json_string, + consts_json_string, + compiler_opts, +): + """Compile the model using Marvell Backend compiler and return the generated binary""" + # generate pair of json files + nodes_json_file = write_json_file(nodes_json_string, f"{symbol_name}-nodes.json") + consts_json_file = write_json_file(consts_json_string, f"{symbol_name}-consts.json") + mrvl_exec = "mrvl-tmlc" + exec_on_path = shutil.which(mrvl_exec) + if exec_on_path is None: + error_msg = ( + "Marvell Compiler not found! Please specify the path to Marvell tools " + "by adding it to $PATH." + ) + raise RuntimeError(error_msg) + + # Parse the nodes_json string for the batch size + dictionary = json.loads(nodes_json_string) + batch_size = dictionary["batch_size"] + + # Check for supported batch size + if int(batch_size) > 8: + error_msg = "Compilation ERROR: mrvl-tmlc supports batch_size <= 8" + raise RuntimeError(error_msg) + + # Invoke Marvell Backend with appropriate options + compile_cmd = ( + mrvl_exec + + " -mn " + + symbol_name + + " -f1 " + + nodes_json_file + + " -f2 " + + consts_json_file + + " " + + compiler_opts + + " -b " + + batch_size + ) + + ret_val = os.system(compile_cmd) + if ret_val == 0: + # Read generated binary and encode in base64 format + working_dir = get_working_dir() + bin_file = os.path.join(working_dir, "bin_" + symbol_name, symbol_name + ".bin") + + with open(bin_file, "rb") as f: + data = bytearray(f.read()) + base64_bytes = base64.b64encode(data) + if not data: + raise RuntimeError("Compilation ERROR: Marvell binary could not be generated") + # Cleanup Temporary Files + delete_temp_files(symbol_name) + return base64_bytes + else: + error_msg = "Compilation ERROR: Error compiling Marvell region!" + raise RuntimeError(error_msg) + + +@tvm._ffi.register_func("tvm.mrvl.CleanUpSim") +def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): + os.remove(bin_file) + os.remove(input_json) + os.remove(input_bin) + for i in range(num_outputs): + out_bin = out_bin_prefix + "-" + str(i) + ".bin" + os.remove(out_bin) + + +@tvm._ffi.register_func("tvm.mrvl.SearchPath") +def search_path(file_name): + path = shutil.which(file_name) + if path is None: + return "" + return os.path.dirname(path) + + +@tvm._ffi.register_func("tvm.mrvl.JsonToBin") +def convert_json_to_bin(json_file, input_bin_file): + with open(json_file) as input_json: + data = json.load(input_json) + data_float = np.array(data["inputs"], dtype=np.float32) + data_b = data_float.tobytes() + with open(input_bin_file, "wb") as f: + f.write(data_b) + + +@tvm._ffi.register_func("tvm.mrvl.RunSim") +def run_simulation(run_command, sim_directory): + cwd_path = get_working_dir() + os.mkdir(sim_directory) + os.chdir(sim_directory) + os.system(run_command) + os.chdir(cwd_path) + shutil.rmtree(sim_directory) + + +@tvm._ffi.register_func("tvm.mrvl.TempDir") +def get_temp_dir(): + return tempfile.gettempdir() diff --git a/python/tvm/relay/op/contrib/mrvl.py b/python/tvm/relay/op/contrib/mrvl.py index 016e7ea7f6b1..75041fbc8c44 100644 --- a/python/tvm/relay/op/contrib/mrvl.py +++ b/python/tvm/relay/op/contrib/mrvl.py @@ -432,14 +432,14 @@ def conv2d_batchnorm(pattern): return pad | no_pad - def sum2d_pattern(): - """Create a sum2d pattern. + def sum_pattern(): + """Create a sum pattern. review tvm/tests/python/relay/test_dataflow_pattern.py for examples Returns ------- pattern : dataflow_pattern.AltPattern - Denotes the sum2d pattern. + Denotes the sum pattern. """ pattern = is_op("add")(wildcard(), wildcard()) pattern = is_activation(pattern) @@ -466,13 +466,28 @@ def fc_pattern(): pattern : dataflow_pattern.AltPattern Denotes the fc pattern. """ - pattern = is_op("nn.dense")(wildcard(), is_constant()) - pattern = pattern.optional( - lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant())) + + def fc_base_pattern(pattern): + pattern = is_op("nn.dense")(pattern, is_constant()) + pattern = pattern.optional( + lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant())) + ) + pattern = is_activation(pattern) + + return pattern + + transform1 = is_op("layout_transform")(wildcard()).has_attr( + {"src_layout": "NHWC", "dst_layout": "NCHW"} ) - pattern = is_activation(pattern) + reshape = is_op("reshape")(transform1) + flatten = is_op("nn.batch_flatten")(transform1) + flatten = reshape | flatten + flatten = fc_base_pattern(flatten) - return pattern + no_flatten = wildcard() + no_flatten = fc_base_pattern(no_flatten) + + return flatten | no_flatten def maxpool2d_pattern(): """Create a maxpool2d pattern. @@ -543,16 +558,6 @@ def layout_transform_nchw2nhwc_pattern(): ) return pattern - def layout_transform_nhwc2nchw_to_2D_pattern(): - # Layout_Transform + Reshape/BatchFlatten - transform1 = is_op("layout_transform")(wildcard()).has_attr( - {"src_layout": "NHWC", "dst_layout": "NCHW"} - ) - pattern1 = is_op("reshape")(transform1) - pattern2 = is_op("nn.batch_flatten")(transform1) - - return pattern1 | pattern2 - def check_conv2d(extract): """Check conv pattern is supported by Mrvl.""" call = extract @@ -609,21 +614,12 @@ def check_layout_transform_nchw2nhwc(extract): call = call.args[0] return layout_transform_nchw2nhwc(call) - def check_layout_transform_nhwc2nchw_2D(extract): - call = extract - if call.op.name == "reshape" or call.op.name == "nn.batch_flatten": - call = call.args[0] - if call.op.name == "layout_transform": - if call.attrs.src_layout == "NHWC" and call.attrs.dst_layout == "NCHW": - return True - return False - - def check_sum2d(extract): + def check_sum(extract): """Check sum2d pattern is supported by Mrvl.""" call = extract while call.op.name != "add": call = call.args[0] - return sum2d(call) + return summation(call) def check_concat(extract): """Check concat pattern is supported by Mrvl.""" @@ -638,13 +634,8 @@ def check_concat(extract): ("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d), ("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d), ("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(), check_globalavgpool2d), - ("mrvl.sum2d", sum2d_pattern(), check_sum2d), + ("mrvl.sum", sum_pattern(), check_sum), ("mrvl.concat", concat_pattern(), check_concat), - ( - "mrvl.layout_transform_nhwc2nchw_reshape", - layout_transform_nhwc2nchw_to_2D_pattern(), - check_layout_transform_nhwc2nchw_2D, - ), ( "mrvl.layout_transform_nchw2nhwc", layout_transform_nchw2nhwc_pattern(), @@ -692,8 +683,8 @@ def conv2d_nhwc2nhwc(expr): # register a helper function to indicate that the given operator can be supported by Mrvl. @tvm.ir.register_op_attr("add", "target.mrvl") -def sum2d(expr): - """Check if the external Mrvl codegen for sum2d should be used.""" +def summation(expr): + """Check if the external Mrvl codegen for sum should be used.""" arg0 = expr.args[0] # - need to further checking if the call_func of arg0 is not nn.conv2d nor nn.dense @@ -707,7 +698,7 @@ def sum2d(expr): # - need to further checking if dimension of input or output tensor is 4 data_type = arg0.checked_type if ( - (len(data_type.shape) != 4) + (len(data_type.shape) != 4 and len(data_type.shape) != 3) or not is_valid_batch_size(data_type.shape[0]) or (data_type.dtype not in ["float32"]) ): @@ -827,14 +818,13 @@ def reshape_mrvl(expr): """Check if the external Mrvl codegen for reshape should be used.""" if expr.op.name != "reshape": return False - else: - data_type = expr.checked_type - if not (len(data_type.shape) == 4 or len(data_type.shape) == 2): - return False + data_type = expr.checked_type + if not (len(data_type.shape) == 4 or len(data_type.shape) == 2): + return False - args = expr.args - data_type = args[0].checked_type - return True + args = expr.args + data_type = args[0].checked_type + return True @tvm.ir.register_op_attr("nn.batch_flatten", "target.mrvl") diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc index d395de6694ff..6d7e593b9b04 100644 --- a/src/relay/backend/contrib/mrvl/codegen.cc +++ b/src/relay/backend/contrib/mrvl/codegen.cc @@ -187,9 +187,9 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { /*! * \brief A series of operators that form a composite - * sum2d. + * sum. */ - struct CompositeSum2DNode { + struct CompositeSumNode { const CallNode* add = nullptr; const CallNode* activation = nullptr; }; @@ -218,14 +218,6 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { const CallNode* reshape = nullptr; }; - /*! - * \brief A series of operators that form a transform reshape node. - */ - struct CompositeLayoutTransformReshapeNode { - const CallNode* transform = nullptr; - const CallNode* reshape = nullptr; - }; - /*! * \brief A series of operators that form a batch flatten node. */ @@ -238,6 +230,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { * fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no. */ struct CompositeFcNode { + const CallNode* transform = nullptr; + const CallNode* flatten = nullptr; const CallNode* fc = nullptr; const CallNode* add = nullptr; const CallNode* activation = nullptr; @@ -284,12 +278,10 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { json_kernel_node = CreateCompositeMrvlAvgpool2DLayer(cn); } else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") { json_kernel_node = CreateCompositeMrvlGlobalAvgpool2DLayer(cn); - } else if (name == "mrvl.sum2d") { - json_kernel_node = CreateCompositeMrvlSum2DLayer(cn); + } else if (name == "mrvl.sum") { + json_kernel_node = CreateCompositeMrvlSumLayer(cn); } else if (name == "mrvl.concat") { json_kernel_node = CreateMrvlConcatLayer(cn); - } else if (name == "mrvl.layout_transform_nhwc2nchw_reshape") { - json_kernel_node = CreateMrvlLayoutTransposeReshapeLayer(cn); } else if (name == "mrvl.reshape") { json_kernel_node = CreateMrvlReshapeLayer(cn); } else if (name == "mrvl.batch_flatten") { @@ -308,6 +300,83 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { int node_idx_{0}; int const_suffix_{0}; + void resizeInputOutputLayoutTo4dim(std::shared_ptr json_node, const CallNode* cn, + std::string node_name) { + const uint64_t new_layout_size = 4; + std::string data_layout = "NHWC"; + std::string out_layout = "NHWC"; + + auto num_inputs = GetInputNum(cn); + auto num_outputs = GetOutputNum(cn); + uint64_t max_old_input_layout_size = 0; + // Inputs + if (num_inputs > 1) { + for (uint64_t in_idx = 0; in_idx < num_inputs; in_idx++) { + std::vector layout; + GetInputTensorShapeViaArgN(cn, &layout, in_idx); + uint64_t old_layout_size = layout.size(); + max_old_input_layout_size = std::max(old_layout_size, max_old_input_layout_size); + ICHECK(old_layout_size <= 4) << "Marvell-Compiler-ERROR-Internal::" << node_name + << " with input tensor shape > 4 is not supported yet."; + layout.resize(new_layout_size, 1); + + if (!cn->args[in_idx].as()) { + JsonNodeSetVecAttr(json_node, "data_layout_shape_" + std::to_string(in_idx), layout); + if (in_idx == 0) { + JsonNodeSetVecAttr(json_node, "data_layout_shape", layout); + } + } + } + for (uint64_t in_idx = 0; in_idx < num_inputs; in_idx++) { + std::vector layout; + GetInputTensorShapeViaArgN(cn, &layout, in_idx); + uint64_t old_layout_size = layout.size(); + ICHECK(old_layout_size <= 4) << "Marvell-Compiler-ERROR-Internal::" << node_name + << " with input tensor shape > 4 is not supported yet."; + layout.resize(max_old_input_layout_size, 1); + std::rotate(layout.begin(), layout.end() - (max_old_input_layout_size - old_layout_size), + layout.end()); + layout.resize(new_layout_size, 1); + if (cn->args[in_idx].as()) { + std::vector const_name = {layer_name_ + "_const_" + + std::to_string(const_suffix_++)}; + JsonNodeSetAttr(json_node, "input_const_name", const_name); + JsonNodeSetVecAttr(json_node, "input_const_shape", layout); + } + } + } else { + std::vector layout; + GetInputTensorShapeViaArgN(cn, &layout, 0); + layout.resize(new_layout_size, 1); + JsonNodeSetVecAttr(json_node, "data_layout_shape", layout); + } + // Outputs + if (num_outputs > 1) { + std::vector> layout; + GetOutputTensorShapes(cn, &layout); + for (size_t out_idx = 0; out_idx < num_outputs; out_idx++) { + ICHECK(layout.at(out_idx).size() <= 4) + << "Marvell-Compiler-ERROR-Internal::" << node_name + << " with output tensor shape > 4 is not supported yet."; + layout.at(out_idx).resize(new_layout_size, 1); + JsonNodeSetVecAttr(json_node, "out_layout_shape_" + std::to_string(out_idx), + layout.at(out_idx)); + if (out_idx == 0) { + JsonNodeSetVecAttr(json_node, "out_layout_shape", layout.at(out_idx)); + } + } + } else { + std::vector layout; + GetOutputTensorShape(cn, &layout); + layout.resize(new_layout_size, 1); + JsonNodeSetVecAttr(json_node, "out_layout_shape", layout); + } + + std::vector layout_format_vec = {data_layout}; + JsonNodeSetAttr(json_node, "data_layout", layout_format_vec); + JsonNodeSetAttr(json_node, "out_layout", layout_format_vec); + } + /*! * \brief Extract convolution nodes from a composite function. * @@ -366,13 +435,13 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { } /*! - * \brief Extract sum2d nodes from a composite function. + * \brief Extract sum nodes from a composite function. * * \param call The call node of the composite function. - * \return Extracted composite sum2d nodes. + * \return Extracted composite sum nodes. */ - CompositeSum2DNode UnpackCompositeSum2D(const CallNode* call) { - CompositeSum2DNode nodes{}; + CompositeSumNode UnpackCompositeSum(const CallNode* call) { + CompositeSumNode nodes{}; const auto* fn = call->op.as(); ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed."; @@ -408,30 +477,6 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { return nodes; } - /*! - * \brief Extract LayoutTransposeReshape nodes from a composite function. - * - * \param call The call node of the composite function. - * \return Extracted composite layouttranspose reshape nodes. - */ - CompositeLayoutTransformReshapeNode UnpackCompositeLayoutTransposeReshape(const CallNode* call) { - CompositeLayoutTransformReshapeNode nodes{}; - const auto* fn = call->op.as(); - ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed."; - - const CallNode* current_call = fn->body.as(); - ICHECK(backend::IsOp(current_call, "reshape") || - backend::IsOp(current_call, "nn.batch_flatten")) - << "Marvell-Compiler-ERROR-Internal::Reshape/Batch_flatten Op missing."; - nodes.reshape = current_call; - current_call = current_call->args[0].as(); - - ICHECK(backend::IsOp(current_call, "layout_transform")) - << "Marvell-Compiler-ERROR-Internal::Layout_Transform Op missing."; - nodes.transform = current_call; - return nodes; - } - /*! * \brief Extract Reshape nodes from a composite function. * @@ -530,6 +575,18 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { ICHECK(backend::IsOp(current_call, "nn.dense")) << "Marvell-Compiler-ERROR-Internal::nn.dense Op missing."; nodes.fc = current_call; + current_call = current_call->args[0].as(); + if (current_call) { + if (backend::IsOp(current_call, "reshape") | + backend::IsOp(current_call, "nn.batch_flatten")) { + nodes.flatten = current_call; + current_call = current_call->args[0].as(); + ICHECK(backend::IsOp(current_call, "layout_transform")) + << "Marvell-Compiler-ERROR-Internal::layout_transform Op missing."; + nodes.transform = current_call; + } + } + return nodes; } @@ -627,7 +684,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { if (num_inputs > 1) { for (size_t in_idx = 0; in_idx < num_inputs; in_idx++) { std::vector data_layout_vec_n; - GetInputTensorShapeViaArg(cn, &data_layout_vec_n, &tuple_idx, in_idx); + tuple_idx = GetInputTensorShapeViaArgN(cn, &data_layout_vec_n, in_idx); std::string attr_name = "data_layout_shape_" + std::to_string(in_idx); JsonNodeSetVecAttr(json_node, attr_name, data_layout_vec_n); tuple_idx_vec.push_back(tuple_idx); @@ -636,7 +693,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { } } } else { - GetInputTensorShapeViaArg(cn, &data_layout_vec, &tuple_idx, 0); + tuple_idx = GetInputTensorShapeViaArgN(cn, &data_layout_vec, 0); JsonNodeSetVecAttr(json_node, "data_layout_shape", data_layout_vec); tuple_idx_vec.push_back(tuple_idx); } @@ -784,6 +841,17 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { if (tuple_type) { tensor_type = tuple_type->fields[n].as(); } + } else if (call_node_ptr->args[n].as()) { + const auto* arg_n = call_node_ptr->args[n].as(); + ICHECK((arg_n != nullptr) && arg_n->IsInstance()) + << "Marvell-Compiler-ERROR-Internal::Downcast to ConstantNode failed."; + tensor_type = arg_n->checked_type().as(); + if (tensor_type == nullptr) { + const TupleTypeNode* tuple_type = arg_n->checked_type().as(); + if (tuple_type) { + tensor_type = tuple_type->fields[n].as(); + } + } } } else { LOG(INFO) << "TVM Mrvl runtime does not support calls to " @@ -798,10 +866,11 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { } } - void GetInputTensorShapeViaArg0(const CallNode* call_node_ptr, - std::vector* tensor_shape) { + int GetInputTensorShapeViaArgN(const CallNode* call_node_ptr, std::vector* tensor_shape, + int64_t n = 0) { int tuple_idx = -1; - GetInputTensorShapeViaArg(call_node_ptr, tensor_shape, &tuple_idx, 0); + GetInputTensorShapeViaArg(call_node_ptr, tensor_shape, &tuple_idx, n); + return tuple_idx; } void GetTensorShape(const VarNode* var_node_ptr, std::vector* tensor_shape) { @@ -937,32 +1006,25 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { } /*! - * \brief Create a JSON representation of a composite sum2d. + * \brief Create a JSON representation of a composite sum. * * \param cn The call to be represented. * \return A JSON representation of a specific operator. */ - std::shared_ptr CreateCompositeMrvlSum2DLayer(const CallNode* cn) { - CompositeSum2DNode nodes = UnpackCompositeSum2D(cn); + std::shared_ptr CreateCompositeMrvlSumLayer(const CallNode* cn) { + CompositeSumNode nodes = UnpackCompositeSum(cn); ICHECK(nodes.add != nullptr) << "Marvell-Compiler-ERROR-Internal::attribute add can't be nullptr"; std::string mrvlLayerName = "Sum2D"; - std::string name = "sum2d"; + std::string name = "sum"; std::string data_layout; std::string out_layout; std::vector layout_vec; std::vector inputs; - inputs.push_back(VisitExpr(cn->args[0])[0]); - inputs.push_back(VisitExpr(cn->args[1])[0]); - GetInputTensorShapeViaArg0(cn, &layout_vec); - if (layout_vec.size() == 4) { - data_layout = "NHWC"; - out_layout = "NHWC"; - } else if (layout_vec.size() == 2) { - data_layout = "NC"; - out_layout = "NC"; + for (auto arg : cn->args) { + inputs.push_back(VisitExpr(arg)[0]); } // add json node attributes @@ -970,6 +1032,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { SetCallNodeAttribute(json_node, nodes.add); if (nodes.activation) JsonNodeSetAttr(json_node, "activation_type", {"relu"}); SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "", out_layout); + resizeInputOutputLayoutTo4dim(json_node, cn, "Sum"); return json_node; } @@ -989,7 +1052,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { std::vector inputs; inputs.push_back(VisitExpr(cn->args[0])[0]); - GetInputTensorShapeViaArg0(nodes.reshape, &layout_vec); + GetInputTensorShapeViaArgN(nodes.reshape, &layout_vec); ICHECK(layout_vec.size() == 2 || layout_vec.size() == 4) << "Marvell-Compiler-ERROR-Internal::" << "Reshape with input tensor dim != 2 or != 4 is not supported yet."; @@ -1031,7 +1094,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { std::vector inputs; inputs.push_back(VisitExpr(cn->args[0])[0]); - GetInputTensorShapeViaArg0(nodes.batch_flatten, &layout_vec); + GetInputTensorShapeViaArgN(nodes.batch_flatten, &layout_vec); ICHECK(layout_vec.size() == 2 || layout_vec.size() == 4) << "Marvell-Compiler-ERROR-Internal::" << "nn.batch_flatten with input tensor dim != 2 or != 4 is not supported yet."; @@ -1074,7 +1137,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { } std::vector layout_vec; - GetInputTensorShapeViaArg0(cn, &layout_vec); + GetInputTensorShapeViaArgN(cn, &layout_vec); if (layout_vec.size() == 4) { data_layout = "NHWC"; out_layout = "NHWC"; @@ -1090,33 +1153,6 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { return json_node; } - /*! - * \brief Create a JSON representation of a composite LayoutTransform Reshape. - * - * \param cn The call to be represented. - * \return A JSON representation of a specific operator. - */ - std::shared_ptr CreateMrvlLayoutTransposeReshapeLayer(const CallNode* cn) { - CompositeLayoutTransformReshapeNode nodes = UnpackCompositeLayoutTransposeReshape(cn); - ICHECK(nodes.transform != nullptr) - << "Marvell-Compiler-ERROR-Internal::attribute transform can't be nullptr"; - - std::string mrvlLayerName = "TransformReshape"; - std::string name = "transformreshape"; - std::string data_layout; - std::string out_layout = "NC"; - std::vector inputs; - - inputs.push_back(VisitExpr(cn->args[0])[0]); - auto layout_transform_attr = nodes.transform->attrs.as(); - data_layout = layout_transform_attr->src_layout; - - auto json_node = std::make_shared(name, "kernel", inputs, 1); - SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout, - "" /* no kernel_layout */, out_layout); - return json_node; - } - /*! * \brief Create a JSON representation of a composite fc (fully-connected) operator. * @@ -1153,7 +1189,10 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { JsonNodeSetAttr(json_node, "bias_layout", {bias_layout}); } if (nodes.activation) JsonNodeSetAttr(json_node, "activation_type", {"relu"}); - + if (nodes.transform && nodes.flatten) { + JsonNodeSetAttr(json_node, "weights_need_transform", {"yes"}); + data_layout = "NHWC"; + } SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, kernel_layout, out_layout); return json_node; @@ -1251,7 +1290,7 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer { inputs.push_back(VisitExpr(cn->args[0])[0]); std::vector kernel_layout_vec; std::vector data_layout_vec; - GetInputTensorShapeViaArg0(cn, &data_layout_vec); + GetInputTensorShapeViaArgN(cn, &data_layout_vec); ICHECK(data_layout_vec.size() == 4); kernel_layout_vec.push_back(data_layout_vec[1]); kernel_layout_vec.push_back(data_layout_vec[2]); @@ -1311,7 +1350,7 @@ std::vector split(const std::string& s, char delim) { } /*! - * \brief Generate JSON meta files and then return a runtime module for Mrvl. + * \brief Generate compiled model binary and then return a runtime module for Mrvl. * * \note This consists of a series of IR functions, which each represents * a full Mrvl subgraph/region (in tvmc mode) or one fused Mrvl backend layer @@ -1344,9 +1383,13 @@ runtime::Module MrvlCompiler(const ObjectRef& ref) { std::string modified_json = (*modifyConsts)(nodes_json_string, consts_json_string); auto json_vec = split(modified_json, '|'); + // Invoke Marvell Backend compiler to generate binary for sub graph + const auto* compile = runtime::Registry::Get("tvm.mrvl.CompileModel"); + std::string bin = (*compile)(func_name, json_vec[0], json_vec[1], compiler_opt); + const auto* pf = runtime::Registry::Get("runtime.mrvl_runtime_create"); ICHECK(pf != nullptr) << "Cannot find software simulator runtime module to create"; - runtime_lib = (*pf)(func_name, json_vec[0]); + runtime_lib = (*pf)(func_name, json_vec[0], bin); return runtime_lib; } diff --git a/src/relay/backend/contrib/mrvl/compiler_attr.cc b/src/relay/backend/contrib/mrvl/compiler_attr.cc index 4309212e3350..86cb04ab3936 100644 --- a/src/relay/backend/contrib/mrvl/compiler_attr.cc +++ b/src/relay/backend/contrib/mrvl/compiler_attr.cc @@ -36,7 +36,6 @@ struct MrvlCompilerConfigNode : public tvm::AttrsNode { String mcpu; IntImm num_tiles; String mattr; - String working_dir; TVM_DECLARE_ATTRS(MrvlCompilerConfigNode, "ext.attrs.MrvlCompilerConfigNode") { TVM_ATTR_FIELD(mcpu) diff --git a/src/runtime/contrib/mrvl/mrvl_base64.h b/src/runtime/contrib/mrvl/mrvl_base64.h new file mode 100644 index 000000000000..67452597fd48 --- /dev/null +++ b/src/runtime/contrib/mrvl/mrvl_base64.h @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mrvl_base64.h + * \brief Util functions for converting plain bytes back to plain bytes + */ + +#ifndef TVM_RUNTIME_CONTRIB_MRVL_MRVL_BASE64_H_ +#define TVM_RUNTIME_CONTRIB_MRVL_MRVL_BASE64_H_ + +#include + +#include +#include + +#include "../../../../src/support/base64.h" + +namespace tvm { +namespace runtime { +namespace contrib { +namespace mrvl { + +inline size_t b64strlen(const std::string& b64str) { + ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; + size_t length = b64str.size() / 4 * 3; + if (b64str[b64str.size() - 2] == '=') { + length -= 2; + } else if (b64str[b64str.size() - 1] == '=') { + length -= 1; + } + return length; +} + +inline void b64decode(const std::string& b64str, uint8_t* ret) { + size_t index = 0; + const auto length = b64str.size(); + for (size_t i = 0; i < length; i += 4) { + int8_t ch0 = tvm::support::base64::DecodeTable[(int32_t)b64str[i]]; + int8_t ch1 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 1]]; + int8_t ch2 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 2]]; + int8_t ch3 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 3]]; + uint8_t st1 = (ch0 << 2) + (ch1 >> 4); + ret[index++] = st1; + if (b64str[i + 2] != '=') { + uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2); + ret[index++] = st2; + if (b64str[i + 3] != '=') { + uint8_t st3 = ((ch2 & 0b11) << 6) + ch3; + ret[index++] = st3; + } + } + } + ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; +} + +} // namespace mrvl +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_MRVL_MRVL_BASE64_H_ diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 89e8ff108e59..337d81c8a0be 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -34,6 +34,7 @@ #include #include "../json/json_node.h" +#include "mrvl_sw_runtime_lib.h" namespace tvm { namespace runtime { @@ -44,12 +45,16 @@ namespace contrib { hardware and then runs the generated binary using the Marvell software simulator (MlModel). * \param symbol_name The name of the subgraph / relay function * \param nodes_json The serialized JSON representation of relay function + * \param bin_code The binary code generated by the Marvell compiler for the subgraph */ class MarvellSimulatorModuleNode : public ModuleNode { public: - MarvellSimulatorModuleNode(const std::string& symbol_name, const std::string& nodes_json) - : symbol_name_(symbol_name), nodes_json_(nodes_json) {} + MarvellSimulatorModuleNode(const std::string& symbol_name, const std::string& nodes_json, + const std::string& bin_code) + : symbol_name_(symbol_name), nodes_json_(nodes_json), bin_code_(bin_code) { + set_num_inputs_outputs(); + } const char* type_key() const { return "mrvl_sim"; } @@ -85,18 +90,21 @@ class MarvellSimulatorModuleNode : public ModuleNode { // binary format. stream->Write(symbol_name_); stream->Write(nodes_json_); + stream->Write(bin_code_); } static Module LoadFromBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string symbol_name; std::string nodes_json; + std::string bin_code; // Load the symbol_name and other data to construct the module ICHECK(stream->Read(&symbol_name)) << "Marvell-Compiler-ERROR-Internal::Loading symbol name failed"; ICHECK(stream->Read(&nodes_json)) << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; - auto n = make_object(symbol_name, nodes_json); + ICHECK(stream->Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; + auto n = make_object(symbol_name, nodes_json, bin_code); return Module(n); } @@ -111,15 +119,33 @@ class MarvellSimulatorModuleNode : public ModuleNode { protected: std::string symbol_name_; std::string nodes_json_; + std::string bin_code_; + size_t num_inputs_; + size_t num_outputs_; void Run(TVMArgs args) { - ICHECK(false) << "Marvell-Compiler-ERROR-Internal::Run not supported for Marvell Runtime yet!"; + ICHECK_EQ(args.size(), num_inputs_ + num_outputs_) + << "Marvell-Compiler-ERROR-Internal::Mismatch in number of input & number of output args " + "to subgraph"; + tvm::runtime::contrib::mrvl::RunMarvellSimulator(args, symbol_name_, bin_code_, num_inputs_, + num_outputs_); + } + + void set_num_inputs_outputs() { + const auto* get_value_from_key = runtime::Registry::Get("tvm.mrvl.find_value_in_KV_pair"); + + std::string value_for_inputs = (*get_value_from_key)(nodes_json_, "num_subgraph_inputs"); + num_inputs_ = std::stoi(value_for_inputs); + + std::string value_for_outputs = (*get_value_from_key)(nodes_json_, "num_subgraph_outputs"); + num_outputs_ = std::stoi(value_for_outputs); } }; runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, - const String& nodes_json) { - auto n = make_object(symbol_name, nodes_json); + const String& nodes_json, + const String& bin_code) { + auto n = make_object(symbol_name, nodes_json, bin_code); return runtime::Module(n); } diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc new file mode 100644 index 000000000000..f5e222255ce6 --- /dev/null +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc + * \brief Runtime library for Marvell Software Simulator. + */ + +#include "mrvl_sw_runtime_lib.h" + +#include +#include +#include + +#include +#include + +#include "mrvl_base64.h" + +using namespace tvm::runtime; + +template +static void NDArrayToFile(const tvm::runtime::NDArray& arr, std::ostream& os) { + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + T* data_ptr = reinterpret_cast(arr->data); + os << "\t\t["; + os << std::endl; + for (int i = 0; i < tot_dim; i++) { + os << "\t\t\t" << std::setprecision(10) << data_ptr[i] << (i != tot_dim - 1 ? "," : ""); + os << std::endl; + } + os << "\t\t]"; +} + +static void WriteBinToDisk(const std::string& bin_file, const std::string& bin_code) { + auto length = tvm::runtime::contrib::mrvl::b64strlen(bin_code); + std::vector byte_array(length); + tvm::runtime::contrib::mrvl::b64decode(bin_code, byte_array.data()); + std::ofstream file_out; + file_out.open(bin_file, std::ios_base::out | std::ios_base::trunc | std::ios_base::binary); + for (auto byte : byte_array) file_out << byte; +} + +static void ReadInputsAndGenerateInputBin(TVMArgs args, const std::string& input_json, + const std::string& input_bin, + const std::string& bin_directory, size_t num_inputs) { + std::ofstream file_out; + file_out.open(input_json, std::ios_base::out | std::ios_base::trunc); + file_out << "{" << std::endl; + file_out << R"( "inputs": [)" << std::endl; + for (size_t i = 0; i < num_inputs; ++i) { + const DLTensor* tensor; + if (args[i].IsObjectRef()) { + NDArray arr = args[i]; + tensor = arr.operator->(); + } else { + tensor = args[i].operator DLTensor*(); + } + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray arr = NDArray::Empty(shape, tensor->dtype, tensor->device); + arr.CopyFrom(tensor); + NDArrayToFile(arr, file_out); + if (i != num_inputs - 1) { + file_out << std::endl << "\t," << std::endl; + } + } + file_out << std::endl << "\t]" << std::endl; + file_out << "}" << std::endl; + + const auto* json_to_bin = tvm::runtime::Registry::Get("tvm.mrvl.JsonToBin"); + (*json_to_bin)(input_json, input_bin); +} + +static void RunInferenceOnMlModel(const std::string& symbol_name, const std::string& bin_directory, + const std::string& bin_file, const std::string& input_bin, + const std::string& out_bin_prefix) { + auto command = bin_directory + "/mrvl-mlsim " + "-m " + bin_file + " -d " + input_bin + " -o " + + out_bin_prefix; + std::string sim_directory = "mrvl_sw_sim_" + symbol_name; + const auto* run_sim = tvm::runtime::Registry::Get("tvm.mrvl.RunSim"); + (*run_sim)(command, sim_directory); +} + +static void ReadOutputsAndUpdateRuntime(TVMArgs args, size_t num_inputs, + const std::string& out_bin_prefix) { + for (int out = num_inputs; out < args.size(); out++) { + const DLTensor* outTensor; + if (args[out].IsObjectRef()) { + NDArray arr = args[out]; + outTensor = arr.operator->(); + } else { + outTensor = args[out].operator DLTensor*(); + } + std::vector shape; + for (int64_t i = 0; i < outTensor->ndim; i++) { + shape.push_back(outTensor->shape[i]); + } + NDArray arr = NDArray::Empty(shape, outTensor->dtype, outTensor->device); + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + float f; + float* data = new float[tot_dim](); + String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; + std::ifstream fin(outbin, std::ios::binary); + ICHECK(fin.is_open()) << "Cannot open file: " << outbin; + int i = 0; + while (fin.read(reinterpret_cast(&f), sizeof(float))) { + data[i] = f; + ICHECK(i < tot_dim) << "Output data size mismatch"; + i++; + } + arr.CopyFromBytes(data, tot_dim * sizeof(float)); + arr.CopyTo(const_cast(outTensor)); + delete[] data; + } +} + +static void CleanUp(TVMArgs args, const std::string& bin_file, const std::string& input_json, + const std::string& input_bin, const std::string& out_bin_prefix, + size_t num_outputs) { + const auto* clean_up = tvm::runtime::Registry::Get("tvm.mrvl.CleanUpSim"); + (*clean_up)(bin_file, input_json, input_bin, out_bin_prefix, num_outputs); +} + +void tvm::runtime::contrib::mrvl::RunMarvellSimulator(TVMArgs args, const std::string& symbol_name, + const std::string& bin_code, + size_t num_inputs, size_t num_outputs) { + // check $PATH for the presence of MRVL dependent tools/scripts + std::string file_name("mrvl-mlsim"); + const auto* search_path = tvm::runtime::Registry::Get("tvm.mrvl.SearchPath"); + std::string tools_directory = (*search_path)(file_name); + if (tools_directory.empty()) { + ICHECK(false) << "mrvl-mlsim simulator not found! Please specify the path to Marvell " + "tools by adding it to $PATH."; + } + + const auto* temp_dir = tvm::runtime::Registry::Get("tvm.mrvl.TempDir"); + std::string working_directory = (*temp_dir)(); + auto bin_file = working_directory + "/" + symbol_name + ".bin"; + auto input_json = working_directory + "/indata.json"; + auto input_bin = working_directory + "/input.bin"; + auto out_bin_prefix = working_directory + "/mrvl_sim_out"; + + WriteBinToDisk(bin_file, bin_code); + ReadInputsAndGenerateInputBin(args, input_json, input_bin, tools_directory, num_inputs); + RunInferenceOnMlModel(symbol_name, tools_directory, bin_file, input_bin, out_bin_prefix); + ReadOutputsAndUpdateRuntime(args, num_inputs, out_bin_prefix); + CleanUp(args, bin_file, input_json, input_bin, out_bin_prefix, num_outputs); +} diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h new file mode 100644 index 000000000000..4670487ed1c4 --- /dev/null +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h + * \brief Runtime library for Marvell Software Simulator + */ + +#ifndef TVM_RUNTIME_CONTRIB_MRVL_MRVL_SW_RUNTIME_LIB_H_ +#define TVM_RUNTIME_CONTRIB_MRVL_MRVL_SW_RUNTIME_LIB_H_ + +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { +namespace mrvl { + +void RunMarvellSimulator(tvm::runtime::TVMArgs args, const std::string& symbol_name, + const std::string& bin_code, size_t num_inputs, size_t num_outputs); +} +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_MRVL_MRVL_SW_RUNTIME_LIB_H_ diff --git a/tests/python/contrib/test_mrvl/infrastructure.py b/tests/python/contrib/test_mrvl/infrastructure.py index c46753d4e799..c4c56edfead5 100644 --- a/tests/python/contrib/test_mrvl/infrastructure.py +++ b/tests/python/contrib/test_mrvl/infrastructure.py @@ -18,11 +18,14 @@ """Infrastructure to Test Marvell Code Generation""" import json -import os import tvm from tvm import relay from tvm.relay.op.contrib import mrvl +import numpy as np +from tvm.contrib import graph_executor +from tvm.relay.build_module import build +from tvm.relay.op.contrib.mrvl import partition_for_mrvl def get_cpu_op_count(mod): @@ -103,3 +106,48 @@ def verify_codegen( if contains is not None: actual_str = json.dumps(json.loads(mrvl_modules[0].get_source())) assert actual_str.find(contains) + + +def run_and_verify_func(config, data_type="float32"): + + np.random.seed(0) + tvm_target = "llvm" + + func, input_shapes, is_param, option_dict = config + params = { + x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype=data_type) for x in is_param + } + inputs_dict = { + k: np.random.uniform(-1, 1, v).astype(dtype=data_type) + for k, v in input_shapes.items() + if k not in is_param + } + + dev = tvm.cpu() + for use_mrvl in [True, False]: + mod = tvm.IRModule() + mod["main"] = func + if use_mrvl: + mod = partition_for_mrvl(mod, params, **option_dict) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.mrvl.options": option_dict} + ): + model_lib = relay.build(mod, tvm_target, params=params) + + model_rt_graph = graph_executor.GraphModule(model_lib["default"](dev)) + model_rt_graph.set_input(**inputs_dict) + model_rt_graph.run() + output_tensor1 = model_rt_graph.get_output(0).numpy() + + else: + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.mrvl.options": option_dict} + ): + model_lib = relay.build(mod, tvm_target, params=params) + + model_rt_graph = graph_executor.GraphModule(model_lib["default"](dev)) + model_rt_graph.set_input(**inputs_dict) + model_rt_graph.run() + output_tensor2 = model_rt_graph.get_output(0).numpy() + + tvm.testing.assert_allclose(output_tensor1, output_tensor2, rtol=1e-2, atol=1e-2) diff --git a/tests/python/contrib/test_mrvl/test_mrvl.py b/tests/python/contrib/test_mrvl/test_mrvl.py index 03fdcedc93e5..26956c97c5c1 100644 --- a/tests/python/contrib/test_mrvl/test_mrvl.py +++ b/tests/python/contrib/test_mrvl/test_mrvl.py @@ -26,6 +26,7 @@ from tvm.testing.utils import requires_mrvl from tvm.relay.op.contrib.mrvl import partition_for_mrvl from .infrastructure import verify_codegen +from .infrastructure import run_and_verify_func from tvm.testing import requires_mrvl @@ -142,30 +143,42 @@ def test_partition_mobilenet(num_expected_partition): def test_conv2d(): """Test conv2d operator for "mrvl" targets""" - x = relay.var("x", shape=(1, 3, 224, 224)) - w = relay.const(np.zeros((16, 3, 3, 3), dtype="float32")) - y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3]) - func = relay.Function([x], y) - params = {} - params["w"] = np.random.rand(16, 3, 3, 3).astype("float32") - mod = tvm.IRModule() - mod["main"] = func - verify_codegen(mod, params=params, tvm_ops=1, contains="mrvl.conv2d_nhwc2nhwc") + def get_graph(): + x = relay.var("x", shape=(1, 3, 224, 224)) + arr = np.random.rand(16, 3, 3, 3).astype("float32") + w = relay.const(arr) + y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3]) + func = relay.Function([x], y) + params = {} + params["w"] = arr + mod = tvm.IRModule() + mod["main"] = func + option_dict = {"num_tiles": 1} + verify_codegen(mod, params=params, tvm_ops=1, contains="mrvl.conv2d_nhwc2nhwc") + return func, {"x": (1, 3, 224, 224), "w": (16, 3, 3, 3)}, ["w"], option_dict + + run_and_verify_func(get_graph()) @requires_mrvl def test_dense(): """Test dense operator for "mrvl" targets""" - x = relay.var("x", shape=(1, 16)) - w = relay.const(np.zeros((32, 16), dtype="float32")) - y = relay.nn.dense(x, w) - func = relay.Function([x], y) - params = {} - params["w"] = np.random.rand(16, 3, 3, 3).astype("float32") - mod = tvm.IRModule() - mod["main"] = func - verify_codegen(mod, params=params, tvm_ops=0, contains="mrvl.fc_ni2no") + def get_graph(): + x = relay.var("x", shape=(1, 16)) + arr = np.random.rand(16, 16).astype("float32") + w = relay.const(arr) + y = relay.nn.dense(x, w) + func = relay.Function([x], y) + params = {} + params["w"] = arr + mod = tvm.IRModule() + mod["main"] = func + option_dict = {"num_tiles": 1} + verify_codegen(mod, params=params, tvm_ops=0, contains="mrvl.fc_ni2no") + return func, {"x": (1, 16), "w": (16, 16)}, ["w"], option_dict + + run_and_verify_func(get_graph()) if __name__ == "__main__":