Multinode #752

Merged: 69 commits, Mar 6, 2025

Commits (69)
65428a1
supporting multinode
jayfeather9 Feb 16, 2025
a26d48d
fix format
jayfeather9 Feb 16, 2025
64ccb4f
add cuda() with device id
jayfeather9 Feb 16, 2025
b3de424
Merge branch 'main' into multinode
Feb 18, 2025
3fd6e48
fix multinode abort
Feb 20, 2025
c0b1146
support chunked prefill
Feb 20, 2025
855e037
Merge branch 'main' into multinode
Feb 20, 2025
8ad585a
modify dist_utils & remove child_ips
Feb 21, 2025
e0844d3
Merge branch 'multinode' of https://github.com/ModelTC/lightllm into …
Feb 21, 2025
fa3f826
fix chunked_prefill for multinode
Feb 23, 2025
7e92ff6
Merge branch 'multinode' of https://github.com/ModelTC/lightllm into …
Feb 23, 2025
4a83a6a
merge main
Feb 23, 2025
461ec65
fix health
Feb 23, 2025
b15a487
update port
Feb 24, 2025
4e419df
Adjust the rank configuration.
hiworldwzj Feb 26, 2025
8237f19
refactor multinode
Feb 27, 2025
b55edd7
Merge branch 'main' into multinode
Feb 27, 2025
958b83d
fix get_dp_size
Feb 27, 2025
fffb99e
remove tp_rank of get_available_gpu_memory
Feb 27, 2025
ecd495c
fix chunked prefill
Feb 27, 2025
82a756f
fix dist_utils
Feb 27, 2025
814f095
multinode utils
Feb 27, 2025
b33b3b4
update router multinode manager
Feb 27, 2025
296a579
fix chunked prefill
Feb 27, 2025
2afd14b
reformat
Feb 27, 2025
7061bfb
fix
hiworldwzj Feb 27, 2025
0bde847
refactor order
Feb 28, 2025
39b90bf
fix
Feb 28, 2025
429f9c3
update httpserver sync
Feb 28, 2025
4377c20
update
Feb 28, 2025
4abb4a1
remove cudagraph_step_length
Feb 28, 2025
7646d6e
modify the default value of current_waiting_num
Feb 28, 2025
8446544
update
Feb 28, 2025
877b98f
fix
Feb 28, 2025
8ff7ed6
fix visualserver
Feb 28, 2025
249dea7
fix
Feb 28, 2025
64bdc11
update mem_manager
Feb 28, 2025
d68e0d7
fix start rank params.
hiworldwzj Feb 28, 2025
2461069
fix
Feb 28, 2025
b3dbecd
fix
hiworldwzj Mar 1, 2025
ea4dc98
fix
hiworldwzj Mar 1, 2025
1cffab4
fix
hiworldwzj Mar 1, 2025
4f68164
update docs
Mar 1, 2025
2b3f07a
fix
hiworldwzj Mar 1, 2025
068663a
fix
hiworldwzj Mar 1, 2025
d14fcea
fix
hiworldwzj Mar 1, 2025
46dbbb2
fix
Mar 1, 2025
72f4eb3
fix
hiworldwzj Mar 1, 2025
a640f72
fix
hiworldwzj Mar 1, 2025
5cd5dbf
fix
hiworldwzj Mar 1, 2025
5d13dc8
fix
hiworldwzj Mar 1, 2025
ce4d3eb
reformat
Mar 1, 2025
fd49ab3
Improve the management of rank information.
hiworldwzj Mar 3, 2025
b6a8168
fix
hiworldwzj Mar 3, 2025
4486665
fix radix cache rank.
hiworldwzj Mar 3, 2025
7bb9967
fix req queue init error.
hiworldwzj Mar 3, 2025
3db8420
Merge remote-tracking branch 'origin/main' into multinode
hiworldwzj Mar 3, 2025
1659c0c
python version >= 3.9.16
hiworldwzj Mar 3, 2025
46d0b13
fix is_master_in_dp.
hiworldwzj Mar 4, 2025
fb1c40b
Fix duplicated init-req code on the decode node.
hiworldwzj Mar 4, 2025
27a7045
Merge remote-tracking branch 'origin/main' into multinode
hiworldwzj Mar 4, 2025
16e67e3
Fix the PD-disaggregation implementation under the new rank management scheme.
hiworldwzj Mar 5, 2025
6418a1e
fix bug for pd reformater.
hiworldwzj Mar 5, 2025
fbf58d3
add test.sh
hiworldwzj Mar 5, 2025
6f340e9
fix reformatted
hiworldwzj Mar 5, 2025
5c4a94b
add details for rank.
hiworldwzj Mar 5, 2025
b2f5405
fix
hiworldwzj Mar 5, 2025
c48e657
fix
hiworldwzj Mar 5, 2025
ccbeea9
fix
hiworldwzj Mar 6, 2025
33 changes: 18 additions & 15 deletions lightllm/common/basemodel/infer_lock.py
@@ -17,28 +17,32 @@


class InferStateLock:
def __init__(self, name):
def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int):
self.infer_lock = threading.Lock()
self.dp_rank_in_node = dp_rank_in_node
# sync_world_size should be min(dp_world_size, node_world_size)
self.dp_world_size = dp_world_size
self.rank_in_dp = rank_in_dp
# Reserve space for a TP of 128 by default; hardly any setup today can run a TP this large
self.lock_tp_infos = SharedArray(f"{name}_lock_tp_infos", shape=(129,), dtype=np.int64)
self.lock_tp_infos = SharedArray(
f"{name}_dp_rank_{str(self.dp_rank_in_node)}_lock_tp_infos", shape=(self.dp_world_size + 1,), dtype=np.int64
)
self.lock_tp_infos.arr[:] = 0
self.rank_id = dist.get_rank()
self.world_size = dist.get_world_size()

def add_cur_mark(self):
self.lock_tp_infos.arr[self.rank_id] += 1
self.lock_tp_infos.arr[self.rank_in_dp] += 1

def get_cur_mark(self):
return self.lock_tp_infos.arr[self.rank_id]
return self.lock_tp_infos.arr[self.rank_in_dp]

def get_max_mark_in_group(self):
return np.max(self.lock_tp_infos.arr[0 : self.world_size])
return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size])

def judge_cur_mark_equal_max_mark_in_group(self):
return self.get_cur_mark() == self.get_max_mark_in_group()

def judge_mark_in_group_all_same(self):
marks = self.lock_tp_infos.arr[0 : self.world_size]
marks = self.lock_tp_infos.arr[0 : self.dp_world_size]
return bool(np.all(marks == marks[0]))

def acquire_lock_and_update_cur_mark(self):
@@ -49,11 +49,11 @@ def release_lock(self):
self.infer_lock.release()

def set_group_wait_mark(self):
if self.rank_id == 0:
if self.rank_in_dp == 0:
self.lock_tp_infos.arr[-1] = 1

def unset_group_wait_mark(self):
if self.rank_id == 0:
if self.rank_in_dp == 0:
self.lock_tp_infos.arr[-1] = 0

def get_group_wait_mark(self):
@@ -63,7 +67,7 @@ def get_group_wait_mark(self):
@dataclass
class G_Infer_Lock:
obj: InferStateLock = None
dp_size: int = None
dp_world_size: int = None

def acquire(self):
if self.obj is not None:
@@ -86,9 +90,8 @@ def release(self):

# The following two functions must be used as a pair
def acquire_lock_until_ready(nccl_group):
# In deepseekv2's mixed TP/DP running mode, no coordination or synchronization across inference processes is needed,
# so simply acquiring and releasing the lock is enough
if g_infer_state_lock.dp_size != 1:
# With only one rank per DP group there is no need for heavy locking
if g_infer_state_lock.dp_world_size == 1:
g_infer_state_lock.obj.infer_lock.acquire()
return

@@ -118,7 +121,7 @@ def release_acquired_lock():
@dataclass
class G_Router_Lock:
"""
Protect operations on some data in PD-disaggregation mode
Protect operations on scheduling-related information in PD-disaggregation mode
"""

obj = None # process-level lock object
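Note: the rework above scopes the lock marks to a single DP group (dp_world_size + 1 slots instead of the fixed 129). Below is a minimal, self-contained sketch of that mark-synchronization idea; a plain numpy array stands in for SharedArray, and the group size and ranks are assumed values, not taken from this PR.

import numpy as np

dp_world_size = 4
# dp_world_size mark slots plus one trailing group-wait slot
lock_tp_infos = np.zeros(dp_world_size + 1, dtype=np.int64)

def add_cur_mark(rank_in_dp: int) -> None:
    # each rank bumps its own slot after taking its local lock
    lock_tp_infos[rank_in_dp] += 1

def group_marks_all_same() -> bool:
    marks = lock_tp_infos[0:dp_world_size]
    return bool(np.all(marks == marks[0]))

# ranks 0..2 have advanced, rank 3 has not, so the group is not yet in sync
for r in range(3):
    add_cur_mark(r)
assert not group_marks_all_same()
add_cur_mark(3)
assert group_marks_all_same()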
18 changes: 10 additions & 8 deletions lightllm/common/deepseek2_mem_manager.py
@@ -40,9 +40,9 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def send_to_decode_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size: int
self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
):
assert dp_size == 1
assert dp_size_in_node == 1

# First gather the data into a buffer on one designated card, then send it.
move_token_indexes = []
@@ -66,9 +66,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
assert dp_size == 1
assert dp_size_in_node == 1

# First receive the data into a buffer on one designated card, then copy it to the other cards.
move_token_indexes = []
@@ -97,11 +97,13 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
return

def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
def send_to_decode_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
"""
Implementation that uses a p2p triton kernel for data copy and transfer.
"""
assert dp_size == 1
assert dp_size_in_node == 1

move_token_indexes = []
for task in move_tasks:
@@ -124,9 +126,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
return move_buffer

def receive_from_prefill_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
assert dp_size == 1
assert dp_size_in_node == 1

move_token_indexes = []
for task in move_tasks:
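Note: the send/receive paths above follow a gather-then-transfer pattern: KV rows selected by token index are first packed into a contiguous buffer on one designated card before any cross-node transfer. The sketch below illustrates only that packing step; the tensor shapes and names are assumptions and do not match the real kv_buffer layout.

import torch

# stand-in for one layer's KV storage: [tokens, heads, head_dim]
layer_kv = torch.randn(1024, 8, 64)
move_token_indexes = torch.tensor([3, 17, 256, 511])

# pack the scattered rows into a contiguous move buffer, then hand it to the transport
move_buffer = layer_kv[move_token_indexes, :, :].contiguous()
assert move_buffer.shape == (4, 8, 64)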
68 changes: 30 additions & 38 deletions lightllm/common/mem_manager.py
@@ -8,7 +8,7 @@
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_global_rank
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args


@@ -37,8 +37,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
# Shared via shared memory; the router module reads it for accurate scheduling estimation. The nccl port serves as a marker for a single instance within a single node, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name

rank_id = get_global_rank()
self.shared_can_use_token_num = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_id}")
rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
@@ -83,13 +85,10 @@ def alloc_kv_move_buffer(self, max_req_total_len):
self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
return

def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
"""
dp_size is a parameter tailored for deepseekv2-style models that can run in a mixed DP/TP mode;
in plain TP mode, dp_size is always 1 and dp_index is always 0, and in that mode these two parameters
are not actually used
"""
assert dp_size == 1
def send_to_decode_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
assert dp_size_in_node == 1

# First gather the data into a buffer on one designated card, then send it.

@@ -123,14 +122,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
"""
dp_size is a parameter tailored for deepseekv2-style models that can run in a mixed DP/TP mode;
in plain TP mode, dp_size is always 1, and in that mode these two parameters
are not actually used
"""
assert dp_size == 1
assert dp_size_in_node == 1

# First receive the data into a buffer on one designated card, then copy it to the other cards.

@@ -160,11 +154,13 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
return

def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
def send_to_decode_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
"""
Implementation that uses a p2p triton kernel for data copy and transfer.
"""
assert dp_size == 1
assert dp_size_in_node == 1

# First gather the data into a buffer on one designated card, then send it.

@@ -190,9 +186,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
return move_buffer

def receive_from_prefill_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
):
assert dp_size == 1
assert dp_size_in_node == 1

# First receive the data into a buffer on one designated card, then copy it to the other cards.

@@ -303,20 +299,16 @@ class ReadOnlyStaticsMemoryManager:
def __init__(self) -> None:
args = get_env_start_args()
self.global_world_size = args.tp
node_world_size = args.tp // args.nnodes
rank_start = args.node_rank * node_world_size
rank_end = (args.node_rank + 1) * node_world_size
self.shared_tp_infos = {
rank: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank}")
for rank in range(rank_start, rank_end)
}

def get_unrefed_token_num(self, dp_rank: int):
args = get_env_start_args()
if args.dp == 1 and args.nnodes > 1:
# handle the multi-node dp size=1 case
rank_id = args.tp // args.nnodes * args.node_rank
return self.shared_tp_infos[rank_id].get_value()
dp_size = args.dp
dp_world_size = self.global_world_size // dp_size
return self.shared_tp_infos[dp_rank * dp_world_size].get_value()
self.node_world_size = args.tp // args.nnodes
self.dp_world_size = self.global_world_size // args.dp
# handle the multi-node pure-TP case where dp size=1
self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
self.shared_tp_infos = [
SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
]

def get_unrefed_token_num(self, dp_rank_in_node: int):
if self.is_multinode_tp:
return self.shared_tp_infos[0].get_value()
return self.shared_tp_infos[dp_rank_in_node].get_value()
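Note: a worked example of the new per-node indexing in ReadOnlyStaticsMemoryManager, with assumed values (tp=16, nnodes=2, dp=4) that are not taken from this PR. Only the DP-group leader rank inside each node owns a shared counter, so dp_rank_in_node indexes the list directly.

tp, nnodes, dp = 16, 2, 4
global_world_size = tp
node_world_size = tp // nnodes            # 8 ranks per node
dp_world_size = global_world_size // dp   # 4 ranks per DP group
leader_ranks_in_node = list(range(0, node_world_size, dp_world_size))
assert leader_ranks_in_node == [0, 4]     # dp_rank_in_node 0 and 1 map to slots 0 and 1

# multi-node pure-TP mode (dp=1, nnodes=2): dp_world_size=16 > node_world_size=8,
# so only slot 0 exists and is_multinode_tp routes every query to it
assert list(range(0, 8, 16)) == [0]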
3 changes: 2 additions & 1 deletion lightllm/server/api_http.py
@@ -101,7 +101,8 @@ def set_args(self, args):
enable_multimodal=args.enable_multimodal,
metric_port=args.metric_port,
)
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", args.dp)
dp_size_in_node = max(1, args.dp // args.nnodes) # handle the multi-node pure-TP running mode, where 1 // 2 == 0 and must be clamped to 1
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)


g_objs = G_Objs()
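Note: a quick check of why the max(1, ...) clamp above is needed (values are assumed for illustration): in multi-node pure-TP mode args.dp can be smaller than args.nnodes, and integer division alone would give zero DP groups per node.

assert max(1, 4 // 2) == 2   # dp=4 over 2 nodes -> 2 DP groups per node
assert max(1, 1 // 2) == 1   # dp=1 over 2 nodes -> 1 // 2 == 0, clamped up to 1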
10 changes: 5 additions & 5 deletions lightllm/server/api_start.py
@@ -165,13 +165,13 @@ def normal_or_p_d_start(args):
ports_locker = PortLocker(already_uesd_ports)
ports_locker.lock_port()

node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
num=6 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
)
logger.info(f"alloced ports: {can_use_ports}")
router_port, detokenization_port, detokenization_pub_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
model_rpc_ports = can_use_ports[6 : 6 + args.tp]
can_use_ports = can_use_ports[6 + args.tp :]
can_use_ports = can_use_ports[6:]

visual_model_tp_ports = []
for _ in range(args.visual_dp):
@@ -188,7 +188,7 @@
args.metric_port = metric_port

# Allocate the ports that will be used in PD-disaggregation mode
args.pd_tp_infer_rpyc_ports = can_use_ports[0 : args.tp]
args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
# ID used to identify a node in PD-disaggregation mode
args.pd_node_id = uuid.uuid4().int
# Range of available ports the P node uses to set up the torch KV-transfer distributed groups
@@ -231,7 +231,7 @@
process_manager.start_submodule_processes(
start_funcs=[start_router_process, start_detokenization_process],
start_args=[
(args, router_port, detokenization_port, model_rpc_ports, metric_port),
(args, router_port, detokenization_port, metric_port),
(args, detokenization_port, detokenization_pub_port),
],
)
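Note: a rough per-node port-count comparison under assumed values (tp=16, nnodes=2, visual_dp=1, visual_tp=1); this is illustration only, not output from the PR. Each node now only reserves ports for its local ranks instead of the full TP world, and the separate model_rpc_ports carve-out is gone.

tp, nnodes, visual_dp, visual_tp = 16, 2, 1, 1
node_world_size = tp // nnodes
old_ports = 6 + tp + tp + visual_dp * visual_tp            # 39 ports requested per node
new_ports = 6 + node_world_size + visual_dp * visual_tp    # 15 ports requested per node
assert (old_ports, new_ports) == (39, 15)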
6 changes: 3 additions & 3 deletions lightllm/server/router/batch.py
@@ -8,11 +8,11 @@


class Batch:
def __init__(self, batch_id, reqs: List[Req], dp_size: int):
def __init__(self, batch_id, reqs: List[Req], dp_size_in_node: int):
self.batch_id = batch_id
self.reqs = reqs
self.id_to_reqs = {req.request_id: req for req in reqs}
self.dp_size = dp_size
self.dp_size_in_node = dp_size_in_node
return

def input_tokens(self):
@@ -22,7 +22,7 @@ def input_tokens(self):
return batch_input_tokens

def get_batch_decode_need_tokens(self):
new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)] # for chunked prefill
new_batch_decode_need_tokens = [0 for _ in range(self.dp_size_in_node)] # for chunked prefill

for req in self.reqs:
req_dp_index = req.sample_params.suggested_dp_index
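Note: a hedged sketch of the per-DP-group accounting that get_batch_decode_need_tokens performs within a node; requests are bucketed by sample_params.suggested_dp_index. The per-request decode_need value is a hypothetical stand-in, since the accumulation detail is not shown in this hunk.

dp_size_in_node = 2
reqs = [
    {"suggested_dp_index": 0, "decode_need": 1},
    {"suggested_dp_index": 1, "decode_need": 1},
    {"suggested_dp_index": 0, "decode_need": 1},
]
need_tokens = [0 for _ in range(dp_size_in_node)]
for req in reqs:
    need_tokens[req["suggested_dp_index"]] += req["decode_need"]
assert need_tokens == [2, 1]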