Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,6 @@ def __init__(
self.dimensions = dimensions

# For diffusion LLM
self.dllm_ids = []
self.dllm_block_offset = 0
self.dllm_config = dllm_config

Expand Down Expand Up @@ -786,22 +785,19 @@ def finished(self) -> bool:
def is_dllm(self):
return self.dllm_config is not None

def _init_fill_ids_for_dllm(self):
if not self.fill_ids:
self.fill_ids = (
self.origin_input_ids
+ [self.dllm_config.mask_id] * self.dllm_config.block_size
)
else:
self.dllm_block_offset += self.dllm_config.block_size
self.fill_ids += [self.dllm_config.mask_id] * self.dllm_config.block_size

def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
    """Prepare ``fill_ids`` for the next forward round.

    Diffusion-LLM requests delegate to ``_init_fill_ids_for_dllm`` (which
    appends one masked block per round); autoregressive requests rebuild
    ``fill_ids`` as prompt tokens plus the tokens generated so far.

    The previous inline DLLM branch duplicated ``_init_fill_ids_for_dllm``
    and wrote to the removed ``self.dllm_ids`` attribute; it is deleted here
    in favor of the single helper.

    Args:
        tree_cache: Prefix cache. Unused in this path; kept so callers that
            pass it keep working.
    """
    if self.is_dllm():
        self._init_fill_ids_for_dllm()
    else:
        self.fill_ids = self.origin_input_ids + self.output_ids

Expand Down Expand Up @@ -1322,9 +1318,11 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int])
), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

def prepare_for_extend(self):
self.forward_mode = (
ForwardMode.DLLM_EXTEND if self.is_dllm() else ForwardMode.EXTEND
)
self.forward_mode = ForwardMode.EXTEND

if self.is_dllm():
# For DLLM, we use a separate forward mode
self.forward_mode = ForwardMode.DLLM_EXTEND

# Init tensors
reqs = self.reqs
Expand Down
26 changes: 17 additions & 9 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,11 @@ def __init__(
is_draft_model=is_draft_worker,
)

# Init DLLM algorithm
if server_args.dllm_algorithm is not None:
self.dllm_algorithm = DllmAlgorithm.from_server_args(server_args)
else:
self.dllm_algorithm = None

self._model_runner = ModelRunner(
model_config=self.model_config,
Expand Down Expand Up @@ -349,7 +352,19 @@ def get_worker_info(self):
)

def is_dllm(self):
    """Return True when this worker runs a diffusion-LLM decoding algorithm.

    Checks the attribute's value, not its existence: ``__init__`` always
    assigns ``self.dllm_algorithm`` (possibly ``None``), so the old
    ``hasattr(self, "dllm_algorithm")`` test was always True and its
    duplicated, unreachable second return is removed here.
    """
    return self.dllm_algorithm is not None

def _forward_batch_generation_dllm(
    self, forward_batch: ForwardBatch
) -> GenerationBatchResult:
    """Run the diffusion-LLM decoding algorithm on one batch.

    Delegates the forward pass to ``self.dllm_algorithm`` and repackages
    its three outputs into a ``GenerationBatchResult``.
    """
    logits, tokens, cuda_graph_ok = self.dllm_algorithm.run(
        self.model_runner, forward_batch
    )
    return GenerationBatchResult(
        logits_output=logits,
        next_token_ids=tokens,
        can_run_cuda_graph=cuda_graph_ok,
    )

def forward_batch_generation(
self,
Expand Down Expand Up @@ -380,14 +395,7 @@ def forward_batch_generation(

if self.pp_group.is_last_rank:
if self.is_dllm():
logits_output, next_token_ids, can_run_cuda_graph = (
self.dllm_algorithm.run(self.model_runner, forward_batch)
)
return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
can_run_cuda_graph=can_run_cuda_graph,
)
return self._forward_batch_generation_dllm(forward_batch)

logits_output, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
Expand Down
18 changes: 10 additions & 8 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,16 +832,18 @@ def replay(
graph_key = self.bs
self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]

if isinstance(output, LogitsProcessorOutput):
if self.is_dllm:
next_token_logits = None
full_logits = output.full_logits[: self.raw_num_token]
else:
full_logits = None
next_token_logits = output.next_token_logits[: self.raw_num_token]

return LogitsProcessorOutput(
next_token_logits=(
output.next_token_logits[: self.raw_num_token]
if not self.is_dllm
else None
),
full_logits=(
output.full_logits[: self.raw_num_token] if self.is_dllm else None
),
next_token_logits=next_token_logits,
full_logits=full_logits,
hidden_states=(
output.hidden_states[: self.raw_num_token]
if output.hidden_states is not None
Expand Down
Loading