Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/attention/base_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def forward(
**kwargs,
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
if forward_batch.forward_mode.is_idle():
return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
elif forward_batch.forward_mode.is_decode():
return self.forward_decode(
q,
k,
Expand Down
18 changes: 2 additions & 16 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Tuple

import torch
import triton
Expand Down Expand Up @@ -162,7 +162,7 @@ def disable_dp_size():
_ATTN_DP_SIZE = old_dp_size


def get_dp_local_info(forward_batch: ForwardBatch):
def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
dp_rank = get_attention_dp_rank()

Expand Down Expand Up @@ -238,13 +238,6 @@ def _dp_gather(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between global_tokens and local_tokens not allowed"

# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
Expand Down Expand Up @@ -296,13 +289,6 @@ def dp_scatter(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between local_tokens and global_tokens not allowed"

# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# limitations under the License.
# ==============================================================================
"""Radix attention."""
from __future__ import annotations

from enum import Enum
from typing import Optional
from typing import TYPE_CHECKING, Optional

from torch import nn

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionType(Enum):
Expand Down
5 changes: 2 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
Expand All @@ -68,6 +67,7 @@
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

Expand Down Expand Up @@ -1879,7 +1879,7 @@ class ModelWorkerBatch:
sampling_info: SamplingBatchInfo

# The input Embeds
input_embeds: Optional[torch.tensor] = None
input_embeds: Optional[torch.Tensor] = None

# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
Expand All @@ -1889,7 +1889,6 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
spec_num_draft_tokens: Optional[int] = None
hicache_consumer_index: int = 0

# Overlap event
Expand Down
80 changes: 60 additions & 20 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import triton
import triton.language as tl

from sglang.srt.layers.dp_attention import get_attention_dp_rank
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
flatten_nested_list,
Expand Down Expand Up @@ -242,7 +243,7 @@ class ForwardBatch:
lora_paths: Optional[List[str]] = None

# For input embeddings
input_embeds: Optional[torch.tensor] = None
input_embeds: Optional[torch.Tensor] = None

# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -340,20 +341,38 @@ def init_new(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)

# For DP attention
# For MLP sync
if batch.global_num_tokens is not None:

spec_num_draft_tokens = (
batch.spec_num_draft_tokens
if batch.spec_num_draft_tokens is not None
else 1
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
)
global_num_tokens = [
x * spec_num_draft_tokens for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
]

assert batch.global_num_tokens_for_logprob is not None
# process global_num_tokens and global_num_tokens_for_logprob
if batch.spec_info is not None:
if isinstance(batch.spec_info, EagleDraftInput):
global_num_tokens = [
x * batch.spec_info.num_tokens_per_batch
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.num_tokens_for_logprob_per_batch
for x in batch.global_num_tokens_for_logprob
]
else:
assert isinstance(batch.spec_info, EagleVerifyInput)
global_num_tokens = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens_for_logprob
]
else:
global_num_tokens = batch.global_num_tokens
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob

ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
Expand All @@ -365,13 +384,6 @@ def init_new(
global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)

sum_len = sum(global_num_tokens)
ret.gathered_buffer = torch.zeros(
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)

if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(
Expand Down Expand Up @@ -573,6 +585,34 @@ def prepare_chunked_kv_indices(self, device: torch.device):
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)

def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int):
return torch.cat(
[tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])], dim=0
)

def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
    """Prepare per-rank buffers for a DP-attention MLP-sync forward pass.

    Allocates ``self.gathered_buffer`` sized to the sum of tokens across
    all DP ranks, and — for draft-extend mode — zero-pads the per-rank
    input tensors up to this rank's global token count so every rank
    contributes a fixed-size slice to the gather.
    """
    # MLP sync only makes sense when the global token counts were populated.
    assert self.global_num_tokens_cpu is not None
    global_num_tokens = self.global_num_tokens_cpu
    sum_len = sum(global_num_tokens)
    # One row per token across ALL DP ranks; filled later by the DP gather.
    self.gathered_buffer = torch.zeros(
        (sum_len, model_runner.model_config.hidden_size),
        dtype=model_runner.dtype,
        device=model_runner.device,
    )
    if self.forward_mode.is_draft_extend():
        # Pick this rank's token budget: indexed by DP rank when there are
        # multiple entries, otherwise the single shared entry.
        if len(global_num_tokens) > 1:
            num_tokens = global_num_tokens[get_attention_dp_rank()]
        else:
            num_tokens = global_num_tokens[0]
        # Pad every draft-extend input to num_tokens rows so shapes match
        # the padded gathered_buffer layout. NOTE(review): assumes
        # self.spec_info is an EagleDraftInput carrying hidden_states —
        # confirm against the draft-extend call path.
        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
        self.out_cache_loc = self._pad_tensor_to_size(
            self.out_cache_loc, num_tokens
        )
        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
        self.spec_info.hidden_states = self._pad_tensor_to_size(
            self.spec_info.hidden_states, num_tokens
        )

# Here we suppose the length of each chunk is equal
# For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
# num_prefix_chunks = cdiv(1024, 256) = 4
Expand Down
22 changes: 18 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,9 +1462,13 @@ def apply_torch_tp(self):
tensor_parallel(self.model, device_mesh)

def forward_decode(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
self.attn_backend.init_forward_metadata(forward_batch)
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
# FIXME: add pp_proxy_tensors arg to all models
kwargs = {}
if self.support_pp:
Expand Down Expand Up @@ -1576,8 +1580,18 @@ def _forward_raw(
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_decode():
ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
return ret, can_run_cuda_graph

# For MLP sync
if forward_batch.global_num_tokens_cpu is not None:
forward_batch.prepare_mlp_sync_batch(self)

if forward_batch.forward_mode.is_decode():
ret = self.forward_decode(
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_extend():
ret = self.forward_extend(
forward_batch,
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/speculative/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None

# Shape info for padding
num_tokens_per_batch: int = -1
num_tokens_for_logprob_per_batch: int = -1

# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
Expand Down
26 changes: 17 additions & 9 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ def forward_target_extend(
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
Expand Down Expand Up @@ -510,13 +509,15 @@ def draft(self, batch: ScheduleBatch):
self._draft_preprocess_decode(batch)

spec_info = batch.spec_info
assert isinstance(spec_info, EagleDraftInput)

spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
spec_info.num_tokens_per_batch = self.topk
spec_info.num_tokens_for_logprob_per_batch = self.topk
batch.return_hidden_states = False

# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
Expand All @@ -529,6 +530,7 @@ def draft(self, batch: ScheduleBatch):
forward_batch
)
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
Expand Down Expand Up @@ -580,6 +582,7 @@ def draft(self, batch: ScheduleBatch):
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
assert isinstance(spec_info, EagleDraftInput)
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
Expand Down Expand Up @@ -623,8 +626,8 @@ def draft_forward(self, forward_batch: ForwardBatch):
spec_info.hidden_states = hidden_states

# Run forward
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
Expand All @@ -644,10 +647,10 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
else ForwardMode.IDLE
)
batch.spec_info = spec_info

model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

if batch.has_grammar:
Expand Down Expand Up @@ -797,14 +800,15 @@ def forward_draft_extend(
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
num_tokens_per_batch=1,
num_tokens_for_logprob_per_batch=1,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
Expand All @@ -816,6 +820,7 @@ def forward_draft_extend(
self.capture_for_decode(logits_output, forward_batch.spec_info)

def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
assert isinstance(batch.spec_info, EagleDraftInput)
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
Expand All @@ -839,6 +844,9 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)

batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.spec_info.num_tokens_for_logprob_per_batch = 1
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
Expand All @@ -851,7 +859,6 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):

batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
Expand All @@ -876,12 +883,13 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self.capture_for_decode(logits_output, forward_batch.spec_info)

Expand Down
Loading