Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@
from . import analysis
from . import stmt_functor
from .build import build
from .pipeline import get_pipeline
from .pipeline import get_tir_pipeline, get_default_tir_pipeline
46 changes: 28 additions & 18 deletions python/tvm/tir/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,41 +131,51 @@ def build(
assert isinstance(mod, tvm.IRModule)

# Step 0: Determine the target in environment
# It's used to bind the PrimFunc without target attr to serve as a default target
target_to_bind = Target.current() if target is None else target
if target_to_bind is None:
target_to_bind = "llvm"
assert target_to_bind is not None
target_to_bind = Target.canon_target(target_to_bind)

# Step 1: Determine the target to search for tir pipeline
target = Target.current() if target is None else target
if target is None:
target = "llvm"
assert target is not None
target = Target.canon_target(target)
for func in mod.functions.values():
f_target = func.attrs.get("target", None)
if f_target is not None:
target = f_target
break
if target is not None:
target = Target.canon_target(target)

# Step 1: Determine the host
# Step 2: Determine the host target
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
if target is not None:
if target.host is not None:
target_host = target.host
elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type:
target_host = target
else:
for func in mod.functions.values():
f_target = func.attrs.get("target", None)
if f_target is not None and f_target.host is not None:
target_host = f_target.host
assert target_host is not None
target_host = Target.canon_target(target_host)
target = target.with_host(target_host)
target_to_bind = target_to_bind.with_host(target_host)

# Step 2: Bind the target to the input module
mod = tvm.tir.transform.BindTarget(target)(mod)
# Step 3: Bind the target to the input module
mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)

# Step 3: Apply the pipeline
# Step 4: Apply the tir pipeline
if pipeline is not None:
# custom pipeline
if isinstance(pipeline, str):
pipeline = tvm.tir.get_pipeline(pipeline)
mod = pipeline(mod)
pipeline = tvm.tir.get_tir_pipeline(pipeline)
else:
# default pipeline depends on the target
pipeline = tvm.tir.get_default_tir_pipeline(target)
mod = pipeline(mod)

# Step 4: Get host and device modules
# Step 5: Get host and device modules
host_mod, device_mod_dict = split_host_device_mods(mod)

# Step 5: Apply finalization passes
# Step 6: Apply finalization passes
host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod)
device_mod_dict = {
target: tvm.tir.pipeline.finalize_device_passes()(device_mod)
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def finalize_device_passes(): # pylint: disable=unused-argument
}


def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass:
def get_tir_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass:
"""Get pre-build pipeline by name

Parameters
Expand All @@ -173,3 +173,10 @@ def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass:
f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}"
)
return PIPELINE_MAP[name](**kwargs)


def get_default_tir_pipeline(
target: tvm.target.Target, # pylint: disable=unused-argument
) -> tvm.transform.Pass:
"""Get the default TIR pipeline for the given target."""
return default_tir_pipeline()
Loading