diff --git a/evaluate.py b/evaluate.py index 731d18f215cf..d030545f2996 100644 --- a/evaluate.py +++ b/evaluate.py @@ -481,6 +481,11 @@ def get_args(): action="store_true", help="Use graph runtime debugger to output per layer perf. data and other statistics", ) + parser.add_argument( + "--VM", + action="store_true", + help="Use VM compiling and benchmarking", + ) args = parser.parse_args() if args.rpc_tracker_port != None: @@ -733,9 +738,9 @@ def __init__(self, shape_dict, layout="NCHW", preproc=None): self.inputs = {name : image} - def Validate(self, m, ref_outputs=[]): + def Validate(self, m, ref_outputs=[], data={}): if isinstance(m, tvm.runtime.vm.VirtualMachine) or isinstance(m, tvm.runtime.profiler_vm.VirtualMachineProfiler): - tvm_output = m.get_outputs()[0] + tvm_output = m.invoke("main", **data) else: tvm_output = m.get_output(0) #import ipdb; ipdb.set_trace() @@ -1545,12 +1550,8 @@ def _benchmark_vm( dtype="float32", validator=None ): - #if args.debug: - # from tvm.contrib.debugger import debug_runtime as graph_executor - #else: - # from tvm.contrib import graph_executor from tvm.runtime.vm import VirtualMachine - + from tvm.runtime import profiler_vm if self.use_tracker and self.remote == None: self._connect_tracker() @@ -1560,7 +1561,6 @@ def _benchmark_vm( mod = tvm.IRModule() mod["main"] = tvm_mod - #target = tvm.target.Target(args.target, host=args.target_host) with tvm.transform.PassContext(opt_level=3): vmc = relay.vm.compile(mod, target_host=target_host, target=target, params=params) @@ -1588,7 +1588,6 @@ def _benchmark_vm( vm = tvm.runtime.profiler_vm.VirtualMachineProfiler(vmc, ctx, "naive") else: vm = VirtualMachine(vmc, ctx, "naive") - inputs = [] if isinstance(validator, Validator): inputs = validator.GetInputDictionary() @@ -1604,11 +1603,14 @@ def _benchmark_vm( else: data = tvm.nd.array(np.random.normal(size=input_shape).astype("float32"), ctx) vm.set_input("main", data) - + print("Evaluating...", flush=True) + if args.debug: + res = vm.profile(**data, func_name="main") + print(res) + number = 1 repeat = 100 - repeat = 1 min_repeat_ms = 0 time_to_work_ms = 1000 cooldown_interval_ms=1000 @@ -1622,7 +1624,7 @@ def _benchmark_vm( if validator: if isinstance(validator, Validator): ref_outputs = validator.GetReference() - validator.Validate(vm, ref_outputs) + validator.Validate(vm, ref_outputs, data) else: ref_outputs = validator(inputs) for i, ref_output in enumerate(ref_outputs):