diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f712fe0164e4..baeba5298182 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -710,7 +710,6 @@ def __init__( self.dimensions = dimensions # For diffusion LLM - self.dllm_ids = [] self.dllm_block_offset = 0 self.dllm_config = dllm_config @@ -786,22 +785,19 @@ def finished(self) -> bool: def is_dllm(self): return self.dllm_config is not None + def _init_fill_ids_for_dllm(self): + if not self.fill_ids: + self.fill_ids = ( + self.origin_input_ids + + [self.dllm_config.mask_id] * self.dllm_config.block_size + ) + else: + self.dllm_block_offset += self.dllm_config.block_size + self.fill_ids += [self.dllm_config.mask_id] * self.dllm_config.block_size + def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): if self.is_dllm(): - if not self.fill_ids: - self.dllm_ids = ( - self.origin_input_ids - + [ - self.dllm_config.mask_id, - ] - * self.dllm_config.block_size - ) - else: - self.dllm_block_offset += self.dllm_config.block_size - self.dllm_ids += [ - self.dllm_config.mask_id - ] * self.dllm_config.block_size - self.fill_ids = self.dllm_ids + self._init_fill_ids_for_dllm() else: self.fill_ids = self.origin_input_ids + self.output_ids @@ -1322,9 +1318,11 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}" def prepare_for_extend(self): - self.forward_mode = ( - ForwardMode.DLLM_EXTEND if self.is_dllm() else ForwardMode.EXTEND - ) + self.forward_mode = ForwardMode.EXTEND + + if self.is_dllm(): + # For DLLM, we use a separate forward mode + self.forward_mode = ForwardMode.DLLM_EXTEND # Init tensors reqs = self.reqs diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index f37138a72749..758f0ffc9571 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -239,8 +239,11 @@ def __init__( is_draft_model=is_draft_worker, ) + # Init DLLM algorithm if server_args.dllm_algorithm is not None: self.dllm_algorithm = DllmAlgorithm.from_server_args(server_args) + else: + self.dllm_algorithm = None self._model_runner = ModelRunner( model_config=self.model_config, @@ -349,7 +352,19 @@ def get_worker_info(self): ) def is_dllm(self): - return hasattr(self, "dllm_algorithm") + return self.dllm_algorithm is not None + + def _forward_batch_generation_dllm( + self, forward_batch: ForwardBatch + ) -> GenerationBatchResult: + logits_output, next_token_ids, can_run_cuda_graph = self.dllm_algorithm.run( + self.model_runner, forward_batch + ) + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + can_run_cuda_graph=can_run_cuda_graph, + ) def forward_batch_generation( self, @@ -380,14 +395,7 @@ def forward_batch_generation( if self.pp_group.is_last_rank: if self.is_dllm(): - logits_output, next_token_ids, can_run_cuda_graph = ( - self.dllm_algorithm.run(self.model_runner, forward_batch) - ) - return GenerationBatchResult( - logits_output=logits_output, - next_token_ids=next_token_ids, - can_run_cuda_graph=can_run_cuda_graph, - ) + return self._forward_batch_generation_dllm(forward_batch) logits_output, can_run_cuda_graph = self.model_runner.forward( forward_batch, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1fd483da3752..06726733aaeb 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -832,16 +832,18 @@ def replay( graph_key = self.bs self.graphs[graph_key].replay() output = self.output_buffers[graph_key] + if isinstance(output, LogitsProcessorOutput): + if self.is_dllm: + next_token_logits = None + full_logits = output.full_logits[: self.raw_num_token] + else: + full_logits = None + next_token_logits = output.next_token_logits[: self.raw_num_token] + return LogitsProcessorOutput( - next_token_logits=( - output.next_token_logits[: self.raw_num_token] - if not self.is_dllm - else None - ), - full_logits=( - output.full_logits[: self.raw_num_token] if self.is_dllm else None - ), + next_token_logits=next_token_logits, + full_logits=full_logits, hidden_states=( output.hidden_states[: self.raw_num_token] if output.hidden_states is not None