Skip to content

Commit dbd076a

Browse files
authored
[BYORTL][Verilator] update ops and add MobileNet (#7972)
* update * update vta submodule * cpp fmt * python fmt * skip if tflite is not available * fmt * change assertion * update comment
1 parent c510c2b commit dbd076a

File tree

8 files changed

+554
-103
lines changed

8 files changed

+554
-103
lines changed

3rdparty/vta-hw

src/runtime/contrib/verilator/verilator_kernel.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ namespace tvm {
3333
namespace runtime {
3434
namespace contrib {
3535

36-
extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* data, int* weight, int* out,
36+
extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* left, int* right, int* out,
3737
int p_h_, int p_w_);
3838

39+
extern "C" TVM_DLL void verilator_bias_add(VerilatorHandle handle, int* data, int* bias, int* out,
40+
int p_n_, int p_c_, int p_h_, int p_w_);
41+
3942
} // namespace contrib
4043
} // namespace runtime
4144
} // namespace tvm

src/runtime/contrib/verilator/verilator_runtime.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ VerilatorRuntime::~VerilatorRuntime() {
8080
auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
8181
ICHECK(dealloc != nullptr);
8282
dealloc(device_);
83-
delete lib_;
83+
lib_->~VerilatorLibrary();
8484
}
8585

8686
void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }
@@ -100,15 +100,14 @@ void VerilatorRuntime::Init(const Array<NDArray>& consts) {
100100
ICHECK(reset != nullptr);
101101
read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
102102
ICHECK(read_ != nullptr);
103-
add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
104103

105104
// alloc verilator device
106105
device_ = alloc();
107106

108107
// enable profiler
109108
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();
110109

111-
// reset verilator device.
110+
// reset verilator device
112111
reset(device_, reset_cycles_);
113112

114113
CHECK_EQ(consts.size(), const_idx_.size())
@@ -136,11 +135,17 @@ void VerilatorRuntime::Run() {
136135
if (node.GetOpType() == "kernel") {
137136
CHECK_EQ(node.GetOpType(), "kernel");
138137
auto op_name = node.GetOpName();
138+
auto entry = node.GetInputs()[0];
139+
auto shape = node.GetOpShape()[entry.index_];
139140
if ("add" == op_name) {
140-
auto entry = node.GetInputs()[0];
141-
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
142-
ICHECK(add_op_ != nullptr);
143-
add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
141+
auto add = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
142+
ICHECK(add != nullptr);
143+
add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
144+
} else if ("nn.bias_add" == op_name) {
145+
auto bias_add =
146+
reinterpret_cast<VerilatorBiasAddFunc>(lib_->GetSymbol("verilator_bias_add"));
147+
ICHECK(bias_add != nullptr);
148+
bias_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[3], shape[1], shape[2]);
144149
} else {
145150
LOG(FATAL) << "Unsupported op: " << op_name;
146151
}

src/runtime/contrib/verilator/verilator_runtime.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ using namespace tvm::runtime::json;
5050
typedef VerilatorHandle (*VerilatorAllocFunc)();
5151
typedef void (*VerilatorDeallocFunc)(VerilatorHandle);
5252
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
53-
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
5453
typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int);
54+
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
55+
typedef void (*VerilatorBiasAddFunc)(VerilatorHandle, int*, int*, int*, int, int, int, int);
5556

5657
class VerilatorLibrary : public Library {
5758
public:
@@ -122,8 +123,6 @@ class VerilatorRuntime : public JSONRuntimeBase {
122123
VerilatorProfiler* prof_{nullptr};
123124
/*! \brief the verilator read function */
124125
VerilatorReadFunc read_{nullptr};
125-
/*! \brief the verilator add op function */
126-
VerilatorAddFunc add_op_{nullptr};
127126
/*! \brief the verilator reset cycles */
128127
int reset_cycles_{1};
129128
/*! \brief the verilator profiler status */

tests/python/contrib/test_verilator/infrastructure.py

Lines changed: 104 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import sys
2121
import subprocess as sp
22+
import json
2223

2324
import tvm
2425
from tvm import relay
@@ -48,6 +49,10 @@ def _func_wrapper(expr):
4849
return _func_wrapper
4950

5051

52+
_register_verilator_op("add")
53+
_register_verilator_op("nn.bias_add")
54+
55+
5156
def skip_test():
5257
"""Skip test if it requires the Verilator codegen and it's not present."""
5358
if not tvm.get_global_func("relay.ext.verilator", True):
@@ -59,8 +64,33 @@ def skip_test():
5964
return False
6065

6166

67+
def clear_stats():
68+
"""Clear profiler statistics."""
69+
f = tvm.get_global_func("verilator.profiler_clear", True)
70+
if f:
71+
f()
72+
73+
74+
def stats():
75+
"""Get profiler statistics."""
76+
77+
x = tvm.get_global_func("verilator.profiler_status")()
78+
return json.loads(x)
79+
80+
6281
def offload(mod):
63-
"""Offload ops based on the registered ops"""
82+
"""Offload ops based on the registered ops
83+
84+
Paramters
85+
---------
86+
mod : Module
87+
The input module.
88+
89+
Returns
90+
-------
91+
mod : Module
92+
The output module with offloaded ops.
93+
"""
6494

6595
backend = "verilator"
6696
mod = transform.AnnotateTarget([backend])(mod)
@@ -69,7 +99,7 @@ def offload(mod):
6999

70100

71101
def verilator_app_path():
72-
"""Find verilator hardware app path"""
102+
"""Create verilator hardware app path."""
73103

74104
cur_dir = os.path.dirname(os.path.realpath(__file__))
75105
return os.path.join(
@@ -82,37 +112,87 @@ def verilator_app_path():
82112
"vta-hw",
83113
"apps",
84114
"verilator",
115+
"add",
85116
)
86117

87118

88-
def compile_hardware():
89-
"""Compile hardware into shared library"""
119+
def compile_hardware(lanes):
120+
"""Compile hardware into shared library
121+
122+
Paramters
123+
---------
124+
lanes : Int
125+
The number of vector lanes.
126+
127+
Returns
128+
-------
129+
path : Str
130+
The path of the shared library.
131+
"""
132+
lib_name = "libverilator_{}".format(lanes)
133+
lib_name_ext = "{}.so".format(lib_name)
134+
lib = os.path.join(verilator_app_path(), lib_name_ext)
135+
if not os.path.isfile(lib):
136+
opt_lib_name = "LIB_NAME={}".format(lib_name)
137+
opt_lanes = "LANES={}".format(lanes)
138+
cmd = []
139+
cmd.append("make")
140+
cmd.append("--directory")
141+
cmd.append(verilator_app_path())
142+
cmd.append(opt_lib_name)
143+
cmd.append(opt_lanes)
144+
sp.run(cmd, check=True, stdout=sp.DEVNULL)
145+
return lib
146+
90147

91-
cmd = []
92-
cmd.append("make")
93-
cmd.append("--directory")
94-
cmd.append(verilator_app_path())
95-
sp.run(cmd, check=True)
148+
def compiler_opts(lib):
149+
"""Create compiler options
96150
151+
Paramters
152+
---------
153+
lib : Str
154+
The path of the hardware shared library.
97155
98-
def compile_module(mod):
99-
"""Compile Relay module and hardware library"""
156+
Returns
157+
-------
158+
opts : Dict
159+
The compiler options.
160+
"""
161+
opts = {
162+
"lib_path": lib,
163+
"profiler_enable": True,
164+
"profiler_cycle_counter_id": 0,
165+
}
166+
return opts
100167

101-
lib = os.path.join(verilator_app_path(), "libverilator.so")
102-
if not os.path.isfile(lib):
103-
compile_hardware()
104168

105-
opts = {"lib_path": lib}
169+
def run_module(inp, mod, params=None, opts=None):
170+
"""Compile Relay module and hardware library
106171
107-
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
108-
exe = relay.vm.compile(mod, target="llvm", params=None)
109-
code, lib = exe.save()
110-
return runtime.vm.Executable.load_exec(code, lib)
172+
Paramters
173+
---------
174+
inp : Data
175+
The input data.
111176
177+
mod : Module
178+
The relay module.
112179
113-
def run_module(exe, inputs):
114-
"""Run Relay module"""
180+
params : Parameters
181+
The model Parameters.
115182
116-
dev = tvm.cpu()
117-
vm = runtime.vm.VirtualMachine(exe, dev)
118-
return vm.run(**inputs)
183+
opts : Dict
184+
The compiler
185+
186+
Returns
187+
-------
188+
out : Data
189+
The output data.
190+
"""
191+
192+
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
193+
lib = relay.vm.compile(mod, target="llvm", params=params)
194+
code, lib = lib.save()
195+
exe = runtime.vm.Executable.load_exec(code, lib)
196+
vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
197+
out = vm.run(**inp)
198+
return out

0 commit comments

Comments
 (0)