Merged
22 changes: 16 additions & 6 deletions python/mlc_llm/support/preshard.py
@@ -2,6 +2,7 @@
import logging
from typing import Any, Callable, Dict, Sequence, Tuple, List

import tvm
from tvm import IRModule
from tvm import dlight as dl
from tvm import relax
@@ -43,26 +44,35 @@ def _update_quantize_map(
        for worker_id in range(tensor_parallel_shards):
            named_params[_sharded_param_name(param_name, worker_id)] = param

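For context, the loop above registers the same parameter once per worker, keyed by `_sharded_param_name`, whose definition lies outside this diff. A toy illustration with a hypothetical suffix scheme:

# Toy illustration only: the real key format comes from _sharded_param_name,
# which is not shown in this diff; the "_shard-<worker_id>" suffix is an assumption.
named_params = {}
param_name, tensor_parallel_shards, param = "model.layers.0.qkv.weight", 2, object()
for worker_id in range(tensor_parallel_shards):
    named_params[f"{param_name}_shard-{worker_id}"] = param
# named_params now holds one entry per worker, all pointing at the same parameter:
# {"model.layers.0.qkv.weight_shard-0": ..., "model.layers.0.qkv.weight_shard-1": ...}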
def create_tir_shard_func(
    param: nn.Parameter,
    tensor_parallel_shards: int,
) -> Tuple[tvm.tir.PrimFunc, List[tvm.tir.PrimExpr], List[tvm.tir.PrimExpr]]:
    shard_strategy = param.attrs.get("shard_strategy", None)
    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
    tir_func = tir_func.without_attr("global_symbol")
    weight_shape = list(param.shape)
    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
    sharded_weight_shape = [tensor_parallel_shards, *param.shape]

    return tir_func, weight_shape, sharded_weight_shape

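For reference, a small numeric illustration of the shape bookkeeping the new helper performs; the 2-way shard, shard dim 0, and the [4096, 11008] per-shard shape are made-up example values, not taken from the PR:

# Illustrative values only: shard count, shard dim, and shapes are assumptions.
tensor_parallel_shards = 2
shard_dim = 0                      # plays the role of shard_strategy.dim
per_shard_shape = [4096, 11008]    # plays the role of param.shape

# Full (unsharded) weight: the shard dimension is scaled by the shard count.
weight_shape = list(per_shard_shape)
weight_shape[shard_dim] *= tensor_parallel_shards                    # -> [8192, 11008]

# Output of the TIR shard function: all shards stacked on a new leading axis.
sharded_weight_shape = [tensor_parallel_shards, *per_shard_shape]    # -> [2, 4096, 11008]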
def create_shard_func(
    bb: relax.BlockBuilder,
    param: nn.Parameter,
    tensor_parallel_shards: int,
    do_split: bool = True,
):  # pylint: disable=too-many-locals
    shard_strategy = param.attrs.get("shard_strategy", None)

    # generate tir shard function
    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
    tir_func = tir_func.with_attr("global_symbol", f"{shard_strategy.name}_tir")
    tir_func, weight_shape, sharded_weight_shape = create_tir_shard_func(param, tensor_parallel_shards)
    shard_strategy = param.attrs.get("shard_strategy", None)
    # add tir shard function to the IRModule
    tir_gvar = bb.add_func(tir_func, func_name=f"{shard_strategy.name}_tir")
    # create relax function that
    # 1. shard weight with tir shard function, result: [num_shards, *sharded_weight_shape]
    # 2. split the sharded weight along dim 0, result: num_shards * [1, *sharded_weight_shape]
    # 3. squeeze the 0th-dim of all shards, result: num_shards * [*sharded_weight_shape]
    weight_shape = param.shape
    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
    sharded_weight_shape = [tensor_parallel_shards, *param.shape]
    weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, param.dtype))

    with bb.function(name=shard_strategy.name, params=[weight_var]):
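The body under `bb.function` is truncated in this diff; its structure is outlined by the three numbered comments above. A minimal sketch of that structure, assuming the relax.BlockBuilder API and reusing the surrounding names (`bb`, `shard_strategy`, `weight_var`, `tir_gvar`, `sharded_weight_shape`, `param`, `tensor_parallel_shards`); the new `do_split` flag presumably gates steps 2 and 3, which the sketch ignores. This illustrates the described steps, not the PR's exact body:

    # Sketch only: shard with the TIR kernel, split along dim 0, then squeeze each piece.
    with bb.function(name=shard_strategy.name, params=[weight_var]):
        with bb.dataflow():
            # 1. run the TIR shard function: full weight -> [num_shards, *per_shard_shape]
            stacked = bb.emit(
                relax.call_tir(
                    tir_gvar,
                    [weight_var],
                    out_sinfo=relax.TensorStructInfo(sharded_weight_shape, param.dtype),
                )
            )
            # 2. split along dim 0: num_shards tensors of shape [1, *per_shard_shape]
            pieces = bb.emit(relax.op.split(stacked, tensor_parallel_shards, axis=0))
            # 3. squeeze the leading axis of every piece: [*per_shard_shape]
            shards = [
                bb.emit(relax.op.squeeze(relax.TupleGetItem(pieces, i), axis=0))
                for i in range(tensor_parallel_shards)
            ]
            out = bb.emit_output(relax.Tuple(shards))
        bb.emit_func_output(out)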