Skip to content

Commit 77df6e8

Browse files
[CMSIS-NN] Add a runtime error message (#13643)
[CMSIS-NN] Add a runtime error message APIs TVMAPISetLastError and TVMGetLastError are used to propagate CMSIS-NN errors caught in the backend. AOT test runner was improved to observe the contents of this global variable. A test was added to check for the last set error as part of this commit.
1 parent 0e046da commit 77df6e8

File tree

8 files changed

+287
-34
lines changed

8 files changed

+287
-34
lines changed

python/tvm/testing/aot.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import subprocess
2424
import tarfile
2525
import logging
26-
from typing import Any, NamedTuple, Union, Tuple, Optional, List, Dict
26+
from typing import Any, NamedTuple, Union, Tuple, Optional, List, Dict, Callable
2727
import numpy as np
2828

2929
import tvm
@@ -200,6 +200,7 @@ def _emit_main_prologue(
200200
compiled_models,
201201
interface_api,
202202
use_stack_allocator=True,
203+
debug_last_error=False,
203204
):
204205
if use_stack_allocator:
205206
workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
@@ -243,11 +244,28 @@ def _emit_main_prologue(
243244
va_start(args, msg);
244245
vfprintf(stdout, msg, args);
245246
va_end(args);
246-
}\n
247-
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
248-
int main(){\n
247+
}
249248
"""
250249
)
250+
if debug_last_error:
251+
main_file.write(
252+
"""\n
253+
tvm_crt_error_t TVMPlatformTimerStart() {
254+
return kTvmErrorFunctionCallNotImplemented;
255+
}
256+
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
257+
return kTvmErrorFunctionCallNotImplemented;
258+
}
259+
const TVMModule* TVMSystemLibEntryPoint(void) { return NULL; }
260+
"""
261+
)
262+
else:
263+
main_file.write(
264+
"""\n
265+
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
266+
"""
267+
)
268+
main_file.write("\nint main(){\n")
251269
main_file.write(custom_prologue)
252270

253271

@@ -332,10 +350,10 @@ def _emit_main_data_setup(main_file, input_map, output_map, mod_name):
332350

333351

334352
def _emit_main_c_interface_call(
335-
main_file, devices, workspace_pool_names, mod_name, use_workspace_io
353+
main_file, devices, workspace_pool_names, mod_name, use_workspace_io, debug_last_error
336354
):
337355
sub_strings = list()
338-
sub_strings.append(f'{_mangle_name(mod_name,"run")}(')
356+
sub_strings.append(f'if ({_mangle_name(mod_name,"run")}(')
339357
if not use_workspace_io:
340358
sub_strings.append(f'&{_mangle_name(mod_name,"inputs")}, ')
341359
sub_strings.append(f'&{_mangle_name(mod_name,"outputs")}, ')
@@ -346,10 +364,14 @@ def _emit_main_c_interface_call(
346364
# Removing the last two characters that is a comma and a space
347365
sub_strings[-1] = sub_strings[-1][:-2]
348366
# Adding brackets and newline instead
349-
sub_strings[-1] = sub_strings[-1] + ");\n"
350-
367+
sub_strings[-1] = sub_strings[-1] + ") == -1) {\n"
351368
main_file_string = "".join(sub_strings)
352369
main_file.write(main_file_string)
370+
if debug_last_error:
371+
main_file.write(f'\tprintf("ERROR: %s\\n", TVMGetLastError());\n')
372+
main_file.write(f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n')
373+
main_file.write(f"\treturn -1;\n")
374+
main_file.write("}\n")
353375

354376

355377
def _emit_main_fake_packed_values(main_file):
@@ -447,13 +469,15 @@ def _emit_main_epilogue(main_file, custom_epilogue):
447469
main_file.write("}\n")
448470

449471

450-
def _emit_main_common_includes(main_file, custom_includes):
472+
def _emit_main_common_includes(main_file, custom_includes, debug_last_error):
451473
main_file.write("#include <stdio.h>\n")
452474
main_file.write("#include <stdarg.h>\n")
453475
main_file.write("#include <stdlib.h>\n")
454476
main_file.write("#include <math.h>\n")
455477
main_file.write('#include "tvm/runtime/c_runtime_api.h"\n')
456478
main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n')
479+
if debug_last_error:
480+
main_file.write('#include "tvm/runtime/crt/module.h"\n')
457481
for include in custom_includes:
458482
main_file.write(f'#include "{include}"\n')
459483

@@ -474,12 +498,13 @@ def _create_main(
474498
workspace_bytes,
475499
use_stack_allocator=True,
476500
use_workspace_io=False,
501+
debug_last_error=False,
477502
):
478503
file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
479504
# create header file
480505
raw_path = file_path.with_suffix(".c").resolve()
481506
with open(raw_path, "w") as main_file:
482-
_emit_main_common_includes(main_file, custom_includes)
507+
_emit_main_common_includes(main_file, custom_includes, debug_last_error)
483508

484509
if interface_api == "c":
485510
for compiled_model in compiled_models:
@@ -497,6 +522,7 @@ def _create_main(
497522
compiled_models,
498523
interface_api,
499524
use_stack_allocator,
525+
debug_last_error,
500526
)
501527
if use_stack_allocator:
502528
_emit_main_init_memory_manager(main_file)
@@ -529,6 +555,7 @@ def _create_main(
529555
list(workspace_pool_names.keys()),
530556
model.name,
531557
use_workspace_io,
558+
debug_last_error,
532559
)
533560
else:
534561
_emit_main_fake_packed_values(main_file)
@@ -701,6 +728,8 @@ def run_and_check(
701728
test_dir: str = None,
702729
verbose: bool = False,
703730
use_workspace_io: bool = False,
731+
debug_last_error: bool = False,
732+
checker: Optional[Callable[[str], bool]] = None,
704733
):
705734
"""
706735
This method uses the original test data and compiled runtime.Modules
@@ -780,8 +809,12 @@ def run_and_check_body(base_path):
780809
workspace_bytes,
781810
use_stack_allocator,
782811
use_workspace_io,
812+
debug_last_error,
783813
)
784814

815+
if checker and (not checker(base_path)):
816+
return False
817+
785818
# Verify that compiles fine
786819
file_dir = os.path.dirname(os.path.abspath(__file__))
787820
makefile_dir = os.path.join(file_dir, "../../../tests/python/relay/aot")
@@ -829,11 +862,13 @@ def run_and_check_body(base_path):
829862
with open(run_log_path) as run_log:
830863
assert AOT_SUCCESS_TOKEN in run_log.read()
831864

865+
return True
866+
832867
if test_dir is None:
833868
tmpdir = utils.tempdir()
834-
run_and_check_body(os.path.join(tmpdir.path, "test"))
869+
return run_and_check_body(os.path.join(tmpdir.path, "test"))
835870
else:
836-
run_and_check_body(test_dir)
871+
return run_and_check_body(test_dir)
837872

838873

839874
def compile_and_run(
@@ -852,7 +887,9 @@ def compile_and_run(
852887
test_dir: str = None,
853888
verbose: bool = False,
854889
schedule_name: str = None,
855-
):
890+
debug_last_error: bool = False,
891+
checker: Optional[Callable[[str], bool]] = None,
892+
) -> bool:
856893
"""This is a wrapper API to compile and run models as test for AoT
857894
858895
Parameters
@@ -883,7 +920,7 @@ def compile_and_run(
883920
schedule_name=schedule_name,
884921
)
885922

886-
run_and_check(
923+
return run_and_check(
887924
models=compiled_test_mods,
888925
runner=runner,
889926
interface_api=interface_api,
@@ -893,6 +930,8 @@ def compile_and_run(
893930
data_linkage=data_linkage,
894931
test_dir=test_dir,
895932
verbose=verbose,
933+
debug_last_error=debug_last_error,
934+
checker=checker,
896935
)
897936

898937

src/relay/backend/contrib/cmsisnn/compiler_attrs.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) {
4040

4141
String mcpu = cfg.value()->mcpu;
4242
Array<String> mattr = {cfg.value()->mattr};
43+
Bool debug_last_error = cfg.value()->debug_last_error;
4344

4445
Target cmsis_nn_target(TargetJSON{
4546
{"kind", String("cmsis-nn")},
4647
{"mcpu", mcpu},
4748
{"mattr", mattr},
49+
{"debug_last_error", debug_last_error},
4850
});
4951

5052
return cmsis_nn_target;

src/relay/backend/contrib/cmsisnn/compiler_attrs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace cmsisnn {
3737
struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNode> {
3838
String mcpu;
3939
String mattr;
40+
Bool debug_last_error = Bool(false);
4041

4142
TVM_DECLARE_ATTRS(CMSISNNCompilerConfigNode, "ext.attrs.CMSISNNCompilerConfigNode") {
4243
TVM_ATTR_FIELD(mcpu)
@@ -47,6 +48,9 @@ struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNo
4748
TVM_ATTR_FIELD(mattr)
4849
.describe("The attributes to configure CMSIS-NN (i.e. +nodsp, +nomve)")
4950
.set_default("");
51+
TVM_ATTR_FIELD(debug_last_error)
52+
.describe("Whether to enable storing the last error")
53+
.set_default(Bool(false));
5054
}
5155
};
5256

src/relay/backend/contrib/cmsisnn/target.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
3636
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
3737
.add_attr_option<Array<String>>("mattr")
3838
.add_attr_option<String>("mcpu")
39+
.add_attr_option<Bool>("debug_last_error")
3940
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
4041
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
4142
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include <tvm/ir/transform.h>
20+
1921
#include <cmath>
2022
#include <fstream>
2123
#include <map>
@@ -26,6 +28,7 @@
2628
#include "../../../../runtime/file_utils.h"
2729
#include "../../../../target/source/codegen_c.h"
2830
#include "../../../../target/source/codegen_c_host.h"
31+
#include "compiler_attrs.h"
2932

3033
namespace tvm {
3134
using namespace tir;
@@ -35,7 +38,9 @@ namespace cmsisnn {
3538

3639
class CodeGenCMSISNN : public codegen::CodeGenCHost {
3740
public:
38-
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str) {
41+
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str,
42+
bool debug_last_error) {
43+
this->debug_last_error = debug_last_error;
3944
std::unordered_set<std::string> devices;
4045
devices.insert("cmsis-nn");
4146
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
@@ -49,6 +54,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
4954
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
5055

5156
private:
57+
/*! * \brief Enable storing the last error */
58+
bool debug_last_error;
59+
5260
/*! * \brief CMSIS-NN context buffer info */
5361
struct CMSISNNContextBuffer {
5462
std::string name;
@@ -357,13 +365,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
357365
stream << "&" << filter_dim << ", " << filter_data << ", ";
358366
stream << "&" << bias_dim << ", " << bias_data << ", ";
359367
stream << "&" << output_dim << ", " << output_data << ");\n";
360-
PrintIndent();
361-
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
362-
PrintIndent();
363-
PrintIndent();
364-
stream << "return -1;\n";
365-
PrintIndent();
366-
stream << "}\n";
368+
EmitErrorCheck();
367369
}
368370

369371
/*! * \brief Emits CMSIS-NN APIs for every call_extern comprising fully connected */
@@ -426,13 +428,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
426428
stream << "&" << filter_dim << ", " << filter_data << ", ";
427429
stream << "&" << bias_dim << ", " << bias_data << ", ";
428430
stream << "&" << output_dim << ", " << output_data << ");\n";
429-
PrintIndent();
430-
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
431-
PrintIndent();
432-
PrintIndent();
433-
stream << "return -1;\n";
434-
PrintIndent();
435-
stream << "}\n";
431+
EmitErrorCheck();
436432
}
437433

438434
/*! * \brief Emits CMSIS-NN APIs for every call_extern comprising pooling ops */
@@ -480,24 +476,51 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
480476
stream << "&" << input_dim << ", " << input_data << ", ";
481477
stream << "&" << filter_dim << ", ";
482478
stream << "&" << output_dim << ", " << output_data << ");\n";
479+
EmitErrorCheck();
480+
}
481+
482+
void EmitErrorCheck() {
483+
auto emit_error = [&](std::string error) {
484+
if (this->debug_last_error) {
485+
stream << "TVMAPISetLastError(\"" << error << "\"); ";
486+
}
487+
};
488+
483489
PrintIndent();
484-
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
490+
stream << "switch (!status) {\n";
485491
PrintIndent();
492+
stream << "case ARM_CMSIS_NN_SUCCESS: break;\n";
486493
PrintIndent();
494+
stream << "case ARM_CMSIS_NN_ARG_ERROR: ";
495+
emit_error("ARM_CMSIS_NN_ARG_ERROR");
496+
stream << "return -1;\n";
497+
PrintIndent();
498+
stream << "case ARM_CMSIS_NN_NO_IMPL_ERROR: ";
499+
emit_error("ARM_CMSIS_NN_NO_IMPL_ERROR");
487500
stream << "return -1;\n";
488501
PrintIndent();
489502
stream << "}\n";
490503
}
491504
};
492505

506+
static CMSISNNCompilerConfig GetCompilerAttrs() {
507+
auto ctx = tvm::tir::transform::PassContext::Current();
508+
Optional<CMSISNNCompilerConfig> cfg =
509+
ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
510+
if (!cfg.defined()) {
511+
return AttrsWithDefaultValues<CMSISNNCompilerConfig>();
512+
}
513+
return cfg.value();
514+
}
515+
493516
runtime::Module TIRToRuntime(IRModule mod, Target target) {
494517
bool output_ssa = false;
495518
bool emit_asserts = false;
496519
bool emit_fwd_func_decl = false;
520+
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
497521
CodeGenCMSISNN codegen;
498522
Array<String> function_names;
499-
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str());
500-
523+
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
501524
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
502525
for (auto kv : mod->functions) {
503526
funcs.push_back(kv);

0 commit comments

Comments
 (0)