Skip to content

Commit 17e1ac7

Browse files
committed
Merge branch 'unity' into unity_reorder_take_after_matmul_pr_16315
2 parents cb1af2c + 4c7c010 commit 17e1ac7

File tree

32 files changed

+1142
-537
lines changed

32 files changed

+1142
-537
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cmake_minimum_required(VERSION 3.18)
1+
cmake_minimum_required(VERSION 3.24)
22
project(tvm C CXX)
33

44
# Utility functions

cmake/modules/CUDA.cmake

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,10 @@ if(USE_CUDA)
3838
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
3939
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
4040

41-
# Compatibility with cmake 3.18+
42-
#
43-
# The updates to the cutlass kernels made in TVM PR#16244 require
44-
# symbols provided in cuda 7.5+. While the cuda architecture is
45-
# specified by setting `NVCC_FLAGS` in the `CMakeLists.txt` for each
46-
# kernel, cmake 3.18+ also sets it based on the
47-
# `CMAKE_CUDA_ARCHITECTURES` value. If not set, cmake will explicitly
48-
# pass the compute capability as nvidia's default of 5.2, *EVEN IF* it
49-
# has already been specified in `NVCC_FLAGS`. Because the kernels
50-
# cannot compile with compute capability of 5.2, this causes
51-
# compilation errors.
52-
#
53-
# By setting `CMAKE_CUDA_ARCHITECTURES` to `OFF`, cmake does not add
54-
# 5.2 as a target architecture.
55-
#
56-
# See https://cmake.org/cmake/help/latest/policy/CMP0104.html for
57-
# details on CMake's policy for CUDA architecture flags.
58-
#
59-
# See https://cmake.org/cmake/help/latest/policy/CMP0104.html for the
60-
# default CUDA architecture for each version of CUDA.
61-
set(CMAKE_CUDA_ARCHITECTURES OFF)
41+
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
42+
message(STATUS "CMAKE_CUDA_ARCHITECTURES not set, using native")
43+
set(CMAKE_CUDA_ARCHITECTURES native)
44+
endif()
6245

6346
if(USE_CUDNN)
6447
message(STATUS "Build with cuDNN support")

python/tvm/contrib/msc/core/codegen/codegen.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,18 @@ def relay_to_relax(
172172
def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
173173
return BindParams("main", weights)(mod)
174174

175-
return codegen.load(inputs, post_load=_bind_weights)
175+
mod = codegen.load(inputs, post_load=_bind_weights)
176+
177+
mod = tvm.ir.transform.Sequential(
178+
[
179+
# The canonicalization of relax variable bindings is not required
180+
# for correctness. It does, however, remove trivial `x = y`
181+
# bindings, preventing test cases from depending on their
182+
# presence.
183+
tvm.relax.transform.CanonicalizeBindings(),
184+
tvm.relax.transform.ConvertToDataflow(min_size=1),
185+
],
186+
name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc",
187+
)(mod)
188+
189+
return mod

python/tvm/contrib/msc/framework/tvm/codegen/codegen.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,18 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo
7171
return mod
7272

7373
codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
74-
return codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)
74+
mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)
75+
76+
mod = tvm.ir.transform.Sequential(
77+
[
78+
# The canonicalization of relax variable bindings is not required
79+
# for correctness. It does, however, remove trivial `x = y`
80+
# bindings, preventing test cases from depending on their
81+
# presence.
82+
tvm.relax.transform.CanonicalizeBindings(),
83+
tvm.relax.transform.ConvertToDataflow(min_size=1),
84+
],
85+
name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc",
86+
)(mod)
87+
88+
return mod

python/tvm/driver/tvmc/compiler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import autotvm, auto_scheduler
3232
from tvm import relay
3333
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
34-
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
34+
from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintingInstrument
3535
from tvm.ir.memory_pools import WorkspaceMemoryPools
3636
from tvm.target import Target
3737
from tvm.relay.backend import Executor, Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
162162
action="store_true",
163163
help="print compilation time per pass",
164164
)
165+
parser.add_argument(
166+
"--print-ir-before",
167+
help="print IR before each named pass of a comma-separated list of pass names."
168+
"e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' ",
169+
default="",
170+
)
171+
parser.add_argument(
172+
"--print-ir-after",
173+
help="print IR after each named pass of a comma-separated list of pass names."
174+
"e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' ",
175+
default="",
176+
)
165177
for one_entry in json_params:
166178
parser.set_defaults(**one_entry)
167179

@@ -220,6 +232,8 @@ def drive_compile(args):
220232
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
221233
),
222234
print_pass_times=args.print_pass_times,
235+
print_ir_before=args.print_ir_before,
236+
print_ir_after=args.print_ir_after,
223237
**transform_args,
224238
)
225239

@@ -247,6 +261,8 @@ def compile_model(
247261
mod_name: Optional[str] = "default",
248262
workspace_pools: Optional[WorkspaceMemoryPools] = None,
249263
print_pass_times: bool = False,
264+
print_ir_before: Optional[List[str]] = None,
265+
print_ir_after: Optional[List[str]] = None,
250266
instruments: Optional[Sequence[PassInstrument]] = None,
251267
desired_layout: Optional[str] = None,
252268
desired_layout_ops: Optional[List[str]] = None,
@@ -295,7 +311,7 @@ def compile_model(
295311
needs to be generated.
296312
disabled_pass: str, optional
297313
Comma-separated list of passes which needs to be disabled
298-
during compilation
314+
during compilation.
299315
pass_context_configs: list[str], optional
300316
List of strings containing a set of configurations to be passed to the
301317
PassContext.
@@ -310,6 +326,10 @@ def compile_model(
310326
compilation.
311327
print_pass_times: bool
312328
To enable printing a breakdown of compilation times by pass. Disabled by default.
329+
print_ir_before: list[str], optional
330+
Print the IR before each pass named in this comma-separated list of pass names.
331+
print_ir_after: list[str], optional
332+
Print the IR after each pass named in this comma-separated list of pass names.
313333
instruments: Optional[Sequence[PassInstrument]]
314334
The list of pass instrument implementations.
315335
desired_layout: str, optional
@@ -369,6 +389,12 @@ def compile_model(
369389
timing_inst = PassTimingInstrument()
370390
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments
371391

392+
if print_ir_before or print_ir_after:
393+
print_ir_instr = PassPrintingInstrument(
394+
print_before_pass_names=print_ir_before, print_after_pass_names=print_ir_after
395+
)
396+
instruments = [print_ir_instr] if instruments is None else [print_ir_instr] + instruments
397+
372398
with tvm.transform.PassContext(
373399
opt_level=opt_level,
374400
config=config,
@@ -581,7 +607,6 @@ def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule,
581607
save_to_file = all([dump_path != "-", dump_path != ""])
582608

583609
if print_to_console or save_to_file:
584-
585610
operations_distribution = analyze_operations_distribution(mod)
586611

587612
def annotate_f(x):

python/tvm/ir/instrument.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,21 @@ def render():
255255
profiles = timing_inst.render()
256256
"""
257257
return _ffi_instrument_api.RenderTimePassProfiles()
258+
259+
260+
@pass_instrument
261+
class PassPrintingInstrument:
262+
"""A pass instrument to print ir before or
263+
print ir after each element of a named pass."""
264+
265+
def __init__(self, print_before_pass_names, print_after_pass_names):
266+
self.print_before_pass_names = print_before_pass_names
267+
self.print_after_pass_names = print_after_pass_names
268+
269+
def run_before_pass(self, mod, pass_info):
270+
if pass_info.name in self.print_before_pass_names:
271+
print(f"Print IR before: {pass_info.name}\n{mod}\n\n")
272+
273+
def run_after_pass(self, mod, pass_info):
274+
if pass_info.name in self.print_after_pass_names:
275+
print(f"Print IR after: {pass_info.name}\n{mod}\n\n")

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
116116
tir_call = self.builder_.call_te(
117117
te_func,
118118
call.args[0],
119+
k=call.attrs.k,
119120
axis=call.attrs.axis,
120121
ret_type=call.attrs.ret_type,
121122
is_ascend=not call.attrs.largest,

python/tvm/relax/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Relax transformations. """
1818

1919
from .transform import (
20+
AdjustMatmulOrder,
2021
AllocateWorkspace,
2122
AlterOpImpl,
2223
AnnotateTIROpPattern,

python/tvm/relax/transform/transform.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,23 @@ def UpdateParamStructInfo(sinfo_func: Callable[[Var], Optional[StructInfo]]):
12491249
return _ffi_api.UpdateParamStructInfo(sinfo_func) # type: ignore
12501250

12511251

1252+
def AdjustMatmulOrder():
1253+
"""Reorder `x*(A*B)` to `(x*A)*B`
1254+
1255+
Useful for optimizing LoRA computations, where `matmul(x,
1256+
LoraA*LoraB)` may be computed as `matmul(matmul(x, LoraA),
1257+
LoraB)`, reducing the total memory usage.
1258+
1259+
1260+
Returns
1261+
-------
1262+
ret : tvm.transform.Pass
1263+
The corresponding pass.
1264+
"""
1265+
1266+
return _ffi_api.AdjustMatmulOrder() # type: ignore
1267+
1268+
12521269
def ReorderTakeAfterMatmul():
12531270
"""Reorder `matmul(x, take(weights, indices))` to `take(matmul(x,weights),indices)`
12541271

0 commit comments

Comments
 (0)