
Commit da771ed

Yejing-Lai, loadams, and tjruwase authored
Add MLP/lm_head tp grain size setting. (#6828)
This PR adds an MLP/lm_head tensor-parallel grain-size setting to the deepspeed.init_inference() API, so the sharding granularity of the MLP and lm_head layers can be chosen flexibly. DNN libraries favor tensor sizes that are multiples of a power of 2, so 64 is used as the default grain size. This is a preliminary solution; if there is a better one, we can discuss it together. Thanks~

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 87c6506 commit da771ed
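
For context, a minimal sketch of how the new knob might be passed through deepspeed.init_inference(); the model checkpoint and the tp_size/tp_grain_size values below are illustrative assumptions, not part of this commit:

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Load any causal LM; the checkpoint name here is only an example.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

# The tensor_parallel keyword maps onto DeepSpeedTPConfig, which now exposes tp_grain_size.
engine = deepspeed.init_inference(
    model,
    tensor_parallel={
        "tp_size": 2,          # number of devices to shard across
        "tp_grain_size": 128,  # shard MLP/lm_head widths in multiples of 128 instead of the default 64
    },
    dtype=torch.float16,
)
# `engine` can then be used in place of the original model for inference.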

3 files changed: +15, -4 lines


deepspeed/inference/config.py (+3)

@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """
 
+    tp_grain_size: int = 64
+    "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."
+
     mpu: object = None
     """
     A model parallelism unit object that implements
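
The new field rides on the existing pydantic config model. A small sketch of where the value ends up, assuming the config object is constructed directly (init_inference normally builds it internally from its keyword arguments):

from deepspeed.inference.config import DeepSpeedInferenceConfig

# Build the inference config by hand; the values are illustrative.
cfg = DeepSpeedInferenceConfig(tensor_parallel={"tp_size": 2, "tp_grain_size": 128})
print(cfg.tensor_parallel.tp_size)        # 2
print(cfg.tensor_parallel.tp_grain_size)  # 128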

deepspeed/module_inject/replace_module.py (+4, -1)

@@ -17,7 +17,7 @@
 from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
 
 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         if hasattr(model_config, 'num_attention_heads'):
             set_num_attention_heads(getattr(model_config, 'num_attention_heads'))
 
+        # 4.4 set tp_grain_size
+        set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()

deepspeed/module_inject/tp_shard.py (+8, -3)

@@ -22,6 +22,11 @@ def set_n_embd(num):
     n_embd = num
 
 
+def set_tp_grain_size(num):
+    global tp_grain_size
+    tp_grain_size = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        if total_size >= 64:
-            grain_size = total_size // 64
-            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+        if total_size >= tp_grain_size:
+            grain_size = total_size // tp_grain_size
+            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
         else:
             return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
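
To make the rounding concrete, here is a standalone re-statement of the grain-based branch of get_shard_size above (it omits the num_kv_heads branch, and the 11008-wide MLP split across 3 ranks is an illustrative example, not taken from the commit):

def shard_size(total_size, mp_size, rank, tp_grain_size=64):
    # Hand out whole grains of tp_grain_size; lower ranks absorb any remainder grains.
    if total_size >= tp_grain_size:
        grains = total_size // tp_grain_size
        return (grains // mp_size + (1 if rank < (grains % mp_size) else 0)) * tp_grain_size
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

# 11008 columns / grain 64 = 172 grains; across 3 ranks that is 58, 57, 57 grains,
# i.e. shard widths of 3712, 3648, 3648 (each a multiple of the grain size).
print([shard_size(11008, 3, r) for r in range(3)])                      # [3712, 3648, 3648]
# A larger grain keeps every shard a multiple of that grain instead:
print([shard_size(11008, 3, r, tp_grain_size=128) for r in range(3)])   # [3712, 3712, 3584]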
