 import torch
 import os
-
+import torch.distributed as dist
 from lightllm.server.pd_io_struct import KVMoveTask
 from .mem_manager import MemoryManager
 from typing import List
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 
 logger = init_logger(__name__)
@@ -33,6 +34,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
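+        # Index vector [0, max_req_total_len + 8), reused below as the buffer-side indexes for the p2p kv_trans copies.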
+        self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
         return
 
     def send_to_decode_node(
@@ -41,8 +43,6 @@ def send_to_decode_node(
         assert dp_size == 1
 
         # First gather the data into the buffer on one designated card, then send it.
-        import torch.distributed as dist
-
         move_token_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
@@ -69,8 +69,6 @@ def receive_from_prefill_node(
         assert dp_size == 1
 
         # First receive the data into the buffer on one designated card, then copy it to the other cards.
-        import torch.distributed as dist
-
         move_token_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
@@ -97,6 +95,58 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
         return
 
+    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+        """
+        Implementation that uses the p2p triton kernel to copy and transfer the data.
+        """
+        assert dp_size == 1
+
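+        # Each task contributes its trailing move_kv_len prefill token indexes; these are the KV rows to transfer.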
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        for layer_index in range(self.layer_num):
+            move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
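+            # Send the staged layer slice synchronously to the decode peer (dst rank 1 is hard-coded here).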
+            dist.send(move_buffer, dst=1)
+        return
+
+    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+        move_token_num = len(token_indexes)
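+        # Per-token, per-layer element count (assuming self.size is the token capacity of kv_buffer), scaled by the token count.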
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
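+        # kv_trans gathers the selected token rows of this layer from kv_buffer into the contiguous move buffer.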
+        kv_trans(
+            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        )
+        return move_buffer
+
+    def receive_from_prefill_node_p2p(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+    ):
+        assert dp_size == 1
+
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
+        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+
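+        # Size the shared receive view once; it is reused for every layer's recv below.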
+        token_num = len(move_token_indexes)
+        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        receive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
+        for layer_index in range(self.layer_num):
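+            # Receive one layer at a time from the prefill peer (src rank 0), then scatter it into every local memory manager.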
+            dist.recv(receive_buffer, src=0)
+            for mem in mem_managers:
+                mem._write_kv_move_data_p2p(move_token_indexes, receive_buffer, layer_index)
+        return
+
+    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+        move_token_num = len(token_indexes)
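+        # Reverse of the gather on the prefill side: scatter rows from the receive buffer back into this manager's kv_buffer.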
+        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        return
+
     @torch.no_grad()
     def free_all(self):
         self.can_use_mem_size = len(self.mem_state) - self.holding_size