Commit dadf5d3

[Pass] Attach memory-planning attributes for dynamic func output (#1604)
This PR adds a pass to the model compilation pipeline that attaches the attribute `"relax.memory_plan_dynamic_func_output"` to each Relax function in the IRModule. The attribute indicates that the functions' output tensors, despite having dynamic shapes, are statically plannable. This ensures that in serving scenarios our memory allocation is completely static once it has stabilized, so we no longer need to worry about ever-growing memory usage and can allocate more memory to the KV cache. This PR can be merged early, but it will not take effect until apache/tvm#16111 is merged.
1 parent 3e9d185 commit dadf5d3
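
For illustration only (not part of this commit), here is a minimal sketch of the kind of Relax function the attribute targets: the output tensor's shape depends on a symbolic dimension, so it is dynamic, yet once that dimension is bounded (e.g. by `AttachVariableBounds`) its allocation can still be planned statically. The module, function name, and shapes below are hypothetical.

```python
# Hypothetical example (not from this PR): a Relax function whose output
# tensor has a dynamic first dimension `seq_len`.
from tvm.script import ir as I, relax as R


@I.ir_module
class ExampleModule:
    @R.function
    def prefill(
        x: R.Tensor(("seq_len", 4096), "float16")
    ) -> R.Tensor(("seq_len", 4096), "float16"):
        # The output shares the input's dynamic shape; with an upper bound on
        # `seq_len`, its buffer can still be allocated statically.
        y = R.add(x, x)
        return y
```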

File tree: 2 files changed, +18 −1 lines


python/mlc_chat/compiler_pass/attach_to_ir_module.py

Lines changed: 12 additions & 0 deletions
@@ -32,3 +32,15 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         for func_name, func in self.functions.items():
             mod[func_name] = func.with_attr("global_symbol", func_name)
         return mod
+
+
+@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr")
+class AttachMemoryPlanAttr:  # pylint: disable=too-few-public-methods
+    """Attach memory planning attribute for dynamic function output planning to Relax functions."""
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Entrypoint"""
+        for g_var, func in mod.functions_items():
+            if isinstance(func, relax.Function):
+                mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True)
+        return mod
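
A hedged usage sketch, not part of the diff: instantiate the decorated pass and apply it to a module directly, then inspect the attached attribute. The toy module below is hypothetical; `functions_items()` and the attribute key come from the code above.

```python
from tvm import relax
from tvm.script import ir as I, relax as R

from mlc_chat.compiler_pass.attach_to_ir_module import AttachMemoryPlanAttr


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("n", 8), "float32")) -> R.Tensor(("n", 8), "float32"):
        y = R.add(x, x)
        return y


# The module_pass-decorated class is instantiated and called like any TVM pass.
mod = AttachMemoryPlanAttr()(Module)

for g_var, func in mod.functions_items():
    if isinstance(func, relax.Function):
        # Every Relax function should now carry the planning hint.
        print(g_var.name_hint, func.attrs["relax.memory_plan_dynamic_func_output"])
```

In the full pipeline (see the second file below) this pass runs in Phase 0; the downstream memory-planning change in apache/tvm#16111 is what actually consumes the attribute.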

python/mlc_chat/compiler_pass/pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -10,7 +10,11 @@
 
 from mlc_chat.support import logging
 
-from .attach_to_ir_module import AttachAdditionalPrimFuncs, AttachVariableBounds
+from .attach_to_ir_module import (
+    AttachAdditionalPrimFuncs,
+    AttachMemoryPlanAttr,
+    AttachVariableBounds,
+)
 from .clean_up_tir_attrs import CleanUpTIRAttrs
 from .cublas_dispatch import CublasDispatch
 from .estimate_memory_usage import AttachMetadataWithMemoryUsage
@@ -81,6 +85,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 PruneRelaxFunc(flashinfer=flashinfer),
                 AttachVariableBounds(variable_bounds),
                 AttachAdditionalPrimFuncs(additional_tirs),
+                AttachMemoryPlanAttr(),
                 _DebugDump("debug-phase0.py", debug_dump, show_meta=False),
                 # Phase 1. Passes on high-level operator graph
                 _LogProgress("Running TVM Relax graph-level optimizations"),
