
Commit 2995e1c

[MLC-LLM] Separate function for generating sharding PrimFunc (#250)
The `create_shard_func` helper produces both the `tir.PrimFunc` that performs the sharding and a `relax.Function` wrapper that calls into it. This commit separates the two, so that the `tir.PrimFunc` can be generated without also generating the `relax.Function` wrapper.
1 parent 7343d87 commit 2995e1c

File tree: 1 file changed (+16 −6 lines)


python/mlc_llm/support/preshard.py

Lines changed: 16 additions & 6 deletions
@@ -2,6 +2,7 @@
 import logging
 from typing import Any, Callable, Dict, Sequence, Tuple, List
 
+import tvm
 from tvm import IRModule
 from tvm import dlight as dl
 from tvm import relax
@@ -43,26 +44,35 @@ def _update_quantize_map(
     for worker_id in range(tensor_parallel_shards):
         named_params[_sharded_param_name(param_name, worker_id)] = param
 
+def create_tir_shard_func(
+    param: nn.Parameter,
+    tensor_parallel_shards: int,
+) -> Tuple[tvm.tir.PrimFunc, List[tvm.tir.PrimExpr], List[tvm.tir.PrimExpr]]:
+    shard_strategy = param.attrs.get("shard_strategy", None)
+    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
+    tir_func = tir_func.without_attr("global_symbol")
+    weight_shape = list(param.shape)
+    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
+    sharded_weight_shape = [tensor_parallel_shards, *param.shape]
+
+    return tir_func, weight_shape, sharded_weight_shape
 
 def create_shard_func(
     bb: relax.BlockBuilder,
     param: nn.Parameter,
     tensor_parallel_shards: int,
     do_split: bool = True,
 ):  # pylint: disable=too-many-locals
-    shard_strategy = param.attrs.get("shard_strategy", None)
+
     # generate tir shard function
-    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
-    tir_func = tir_func.with_attr("global_symbol", f"{shard_strategy.name}_tir")
+    tir_func, weight_shape, sharded_weight_shape = create_tir_shard_func(param, tensor_parallel_shards)
+    shard_strategy = param.attrs.get("shard_strategy", None)
     # add tir shard function to the IRModule
     tir_gvar = bb.add_func(tir_func, func_name=f"{shard_strategy.name}_tir")
     # create relax function that
     # 1. shard weight with tir shard function, result: [num_shards, *sharded_weight_shape]
     # 2. split the sharded weight along dim 0, result: num_shards * [1, *sharded_weight_shape]
     # 3. squeeze the 0th-dim of all shards, result: num_shards * [*sharded_weight_shape]
-    weight_shape = param.shape
-    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
-    sharded_weight_shape = [tensor_parallel_shards, *param.shape]
     weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, param.dtype))
 
     with bb.function(name=shard_strategy.name, params=[weight_var]):
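
Usage sketch (illustrative only, not part of this commit): the helper names below are hypothetical, and `param` is assumed to be an `nn.Parameter` whose attrs already carry a "shard_strategy", as the module expects. It shows the point of the split: `create_tir_shard_func` yields only the `tir.PrimFunc` plus the pre- and post-shard shapes, while `create_shard_func` still builds the `relax.Function` wrapper on a BlockBuilder.

# Illustrative sketch; helper names are hypothetical and the model-specific
# setup of `param` (with a "shard_strategy" attr) is omitted.
from tvm import relax

from mlc_llm.support.preshard import create_shard_func, create_tir_shard_func


def build_prim_func_only(param, tensor_parallel_shards):
    # New path: obtain just the sharding tir.PrimFunc and the original /
    # sharded weight shapes, without emitting a relax.Function wrapper.
    tir_func, weight_shape, sharded_weight_shape = create_tir_shard_func(
        param, tensor_parallel_shards
    )
    return tir_func, weight_shape, sharded_weight_shape


def build_with_relax_wrapper(param, tensor_parallel_shards):
    # Existing path: create_shard_func now delegates to create_tir_shard_func
    # and still registers both functions on the BlockBuilder.
    bb = relax.BlockBuilder()
    create_shard_func(bb, param, tensor_parallel_shards)
    return bb.get()  # IRModule containing the PrimFunc and its relax wrapper

Callers that only need the PrimFunc can now skip the BlockBuilder entirely, which is the change described in the commit message.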
