From 846ae0c44384c5915f86c1a0d6b0005d3dbd4696 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 14 Oct 2025 13:54:07 +0800 Subject: [PATCH 1/2] fix: MTP in chunked prefill mode --- lightllm/common/basemodel/batch_objs.py | 1 + .../server/router/model_infer/infer_batch.py | 10 +++++-- .../mode_backend/chunked_prefill/impl.py | 8 +++++- .../mode_backend/dp_backend/impl.py | 28 ++++++++++++++++--- .../generic_padded_pre_process.py | 6 +++- .../mode_backend/generic_pre_process.py | 5 +++- 6 files changed, 49 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index ae70f4940..8a8bac1a3 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -31,6 +31,7 @@ class ModelInput: # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 + b_chunked_prefill_next_token_ids_cpu: List[int] = None # for chunked prefill mtp # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输入变量。只在特殊的模型模式下才会具体使用和生效。 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 643c317bd..88418b6f6 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -393,7 +393,13 @@ def get_input_token_ids(self): def get_chuncked_input_token_ids(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) - return self.shm_req.shm_prompt_ids.arr[0:chunked_end] + + if chunked_end < self.get_cur_total_len(): + next_token_id = self.shm_req.shm_prompt_ids.arr[chunked_end] + else: + next_token_id = -1 # last chunk + + return self.shm_req.shm_prompt_ids.arr[0:chunked_end], next_token_id def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len @@ -438,7 +444,7 @@ def _stop_sequences_matched(self, output_len: int): def prefill_need_token_num(self, is_chuncked_prefill: bool): if is_chuncked_prefill: - input_token_ids = self.get_chuncked_input_token_ids() + input_token_ids, _ = self.get_chuncked_input_token_ids() else: input_token_ids = self.get_input_token_ids() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index f2b0fab49..17977b4b7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -192,8 +192,14 @@ def prefill_mtp( mask_func=self.prefill_mask_func, ) # mtp kv fill + b_has_out = torch.tensor(model_input.b_prefill_has_output_cpu, dtype=torch.bool, device="cuda") + b_chunked_next_token_ids = torch.tensor( + model_input.b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda" + ) + mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids) + self._draft_prefill_forward( - model_input=model_input, model_output=model_output, next_token_ids=next_token_ids + model_input=model_input, model_output=model_output, next_token_ids=mtp_next_token_ids ) sync_event = torch.cuda.Event() sync_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 8d47d1057..22be58dba 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ 
b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -354,7 +354,13 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] # mtp kv fill draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") if req_num > 0: - draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids) + b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda") + b_chunked_next_token_ids = torch.tensor( + model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda" + ) + mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids) + draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids) + self._draft_prefill_forward( model_input=model_input, model_output=model_output, @@ -633,13 +639,27 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_model_input0, draft_model_input1 = model_input0, model_input1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: - draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) + b_has_out0 = torch.tensor( + model_input0.b_prefill_has_output_cpu[0:req_num0], dtype=torch.bool, device="cuda" + ) + b_chunked_next_token_ids0 = torch.tensor( + model_input0.b_chunked_prefill_next_token_ids_cpu[0:req_num0], dtype=torch.int64, device="cuda" + ) + mtp_next_token_ids0 = torch.where(b_has_out0, next_token_ids[0:req_num0], b_chunked_next_token_ids0) + draft_next_token_ids_gpu0[0:req_num0].copy_(mtp_next_token_ids0, non_blocking=True) draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num1 > 0: - draft_next_token_ids_gpu1[0:req_num1].copy_( - next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True + b_has_out1 = torch.tensor( + model_input1.b_prefill_has_output_cpu[0:req_num1], dtype=torch.bool, device="cuda" + ) + b_chunked_next_token_ids1 = torch.tensor( + model_input1.b_chunked_prefill_next_token_ids_cpu[0:req_num1], dtype=torch.int64, device="cuda" + ) + mtp_next_token_ids1 = torch.where( + b_has_out1, next_token_ids[req_num0 : (req_num0 + req_num1)], b_chunked_next_token_ids1 ) + draft_next_token_ids_gpu1[0:req_num1].copy_(mtp_next_token_ids1, non_blocking=True) draft_model_output0, draft_model_output1 = model_output0, model_output1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 62e3628ce..0c815c446 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -36,6 +36,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_chunked_prefill_next_token_ids = [] for req in req_objs: @@ -43,7 +44,8 @@ def padded_prepare_prefill_inputs( batch_multimodal_params.append(req.multimodal_params) b_req_idx.append(req.req_idx) - input_token_ids = req.get_chuncked_input_token_ids() + input_token_ids, next_token_id = req.get_chuncked_input_token_ids() + b_chunked_prefill_next_token_ids.append(next_token_id) b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True) seq_len = len(input_token_ids) input_token_len = seq_len - req.cur_kv_len @@ -65,6 +67,7 @@ def padded_prepare_prefill_inputs( b_q_seq_len.append(1) 
b_mtp_index.append(0) b_prefill_has_output.append(False) + b_chunked_prefill_next_token_ids.append(-1) b_ready_cache_len.append(0) total_token_num += 1 prefix_total_token_num += 0 @@ -112,6 +115,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len=b_ready_cache_len, is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, + b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids, ) if is_multimodal: model_input.multimodal_params = batch_multimodal_params diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 7939a98ab..66c171506 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -20,6 +20,7 @@ def prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_chunked_prefill_next_token_ids = [] for req in req_objs: run_reqs.append(req) @@ -27,7 +28,8 @@ def prepare_prefill_inputs( b_req_idx.append(req.req_idx) if is_chuncked_mode: - input_token_ids = req.get_chuncked_input_token_ids() + input_token_ids, next_token_id = req.get_chuncked_input_token_ids() + b_chunked_prefill_next_token_ids.append(next_token_id) else: input_token_ids = req.get_input_token_ids() @@ -80,6 +82,7 @@ def prepare_prefill_inputs( b_ready_cache_len=b_ready_cache_len, is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, + b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids, prefix_total_token_num=prefix_total_token_num, ) if is_multimodal: From 64fda2b73cca8cf6891e729bc076e10b36ac5efb Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 14 Oct 2025 15:04:39 +0800 Subject: [PATCH 2/2] clean code --- lightllm/common/basemodel/batch_objs.py | 2 +- .../server/router/model_infer/infer_batch.py | 3 +- .../model_infer/mode_backend/base_backend.py | 6 ++++ .../mode_backend/chunked_prefill/impl.py | 10 ++---- .../mode_backend/dp_backend/impl.py | 33 +++++-------------- .../generic_padded_pre_process.py | 8 ++--- .../mode_backend/generic_pre_process.py | 7 ++-- 7 files changed, 28 insertions(+), 41 deletions(-) diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 8a8bac1a3..ca924e367 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -31,7 +31,7 @@ class ModelInput: # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 - b_chunked_prefill_next_token_ids_cpu: List[int] = None # for chunked prefill mtp + b_next_chunck_first_token_ids_cpu: List[int] = None # for chuncked prefill mtp # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输入变量。只在特殊的模型模式下才会具体使用和生效。 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 88418b6f6..e96c9abec 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -397,7 +397,8 @@ def get_chuncked_input_token_ids(self): if chunked_end < self.get_cur_total_len(): next_token_id = self.shm_req.shm_prompt_ids.arr[chunked_end] else: - next_token_id = -1 # last chunk + # padding id for last chunck, will be discarded. 
+ next_token_id = self.shm_req.shm_prompt_ids.arr[0] return self.shm_req.shm_prompt_ids.arr[0:chunked_end], next_token_id diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 07dfc19fa..8127848c7 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -658,6 +658,7 @@ def _sample_and_scatter_token( is_prefill: bool, b_prefill_has_output_cpu: torch.Tensor = None, mask_func: Optional[Callable] = None, + b_next_chunck_first_token_ids_cpu: torch.Tensor = None, ): if mask_func is not None: @@ -670,6 +671,11 @@ def _sample_and_scatter_token( b_has_out = g_pin_mem_manager.gen_from_list( key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool ).cuda(non_blocking=True) + if b_next_chunck_first_token_ids_cpu is not None: + b_next_chunck_first_token_ids = g_pin_mem_manager.gen_from_list( + key="b_next_chunck_first_token_ids", data=b_next_chunck_first_token_ids_cpu, dtype=torch.int64 + ).cuda(non_blocking=True) + next_token_ids = torch.where(b_has_out, next_token_ids, b_next_chunck_first_token_ids) scatter_token( next_token_ids=next_token_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 17977b4b7..10a31fe4c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -190,16 +190,10 @@ def prefill_mtp( is_prefill=True, b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, mask_func=self.prefill_mask_func, + b_next_chunck_first_token_ids_cpu=model_input.b_next_chunck_first_token_ids_cpu, ) - # mtp kv fill - b_has_out = torch.tensor(model_input.b_prefill_has_output_cpu, dtype=torch.bool, device="cuda") - b_chunked_next_token_ids = torch.tensor( - model_input.b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda" - ) - mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids) - self._draft_prefill_forward( - model_input=model_input, model_output=model_output, next_token_ids=mtp_next_token_ids + model_input=model_input, model_output=model_output, next_token_ids=next_token_ids ) sync_event = torch.cuda.Event() sync_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 22be58dba..3f1ba321b 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -354,13 +354,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] # mtp kv fill draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") if req_num > 0: - b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda") - b_chunked_next_token_ids = torch.tensor( - model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda" - ) - mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids) - draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids) - + draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids) self._draft_prefill_forward( model_input=model_input, model_output=model_output, @@ -622,6 +616,10 @@ def prefill_overlap_mtp(self, 
event_pack: OverlapEventPack, prefill_reqs: List[I b_has_out_cpu = ( model_input0.b_prefill_has_output_cpu[0:req_num0] + model_input1.b_prefill_has_output_cpu[0:req_num1] ) + b_next_chunck_first_token_ids_cpu = ( + model_input0.b_next_chunck_first_token_ids_cpu[0:req_num0] + + model_input1.b_next_chunck_first_token_ids_cpu[0:req_num1] + ) b_mtp_index = torch.cat((model_input0.b_mtp_index[0:req_num0], model_input1.b_mtp_index[0:req_num1]), dim=0) b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0) @@ -633,33 +631,20 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I b_mtp_index=b_mtp_index, is_prefill=True, b_prefill_has_output_cpu=b_has_out_cpu, + b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids_cpu, ) # spec prefill: MTP draft_model_input0, draft_model_input1 = model_input0, model_input1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: - b_has_out0 = torch.tensor( - model_input0.b_prefill_has_output_cpu[0:req_num0], dtype=torch.bool, device="cuda" - ) - b_chunked_next_token_ids0 = torch.tensor( - model_input0.b_chunked_prefill_next_token_ids_cpu[0:req_num0], dtype=torch.int64, device="cuda" - ) - mtp_next_token_ids0 = torch.where(b_has_out0, next_token_ids[0:req_num0], b_chunked_next_token_ids0) - draft_next_token_ids_gpu0[0:req_num0].copy_(mtp_next_token_ids0, non_blocking=True) + draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num1 > 0: - b_has_out1 = torch.tensor( - model_input1.b_prefill_has_output_cpu[0:req_num1], dtype=torch.bool, device="cuda" - ) - b_chunked_next_token_ids1 = torch.tensor( - model_input1.b_chunked_prefill_next_token_ids_cpu[0:req_num1], dtype=torch.int64, device="cuda" - ) - mtp_next_token_ids1 = torch.where( - b_has_out1, next_token_ids[req_num0 : (req_num0 + req_num1)], b_chunked_next_token_ids1 + draft_next_token_ids_gpu1[0:req_num1].copy_( + next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) - draft_next_token_ids_gpu1[0:req_num1].copy_(mtp_next_token_ids1, non_blocking=True) draft_model_output0, draft_model_output1 = model_output0, model_output1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 0c815c446..a2c74da79 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -36,7 +36,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] - b_chunked_prefill_next_token_ids = [] + b_next_chunck_first_token_ids = [] for req in req_objs: @@ -45,7 +45,7 @@ def padded_prepare_prefill_inputs( b_req_idx.append(req.req_idx) input_token_ids, next_token_id = req.get_chuncked_input_token_ids() - b_chunked_prefill_next_token_ids.append(next_token_id) + b_next_chunck_first_token_ids.append(next_token_id) b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True) seq_len = len(input_token_ids) input_token_len = seq_len - req.cur_kv_len @@ -67,7 +67,7 @@ def padded_prepare_prefill_inputs( b_q_seq_len.append(1) b_mtp_index.append(0) b_prefill_has_output.append(False) - b_chunked_prefill_next_token_ids.append(-1) + 
b_next_chunck_first_token_ids.append(0) b_ready_cache_len.append(0) total_token_num += 1 prefix_total_token_num += 0 @@ -115,7 +115,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len=b_ready_cache_len, is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, - b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids, + b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids, ) if is_multimodal: model_input.multimodal_params = batch_multimodal_params diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 66c171506..d0a40e0c1 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -20,7 +20,7 @@ def prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] - b_chunked_prefill_next_token_ids = [] + b_next_chunck_first_token_ids = [] for req in req_objs: run_reqs.append(req) @@ -29,7 +29,7 @@ def prepare_prefill_inputs( if is_chuncked_mode: input_token_ids, next_token_id = req.get_chuncked_input_token_ids() - b_chunked_prefill_next_token_ids.append(next_token_id) + b_next_chunck_first_token_ids.append(next_token_id) else: input_token_ids = req.get_input_token_ids() @@ -59,6 +59,7 @@ def prepare_prefill_inputs( b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") + b_next_chunck_first_token_ids = torch.tensor(b_next_chunck_first_token_ids, dtype=torch.int64, device="cpu") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() @@ -82,7 +83,7 @@ def prepare_prefill_inputs( b_ready_cache_len=b_ready_cache_len, is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, - b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids, + b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids, prefix_total_token_num=prefix_total_token_num, ) if is_multimodal:
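
Why the where() selection is needed (a reviewer's sketch, not part of the patch): during chunked prefill only a request's final chunk yields a meaningful sampled token, so the MTP draft model cannot always be driven by the sampler output. The patch therefore records the first token id of the request's next chunk and, inside _sample_and_scatter_token, picks per request between the sampled token and that recorded token with torch.where. The snippet below is a minimal, self-contained illustration of that selection; the function and argument names are hypothetical, not lightllm's actual API.

import torch

def select_mtp_draft_tokens(
    sampled_next_token_ids: torch.Tensor,   # [batch] tokens sampled from the main model
    b_prefill_has_output,                   # [batch] True if this step covers the request's last chunk
    b_next_chunk_first_token_ids,           # [batch] first prompt token of the request's next chunk
) -> torch.Tensor:
    device = sampled_next_token_ids.device
    has_out = torch.tensor(b_prefill_has_output, dtype=torch.bool, device=device)
    next_chunk_first = torch.tensor(b_next_chunk_first_token_ids, dtype=torch.int64, device=device)
    # Last chunk -> feed the draft model the sampled token;
    # otherwise -> feed it the first prompt token of the next chunk.
    return torch.where(has_out, sampled_next_token_ids, next_chunk_first)

# Example: request 0 still has prompt chunks pending, request 1 finished its prompt this step.
sampled = torch.tensor([11, 42], dtype=torch.int64)
print(select_mtp_draft_tokens(sampled, [False, True], [7, 0]))  # tensor([ 7, 42])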