
Commit c7896c9

hiworldwzj and wangzaijun authored
pd mode use p2p triton kernel to manage kv trans && refactor deepseekv2 code (#691)
Co-authored-by: wangzaijun <[email protected]>
1 parent 5aa8d9b commit c7896c9

15 files changed (+585, -436 lines)

lightllm/common/deepseek2_mem_manager.py (+55, -5)
@@ -1,10 +1,11 @@
 import torch
 import os
-
+import torch.distributed as dist
 from lightllm.server.pd_io_struct import KVMoveTask
 from .mem_manager import MemoryManager
 from typing import List
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 
 logger = init_logger(__name__)
 
@@ -33,6 +34,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
+        self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
         return
 
     def send_to_decode_node(
@@ -41,8 +43,6 @@ def send_to_decode_node(
         assert dp_size == 1
 
         # First stage the data in the buffer on one designated GPU, then send it.
-        import torch.distributed as dist
-
         move_token_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
@@ -69,8 +69,6 @@ def receive_from_prefill_node(
         assert dp_size == 1
 
         # First receive the data into the buffer on one designated GPU, then copy it to the other GPUs.
-        import torch.distributed as dist
-
         move_token_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
@@ -97,6 +95,58 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
         return
 
+    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+        """
+        Implementation that uses the p2p triton kernel to copy and transfer the data.
+        """
+        assert dp_size == 1
+
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        for layer_index in range(self.layer_num):
+            move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
+            dist.send(move_buffer, dst=1)
+        return
+
+    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+        move_token_num = len(token_indexes)
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
+        kv_trans(
+            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        )
+        return move_buffer
+
+    def receive_from_prefill_node_p2p(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+    ):
+        assert dp_size == 1
+
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+
+        token_num = len(move_token_indexes)
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
+        for layer_index in range(self.layer_num):
+            dist.recv(recive_buffer, src=0)
+            for i, mem in enumerate(mem_managers):
+                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+        return
+
+    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+        move_token_num = len(token_indexes)
+        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        return
+
     @torch.no_grad()
     def free_all(self):
         self.can_use_mem_size = len(self.mem_state) - self.holding_size
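
For readers tracing the slicing in _get_kv_move_data_p2p above: kv_buffer.numel() // layer_num // size reduces to the element count of a single token in a single layer, so move_size is exactly the space the selected tokens occupy, and the flat kv_move_buffer can be reshaped to (move_token_num, head_num, head_dim). Below is a minimal sketch of that arithmetic with made-up shapes, for illustration only; it is not code from the commit.

import torch

# Hypothetical sizes, chosen only to make the slicing arithmetic concrete.
layer_num, size, head_num, head_dim = 2, 16, 1, 512
kv_buffer = torch.zeros(layer_num, size, head_num, head_dim, dtype=torch.float16)
kv_move_buffer = torch.empty(1, size + 8, head_num, head_dim, dtype=torch.float16)

move_token_num = 3
# Elements one token occupies in one layer: head_num * head_dim.
per_token = kv_buffer.numel() // layer_num // size
move_size = per_token * move_token_num

# Slice exactly move_size elements off the flat staging buffer and give them
# the (move_token_num, head_num, head_dim) shape that kv_trans expects.
move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, head_num, head_dim)
print(per_token, tuple(move_buffer.shape))  # -> 512 (3, 1, 512)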

lightllm/common/kv_trans_kernel/__init__.py

Whitespace-only changes.
lightllm/common/kv_trans_kernel/kv_trans.py (new file, +78; filename inferred from the import path lightllm.common.kv_trans_kernel.kv_trans used above)

@@ -0,0 +1,78 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _kv_trans_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    input_token_idx_ptr,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    output_stride_2,
+    output_token_idx_ptr,
+    token_num: int,
+    head_num: int,
+    head_dim: int,
+    grid_count: int,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)
+
+    head_num_dim = head_num * head_dim
+    tid = tl.program_id(0)
+
+    offs = tl.arange(0, BLOCK_SIZE)
+    while tid < token_num:
+        input_token_idx = tl.load(input_token_idx_ptr + tid)
+        output_token_idx = tl.load(output_token_idx_ptr + tid)
+        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
+            cur_offs = block_idx * BLOCK_SIZE + offs
+            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
+            tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim)
+
+        tid += grid_count
+
+    return
+
+
+def kv_trans(input: torch.Tensor, input_idx: torch.Tensor, output: torch.Tensor, output_idx: torch.Tensor):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert len(input.shape) == 3
+    assert len(output.shape) == 3
+    assert len(input_idx) == len(output_idx)
+
+    _, head_num, head_dim = input.shape
+    token_num = len(input_idx)
+    # Use only modest resources for the data transfer so it does not occupy too many SM compute units.
+    grid_count = 20
+    BLOCK_SIZE = 256
+    NUM_STAGES = 3
+    grid = (grid_count,)
+
+    _kv_trans_kernel[grid](
+        input,
+        *input.stride(),
+        input_idx,
+        output,
+        *output.stride(),
+        output_idx,
+        token_num=token_num,
+        head_num=head_num,
+        head_dim=head_dim,
+        grid_count=grid_count,
+        BLOCK_SIZE=BLOCK_SIZE,
+        NUM_STAGES=NUM_STAGES,
+        num_warps=1,
+    )
+    return
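
The kernel gathers one token's head_num * head_dim contiguous elements per index pair and scatters them into the corresponding output row. A minimal usage sketch of kv_trans follows; it assumes a CUDA device with Triton installed, and the shapes and index values are made up for illustration.

import torch
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans

# Made-up shapes for illustration only.
size, head_num, head_dim = 64, 2, 128
src = torch.randn(size, head_num, head_dim, dtype=torch.float16, device="cuda")
dst = torch.zeros(4, head_num, head_dim, dtype=torch.float16, device="cuda")

src_idx = torch.tensor([5, 9, 17, 33], dtype=torch.int64, device="cuda")
dst_idx = torch.arange(0, 4, dtype=torch.int64, device="cuda")

# Copies src[src_idx[i]] into dst[dst_idx[i]] for every i, one token row at a time.
kv_trans(src, src_idx, dst, dst_idx)

# The result should match a plain PyTorch row gather.
assert torch.equal(dst, src[src_idx])

The fixed grid_count = 20 with num_warps=1 deliberately keeps the launch small so that, as the in-code comment notes, the copy does not compete with compute kernels for SMs.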

lightllm/common/mem_manager.py (+61)
@@ -7,6 +7,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
+from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 
 logger = init_logger(__name__)
 
@@ -78,6 +79,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
+        self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
         return
 
     def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
@@ -159,6 +161,65 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
         return
 
+    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+        """
+        Implementation that uses the p2p triton kernel to copy and transfer the data.
+        """
+        assert dp_size == 1
+
+        # First stage the data in the buffer on one designated GPU, then send it.
+        import torch.distributed as dist
+
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        for i, mem in enumerate(mem_managers):
+            for layer_index in range(mem.layer_num):
+                move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
+                dist.send(move_buffer, dst=1)
+        return
+
+    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+        move_token_num = len(token_indexes)
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
+        kv_trans(
+            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        )
+        return move_buffer
+
+    def receive_from_prefill_node_p2p(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+    ):
+        assert dp_size == 1
+
+        # First receive the data into the buffer on one designated GPU, then copy it to the other GPUs.
+        import torch.distributed as dist
+
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+
+        token_num = len(move_token_indexes)
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
+        for i, mem in enumerate(mem_managers):
+            for layer_index in range(mem.layer_num):
+                dist.recv(recive_buffer, src=0)
+                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+        return
+
+    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+        move_token_num = len(token_indexes)
+        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        return
+
     def _free_buffers(self):
         self.kv_buffer = None

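Both memory managers rely on the same pairing discipline: the prefill rank issues dist.send once per transferred layer buffer in a fixed order, and the decode rank issues dist.recv into a single reused staging buffer in exactly the same order before scattering it with the kernel. The toy sketch below illustrates that discipline only; it uses the gloo backend on CPU, a made-up worker function, made-up shapes, and a hypothetical two-rank group, and is not the actual lightllm PD-disaggregation wiring.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Toy sizes; the real code reuses slices of kv_move_buffer instead.
LAYER_NUM, TOKENS, HIDDEN = 2, 3, 8

def worker(rank: int):
    # Hypothetical two-process group: rank 0 plays the prefill node, rank 1 the decode node.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=2)
    buf = torch.empty(TOKENS, HIDDEN)
    for layer in range(LAYER_NUM):
        if rank == 0:
            buf.fill_(float(layer))          # stand-in for _get_kv_move_data_p2p
            dist.send(buf, dst=1)
        else:
            dist.recv(buf, src=0)            # must mirror the sender's loop order
            assert torch.all(buf == layer)   # stand-in for _write_kv_move_data_p2p
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)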