Skip to content

Commit 1511779

Browse files
committed
* using packed functions
1 parent 2bc0763 commit 1511779

File tree

6 files changed

+30
-46
lines changed

6 files changed

+30
-46
lines changed

python/tvm/driver/tvmc/autotuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def add_tune_parser(subparsers, _, json_params):
226226
"--pre-build-hooks",
227227
action="append",
228228
help="specify any pre processing hooks before relay.build",
229-
choices=composite_target.get_pre_build_hooks(),
230229
default=[],
231230
)
232231
for one_entry in json_params:
@@ -410,7 +409,8 @@ def tune_model(
410409

411410
if pre_build_hooks:
412411
for hook in pre_build_hooks:
413-
mod = composite_target.PRE_BUILD_HOOKS[hook](mod, params)
412+
packed_function = tvm.get_global_func(hook)
413+
mod = packed_function(mod, params)
414414

415415
# min_repeat_ms should be:
416416
# a. the value provided by the user, if any, or

python/tvm/driver/tvmc/compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,12 @@ def add_compile_parser(subparsers, _, json_params):
149149
"--pre-build-hooks",
150150
action="append",
151151
help="specify any pre processing hooks before relay.build",
152-
choices=composite_target.get_pre_build_hooks(),
153152
default=[],
154153
)
155154
parser.add_argument(
156155
"--post-build-hooks",
157156
action="append",
158157
help="specify any post processing hooks after relay.build",
159-
choices=composite_target.get_post_build_hooks(),
160158
default=[],
161159
)
162160
for one_entry in json_params:
@@ -340,7 +338,8 @@ def compile_model(
340338

341339
if pre_build_hooks:
342340
for hook in pre_build_hooks:
343-
mod = composite_target.PRE_BUILD_HOOKS[hook](mod, params)
341+
packed_function = tvm.get_global_func(hook)
342+
mod = packed_function(mod, params)
344343

345344
if tuning_records and os.path.exists(tuning_records):
346345
logger.debug("tuning records file provided: %s", tuning_records)
@@ -393,7 +392,8 @@ def compile_model(
393392

394393
if post_build_hooks:
395394
for hook in post_build_hooks:
396-
graph_module = composite_target.POST_BUILD_HOOKS[hook](mod)
395+
packed_function = tvm.get_global_func(hook)
396+
packed_function(graph_module.get_lib())
397397

398398
# Generate output dump files with sources
399399
if dump_code is None:

python/tvm/driver/tvmc/composite_target.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from tvm.relay.op.contrib.bnns import partition_for_bnns
3030
from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
3131
from tvm.relay.op.contrib.clml import partition_for_clml
32-
from tvm.relay.op.contrib import adreno
3332

3433

3534
from tvm.driver.tvmc import TVMCException
@@ -108,35 +107,3 @@ def get_codegen_by_target(name):
108107
return REGISTERED_CODEGEN[name]
109108
except KeyError:
110109
raise TVMCException("Composite target %s is not defined in TVMC." % name)
111-
112-
113-
# Global dictionary of pre build hook and hook function
114-
PRE_BUILD_HOOKS = {
115-
"adreno-precision-fp16": adreno.mixed_precision_hook_fp16,
116-
"adreno-precision-fp16_acc32": adreno.mixed_precision_hook_fp16_acc32,
117-
}
118-
119-
120-
def get_pre_build_hooks():
121-
"""Return a list of all registered pre build hooks.
122-
123-
Returns
124-
-------
125-
list of str
126-
all registered pre build hooks
127-
"""
128-
return list(PRE_BUILD_HOOKS.keys())
129-
130-
131-
POST_BUILD_HOOKS = {}
132-
133-
134-
def get_post_build_hooks():
135-
"""Return a list of all registered post build hooks.
136-
137-
Returns
138-
-------
139-
list of str
140-
all registered post build hooks
141-
"""
142-
return list(POST_BUILD_HOOKS.keys())

python/tvm/relay/op/contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
from .tensorrt import *
2828
from .cutlass import *
2929
from .clml import *
30+
from .adreno import *

python/tvm/relay/op/contrib/adreno.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,17 @@ def convert_to_dtype(mod, dtype):
8282
mod = seq(mod)
8383
else:
8484
print("Warn: Invald dtype conversion to ", dtype)
85-
print("Mod:", mod)
8685
return mod
8786

8887

88+
@tvm.register_func("adreno.mixed_precision_fp16")
8989
def mixed_precision_hook_fp16(mod, params):
9090
"""TVMC hook api"""
9191

9292
return convert_to_dtype(mod["main"], "float16")
9393

9494

95+
@tvm.register_func("adreno.mixed_precision_fp16_acc32")
9596
def mixed_precision_hook_fp16_acc32(mod, params):
9697
"""TVMC hook api"""
9798

tests/python/driver/tvmc/test_compiler.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -714,17 +714,32 @@ def run_after_pass(self, mod, info):
714714
assert passes_counter.run_after_count == passes_counter.run_before_count
715715

716716

717-
def test_compile_preprocess_hooks(keras_resnet50):
718-
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
717+
def test_compile_hooks(keras_resnet50):
719718
pytest.importorskip("tensorflow")
719+
pre_hook_called = False
720+
post_hook_called = False
721+
722+
@tvm.register_func("tvmc.test_pre_hook")
723+
def test_pre_hook(mod, params):
724+
nonlocal pre_hook_called
725+
pre_hook_called = True
726+
return mod
727+
728+
@tvm.register_func("tvmc.test_post_hook")
729+
def test_post_hook(libm):
730+
nonlocal post_hook_called
731+
post_hook_called = True
732+
return
720733

721734
tvmc_model = tvmc.load(keras_resnet50)
722735
tvmc_package = tvmc.compile(
723-
tvmc_model, target="llvm", pre_build_hooks=["adreno-precision-fp16"]
724-
)
725-
tvmc_package = tvmc.compile(
726-
tvmc_model, target="llvm", pre_build_hooks=["adreno-precision-fp16_acc32"]
736+
tvmc_model,
737+
target="llvm",
738+
pre_build_hooks=["tvmc.test_pre_hook"],
739+
post_build_hooks=["tvmc.test_post_hook"],
727740
)
741+
assert pre_hook_called
742+
assert post_hook_called
728743

729744

730745
if __name__ == "__main__":

0 commit comments

Comments
 (0)