diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index fae9f9dc5..5a9a9874e 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -81,7 +81,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
-        self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
+        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
         self._init_datatype()
         self._init_config()
@@ -262,10 +262,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.b_req_idx = model_input.b_req_idx
         infer_state.b_seq_len = model_input.b_seq_len
         if model_input.is_prefill:
-            if model_input.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = model_input.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
+            assert model_input.b_ready_cache_len is not None
+            infer_state.b_ready_cache_len = model_input.b_ready_cache_len
         infer_state.multimodal_params = model_input.multimodal_params
@@ -337,14 +335,14 @@ def _prefill(
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
+            model_input.b_req_idx_cpu,
+            model_input.b_seq_len_cpu,
+            model_input.b_ready_cache_len_cpu,
             model_input.max_len_in_batch,
             infer_state.mem_index,
         )
-        infer_state.init_some_extra_state(self, model_input.input_ids)
+        infer_state.init_some_extra_state(self, model_input)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -369,7 +367,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, padded_model_input.input_ids)
+            infer_state.init_some_extra_state(self, padded_model_input)
            if self.graph.need_capture(find_graph_batch_size):
                infer_state.is_cuda_graph = True
@@ -390,7 +388,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, model_input.input_ids)
+            infer_state.init_some_extra_state(self, model_input)
             model_output = self._token_forward(model_input.input_ids, infer_state)
         return model_output
@@ -482,7 +480,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
             model_input0.max_len_in_batch,
             infer_state0.mem_index,
         )
-        infer_state0.init_some_extra_state(self, input_ids0)
+        infer_state0.init_some_extra_state(self, model_input0)
         infer_state1 = self._create_inferstate(model_input1, 1)
         init_req_to_token_indexes(
@@ -493,7 +491,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
             model_input1.max_len_in_batch,
             infer_state1.mem_index,
         )
-        infer_state1.init_some_extra_state(self, input_ids1)
+        infer_state1.init_some_extra_state(self, model_input1)
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -540,7 +538,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state0.b_seq_len,
                 infer_state0.mem_index,
             )
-            infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
+            infer_state0.init_some_extra_state(self, padded_model_input0)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
             copy_kv_index_to_req(
                 self.req_manager.req_to_token_indexs,
@@ -548,7 +546,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state1.b_seq_len,
                 infer_state1.mem_index,
             )
-            infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
+            infer_state1.init_some_extra_state(self, padded_model_input1)
            if self.graph.need_capture(find_graph_batch_size):
                infer_state0.is_cuda_graph = True
@@ -578,7 +576,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state0.b_seq_len,
                 infer_state0.mem_index,
             )
-            infer_state0.init_some_extra_state(self, model_input0.input_ids)
+            infer_state0.init_some_extra_state(self, model_input0)
             infer_state1 = self._create_inferstate(model_input1, 1)
             copy_kv_index_to_req(
                 self.req_manager.req_to_token_indexs,
@@ -586,7 +584,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state1.b_seq_len,
                 infer_state1.mem_index,
             )
-            infer_state1.init_some_extra_state(self, model_input1.input_ids)
+            infer_state1.init_some_extra_state(self, model_input1)
             model_output0, model_output1 = self._overlap_tpsp_token_forward(
                 model_input0.input_ids, infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1
@@ -684,25 +682,25 @@ def _check_max_len_infer(self):
        # 模拟最大长度进行 prefill,观察是否出现 OOM
        try:
            logger.info("begin check max_len infer")
-            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
-            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cpu")
+            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+            b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
            b_seq_len[:] = self.batch_max_tokens
-            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
            total_token_num = self.batch_max_tokens
-            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
            model_input = ModelInput(
                batch_size=1,
                total_token_num=total_token_num,
                max_len_in_batch=self.batch_max_tokens,
-                input_ids=dummy_input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
+                input_ids_cpu=dummy_input_ids,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
+                b_mtp_index_cpu=b_mtp_index,
                is_prefill=True,
-                b_ready_cache_len=b_ready_cache_len,
+                b_ready_cache_len_cpu=b_ready_cache_len,
            )
            model_output = self.forward(
                model_input,
@@ -750,29 +748,29 @@ def _autotune_warmup(self):
        self.layers_num = self.autotune_layers()
        for input_len in tqdm(warmup_lengths, desc="warming up"):
            try:
-                rand_gen = torch.Generator(device="cuda")
+                rand_gen = torch.Generator(device="cpu")
                rand_gen.manual_seed(input_len)
                dummy_input_ids = torch.randint(
-                    0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
+                    0, 10000, (input_len,), dtype=torch.int32, device="cpu", generator=rand_gen
                )
-                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+                b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
                b_seq_len[:] = input_len
-                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
                total_token_num = input_len
-                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
                model_input = ModelInput(
                    batch_size=1,
                    total_token_num=total_token_num,
                    max_len_in_batch=input_len,
-                    input_ids=dummy_input_ids,
-                    mem_indexes=mem_indexes,
-                    b_req_idx=b_req_idx,
-                    b_seq_len=b_seq_len,
-                    b_mtp_index=b_mtp_index,
+                    input_ids_cpu=dummy_input_ids,
+                    mem_indexes_cpu=mem_indexes,
+                    b_req_idx_cpu=b_req_idx,
+                    b_seq_len_cpu=b_seq_len,
+                    b_mtp_index_cpu=b_mtp_index,
                    is_prefill=True,
-                    b_ready_cache_len=b_ready_cache_len,
+                    b_ready_cache_len_cpu=b_ready_cache_len,
                    multimodal_params=[],
                    **self._gen_special_model_input(total_token_num),
                )
@@ -807,27 +805,27 @@ def _init_padded_req(self):
        # prefill init padding req.
        prefill_input_len = 1
        batch_size = 1
-        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
        b_req_idx = torch.tensor(
-            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
        )
        mem_indexes = torch.tensor(
-            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cpu"
        )
-        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cpu")
+        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
        total_token_num = prefill_input_len * batch_size
-        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
        model_input = ModelInput(
            batch_size=batch_size,
            total_token_num=total_token_num,
            max_len_in_batch=prefill_input_len,
-            input_ids=dummy_input_ids,
-            mem_indexes=mem_indexes,
-            b_req_idx=b_req_idx,
-            b_mtp_index=b_mtp_index,
-            b_seq_len=b_seq_len,
-            b_ready_cache_len=b_ready_cache_len,
+            input_ids_cpu=dummy_input_ids,
+            mem_indexes_cpu=mem_indexes,
+            b_req_idx_cpu=b_req_idx,
+            b_mtp_index_cpu=b_mtp_index,
+            b_seq_len_cpu=b_seq_len,
+            b_ready_cache_len_cpu=b_ready_cache_len,
            is_prefill=True,
            multimodal_params=[],
            **self._gen_special_model_input(total_token_num),
diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py
index 9b317b423..2878e1bfb 100644
--- a/lightllm/common/basemodel/batch_objs.py
+++ b/lightllm/common/basemodel/batch_objs.py
@@ -10,17 +10,22 @@ class ModelInput:
     batch_size: int
     total_token_num: int
     max_len_in_batch: int
-    input_ids: torch.Tensor
-    b_req_idx: torch.Tensor
-    b_mtp_index: torch.Tensor
-    b_seq_len: torch.Tensor
+    input_ids: torch.Tensor = None
+    b_req_idx: torch.Tensor = None
+    b_mtp_index: torch.Tensor = None
+    b_seq_len: torch.Tensor = None
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)
     # cpu 变量
+    input_ids_cpu: torch.Tensor = None
+    b_req_idx_cpu: torch.Tensor = None
+    b_mtp_index_cpu: torch.Tensor = None
     mem_indexes_cpu: torch.Tensor = None
+    b_seq_len_cpu: torch.Tensor = None
+    b_ready_cache_len_cpu: torch.Tensor = None
     # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理
     # 的一些变量
     b_prefill_has_output_cpu: List[bool] = None  # 标记进行prefill的请求是否具有输出
@@ -33,15 +38,20 @@ class ModelInput:
     deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
     def to_cuda(self):
-        if self.input_ids is not None:
-            self.input_ids = self.input_ids.cuda(non_blocking=True)
+        # input_ids may be absent; it is recovered later via req_to_token_indexs
+        if self.input_ids is None and self.input_ids_cpu is not None:
+            self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
         if self.mem_indexes is None:
             self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
-        self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
-        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
-        self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
-        if self.b_ready_cache_len is not None:
-            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+        if self.b_req_idx is None:
+            self.b_req_idx = self.b_req_idx_cpu.cuda(non_blocking=True)
+        if self.b_seq_len is None:
+            self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)
+        # b_ready_cache_len only takes effect during prefill
+        if self.b_ready_cache_len_cpu is not None:
+            self.b_ready_cache_len = self.b_ready_cache_len_cpu.cuda(non_blocking=True)
+        if self.b_mtp_index is None:
+            self.b_mtp_index = self.b_mtp_index_cpu.cuda(non_blocking=True)
 @dataclass
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 07792865e..0f2644120 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -195,25 +195,27 @@ def warmup(self, model):
            seq_len = 2
            total_token_num = batch_size * seq_len
            max_len_in_batch = self.graph_max_len_in_batch
-            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
            b_req_idx = torch.tensor(
-                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
            )
-            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
            b_seq_len.fill_(seq_len)
-            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
            model_input = ModelInput(
                batch_size=batch_size,
                total_token_num=total_token_num,
                max_len_in_batch=max_len_in_batch,
-                input_ids=input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
+                input_ids_cpu=input_ids,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
+                b_mtp_index_cpu=b_mtp_index,
                is_prefill=False,
+                b_ready_cache_len_cpu=b_ready_cache_len,
                **model._gen_special_model_input(batch_size),
            )
            model_output: ModelOutput = model.forward(model_input)
@@ -251,25 +253,25 @@ def warmup_overlap(self, model):
            seq_len = 2
            total_token_num = batch_size * seq_len
            max_len_in_batch = self.graph_max_len_in_batch
-            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
            b_req_idx = torch.tensor(
-                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
            )
-            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
            b_seq_len.fill_(seq_len)
-            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
            micro_batch = ModelInput(
                is_prefill=False,
                batch_size=batch_size,
                total_token_num=total_token_num,
                max_len_in_batch=max_len_in_batch,
-                input_ids=input_ids,
+                input_ids_cpu=input_ids,
                b_mtp_index=b_mtp_index,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
                **model._gen_special_model_input(batch_size),
            )
            decode_batches.append(micro_batch)
diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index 79131677e..f99129e0d 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -64,7 +64,7 @@ def __init__(self):
        # 的输入会用到,其他模型和场景都不会用到
        self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model, model_input: ModelInput):
        if self.is_prefill:
            (
                self.b_q_seq_len,
@@ -75,9 +75,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                self.max_q_seq_len,
                self.max_kv_seq_len,
            ) = gen_prefill_params(
-                input_token_num=input_ids.shape[0],
-                b_ready_cache_len=self.b_ready_cache_len,
-                b_seq_len=self.b_seq_len,
+                model_input,
            )
            self.b_start_loc = self.b1_cu_q_seq_len[0:-1]
        else:
diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py
index 9b4d1d814..a7fed6325 100644
--- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py
+++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py
@@ -3,6 +3,8 @@
 import triton
 import triton.language as tl
+from lightllm.common.basemodel.batch_objs import ModelInput
+
 @triton.jit
 def _gen_cumsum_pad0_kernel(
@@ -80,7 +82,14 @@ def _gen_prefill_position(
 @torch.no_grad()
-def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
+def gen_prefill_params(model_input: ModelInput):
+    # input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
+    input_token_num = model_input.input_ids.shape[0]
+    b_seq_len = model_input.b_seq_len
+    b_ready_cache_len = model_input.b_ready_cache_len
+    b_seq_len_cpu = model_input.b_seq_len_cpu
+    b_ready_cache_len_cpu = model_input.b_ready_cache_len_cpu
+
     batch_size = b_ready_cache_len.shape[0]
     position_ids = torch.empty((input_token_num,), dtype=torch.int32, device="cuda")
     assert b_ready_cache_len.shape[0] == b_seq_len.shape[0]
@@ -99,6 +108,6 @@ def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_
        num_stages=1,
    )
    b_kv_seq_len = b_seq_len
-    max_q_seq_len = b_q_seq_len.max().item()
-    max_kv_seq_len = b_kv_seq_len.max().item()
+    max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()
+    max_kv_seq_len = b_seq_len_cpu.max()
    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py
index da2f35e08..35b3dc838 100644
--- a/lightllm/common/infer_utils.py
+++ b/lightllm/common/infer_utils.py
@@ -1,14 +1,14 @@
 def init_req_to_token_indexes(
-    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
+    req_to_token_indexs, b_req_idx_cpu, b_seq_len_cpu, b_ready_cache_len_cpu, max_len_in_batch, alloc_mem_index
 ):
     start_index = 0
-    b_seq_len_numpy = b_seq_len.cpu().numpy()
-    b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
-    b_req_idx_numpy = b_req_idx.cpu().numpy()
-    for i in range(len(b_seq_len)):
-        cur_seq_len = b_seq_len_numpy[i]
-        cur_ready_cache_len = b_ready_cache_len_numpy[i]
-        req_to_token_indexs[b_req_idx_numpy[i], cur_ready_cache_len:cur_seq_len] = alloc_mem_index[
+    # b_seq_len_numpy = b_seq_len.cpu().numpy()
+    # b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
+    # b_req_idx_numpy = b_req_idx.cpu().numpy()
+    for i in range(b_seq_len_cpu.shape[0]):
+        cur_seq_len = b_seq_len_cpu[i]
+        cur_ready_cache_len = b_ready_cache_len_cpu[i]
+        req_to_token_indexs[b_req_idx_cpu[i], cur_ready_cache_len:cur_seq_len] = alloc_mem_index[
             start_index : start_index + cur_seq_len - cur_ready_cache_len
         ]
         start_index += cur_seq_len - cur_ready_cache_len
diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py
index d2ae055ce..9cd26679b 100644
--- a/lightllm/models/deepseek2/flashattention_infer_struct.py
+++ b/lightllm/models/deepseek2/flashattention_infer_struct.py
@@ -4,6 +4,7 @@
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.common.basemodel.batch_objs import ModelInput
 class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
@@ -21,8 +22,9 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
            ]
        return cls._shared_page_table_buffer
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        device = model_input.input_ids.device
+        super().init_some_extra_state(model, model_input)
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len
             self.cu_seqlens_k = self.b1_cu_kv_seq_len
@@ -46,12 +48,9 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                    : self.batch_size * model.graph_max_len_in_batch
                ].reshape(self.batch_size, model.graph_max_len_in_batch)
            else:
-                self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
-                    input_ids.device
-                )
+                self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(device)
            self.page_table[:, :max_seq_len_k].copy_(
                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
            )
-            self.page_table[:, max_seq_len_k:].fill_(0)
        return
diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py
index a00c45601..c4542ba8e 100644
--- a/lightllm/models/deepseek2/flashinfer_struct.py
+++ b/lightllm/models/deepseek2/flashinfer_struct.py
@@ -5,6 +5,7 @@
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+from lightllm.common.basemodel.batch_objs import ModelInput
 class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
@@ -14,15 +15,16 @@ def __init__(self):
         self.decode_wrapper = None
         self.flashinfer_extra_state = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        device = model_input.input_ids.device
+        super().init_some_extra_state(model, model_input)
         self.flashinfer_extra_state = model.flashinfer_extra_state
         import flashinfer
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
-                self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
+                self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(device)
                 if self.batch_size <= model.graph_max_batch_size:
                     self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
                         : self.batch_size * self.flashinfer_extra_state.max_seq_length
@@ -31,7 +33,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                    self.kv_indices = torch.empty(
                        self.batch_size * self.flashinfer_extra_state.max_seq_length,
                        dtype=torch.int32,
-                        device=input_ids.device,
+                        device=device,
                    )
                repack_kv_index(
                    self.req_manager.req_to_token_indexs,
diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py
index 021f7a123..d5d061617 100644
--- a/lightllm/models/deepseek2/infer_struct.py
+++ b/lightllm/models/deepseek2/infer_struct.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.basemodel.batch_objs import ModelInput
 class Deepseek2InferStateInfo(LlamaInferStateInfo):
@@ -10,12 +11,12 @@ def __init__(self):
         super().__init__()
         self.kv_starts = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        super().init_some_extra_state(model, model_input)
         if not self.is_prefill:
             self.kv_starts = self.b1_cu_kv_seq_len
         if self.is_prefill:
             self.b1_kv_start_loc = self.b1_cu_kv_seq_len
-            self.max_value_in_b_seq_len = self.b_seq_len.max().item()
+            self.max_value_in_b_seq_len = self.max_kv_seq_len
         return
diff --git a/lightllm/models/gemma3/infer_struct.py b/lightllm/models/gemma3/infer_struct.py
index 4145124af..c29d63352 100644
--- a/lightllm/models/gemma3/infer_struct.py
+++ b/lightllm/models/gemma3/infer_struct.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.common.basemodel import InferStateInfo
 from lightllm.common.req_manager import ReqManager
+from lightllm.common.basemodel.batch_objs import ModelInput
 class Gemma3InferStateInfo(InferStateInfo):
@@ -12,8 +13,8 @@ def __init__(self):
         self.position_sin_local = None
         self.position_cos_local = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        super().init_some_extra_state(model, model_input)
         if self.is_prefill:
             self.max_seq_len = self.max_kv_seq_len
             position_ids = self.position_ids
diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py
index 98f628f07..1bfbeffe7 100644
--- a/lightllm/models/llama/flashattention_infer_struct.py
+++ b/lightllm/models/llama/flashattention_infer_struct.py
@@ -24,7 +24,8 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
            ]
        return cls._shared_page_table_buffer
-    def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
+    def _init_flash_attention_state(self, model, model_input: ModelInput):
+        input_ids = model_input.input_ids
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
@@ -93,7 +94,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
            )
        return
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
-        self._init_flash_attention_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        super().init_some_extra_state(model, model_input)
+        self._init_flash_attention_state(model, model_input)
         return
diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py
index a0c40b57a..0a4c108ff 100644
--- a/lightllm/models/llama/flashinfer_struct.py
+++ b/lightllm/models/llama/flashinfer_struct.py
@@ -5,6 +5,7 @@
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+from lightllm.common.basemodel.batch_objs import ModelInput
 class LlamaFlashInferStateInfo(LlamaInferStateInfo):
@@ -14,17 +15,16 @@ def __init__(self):
         self.decode_wrapper = None
         self.flashinfer_extra_state = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        super().init_some_extra_state(model, model_input)
         self.flashinfer_extra_state = model.flashinfer_extra_state
+        device = model_input.input_ids.device
         import flashinfer
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
-                self.kv_last_page_len_buffer = torch.full(
-                    (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device
-                )
+                self.kv_last_page_len_buffer = torch.full((self.batch_size,), 1, dtype=torch.int32, device=device)
                 if self.batch_size <= model.graph_max_batch_size:
                     self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
                         : self.batch_size * self.flashinfer_extra_state.max_seq_length
@@ -33,7 +33,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                    self.kv_indices = torch.empty(
                        self.batch_size * self.flashinfer_extra_state.max_seq_length,
                        dtype=torch.int32,
-                        device=input_ids.device,
+                        device=device,
                    )
                repack_kv_index(
@@ -71,11 +71,11 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
            if get_env_start_args().enable_flashinfer_prefill:
                q_starts = self.b1_cu_q_seq_len.int()
                kv_starts = self.b1_cu_kv_seq_len.int()
-                kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device)
+                kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=device)
                kv_indices = torch.empty(
                    self.batch_size * self.flashinfer_extra_state.max_seq_length,
                    dtype=torch.int32,
-                    device=input_ids.device,
+                    device=device,
                )
                repack_kv_index(
                    self.req_manager.req_to_token_indexs,
diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py
index 064c5770b..44b70e3c4 100644
--- a/lightllm/models/llama/infer_struct.py
+++ b/lightllm/models/llama/infer_struct.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.common.basemodel import InferStateInfo
 from lightllm.common.req_manager import ReqManager
+from lightllm.common.basemodel.batch_objs import ModelInput
 class LlamaInferStateInfo(InferStateInfo):
@@ -10,10 +11,10 @@ def __init__(self):
         self.position_cos = None
         self.position_sin = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        super().init_some_extra_state(model, model_input)
         if self.is_prefill:
-            b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
+            b_ready_cache_len_numpy = model_input.b_ready_cache_len_cpu.numpy()
             self.b_ready_cache_len_numpy = b_ready_cache_len_numpy
             self.max_seq_len = self.max_kv_seq_len
diff --git a/lightllm/models/qwen/infer_struct.py b/lightllm/models/qwen/infer_struct.py
index deff17ce2..3c131ce52 100644
--- a/lightllm/models/qwen/infer_struct.py
+++ b/lightllm/models/qwen/infer_struct.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.basemodel.batch_objs import ModelInput
 class QwenInferStateInfo(LlamaInferStateInfo):
@@ -11,13 +12,13 @@ def __init__(self):
         self.position_sin = None
         self.logn_values = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model, model_input: ModelInput):
         use_dynamic_ntk = model.config.get("use_dynamic_ntk", False)
         if not use_dynamic_ntk:
-            super().init_some_extra_state(model, input_ids)
+            super().init_some_extra_state(model, model_input)
             return
-        InferStateInfo.init_some_extra_state(self, model, input_ids)
+        InferStateInfo.init_some_extra_state(self, model, model_input)
         if self.is_prefill:
             position_ids = self.position_ids
             self.position_sin = []
diff --git a/lightllm/models/qwen2_vl/flashattention_infer_struct.py b/lightllm/models/qwen2_vl/flashattention_infer_struct.py
index 7d96d7370..52e30a4fc 100644
--- a/lightllm/models/qwen2_vl/flashattention_infer_struct.py
+++ b/lightllm/models/qwen2_vl/flashattention_infer_struct.py
@@ -11,8 +11,8 @@
 class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        InferStateInfo.init_some_extra_state(self, model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        InferStateInfo.init_some_extra_state(self, model, model_input)
         if self.is_prefill:
             self.max_seq_len = self.max_kv_seq_len
             self.q_max_seq_len = self.max_q_seq_len
@@ -26,5 +26,5 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
         # init flash attention state
-        self._init_flash_attention_state(model, input_ids)
+        self._init_flash_attention_state(model, model_input)
         return
diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py
index f57445454..86a46b6e1 100644
--- a/lightllm/models/qwen2_vl/infer_struct.py
+++ b/lightllm/models/qwen2_vl/infer_struct.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.common.basemodel.batch_objs import ModelInput
 class Qwen2VLInferStateInfo(LlamaInferStateInfo):
@@ -10,8 +11,8 @@ def __init__(self):
         self.position_cos = None
         self.position_sin = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        InferStateInfo.init_some_extra_state(self, model, input_ids)
+    def init_some_extra_state(self, model, model_input: ModelInput):
+        InferStateInfo.init_some_extra_state(self, model, model_input)
         if self.is_prefill:
             position_ids = self.position_ids
             self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
diff --git a/lightllm/models/vit/infer_struct.py b/lightllm/models/vit/infer_struct.py
index 4a22f7f12..d231f3c88 100644
--- a/lightllm/models/vit/infer_struct.py
+++ b/lightllm/models/vit/infer_struct.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.common.basemodel import InferStateInfo
 from lightllm.common.req_manager import ReqManager
+from lightllm.common.basemodel.batch_objs import ModelInput
 class LlamaInferStateInfo(InferStateInfo):
@@ -10,7 +11,7 @@ def __init__(self):
         self.position_cos = None
         self.position_sin = None
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model, model_input: ModelInput):
         if self.is_prefill:
             b_seq_len_numpy = self.b_seq_len.cpu().numpy()
             b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 39d345ff5..02bc29ce1 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -265,7 +265,7 @@ def decode_mtp(
        decode_reqs: List[InferReq],
    ):
        model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-        b_mtp_index_cpu = model_input.b_mtp_index
+        b_mtp_index_cpu = model_input.b_mtp_index_cpu
        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
            model_output = self.model.forward(model_input)
            all_next_token_ids = []
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index a90a946fd..fe044cb3c 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -449,7 +449,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
    def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
        model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
-        b_mtp_index_cpu = model_input.b_mtp_index
+        b_mtp_index_cpu = model_input.b_mtp_index_cpu
        req_num = len(run_reqs)
        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
@@ -680,8 +680,8 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
        ) = padded_overlap_prepare_decode_inputs(decode_reqs)
        req_num0, req_num1 = len(run_reqs0), len(run_reqs1)
        all_next_token_ids = []
-        b_mtp_index_cpu0 = micro_input0.b_mtp_index
-        b_mtp_index_cpu1 = micro_input1.b_mtp_index
+        b_mtp_index_cpu0 = micro_input0.b_mtp_index_cpu
+        b_mtp_index_cpu1 = micro_input1.b_mtp_index_cpu
        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
            model_output0, model_output1 = self.model.microbatch_overlap_decode(micro_input0, micro_input1)
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
index 10090a576..ddee92564 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
@@ -93,12 +93,12 @@ def padded_prepare_prefill_inputs(
        batch_size=b_seq_len.shape[0],
        total_token_num=total_token_num,
        max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
+        input_ids_cpu=input_ids,
        mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
-        b_ready_cache_len=b_ready_cache_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
+        b_ready_cache_len_cpu=b_ready_cache_len,
        is_prefill=True,
        b_prefill_has_output_cpu=b_prefill_has_output,
    )
@@ -180,9 +180,9 @@ def padded_prepare_decode_inputs(
        max_len_in_batch=max_len_in_batch,
        input_ids=None,
        mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
        is_prefill=False,
    )
    return model_input, run_reqs, padded_req_num
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
index d5bba1ae5..a240eb023 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
@@ -63,12 +63,12 @@ def prepare_prefill_inputs(
        batch_size=b_seq_len.shape[0],
        total_token_num=total_token_num,
        max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
+        input_ids_cpu=input_ids,
        mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
-        b_ready_cache_len=b_ready_cache_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
+        b_ready_cache_len_cpu=b_ready_cache_len,
        is_prefill=True,
        b_prefill_has_output_cpu=b_prefill_has_output,
    )
@@ -121,9 +121,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
        max_len_in_batch=max_len_in_batch,
        input_ids=None,
        mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
        is_prefill=False,
    )
    return model_input, run_reqs
diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py
index b7c07d17a..1143dc1eb 100644
--- a/test/benchmark/static_inference/model_infer.py
+++ b/test/benchmark/static_inference/model_infer.py
@@ -189,13 +189,13 @@ def prefill(
        batch_size=batch_size,
        total_token_num=total_token_num,
        max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
-        b_req_idx=b_req_idx,
-        b_seq_len=b_seq_len,
-        b_mtp_index=b_mtp_index,
+        input_ids_cpu=input_ids,
+        b_req_idx_cpu=b_req_idx,
+        b_seq_len_cpu=b_seq_len,
+        b_mtp_index_cpu=b_mtp_index,
        mem_indexes_cpu=mem_indexes,
        is_prefill=True,
-        b_ready_cache_len=b_ready_cache_len,  # b_ready_cache_len
+        b_ready_cache_len_cpu=b_ready_cache_len,  # b_ready_cache_len
    )
    model_output = model_part.forward(model_input)
@@ -209,10 +209,10 @@ def decode(
        batch_size=batch_size,
        total_token_num=total_token_num,
        max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
-        b_req_idx=b_req_idx,
-        b_seq_len=b_seq_len,
-        b_mtp_index=b_mtp_index,
+        input_ids_cpu=input_ids,
+        b_req_idx_cpu=b_req_idx,
+        b_seq_len_cpu=b_seq_len,
+        b_mtp_index_cpu=b_mtp_index,
        mem_indexes_cpu=mem_indexes,
        is_prefill=False,
    )
@@ -230,13 +230,14 @@ def torch_profile(fn, log_dir=None):
    ) as prof:
        fn()
    if get_current_rank_in_dp() == 0:
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=1000))
 def run_forward_once(
     model_kvargs, input_len, output_len, batch_size, model_part, enable_overlap, enable_torch_profile=False
 ):
-    test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)])
+    # test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)])
+    test_data = np.load("test.npy")
     test_data = test_data.reshape(-1)
     test_data = torch.from_numpy(test_data)
     import torch.distributed as dist
@@ -251,15 +252,15 @@ def run_forward_once(
    b_req_idx = torch.tensor(
        [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu"
-    )
-    b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
-    b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    ).pin_memory()
+    b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu").pin_memory()
+    b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu").pin_memory()
    for i in range(batch_size):
        b_seq_len[i] = input_len
    total_token_num = batch_size * input_len
    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
-    b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu").pin_memory()
    rank_id = model_kvargs["rank_id"]
    if enable_overlap:
@@ -354,9 +355,6 @@ def run_forward_once(
                print(str(e))
                raise
-            prob_out = torch.softmax(logits, dim=-1)
-            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-            _ = predict_ids.detach().cpu().numpy()
            torch.cuda.synchronize()
            if i % 100 == 0 or i == output_len - 1:
                if rank_id == 0:
@@ -364,6 +362,9 @@ def run_forward_once(
                        f"i: {i}, step cost time: {(time.time() - step_start) * 1000} ms, "
                        f"throughput: {dp_size * batch_size / (time.time() - step_start)} tokens/s"
                    )
+                    prob_out = torch.softmax(logits, dim=-1)
+                    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+                    _ = predict_ids.detach().cpu().numpy()
    model_part.mem_manager.free_all()
    model_part.req_manager.free_all()
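
The sketch below is not part of the patch. It restates, against a stripped-down stand-in dataclass rather than the real ModelInput in lightllm/common/basemodel/batch_objs.py, the CPU-first convention the diff introduces: callers fill the *_cpu fields with host tensors (mirroring the construction in _check_max_len_infer above), padded decode batches may leave input_ids unset, and to_cuda() lazily materializes the device-side copies. The demo sizes in the __main__ block are made up for illustration.

# Illustrative sketch only; simplified stand-in, not library code.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SketchModelInput:
    batch_size: int
    total_token_num: int
    max_len_in_batch: int
    is_prefill: bool = False
    # CUDA-side tensors, filled lazily by to_cuda().
    input_ids: Optional[torch.Tensor] = None
    b_req_idx: Optional[torch.Tensor] = None
    b_seq_len: Optional[torch.Tensor] = None
    b_mtp_index: Optional[torch.Tensor] = None
    mem_indexes: Optional[torch.Tensor] = None
    b_ready_cache_len: Optional[torch.Tensor] = None
    # CPU-side tensors, what callers construct after this patch.
    input_ids_cpu: Optional[torch.Tensor] = None
    b_req_idx_cpu: Optional[torch.Tensor] = None
    b_seq_len_cpu: Optional[torch.Tensor] = None
    b_mtp_index_cpu: Optional[torch.Tensor] = None
    mem_indexes_cpu: Optional[torch.Tensor] = None
    b_ready_cache_len_cpu: Optional[torch.Tensor] = None

    def to_cuda(self):
        # input_ids may be absent (padded decode batches pass input_ids=None).
        if self.input_ids is None and self.input_ids_cpu is not None:
            self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
        if self.mem_indexes is None:
            self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
        if self.b_req_idx is None:
            self.b_req_idx = self.b_req_idx_cpu.cuda(non_blocking=True)
        if self.b_seq_len is None:
            self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)
        # b_ready_cache_len is only meaningful for prefill batches.
        if self.b_ready_cache_len_cpu is not None:
            self.b_ready_cache_len = self.b_ready_cache_len_cpu.cuda(non_blocking=True)
        if self.b_mtp_index is None:
            self.b_mtp_index = self.b_mtp_index_cpu.cuda(non_blocking=True)


if __name__ == "__main__" and torch.cuda.is_available():
    # Toy prefill batch with made-up sizes.
    bs, seq = 1, 8
    mi = SketchModelInput(
        batch_size=bs,
        total_token_num=bs * seq,
        max_len_in_batch=seq,
        is_prefill=True,
        input_ids_cpu=torch.ones(bs * seq, dtype=torch.int32),
        b_req_idx_cpu=torch.zeros(bs, dtype=torch.int32),
        b_seq_len_cpu=torch.full((bs,), seq, dtype=torch.int32),
        b_mtp_index_cpu=torch.zeros(bs, dtype=torch.int32),
        mem_indexes_cpu=torch.arange(bs * seq, dtype=torch.int32),
        b_ready_cache_len_cpu=torch.zeros(bs, dtype=torch.int32),
    )
    mi.to_cuda()
    assert mi.input_ids.is_cuda and mi.b_seq_len.is_cuda and mi.b_ready_cache_len.is_cuda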