
Commit 04eef9a

yelite authored and sunggg committed
Cutlass offload (mlc-ai#2)
* Add cutlass offload
* Add comments
1 parent 5056620 commit 04eef9a

6 files changed (+68, -23 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ __pycache__/
 *.so
 
 build*
+!build.py
 
 *.ll
 .npm

build.py

Lines changed: 24 additions & 7 deletions
@@ -8,10 +8,12 @@
 import tvm
 from tvm import meta_schedule as ms
 from tvm import relax
+from tvm.relax.backend.pattern_registry import get_pattern
 
 import mlc_llm
 from mlc_llm import utils
 from mlc_llm.relax_model import gpt_neox, llama, moss
+from mlc_llm.transform import rewrite_attention
 
 
 def _parse_args():
@@ -36,6 +38,7 @@ def _parse_args():
         choices=[*utils.quantization_dict.keys()],
         default=list(utils.quantization_dict.keys())[0],
     )
+    args.add_argument("--cutlass-offload", action="store_true", default=False)
     args.add_argument("--max-seq-len", type=int, default=-1)
     args.add_argument("--target", type=str, default="auto")
     args.add_argument(
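The new flag defaults to off and, per the mod_transform_before_build change below, only takes effect when the target kind resolves to "cuda". A hypothetical invocation (all other build.py options elided) would therefore look like:

    python build.py --target cuda --cutlass-offload ...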
@@ -185,6 +188,8 @@ def debug_dump_script(mod, name, args):
         return
     dump_path = os.path.join(args.artifact_path, "debug", name)
     with open(dump_path, "w", encoding="utf-8") as outfile:
+        # Remove runtime modules from external codegen so that the IR module can be printed.
+        mod = mod.without_attr("external_mods").without_attr("const_name_to_constant")
         outfile.write(mod.script(show_meta=True))
     print(f"Dump mod to {dump_path}")
 
@@ -240,11 +245,23 @@ def mod_transform_before_build(
         storage_nbit=args.quantization.storage_nbit,
         dtype=args.quantization.model_dtype,
     )(mod)
-    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)  # pylint: disable=not-callable
-    mod = relax.pipeline.get_pipeline()(mod)  # pylint: disable=no-value-for-parameter
-    mod = mlc_llm.transform.FuseDecodeMatmulEwise(  # pylint: disable=not-callable
-        args.quantization.model_dtype, args.target_kind
-    )(mod)
+    if args.target_kind == "cuda" and args.cutlass_offload:
+        from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
+
+        debug_dump_script(mod, "mod_before_cutlass.py", args)
+        mod = partition_for_cutlass(mod)
+        debug_dump_script(mod, "mod_after_cutlass_partition.py", args)
+        codegen_pass = relax.transform.RunCodegen(
+            {"cutlass": {"sm": 80, "find_first_valid": False}},
+            entry_functions=model_names,
+        )
+        mod = codegen_pass(mod)
+        debug_dump_script(mod, "mod_after_cutlass_codegen.py", args)
+
+    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
+
+    mod = relax.pipeline.get_pipeline()(mod)
+    mod = mlc_llm.transform.FuseDecodeMatmulEwise(args.dtype)(mod)
     mod = relax.transform.DeadCodeElimination(model_names)(mod)
     mod = relax.transform.LiftTransformParams()(mod)
     mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names)
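The two added passes are the standard Relax BYOC flow. A minimal standalone sketch of the same sequence, assuming mod is a Relax IRModule and model_names lists its entry functions (both taken from the hunk above):

    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    def offload_to_cutlass(mod, model_names, sm=80):
        # Group CUTLASS-supported subgraphs (matmul, attention, ...) into composite
        # functions annotated for the "cutlass" external codegen.
        mod = partition_for_cutlass(mod)
        # Compile the partitioned functions with CUTLASS; the generated runtime modules
        # are attached to the IRModule as the "external_mods" attribute (see the
        # debug_dump_script and utils.py changes in this commit).
        mod = relax.transform.RunCodegen(
            {"cutlass": {"sm": sm, "find_first_valid": False}},
            entry_functions=model_names,
        )(mod)
        return mod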
@@ -317,10 +334,10 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
     ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)
 
     output_filename = (
-        f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}"
+        f"{args.model}-{args.quantization.name}-{target_kind}_{args.dtype}.{args.lib_format}"
     )
 
-    debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args)
+    debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}_{args.dtype}", args)
     lib_path = os.path.join(args.artifact_path, output_filename)
     ex.export_library(lib_path, **args.export_kwargs)
     print(f"Finish exporting to {lib_path}")

mlc_llm/relax_model/llama.py

Lines changed: 5 additions & 16 deletions
@@ -303,6 +303,8 @@ def forward(
                 attention_mask.struct_info.shape.values,
                 (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len),
             )
+
+            attn_weights = nn.emit(relax.op.add(attn_weights, attention_mask))
 
             attn_weights = nn.emit(
                 maximum(
@@ -315,12 +317,7 @@ def forward(
             )
             attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask))
 
-        # upcast attention to fp32
-        if attn_weights.struct_info.dtype != "float32":
-            attn_weights = astype(attn_weights, "float32")
         attn_weights = nn.emit(softmax(attn_weights, axis=-1))
-        if attn_weights.struct_info.dtype != query_states.struct_info.dtype:
-            attn_weights = astype(attn_weights, query_states.struct_info.dtype)
         attn_output = nn.emit(matmul(attn_weights, value_states))
 
         tvm.ir.assert_structural_equal(
@@ -402,7 +399,7 @@ def min_max_triu_te():
         return te.compute(
             (tgt_len, tgt_len),
             lambda i, j: tvm.tir.Select(
-                j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)
+                j > i, tvm.tir.min_value(dtype), tvm.tir.FloatImm(dtype, 0)
             ),
             name="make_diag_mask_te",
         )
@@ -416,9 +413,7 @@ def extend_te(x, tgt_len, src_len):
         return te.compute(
             (bsz, 1, tgt_len, src_len),
             lambda b, _, i, j: te.if_then_else(
-                j < src_len - tgt_len,
-                tvm.tir.max_value(dtype),
-                x[b, _, i, j - (src_len - tgt_len)],
+                j < src_len - tgt_len, tvm.tir.FloatImm(dtype, 0), x[b, _, i, j - (src_len - tgt_len)]
             ),
             name="concat_te",
         )
@@ -451,13 +446,7 @@ def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype):
             # Get src_len from input parameters
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            bsz, tgt_len = input_shape
-            combined_attention_mask = nn.emit(
-                relax.op.full(
-                    (bsz, 1, tgt_len, src_len),
-                    relax.const(tvm.tir.max_value(dtype).value, dtype),
-                    dtype,
-                )
-            )
+            combined_attention_mask = nn.emit(relax.op.full((bsz, 1, tgt_len, src_len), relax.const(0, dtype), dtype))
         return combined_attention_mask
 
     def forward(
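Taken together, the llama.py changes switch the causal mask from the old min/max clamp encoding to an additive-bias convention: visible positions now carry 0 instead of max_value(dtype), and the mask is added to the attention scores before softmax. The fp32 upcast around softmax is dropped, presumably so the whole matmul -> add-bias -> softmax -> matmul chain stays in one dtype and matches the rewrite_attention pattern added by this commit. A small NumPy sketch of the additive convention (illustrative only, not part of the diff):

    import numpy as np

    # Visible positions contribute 0, masked positions contribute the dtype minimum,
    # so adding the mask drives masked logits toward -inf before softmax.
    scores = np.array([[1.0, 0.5, 0.2]], dtype="float32")                    # raw attention logits
    mask = np.array([[0.0, 0.0, np.finfo("float32").min]], dtype="float32")  # last position masked
    logits = scores + mask
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # probs is approximately [[0.62, 0.38, 0.0]]: the masked position gets no attention weight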

mlc_llm/transform/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,3 +2,5 @@
 from .lift_tir_global_buffer_alloc import LiftTIRGlobalBufferAlloc
 from .quantization import GroupQuantize
 from .transpose_matmul import FuseTransposeMatmul
+from .decode_matmul_ewise import FuseDecodeMatmulEwise
+from .rewrite_attention import rewrite_attention
mlc_llm/transform/rewrite_attention.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard
+from tvm.script import relax as R
+
+
+def rewrite_attention(f):
+    Q = wildcard()
+    K = wildcard()
+    V = wildcard()
+    bias = wildcard()
+
+    Q_BNSH = is_op("relax.permute_dims")(Q)
+    K_BNSH = is_op("relax.permute_dims")(K)
+    V_BNSH = is_op("relax.permute_dims")(V)
+
+    K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
+
+    matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
+    divide = is_op("relax.divide")(matmul1, is_const())
+    with_bias = is_op("relax.add")(divide, bias)
+    softmax = is_op("relax.nn.softmax")(with_bias)
+    matmul2 = is_op("relax.matmul")(softmax, V_BNSH)
+
+    pattern = is_op("relax.permute_dims")(matmul2)
+
+    def callback(_, matchings):
+        return R.nn.attention(matchings[Q], matchings[K], matchings[V], matchings[bias])
+
+    return rewrite_call(pattern, callback, f)
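rewrite_attention works on one Relax function at a time: it matches the permute_dims -> matmul -> divide-by-constant -> add-bias -> softmax -> matmul -> permute_dims chain emitted by the attention layers and rewrites it into a single R.nn.attention call that the CUTLASS backend can offload. A hypothetical driver for applying it across a whole module (assuming mod is a Relax IRModule; this helper is not part of the commit):

    from tvm import relax
    from mlc_llm.transform import rewrite_attention

    def rewrite_attention_in_module(mod):
        # Run the pattern rewrite over every Relax function in the module.
        for gv, func in mod.functions.items():
            if isinstance(func, relax.Function):
                mod[gv] = rewrite_attention(func)
        return mod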

mlc_llm/utils.py

Lines changed: 8 additions & 0 deletions
@@ -104,6 +104,14 @@ def split_transform_deploy_mod(
     )
     mod_deploy = relax.transform.DeadCodeElimination(model_names)(mod_deploy)
 
+    # Copy the runtime module from external codegen
+    mod_deploy = mod_deploy.with_attrs(
+        {
+            "external_mods": mod.get_attr("external_mods"),
+            "const_name_to_constant": mod.get_attr("const_name_to_constant"),
+        }
+    )
+
     return mod_transform, mod_deploy
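This mirrors the debug_dump_script change in build.py: RunCodegen stores the compiled CUTLASS runtime modules in the module-level "external_mods" attribute and their bound constants in "const_name_to_constant", and the split produces a fresh deploy module without those attributes, so they must be copied over explicitly or relax.build would have nothing to link against. A hypothetical sanity check after the split (not part of the commit):

    # If the source module carried BYOC artifacts, the deploy module must too.
    if mod.get_attr("external_mods") is not None:
        assert mod_deploy.get_attr("external_mods") is not None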