Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tvm import autotvm, auto_scheduler
from tvm import relay
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintingInstrument
from tvm.ir.memory_pools import WorkspaceMemoryPools
from tvm.target import Target
from tvm.relay.backend import Executor, Runtime
Expand Down Expand Up @@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
action="store_true",
help="print compilation time per pass",
)
parser.add_argument(
"--print-ir-before",
help="print IR before each named pass of a comma-separated list of pass names."
"e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' ",
default="",
)
parser.add_argument(
"--print-ir-after",
help="print IR after each named pass of a comma-separated list of pass names."
"e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' ",
default="",
)
for one_entry in json_params:
parser.set_defaults(**one_entry)

Expand Down Expand Up @@ -220,6 +232,8 @@ def drive_compile(args):
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
),
print_pass_times=args.print_pass_times,
print_ir_before=args.print_ir_before,
print_ir_after=args.print_ir_after,
**transform_args,
)

Expand Down Expand Up @@ -247,6 +261,8 @@ def compile_model(
mod_name: Optional[str] = "default",
workspace_pools: Optional[WorkspaceMemoryPools] = None,
print_pass_times: bool = False,
print_ir_before: List[str] = "",
print_ir_after: List[str] = "",
instruments: Optional[Sequence[PassInstrument]] = None,
desired_layout: Optional[str] = None,
desired_layout_ops: Optional[List[str]] = None,
Expand Down Expand Up @@ -295,7 +311,7 @@ def compile_model(
needs to be generated.
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
during compilation.
pass_context_configs: list[str], optional
List of strings containing a set of configurations to be passed to the
PassContext.
Expand All @@ -310,6 +326,10 @@ def compile_model(
compilation.
print_pass_times: bool
To enable printing a breakdown of compilation times by pass. Disabled by default.
print_ir_before: list[str]
To print IR before each named pass of a comma-separated list of passes.
print_ir_after: list[str]
To print IR after each named pass of a comma-separated list of passes.
instruments: Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
desired_layout: str, optional
Expand Down Expand Up @@ -369,6 +389,20 @@ def compile_model(
timing_inst = PassTimingInstrument()
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments

if print_ir_before:
print_ir_before_instr = PassPrintingInstrument("before", print_ir_before)
instruments = (
[print_ir_before_instr]
if instruments is None
else [print_ir_before_instr] + instruments
)

if print_ir_after:
print_ir_after_instr = PassPrintingInstrument("after", print_ir_after)
instruments = (
[print_ir_after_instr] if instruments is None else [print_ir_after_instr] + instruments
)

with tvm.transform.PassContext(
opt_level=opt_level,
config=config,
Expand Down Expand Up @@ -581,7 +615,6 @@ def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule,
save_to_file = all([dump_path != "-", dump_path != ""])

if print_to_console or save_to_file:

operations_distribution = analyze_operations_distribution(mod)

def annotate_f(x):
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/ir/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,18 @@ def render():
profiles = timing_inst.render()
"""
return _ffi_instrument_api.RenderTimePassProfiles()


@pass_instrument
class PassPrintingInstrument:
def __init__(self, before_after, print_pass_name):
self.before_after = before_after
self.print_pass_name = print_pass_name

def run_before_pass(self, mod, pass_info):
if self.before_after == "before" and pass_info.name in self.print_pass_name:
print(f"Print IR before:\n{pass_info.name}\n{mod}\n\n")

def run_after_pass(self, mod, pass_info):
if self.before_after == "after" and pass_info.name in self.print_pass_name:
print(f"Print IR after:\n{pass_info.name}\n{mod}\n\n")
41 changes: 40 additions & 1 deletion tests/python/driver/tvmc/test_command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def paddle_model(paddle_resnet50):
@mock.patch.object(compiler, "compile_model")
# @mock.patch.object(compiler, "compile_model")
def test_tvmc_compile_input_model(mock_compile_model, tmpdir_factory, model):

output_dir = tmpdir_factory.mktemp("output")
output_file = output_dir / "model.tar"

Expand Down Expand Up @@ -289,3 +288,43 @@ def test_tvmc_print_pass_times(capsys, keras_simple, tmpdir_factory):
captured_out = capsys.readouterr().out
for exp_str in ("Compilation time breakdown by pass:", "sequential:", "us]"):
assert exp_str in captured_out


@pytest.mark.parametrize(
"print_cmd, out_str",
[
(
"--print-ir-after=[tir.SplitHostDevice]",
(
"Print IR after:\ntir.SplitHostDevice\n# from tvm.script import ir as I\n",
"@I.ir_module",
),
),
(
"--print-ir-before=[tir.SplitHostDevice]",
("Print IR before:\ntir.SplitHostDevice\n# from tvm.script import ir as I\n"),
),
(
"--print-ir-after=[tir.ThreadSync,tir.SplitHostDevice]",
("tir.ThreadSync\n,tir.SplitHostDevice\n"),
),
(
"--print-ir-before=[tir.SplitHostDevice] --print-ir-after=[tir.SplitHostDevice]",
("Print IR before:\ntir.SplitHostDevice\n", "Print IR after:\ntir.SplitHostDevice\n"),
),
],
)
def test_tvmc_print_ir_before_after(capsys, keras_simple, tmpdir_factory, print_cmd, out_str):
pytest.importorskip("tensorflow")
tmpdir = tmpdir_factory.mktemp("out")

# Compile model
module_file = os.path.join(tmpdir, "keras-tvm.tar")
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
compile_args = compile_cmd.split(" ")[1:]
_main(compile_args)

# Check for printing IR before or IR after
captured_out = capsys.readouterr().out
for exp_str in out_str:
assert exp_str in captured_out