Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ def get_args():
action="store_true",
help="Use graph runtime debugger to output per layer perf. data and other statistics",
)
parser.add_argument(
"--VM",
action="store_true",
help="Use VM compiling and benchmarking",
)

args = parser.parse_args()
if args.rpc_tracker_port != None:
Expand Down Expand Up @@ -733,9 +738,9 @@ def __init__(self, shape_dict, layout="NCHW", preproc=None):

self.inputs = {name : image}

def Validate(self, m, ref_outputs=[]):
def Validate(self, m, ref_outputs=[], data={}):
if isinstance(m, tvm.runtime.vm.VirtualMachine) or isinstance(m, tvm.runtime.profiler_vm.VirtualMachineProfiler):
tvm_output = m.get_outputs()[0]
tvm_output = m.invoke("main", **data)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused about that. Why we should infer the network one more time for getting outputs?

Copy link
Author

@dsbarinov1 dsbarinov1 Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that is because we run "invoke_stateful" before, which "Invoke a function and ignore the returned result." (CC from tvm/python/tvm/runtime/vm.py#invoke_stateful). Besides that, m.get_outputs() should do the work, but for some reason it is not returning correct outputs.

else:
tvm_output = m.get_output(0)
#import ipdb; ipdb.set_trace()
Expand Down Expand Up @@ -1545,12 +1550,8 @@ def _benchmark_vm(
dtype="float32",
validator=None
):
#if args.debug:
# from tvm.contrib.debugger import debug_runtime as graph_executor
#else:
# from tvm.contrib import graph_executor
from tvm.runtime.vm import VirtualMachine

from tvm.runtime import profiler_vm
if self.use_tracker and self.remote == None:
self._connect_tracker()

Expand All @@ -1560,7 +1561,6 @@ def _benchmark_vm(
mod = tvm.IRModule()
mod["main"] = tvm_mod

#target = tvm.target.Target(args.target, host=args.target_host)
with tvm.transform.PassContext(opt_level=3):
vmc = relay.vm.compile(mod, target_host=target_host, target=target, params=params)

Expand Down Expand Up @@ -1588,7 +1588,6 @@ def _benchmark_vm(
vm = tvm.runtime.profiler_vm.VirtualMachineProfiler(vmc, ctx, "naive")
else:
vm = VirtualMachine(vmc, ctx, "naive")

inputs = []
if isinstance(validator, Validator):
inputs = validator.GetInputDictionary()
Expand All @@ -1604,11 +1603,14 @@ def _benchmark_vm(
else:
data = tvm.nd.array(np.random.normal(size=input_shape).astype("float32"), ctx)
vm.set_input("main", data)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant change

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these spaces

print("Evaluating...", flush=True)
if args.debug:
res = vm.profile(**data, func_name="main")
print(res)

number = 1
repeat = 100
repeat = 1
min_repeat_ms = 0
time_to_work_ms = 1000
cooldown_interval_ms=1000
Expand All @@ -1622,7 +1624,7 @@ def _benchmark_vm(
if validator:
if isinstance(validator, Validator):
ref_outputs = validator.GetReference()
validator.Validate(vm, ref_outputs)
validator.Validate(vm, ref_outputs, data)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work? It looks like method Validate takes only 2 arguments but you pass 3...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason the corresponding change in Validate was not in commit. Should be fine now.

else:
ref_outputs = validator(inputs)
for i, ref_output in enumerate(ref_outputs):
Expand Down