 2 |  2 | import logging
 3 |  3 | from typing import Any, Callable, Dict, Sequence, Tuple, List
 4 |  4 |
   |  5 | +import tvm
 5 |  6 | from tvm import IRModule
 6 |  7 | from tvm import dlight as dl
 7 |  8 | from tvm import relax

@@ -43,26 +44,35 @@ def _update_quantize_map(
43 | 44 |     for worker_id in range(tensor_parallel_shards):
44 | 45 |         named_params[_sharded_param_name(param_name, worker_id)] = param
45 | 46 |
   | 47 | +def create_tir_shard_func(
   | 48 | +    param: nn.Parameter,
   | 49 | +    tensor_parallel_shards: int,
   | 50 | +) -> Tuple[tvm.tir.PrimFunc, List[tvm.tir.PrimExpr], List[tvm.tir.PrimExpr]]:
   | 51 | +    shard_strategy = param.attrs.get("shard_strategy", None)
   | 52 | +    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
   | 53 | +    tir_func = tir_func.without_attr("global_symbol")
   | 54 | +    weight_shape = list(param.shape)
   | 55 | +    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
   | 56 | +    sharded_weight_shape = [tensor_parallel_shards, *param.shape]
   | 57 | +
   | 58 | +    return tir_func, weight_shape, sharded_weight_shape
46 | 59 |
47 | 60 | def create_shard_func(
48 | 61 |     bb: relax.BlockBuilder,
49 | 62 |     param: nn.Parameter,
50 | 63 |     tensor_parallel_shards: int,
51 | 64 |     do_split: bool = True,
52 | 65 | ): # pylint: disable=too-many-locals
53 |    | -    shard_strategy = param.attrs.get("shard_strategy", None)
   | 66 | +
54 | 67 |     # generate tir shard function
55 |    | -    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)
56 |    | -    tir_func = tir_func.with_attr("global_symbol", f"{shard_strategy.name}_tir")
   | 68 | +    tir_func, weight_shape, sharded_weight_shape = create_tir_shard_func(param, tensor_parallel_shards)
   | 69 | +    shard_strategy = param.attrs.get("shard_strategy", None)
57 | 70 |     # add tir shard function to the IRModule
58 | 71 |     tir_gvar = bb.add_func(tir_func, func_name=f"{shard_strategy.name}_tir")
59 | 72 |     # create relax function that
60 | 73 |     # 1. shard weight with tir shard function, result: [num_shards, *sharded_weight_shape]
61 | 74 |     # 2. split the sharded weight along dim 0, result: num_shards * [1, *sharded_weight_shape]
62 | 75 |     # 3. squeeze the 0th-dim of all shards, result: num_shards * [*sharded_weight_shape]
63 |    | -    weight_shape = param.shape
64 |    | -    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards
65 |    | -    sharded_weight_shape = [tensor_parallel_shards, *param.shape]
66 | 76 |     weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, param.dtype))
67 | 77 |
68 | 78 |     with bb.function(name=shard_strategy.name, params=[weight_var]):
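
For readers following the refactor, the shape bookkeeping that create_tir_shard_func centralizes can be illustrated in isolation. The sketch below is hypothetical: the ShardInfo dataclass, the shard_shapes helper, and the concrete shapes are made up for this example and are not part of the diff. It only mirrors the two shape computations, weight_shape and sharded_weight_shape, with plain Python ints.

    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class ShardInfo:
        # Hypothetical stand-in for a parameter's shard_strategy:
        # the dimension split across tensor-parallel workers.
        dim: int


    def shard_shapes(
        param_shape: List[int], strategy: ShardInfo, shards: int
    ) -> Tuple[List[int], List[int]]:
        # Mirror the shape arithmetic in create_tir_shard_func.
        # The full (pre-shard) weight is `shards` times larger along the shard dim.
        weight_shape = list(param_shape)
        weight_shape[strategy.dim] *= shards
        # The sharded weight gains a leading worker axis.
        sharded_weight_shape = [shards, *param_shape]
        return weight_shape, sharded_weight_shape


    # A per-worker [4096, 4096] weight split along dim 0 across 2 workers comes
    # from an [8192, 4096] full weight, laid out as [2, 4096, 4096] after sharding.
    print(shard_shapes([4096, 4096], ShardInfo(dim=0), 2))

Note that the new helper copies the shape with list(param.shape) before scaling the shard dimension, whereas the removed lines assigned weight_shape = param.shape and then mutated it, which can modify the parameter's shape in place when param.shape is a plain list.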
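
The three numbered steps in the comment inside create_shard_func (apply the TIR shard kernel, split the result along dim 0, squeeze away the unit dimension of each shard) can also be exercised on their own. The snippet below is a minimal, self-contained sketch of steps 2 and 3 only: it uses a plain relax input tensor in place of the TIR-sharded weight, and the function name split_and_squeeze, the shapes, and num_shards are invented for the example. It is not the function body that follows in the real file.

    from tvm import relax

    num_shards = 2
    sharded_weight_shape = [4, 8]  # stand-in for a per-worker weight shape

    bb = relax.BlockBuilder()
    # Stand-in for the TIR shard kernel's output: [num_shards, *sharded_weight_shape].
    x = relax.Var(
        "x", relax.TensorStructInfo([num_shards, *sharded_weight_shape], "float32")
    )
    with bb.function("split_and_squeeze", params=[x]):
        with bb.dataflow():
            # Step 2: split along dim 0 -> num_shards tensors of [1, *sharded_weight_shape].
            chunks = bb.emit(relax.op.split(x, num_shards, axis=0))
            # Step 3: squeeze the leading unit dim -> num_shards tensors of sharded_weight_shape.
            shards = [
                bb.emit(relax.op.squeeze(relax.TupleGetItem(chunks, i), axis=[0]))
                for i in range(num_shards)
            ]
            out = bb.emit_output(relax.Tuple(shards))
        bb.emit_func_output(out)

    print(bb.get())  # prints the TVMScript of the toy function

In the committed code the first step is instead a call into the TIR function registered via bb.add_func, so the real relax wrapper starts from weight_var rather than from a pre-sharded tensor.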