-
Notifications
You must be signed in to change notification settings - Fork 1
Update evaluate.py to be able to profile and correctly validate the outputs on VM #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: echuraev/virtual_device
Are you sure you want to change the base?
Changes from all commits
b12454c
75807cf
e9d2468
97ed284
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -481,6 +481,11 @@ def get_args(): | |
| action="store_true", | ||
| help="Use graph runtime debugger to output per layer perf. data and other statistics", | ||
| ) | ||
| parser.add_argument( | ||
| "--VM", | ||
| action="store_true", | ||
| help="Use VM compiling and benchmarking", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
| if args.rpc_tracker_port != None: | ||
|
|
@@ -733,9 +738,9 @@ def __init__(self, shape_dict, layout="NCHW", preproc=None): | |
|
|
||
| self.inputs = {name : image} | ||
|
|
||
| def Validate(self, m, ref_outputs=[]): | ||
| def Validate(self, m, ref_outputs=[], data={}): | ||
| if isinstance(m, tvm.runtime.vm.VirtualMachine) or isinstance(m, tvm.runtime.profiler_vm.VirtualMachineProfiler): | ||
| tvm_output = m.get_outputs()[0] | ||
| tvm_output = m.invoke("main", **data) | ||
| else: | ||
| tvm_output = m.get_output(0) | ||
| #import ipdb; ipdb.set_trace() | ||
|
|
@@ -1545,12 +1550,8 @@ def _benchmark_vm( | |
| dtype="float32", | ||
| validator=None | ||
| ): | ||
| #if args.debug: | ||
| # from tvm.contrib.debugger import debug_runtime as graph_executor | ||
| #else: | ||
| # from tvm.contrib import graph_executor | ||
| from tvm.runtime.vm import VirtualMachine | ||
|
|
||
| from tvm.runtime import profiler_vm | ||
| if self.use_tracker and self.remote == None: | ||
| self._connect_tracker() | ||
|
|
||
|
|
@@ -1560,7 +1561,6 @@ def _benchmark_vm( | |
| mod = tvm.IRModule() | ||
| mod["main"] = tvm_mod | ||
|
|
||
| #target = tvm.target.Target(args.target, host=args.target_host) | ||
| with tvm.transform.PassContext(opt_level=3): | ||
| vmc = relay.vm.compile(mod, target_host=target_host, target=target, params=params) | ||
|
|
||
|
|
@@ -1588,7 +1588,6 @@ def _benchmark_vm( | |
| vm = tvm.runtime.profiler_vm.VirtualMachineProfiler(vmc, ctx, "naive") | ||
| else: | ||
| vm = VirtualMachine(vmc, ctx, "naive") | ||
|
|
||
| inputs = [] | ||
| if isinstance(validator, Validator): | ||
| inputs = validator.GetInputDictionary() | ||
|
|
@@ -1604,11 +1603,14 @@ def _benchmark_vm( | |
| else: | ||
| data = tvm.nd.array(np.random.normal(size=input_shape).astype("float32"), ctx) | ||
| vm.set_input("main", data) | ||
|
|
||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant change
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove these spaces |
||
| print("Evaluating...", flush=True) | ||
| if args.debug: | ||
| res = vm.profile(**data, func_name="main") | ||
| print(res) | ||
|
|
||
| number = 1 | ||
| repeat = 100 | ||
| repeat = 1 | ||
| min_repeat_ms = 0 | ||
| time_to_work_ms = 1000 | ||
| cooldown_interval_ms=1000 | ||
|
|
@@ -1622,7 +1624,7 @@ def _benchmark_vm( | |
| if validator: | ||
| if isinstance(validator, Validator): | ||
| ref_outputs = validator.GetReference() | ||
| validator.Validate(vm, ref_outputs) | ||
| validator.Validate(vm, ref_outputs, data) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does it work? It looks like method
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason the corresponding change in |
||
| else: | ||
| ref_outputs = validator(inputs) | ||
| for i, ref_output in enumerate(ref_outputs): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused about that. Why we should infer the network one more time for getting outputs?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that is because we run "invoke_stateful" before, which "Invoke a function and ignore the returned result." (CC from tvm/python/tvm/runtime/vm.py#invoke_stateful). Besides that,
m.get_outputs()should do the work, but for some reason it is not returning correct outputs.