diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 27705a4b8793..3dbe847e324a 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -26,7 +26,7 @@ from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.spec_info import SpecInput +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.utils import ( get_int_env_var, is_flashinfer_available, @@ -606,9 +606,7 @@ def init_forward_metadata_capture_cuda_graph( elif forward_mode.is_target_verify(): # FlashInfer's prefill wrapper decides mask mode based on whether # `custom_mask_buf` is initialized (not whether a custom mask is provided). - # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a - # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise - # FlashInfer will treat the (zero) buffer as a real mask and block attention. + # DFlash relies on layer causal/window metadata instead of a custom mask. use_custom_mask = ( spec_info is not None and getattr(spec_info, "custom_mask", None) is not None @@ -1241,7 +1239,7 @@ def update( seq_lens: torch.Tensor, seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, - prefix_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], @@ -1259,7 +1257,7 @@ def update_single_wrapper( seq_lens: torch.Tensor, seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, - prefix_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], @@ -1269,6 +1267,7 @@ def update_single_wrapper( cross_attention_custom_mask: Optional[torch.Tensor] = None, ): if use_ragged: + assert prefix_lens is not None # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu # and forward_batch.extend_seq_lens_cpu paged_kernel_lens = prefix_lens @@ -1300,7 +1299,7 @@ def update_sliding_window( seq_lens: torch.Tensor, seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, - prefix_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], @@ -1309,12 +1308,23 @@ def update_sliding_window( multi_item_params: Optional[MultiItemScoringParams] = None, cross_attention_custom_mask: Optional[torch.Tensor] = None, ): + if prefix_lens is None: + accept_length = getattr(spec_info, "accept_length", None) + prefix_lens = ( + seq_lens + if accept_length is None + else seq_lens + - accept_length[: seq_lens.shape[0]].to( + device=seq_lens.device, dtype=seq_lens.dtype + ) + ) + window_size = seq_lens.new_tensor(self.sliding_window_size) for wrapper_id in range(2): if wrapper_id == 0: # window attention use paged only paged_kernel_lens = torch.minimum( seq_lens, - torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens, + window_size + seq_lens - prefix_lens, ) paged_kernel_lens_sum = paged_kernel_lens.sum().item() else: @@ -1350,7 +1360,7 @@ def update_cross_attention( seq_lens: torch.Tensor, seq_lens_cpu: Optional[torch.Tensor], seq_lens_sum: int, - prefix_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], @@ -1398,7 +1408,7 @@ def call_begin_forward( paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, seq_lens: torch.Tensor, - prefix_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], kv_start_idx: torch.Tensor, kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, @@ -1411,6 +1421,7 @@ def call_begin_forward( ): bs = len(seq_lens) if spec_info is None: + assert prefix_lens is not None assert len(seq_lens) == len(req_pool_indices) # Normal extend kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) @@ -1435,14 +1446,25 @@ def call_begin_forward( custom_mask = cross_attention_custom_mask else: assert isinstance(spec_info, SpecInput) - kv_indices, kv_indptr, qo_indptr, custom_mask = ( - spec_info.generate_attn_arg_prefill( - req_pool_indices, - paged_kernel_lens, - paged_kernel_lens_sum, - self.req_to_token, + if spec_info.spec_input_type == SpecInputType.DFLASH_VERIFY: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + self.req_to_token, + kv_start_idx=kv_start_idx, + ) + ) + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + self.req_to_token, + ) ) - ) # extend part if use_ragged: diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py index a2bc66eaf967..973953c1bc5d 100644 --- a/python/sglang/srt/managers/overlap_utils.py +++ b/python/sglang/srt/managers/overlap_utils.py @@ -83,41 +83,54 @@ def __init__( def _lazy_init_buf(self, draft_input: EagleDraftInput): self.buf_initialized = True - # Get a reference for each tensor - topk_p0 = draft_input.topk_p[0] - topk_index0 = draft_input.topk_index[0] - bonus_token0 = draft_input.bonus_tokens[0] + self.need_verified_id = getattr(draft_input, "verified_id", None) is not None + self.need_bonus_tokens = getattr(draft_input, "bonus_tokens", None) is not None + + if self.need_verified_id: + verified_id0 = draft_input.verified_id[0] + self.verified_id_buf = torch.empty( + (self.future_buffer_len, *verified_id0.shape), + dtype=verified_id0.dtype, + device=self.device, + ) + if self.need_bonus_tokens: + bonus_token0 = draft_input.bonus_tokens[0] + self.bonus_tokens_buf = torch.empty( + (self.future_buffer_len, *bonus_token0.shape), + dtype=bonus_token0.dtype, + device=self.device, + ) new_seq_lens0 = draft_input.new_seq_lens[0] - self.topk_p_buf = torch.empty( - (self.future_buffer_len, *topk_p0.shape), - dtype=topk_p0.dtype, - device=self.device, - ) - self.topk_index_buf = torch.empty( - (self.future_buffer_len, *topk_index0.shape), - dtype=topk_index0.dtype, - device=self.device, - ) - self.bonus_tokens_buf = torch.empty( - (self.future_buffer_len, *bonus_token0.shape), - dtype=bonus_token0.dtype, - device=self.device, - ) self.new_seq_lens_buf = torch.empty( (self.future_buffer_len, *new_seq_lens0.shape), dtype=new_seq_lens0.dtype, device=self.device, ) - if spec_need_hidden_states(): - hidden_states0 = draft_input.hidden_states[0] - self.hidden_states_buf = torch.empty( - (self.future_buffer_len, *hidden_states0.shape), - dtype=hidden_states0.dtype, + if self.spec_algo.need_topk(): + # Get a reference for each tensor + topk_p0 = draft_input.topk_p[0] + topk_index0 = draft_input.topk_index[0] + self.topk_p_buf = torch.empty( + (self.future_buffer_len, *topk_p0.shape), + dtype=topk_p0.dtype, + device=self.device, + ) + self.topk_index_buf = torch.empty( + (self.future_buffer_len, *topk_index0.shape), + dtype=topk_index0.dtype, device=self.device, ) + if spec_need_hidden_states(): + hidden_states0 = draft_input.hidden_states[0] + self.hidden_states_buf = torch.empty( + (self.future_buffer_len, *hidden_states0.shape), + dtype=hidden_states0.dtype, + device=self.device, + ) + def alloc_future_indices(self, bs: int) -> FutureIndices: """Update the circular buffer pointer and allocate future indices.""" cur_future_ct = self.future_ct @@ -144,12 +157,16 @@ def resolve_future(self, model_worker_batch: ModelWorkerBatch): # caching allocator (torch GC) could reclaim the memory before # the GPU finishes reading it. indices.record_stream(torch.get_device_module(self.device).current_stream()) - draft_input.topk_p = self.topk_p_buf[indices] - draft_input.topk_index = self.topk_index_buf[indices] - draft_input.bonus_tokens = self.bonus_tokens_buf[indices] + if self.need_verified_id: + draft_input.verified_id = self.verified_id_buf[indices] + if self.need_bonus_tokens: + draft_input.bonus_tokens = self.bonus_tokens_buf[indices] draft_input.new_seq_lens = self.new_seq_lens_buf[indices] - if spec_need_hidden_states(): - draft_input.hidden_states = self.hidden_states_buf[indices] + if self.spec_algo.need_topk(): + draft_input.topk_p = self.topk_p_buf[indices] + draft_input.topk_index = self.topk_index_buf[indices] + if spec_need_hidden_states(): + draft_input.hidden_states = self.hidden_states_buf[indices] def is_empty_slice(self, s: slice) -> bool: start, stop, step = s.indices(self.future_buffer_len) @@ -179,9 +196,13 @@ def store_to_map_for_new_batch( if not self.buf_initialized: self._lazy_init_buf(draft_input) - self.topk_p_buf[intv] = draft_input.topk_p - self.topk_index_buf[intv] = draft_input.topk_index - self.bonus_tokens_buf[intv] = draft_input.bonus_tokens + if self.need_verified_id: + self.verified_id_buf[intv] = draft_input.verified_id + if self.need_bonus_tokens: + self.bonus_tokens_buf[intv] = draft_input.bonus_tokens self.new_seq_lens_buf[intv] = draft_input.new_seq_lens - if spec_need_hidden_states(): - self.hidden_states_buf[intv] = draft_input.hidden_states + if self.spec_algo.need_topk(): + self.topk_p_buf[intv] = draft_input.topk_p + self.topk_index_buf[intv] = draft_input.topk_index + if spec_need_hidden_states(): + self.hidden_states_buf[intv] = draft_input.hidden_states diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 99f2744ee763..9d2498362d9b 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2388,10 +2388,20 @@ def prepare_for_decode(self): ) def maybe_wait_verify_done(self): - if self.is_spec_v2: - draft_input: EagleDraftInput = self.spec_info - if draft_input.verify_done is not None: - draft_input.verify_done.synchronize() + if not self.is_spec_v2: + return + + draft_input: EagleDraftInput = self.spec_info + verify_done = getattr(draft_input, "verify_done", None) + if verify_done is None: + return + + if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): + torch.get_device_module(self.device).current_stream().wait_event( + verify_done + ) + else: + verify_done.synchronize() def filter_batch( self, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8f4dfe996afb..afbef56befdb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -208,6 +208,10 @@ from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.session.session_controller import SessionController from sglang.srt.session.streaming_session import StreamingSession +from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_prefill_refill_target, + should_delay_dflash_prefill_for_batching, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( DynamicGradMode, @@ -303,10 +307,12 @@ def copy_to_cpu(self): self.copy_done.record() -def validate_dflash_request(req: Req) -> Optional[str]: +def validate_dflash_request(req: Req, enable_overlap: bool) -> Optional[str]: if req.return_logprob: return "DFLASH speculative decoding does not support return_logprob yet." + if enable_overlap and req.return_hidden_states: + return "DFLASH speculative decoding does not support return_hidden_states yet." if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None @@ -726,6 +732,11 @@ def init_model_worker(self): _, _, ) = self.tp_worker.get_worker_info() + self.dflash_prefill_refill_target = ( + resolve_dflash_prefill_refill_target(self.max_running_requests) + if self.spec_algorithm.is_dflash() + else 1 + ) if not get_global_server_args().pp_max_micro_batch_size: get_global_server_args().pp_max_micro_batch_size = max( self.max_running_requests // self.pp_size, 1 @@ -2085,13 +2096,12 @@ def handle_generate_request( return if self.spec_algorithm.is_dflash(): - error_msg = validate_dflash_request(req) + error_msg = validate_dflash_request(req, self.enable_overlap) if error_msg is not None: req.set_finish_with_abort(error_msg) self.init_req_max_new_tokens(req) self._add_request_to_queue(req) return - # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) @@ -2580,6 +2590,19 @@ def get_num_allocatable_reqs(self, running_bs): res = min(res, self.req_to_token_pool.available_size()) return res + def _should_delay_dflash_prefill_for_batching(self, running_bs: int) -> bool: + if not self.spec_algorithm.is_dflash(): + return False + if running_bs <= 0 or self.chunked_req is not None: + return False + + return should_delay_dflash_prefill_for_batching( + running_bs=running_bs, + num_allocatable_reqs=self.get_num_allocatable_reqs(running_bs), + max_running_requests=self.max_running_requests, + prefill_refill_target=self.dflash_prefill_refill_target, + ) + def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: prefill_delayer_single_pass = None if self.prefill_delayer: @@ -2620,6 +2643,8 @@ def _get_new_batch_prefill_raw( return None running_bs = len(self.running_batch.reqs) + if self._should_delay_dflash_prefill_for_batching(running_bs): + return None # Ignore the check if self.chunked_req is not None. # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0, diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 997073ab8ca4..919c72fef6ca 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -439,12 +439,17 @@ def _resolve_spec_overlap_token_ids( continue if req.finished(): - # -1 because prepare_for_decode pre-claimed the bonus slot. - req.kv_committed_len -= 1 + if not batch.spec_algorithm.is_dflash(): + # EAGLE prepare_for_decode pre-claimed the bonus slot. + req.kv_committed_len -= 1 continue - # -1 because prepare_for_decode pre-claimed the bonus slot. - req.kv_committed_len += accept_lens[i] - 1 + if batch.spec_algorithm.is_dflash(): + # DFLASH materialized accepted draft tokens plus the bonus token. + req.kv_committed_len += accept_lens[i] + else: + # EAGLE prepare_for_decode pre-claimed the bonus slot. + req.kv_committed_len += accept_lens[i] - 1 req.spec_verify_ct += 1 accepted_draft_tokens = result.num_accepted_drafts_per_req_cpu[i] @@ -487,6 +492,20 @@ def process_batch_result_decode( if batch.spec_algorithm.is_none() or batch.is_spec_v2: if batch.is_spec_v2: + prepared_kv_lens_cpu = getattr( + result, "prepared_kv_allocated_lens_cpu", None + ) + if prepared_kv_lens_cpu is not None: + for i, req in enumerate(batch.reqs): + if self.enable_overlap and (req.finished() or req.is_retracted): + continue + # In overlap mode, a newer batch may already have reserved + # further KV slots before this older result is processed. + # Do not move the request-side allocation watermark + # backwards or release_kv_cache can miss those pages. + req.kv_allocated_len = max( + req.kv_allocated_len, int(prepared_kv_lens_cpu[i]) + ) next_token_ids = self._resolve_spec_overlap_token_ids(result, batch) elif isinstance(next_token_ids, list): pass # MLX path: already a list[int], skip torch round-trip diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index b404f15aebf5..738ab126bf02 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -44,6 +44,7 @@ class GenerationBatchResult: # FIXME(lsyin): maybe move to a better place? # sync path: forward stream -> output processor accept_lens: Optional[torch.Tensor] = None + prepared_kv_allocated_lens_cpu: Optional[torch.Tensor] = None # relay path: forward stream -> next step forward next_draft_input: Optional[EagleDraftInput] = None diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 6eff09e176d3..fb57da798c59 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -124,6 +124,67 @@ def _set_kv_buffer_impl( v_cache[indices] = v +def _set_kv_buffer_prefix_valid_impl( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + loc_2d: torch.Tensor, + commit_lens: torch.Tensor, + row_dim: int, + store_dtype: torch.dtype, +) -> None: + if k.numel() == 0 or loc_2d.numel() == 0 or commit_lens.numel() == 0: + return + + if not k.is_contiguous(): + k = k.contiguous() + if not v.is_contiguous(): + v = v.contiguous() + if not loc_2d.is_contiguous(): + loc_2d = loc_2d.contiguous() + if not commit_lens.is_contiguous(): + commit_lens = commit_lens.contiguous() + + row_bytes = row_dim * store_dtype.itemsize + if row_bytes <= 0: + return + + if row_bytes >= 8192: + bytes_per_tile = 512 + num_warps = 8 + elif row_bytes >= 4096: + bytes_per_tile = 256 + num_warps = 4 + else: + bytes_per_tile = 128 + num_warps = 4 + + grid = ( + int(loc_2d.shape[0]), + int(loc_2d.shape[1]), + triton.cdiv(row_bytes, bytes_per_tile), + ) + + set_kv_buffer_prefix_valid_tiled[grid]( + k, + v, + k_cache, + v_cache, + loc_2d, + commit_lens, + int(k.stride(0) * k.element_size()), + int(v.stride(0) * v.element_size()), + int(k_cache.stride(0) * k_cache.element_size()), + int(v_cache.stride(0) * v_cache.element_size()), + int(loc_2d.shape[1]), + ROW_BYTES=row_bytes, + BYTES_PER_TILE=bytes_per_tile, + num_warps=num_warps, + num_stages=2, + ) + + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" @@ -1073,6 +1134,91 @@ def set_kv_buffer( same_kv_dim=self.same_kv_dim, ) + def set_kv_buffer_prefix_valid( + self, + layer: RadixAttention, + loc_2d: torch.Tensor, + commit_lens: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + layer_id_override: Optional[int] = None, + ): + if layer_id_override is not None: + layer_id = layer_id_override + else: + layer_id = layer.layer_id + + if loc_2d.ndim != 2: + raise ValueError(f"loc_2d must be rank-2, got shape={tuple(loc_2d.shape)}.") + if commit_lens.ndim != 1 or commit_lens.shape[0] != loc_2d.shape[0]: + raise ValueError( + "commit_lens must match loc_2d batch size: " + f"{tuple(commit_lens.shape)=} {tuple(loc_2d.shape)=}." + ) + + num_rows = int(loc_2d.numel()) + if cache_k.shape[0] != num_rows or cache_v.shape[0] != num_rows: + raise ValueError( + "dense KV rows must match loc_2d size: " + f"{tuple(cache_k.shape)=} {tuple(cache_v.shape)=} {tuple(loc_2d.shape)=}." + ) + + if cache_k.dtype != self.dtype: + if k_scale is not None: + cache_k.div_(k_scale) + if v_scale is not None: + cache_v.div_(v_scale) + cache_k = cache_k.to(self.dtype) + cache_v = cache_v.to(self.dtype) + + if self.store_dtype != self.dtype: + cache_k = cache_k.contiguous().view(self.store_dtype) + cache_v = cache_v.contiguous().view(self.store_dtype) + else: + cache_k = cache_k.contiguous() + cache_v = cache_v.contiguous() + + if loc_2d.device != self.k_buffer[0].device: + loc_2d = loc_2d.to(device=self.k_buffer[0].device, non_blocking=True) + if commit_lens.device != self.k_buffer[0].device: + commit_lens = commit_lens.to( + device=self.k_buffer[0].device, non_blocking=True + ) + if loc_2d.dtype != torch.int64: + loc_2d = loc_2d.to(torch.int64) + if commit_lens.dtype != torch.int32: + commit_lens = commit_lens.to(torch.int32) + + if not (_is_cuda or _is_hip): + row_offsets = torch.arange(loc_2d.shape[1], device=loc_2d.device) + valid_mask = row_offsets[None, :] < commit_lens.to(torch.int64)[:, None] + valid_idx = torch.nonzero(valid_mask.reshape(-1), as_tuple=False).flatten() + if valid_idx.numel() == 0: + return + self.set_kv_buffer( + layer, + loc_2d.reshape(-1).index_select(0, valid_idx), + cache_k.index_select(0, valid_idx), + cache_v.index_select(0, valid_idx), + k_scale, + v_scale, + layer_id_override=layer_id, + ) + return + + _set_kv_buffer_prefix_valid_impl( + cache_k, + cache_v, + self.k_buffer[layer_id - self.start_layer], + self.v_buffer[layer_id - self.start_layer], + loc_2d, + commit_lens, + row_dim=self.row_dim, + store_dtype=self.store_dtype, + ) + def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): if envs.SGLANG_NATIVE_MOVE_KV_CACHE.get(): move_kv_cache_native(self.k_buffer, self.v_buffer, tgt_loc, src_loc) @@ -2204,6 +2350,53 @@ def move_kv_cache_native( v_cache[tgt_loc_flat] = v_cache[src_loc_flat] +@triton.jit +def set_kv_buffer_prefix_valid_tiled( + src_k_ptr, + src_v_ptr, + dst_k_ptr, + dst_v_ptr, + loc_2d_ptr, + commit_len_ptr, + src_k_row_stride, + src_v_row_stride, + dst_k_row_stride, + dst_v_row_stride, + block_size, + ROW_BYTES: tl.constexpr, + BYTES_PER_TILE: tl.constexpr, +): + bid = tl.program_id(0) + row = tl.program_id(1) + tid = tl.program_id(2) + + commit_len = tl.load(commit_len_ptr + bid) + if row >= commit_len: + return + + byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE) + mask_byte = byte_off < ROW_BYTES + tl.multiple_of(byte_off, 16) + + loc = tl.load(loc_2d_ptr + bid * block_size + row) + src_row = bid * block_size + row + + src_k_ptr = tl.cast(src_k_ptr, tl.pointer_type(tl.uint8)) + src_v_ptr = tl.cast(src_v_ptr, tl.pointer_type(tl.uint8)) + dst_k_ptr = tl.cast(dst_k_ptr, tl.pointer_type(tl.uint8)) + dst_v_ptr = tl.cast(dst_v_ptr, tl.pointer_type(tl.uint8)) + + src_k_row_ptr = src_k_ptr + src_row * src_k_row_stride + byte_off + src_v_row_ptr = src_v_ptr + src_row * src_v_row_stride + byte_off + dst_k_row_ptr = dst_k_ptr + loc * dst_k_row_stride + byte_off + dst_v_row_ptr = dst_v_ptr + loc * dst_v_row_stride + byte_off + + k_val = tl.load(src_k_row_ptr, mask=mask_byte, other=0) + v_val = tl.load(src_v_row_ptr, mask=mask_byte, other=0) + tl.store(dst_k_row_ptr, k_val, mask=mask_byte) + tl.store(dst_v_row_ptr, v_val, mask=mask_byte) + + @triton.jit def copy_all_layer_kv_cache_tiled( data_ptrs, diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index da547df15e5e..f6624834ea95 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -216,8 +216,12 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.EXTEND self.capture_hidden_mode = CaptureHiddenMode.NULL - # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup - if model_runner.server_args.enable_return_hidden_states: + # If returning hidden states is enabled, or if speculative prefill needs + # aux hidden states (DFLASH), capture the FULL variant up front. + if ( + model_runner.server_args.enable_return_hidden_states + or model_runner.spec_algorithm.is_dflash() + ): self.capture_hidden_mode = CaptureHiddenMode.FULL self.max_num_tokens = ( @@ -415,7 +419,7 @@ def warmup_compile(self, num_tokens: int): mrope_positions=mrope_positions, spec_algorithm=None, spec_info=None, - capture_hidden_mode=CaptureHiddenMode.NULL, + capture_hidden_mode=self.capture_hidden_mode, num_token_non_padded=None, num_token_non_padded_cpu=num_tokens, global_forward_mode=ForwardMode.EXTEND, @@ -587,7 +591,7 @@ def capture_one_batch_size(self, num_tokens: int): mrope_positions=mrope_positions, spec_algorithm=None, spec_info=None, - capture_hidden_mode=CaptureHiddenMode.NULL, + capture_hidden_mode=self.capture_hidden_mode, num_token_non_padded=None, num_token_non_padded_cpu=num_tokens, global_forward_mode=ForwardMode.EXTEND, diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index c7df14c08762..8d3f5fc57857 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -28,12 +28,40 @@ from sglang.srt.models.utils import apply_qk_norm from sglang.srt.speculative.dflash_utils import ( can_dflash_slice_qkv_weight, + get_dflash_attention_sliding_window_size, + get_dflash_layer_types, parse_dflash_draft_config, ) +from sglang.srt.utils.hf_transformers_utils import get_rope_config logger = logging.getLogger(__name__) +def _get_dflash_layer_attention_params( + config, layer_id: int +) -> Tuple[int, AttentionType]: + layer_types = get_dflash_layer_types(config) + if layer_types is None: + return -1, AttentionType.ENCODER_ONLY + if layer_id >= len(layer_types): + raise ValueError( + "DFLASH config.layer_types must contain one entry per draft layer. " + f"Got {len(layer_types)} entries, layer_id={layer_id}." + ) + + layer_type = layer_types[layer_id] + if layer_type == "full_attention": + return -1, AttentionType.ENCODER_ONLY + if layer_type == "sliding_attention": + sliding_window_size = get_dflash_attention_sliding_window_size(config) + assert sliding_window_size is not None + return sliding_window_size, AttentionType.DECODER + raise ValueError( + "Unsupported DFLASH draft layer type. " + f"layer_types[{layer_id}]={layer_type!r}." + ) + + class DFlashAttention(nn.Module): def __init__(self, config, layer_id: int) -> None: super().__init__() @@ -90,8 +118,7 @@ def __init__(self, config, layer_id: int) -> None: self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) - rope_theta = float(getattr(config, "rope_theta", 1000000)) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta, rope_scaling = get_rope_config(config) rope_is_neox_style = bool( getattr( config, "rope_is_neox_style", getattr(config, "is_neox_style", True) @@ -108,14 +135,17 @@ def __init__(self, config, layer_id: int) -> None: ) self.scaling = head_dim**-0.5 - # DFlash uses non-causal attention over the draft block. + self.sliding_window_size, self.attn_type = _get_dflash_layer_attention_params( + config, layer_id + ) self.attn = RadixAttention( num_heads=self.num_heads, head_dim=head_dim, scaling=self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, - attn_type=AttentionType.ENCODER_ONLY, + sliding_window_size=self.sliding_window_size, + attn_type=self.attn_type, ) def forward( @@ -292,6 +322,9 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: self.block_size = draft_config.resolve_block_size(default=16) + def get_attention_sliding_window_size(self) -> Optional[int]: + return get_dflash_attention_sliding_window_size(self.config) + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: """Project concatenated target-layer hidden states into draft hidden_size.""" expected = int(self.fc.in_features) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 9c04746d2873..9d9dccac20e8 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -949,6 +949,14 @@ def get_attention_sliding_window_size(self): def dtype(self) -> torch.dtype: return next(self.parameters()).dtype + def set_dflash_layers_to_capture(self, layer_ids: list[int]): + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + @torch.no_grad() def forward( self, diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 4ce8a5909cc9..e3783b1d4d69 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -218,6 +218,7 @@ def __init__( quant_config, prefix=add_prefix("language_model", prefix), ) + self.lm_head = self.language_model.embed_tokens # Create logits processor for the multimodal model self.logits_processor = LogitsProcessor(config.text_config) @@ -265,6 +266,14 @@ def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: def get_attention_sliding_window_size(self): return getattr(self.config.text_config, "sliding_window", -1) - 1 + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + self.capture_aux_hidden_states = True + self.language_model.layers_to_capture = [val + 1 for val in layer_ids] + def prepare_attn_masks( self, forward_batch: ForwardBatch, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ec7791337320..2fa37c505b7a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1553,7 +1553,7 @@ def _handle_gpu_memory_settings(self, gpu_mem): if self.speculative_algorithm == "STANDALONE": # standalonedraft model and cuda graphs reserved_mem += 6 * 1024 - elif self.speculative_algorithm != "NGRAM": + elif self.speculative_algorithm not in {"NGRAM", "DFLASH"}: # eagle draft models and cuda graphs reserved_mem += 4 * 1024 @@ -3609,10 +3609,17 @@ def _handle_speculative_decoding(self): "Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests." ) - self.disable_overlap_schedule = True - logger.warning( - "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)." - ) + if not envs.SGLANG_ENABLE_SPEC_V2.get(): + self.disable_overlap_schedule = True + + if self.disable_overlap_schedule: + logger.warning( + "Spec v1 is used for DFLASH speculative decoding because overlap schedule is disabled." + ) + else: + logger.warning( + "Spec v2 is enabled by default for DFLASH speculative decoding." + ) if self.enable_mixed_chunk: self.enable_mixed_chunk = False diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 9cbba1faa61d..f85844bbffd9 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -1,21 +1,25 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import torch from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, alloc_token_slots, get_last_loc, ) -from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.speculative.dflash_utils import ( + apply_dflash_verify_logits_adjustments, compute_dflash_accept_len_and_bonus, compute_dflash_sampling_accept_len_and_bonus, is_dflash_sampling_verify_available, @@ -23,6 +27,9 @@ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +if TYPE_CHECKING: + from sglang.srt.managers.tp_worker import TpModelWorker + def _compute_paged_keep_slots( *, @@ -161,7 +168,7 @@ class DFlashVerifyInput(SpecInput): # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. # DFLASH verify is linear (non-tree), so this is always 1. topk: int = 1 - # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it. # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. custom_mask: torch.Tensor | None = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL @@ -251,12 +258,49 @@ def prepare_for_verify( else torch.empty((0,), dtype=torch.bool, device=batch.device) ) + def prepare_for_v2_verify( + self, + batch: ModelWorkerBatch, + target_worker: "TpModelWorker", + ) -> tuple[ForwardBatch, bool]: + """Prepare a DFLASH verify forward batch for overlap scheduling. + + Unlike spec-v1, the overlap path already computes and stores + `batch.out_cache_loc` before this method is called. This helper only + packages the verify forward and pre-initializes either CUDA-graph replay + metadata or eager attention metadata so the actual forward can run with + `skip_attn_backend_init=True`. + """ + batch.input_ids = self.draft_token + batch.spec_info = self + batch.forward_mode = ( + ForwardMode.IDLE + if batch.forward_mode.is_idle() + else ForwardMode.TARGET_VERIFY + ) + batch.capture_hidden_mode = self.capture_hidden_mode + verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner) + + can_run_cuda_graph = bool( + target_worker.model_runner.graph_runner + and target_worker.model_runner.graph_runner.can_run(verify_forward_batch) + ) + if can_run_cuda_graph: + target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch) + elif not batch.forward_mode.is_idle(): + target_worker.model_runner.attn_backend.init_forward_metadata( + verify_forward_batch + ) + + return verify_forward_batch, can_run_cuda_graph + def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, req_to_token: torch.Tensor, + kv_start_idx: Optional[torch.Tensor] = None, ): device = req_pool_indices.device bs = len(req_pool_indices) @@ -283,7 +327,7 @@ def generate_attn_arg_prefill( req_pool_indices, paged_kernel_lens, cum_kv_seq_len, - None, + kv_start_idx, kv_indices, req_to_token.size(1), ) @@ -339,28 +383,11 @@ def verify( "DFLASH verify sampling_info size mismatch: " f"len(sampling_info)={len(sampling_info)}, bs={bs}." ) - - # Keep speculative verify semantics consistent with normal sampling path. - if sampling_info.has_custom_logit_processor: - apply_custom_logit_processor( - logits_output.next_token_logits, - sampling_info, - num_tokens_in_batch=self.draft_token_num, - ) - - if ( - sampling_info.penalizer_orchestrator.is_required - or sampling_info.logit_bias is not None - ): - linear_penalty = torch.zeros( - (bs, logits_output.next_token_logits.shape[1]), - dtype=torch.float32, - device=device, - ) - sampling_info.apply_logits_bias(linear_penalty) - logits_output.next_token_logits.add_( - torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) - ) + apply_dflash_verify_logits_adjustments( + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + draft_token_num=self.draft_token_num, + ) candidates = self.draft_token.view(bs, self.draft_token_num) if ( @@ -368,10 +395,17 @@ def verify( and not sampling_info.is_all_greedy and is_dflash_sampling_verify_available() ): + top_ks = [int(req.sampling_params.top_k) for req in batch.reqs] accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( candidates=candidates, next_token_logits=logits_output.next_token_logits, sampling_info=sampling_info, + max_top_k=max(max(top_ks), 1) if top_ks else 1, + uniform_top_k_value=( + top_ks[0] + if top_ks and all(top_k == top_ks[0] for top_k in top_ks) + else None + ), ) else: target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( diff --git a/python/sglang/srt/speculative/dflash_info_v2.py b/python/sglang/srt/speculative/dflash_info_v2.py new file mode 100644 index 000000000000..66d3bafb5e77 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info_v2.py @@ -0,0 +1,348 @@ +"""DFLASH spec-v2 overlap scheduling data structures.""" + +import contextlib +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from sglang.srt.environ import envs +from sglang.srt.managers.overlap_utils import FutureIndices +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.server_args import get_global_server_args +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils.common import is_pin_memory_available + +_OVERLAP_PLAN_STREAMS: dict[str, torch.cuda.Stream] = {} + + +def _get_overlap_plan_stream( + device: torch.device | str, +) -> tuple[Optional[torch.cuda.Stream], contextlib.AbstractContextManager]: + """Return an optional plan stream/context for overlap scheduling prep kernels.""" + if not envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): + return None, contextlib.nullcontext() + + device_str = str(device) + stream = _OVERLAP_PLAN_STREAMS.get(device_str) + if stream is None: + stream = torch.get_device_module(device_str).Stream() + _OVERLAP_PLAN_STREAMS[device_str] = stream + return stream, torch.get_device_module(device_str).stream(stream) + + +@dataclass +class DFlashDraftInputV2(SpecInput): + """Draft-side state carried across overlap iterations (spec-v2).""" + + # Legacy Eagle-shaped fields kept only for dataclass compatibility. DFLASH + # overlap only relays verified_id/new_seq_lens through FutureMap. + topk_p: torch.Tensor + topk_index: torch.Tensor + verified_id: torch.Tensor + new_seq_lens: torch.Tensor + hidden_states: torch.Tensor + verify_done: Optional[torch.cuda.Event] = None + max_top_k: int = 1 + uniform_top_k_value: Optional[int] = None + cur_allocated_seq_lens_cpu: Optional[torch.Tensor] = None + planning_seq_lens_cpu: Optional[torch.Tensor] = None + planning_seq_lens_sum: Optional[int] = None + reserved_seq_lens_cpu: Optional[torch.Tensor] = None + reserved_seq_lens_sum: Optional[int] = None + _prepare_committed_kv_lens_cpu_buf: Optional[torch.Tensor] = None + _prepare_planning_kv_lens_cpu_buf: Optional[torch.Tensor] = None + _prepare_batch_seq_lens_cpu_buf: Optional[torch.Tensor] = None + _prepare_cur_kv_lens_cpu_buf: Optional[torch.Tensor] = None + _prepare_nxt_kv_lens_cpu_buf: Optional[torch.Tensor] = None + _prepare_cur_kv_lens_gpu_buf: Optional[torch.Tensor] = None + _prepare_nxt_kv_lens_gpu_buf: Optional[torch.Tensor] = None + + # Filled by scheduler after dispatch. + future_indices: Optional[FutureIndices] = None + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Spec v2 draft state itself does not change token accounting. + return (1, 1) + + def _ensure_prepare_length_buffers( + self, bs: int, device: torch.device | str + ) -> None: + pin_memory = is_pin_memory_available(device) + + def needs_cpu_alloc(buf: Optional[torch.Tensor]) -> bool: + return buf is None or buf.numel() < bs or buf.is_pinned() != pin_memory + + def needs_gpu_alloc(buf: Optional[torch.Tensor]) -> bool: + return buf is None or buf.numel() < bs or str(buf.device) != str(device) + + def grown_capacity(buf: Optional[torch.Tensor]) -> int: + current = 0 if buf is None else int(buf.numel()) + return max(bs, 32, current * 2 if current > 0 else 0) + + if needs_cpu_alloc(self._prepare_committed_kv_lens_cpu_buf): + capacity = grown_capacity(self._prepare_committed_kv_lens_cpu_buf) + self._prepare_committed_kv_lens_cpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self._prepare_planning_kv_lens_cpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self._prepare_batch_seq_lens_cpu_buf = torch.empty( + (capacity,), dtype=torch.int64, device="cpu" + ) + self._prepare_cur_kv_lens_cpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self._prepare_nxt_kv_lens_cpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + + if needs_gpu_alloc(self._prepare_cur_kv_lens_gpu_buf): + capacity = grown_capacity(self._prepare_cur_kv_lens_gpu_buf) + self._prepare_cur_kv_lens_gpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device=device + ) + self._prepare_nxt_kv_lens_gpu_buf = torch.empty( + (capacity,), dtype=torch.int32, device=device + ) + + @classmethod + def create_idle_input(cls, device: torch.device) -> "DFlashDraftInputV2": + return cls( + topk_p=torch.empty((0, 0), device=device, dtype=torch.float32), + topk_index=torch.empty((0, 0), device=device, dtype=torch.int64), + verified_id=torch.empty((0,), device=device, dtype=torch.int32), + new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32), + hidden_states=torch.empty((0, 0), device=device, dtype=torch.float16), + verify_done=None, + ) + + def prepare_for_decode(self, batch: ScheduleBatch): + """Allocate headroom in the shared req_to_token pool for the next DFLASH step. + + DFLASH spec-v2 uses overlap scheduling's "over-allocation" approach: we reserve + future KV slots ahead of time so the worker can gather `out_cache_loc` directly + from `req_to_token` without allocator backup/restore. CPU metadata intentionally + lags by one iteration; keep it separate from the reserved upper bound that backs + the overallocated mapping. + """ + plan_stream, plan_stream_ctx = _get_overlap_plan_stream(batch.device) + if plan_stream is None: + # Ensure previous forward is completed before mutating shared buffers. + batch.maybe_wait_verify_done() + + bs = batch.batch_size() + if bs == 0: + return + self._ensure_prepare_length_buffers(bs, batch.device) + assert self._prepare_committed_kv_lens_cpu_buf is not None + assert self._prepare_planning_kv_lens_cpu_buf is not None + assert self._prepare_batch_seq_lens_cpu_buf is not None + assert self._prepare_cur_kv_lens_cpu_buf is not None + assert self._prepare_nxt_kv_lens_cpu_buf is not None + assert self._prepare_cur_kv_lens_gpu_buf is not None + assert self._prepare_nxt_kv_lens_gpu_buf is not None + committed_kv_lens_cpu_t = self._prepare_committed_kv_lens_cpu_buf[:bs] + planning_kv_lens_cpu_t = self._prepare_planning_kv_lens_cpu_buf[:bs] + batch_seq_lens_cpu_t = self._prepare_batch_seq_lens_cpu_buf[:bs] + cur_kv_lens_cpu_t = self._prepare_cur_kv_lens_cpu_buf[:bs] + cur_allocated_seq_lens_cpu = self.cur_allocated_seq_lens_cpu + + # For DFLASH, each decode step needs a fixed-size verify block. + block_size = int(get_global_server_args().speculative_num_draft_tokens) + if block_size <= 0: + raise ValueError( + f"DFLASH invalid speculative_num_draft_tokens={block_size}." + ) + + page_size = batch.token_to_kv_pool_allocator.page_size + nxt_kv_lens_cpu_t = self._prepare_nxt_kv_lens_cpu_buf[:bs] + committed_seq_lens_sum = 0 + planning_seq_lens_sum = 0 + reserved_seq_lens_sum = 0 + num_needed_tokens = 0 + max_top_k = 1 + uniform_top_k_value = None + uniform_top_k = True + for i, req in enumerate(batch.reqs): + committed_len = int(req.kv_committed_len) + if cur_allocated_seq_lens_cpu is not None and i < len( + cur_allocated_seq_lens_cpu + ): + cur_alloc_len = int(cur_allocated_seq_lens_cpu[i]) + else: + cur_alloc_len = int(req.kv_allocated_len) + planning_len = committed_len + block_size + reserved_len = max(cur_alloc_len, committed_len + 2 * block_size) + top_k = int(req.sampling_params.top_k) + + committed_kv_lens_cpu_t[i] = committed_len + batch_seq_lens_cpu_t[i] = committed_len + cur_kv_lens_cpu_t[i] = cur_alloc_len + planning_kv_lens_cpu_t[i] = planning_len + nxt_kv_lens_cpu_t[i] = reserved_len + + committed_seq_lens_sum += committed_len + planning_seq_lens_sum += planning_len + reserved_seq_lens_sum += reserved_len + num_needed_tokens += reserved_len - cur_alloc_len + + if top_k > max_top_k: + max_top_k = top_k + if i == 0: + uniform_top_k_value = top_k + elif uniform_top_k and top_k != uniform_top_k_value: + uniform_top_k = False + + self.max_top_k = max(max_top_k, 1) + self.uniform_top_k_value = uniform_top_k_value if uniform_top_k else None + + caller_stream = None + if plan_stream is not None: + caller_stream = torch.get_device_module(batch.device).current_stream() + + with plan_stream_ctx: + if plan_stream is not None and caller_stream is not None: + # `batch.seq_lens`, `batch.req_pool_indices`, and related tensors may + # have just been rebuilt on the scheduler stream by filter/merge ops. + # The plan stream must wait for those writes before reading them. + plan_stream.wait_stream(caller_stream) + + if plan_stream is not None and self.verify_done is not None: + plan_stream.wait_event(self.verify_done) + + cur_kv_lens = self._prepare_cur_kv_lens_gpu_buf[:bs] + nxt_kv_lens = self._prepare_nxt_kv_lens_gpu_buf[:bs] + cur_kv_lens.copy_(cur_kv_lens_cpu_t, non_blocking=True) + nxt_kv_lens.copy_(nxt_kv_lens_cpu_t, non_blocking=True) + + if num_needed_tokens > 0: + if page_size == 1: + out_cache_loc = alloc_token_slots( + batch.tree_cache, num_needed_tokens + ) + else: + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + cur_kv_lens, + ) + out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + cur_kv_lens, + cur_kv_lens_cpu_t, + nxt_kv_lens, + nxt_kv_lens_cpu_t, + last_loc, + num_needed_tokens, + ) + + # Updating req_to_token is a write to a shared tensor: it must not overlap + # with the previous batch's forward, which also reads req_to_token. + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + cur_kv_lens, + nxt_kv_lens, + out_cache_loc, + bs, + ) + if caller_stream is not None: + # Enqueue the dependency on the caller's stream, not inside the + # plan-stream context, so forward work cannot observe partially + # prepared req_to_token / KV allocation state. + caller_stream.wait_stream(plan_stream) + + for i, req in enumerate(batch.reqs): + req.kv_allocated_len = int(nxt_kv_lens_cpu_t[i]) + + # Preserve the lagging committed CPU view on the batch and carry the + # tighter host-side planning bound separately from the full reserved + # allocator upper bound. Overlap scheduling only drifts by at most one + # DFlash block on the committed prefix lengths. + batch.seq_lens_cpu = batch_seq_lens_cpu_t + batch.seq_lens_sum = committed_seq_lens_sum + self.planning_seq_lens_cpu = planning_kv_lens_cpu_t + self.planning_seq_lens_sum = planning_seq_lens_sum + self.reserved_seq_lens_cpu = nxt_kv_lens_cpu_t + self.reserved_seq_lens_sum = reserved_seq_lens_sum + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + if self.cur_allocated_seq_lens_cpu is not None: + self.cur_allocated_seq_lens_cpu = self.cur_allocated_seq_lens_cpu[ + new_indices.cpu() + ] + if self.planning_seq_lens_cpu is not None: + self.planning_seq_lens_cpu = self.planning_seq_lens_cpu[new_indices.cpu()] + self.planning_seq_lens_sum = int(self.planning_seq_lens_cpu.sum().item()) + if self.reserved_seq_lens_cpu is not None: + self.reserved_seq_lens_cpu = self.reserved_seq_lens_cpu[new_indices.cpu()] + self.reserved_seq_lens_sum = int(self.reserved_seq_lens_cpu.sum().item()) + + if self.future_indices is not None: + self.future_indices.indices = self.future_indices.indices[new_indices] + return + + self.topk_p = self.topk_p[new_indices] + self.topk_index = self.topk_index[new_indices] + self.verified_id = self.verified_id[new_indices] + self.new_seq_lens = self.new_seq_lens[new_indices] + self.hidden_states = self.hidden_states[new_indices] + + def merge_batch(self, spec_info: "DFlashDraftInputV2"): + if self.cur_allocated_seq_lens_cpu is not None: + assert spec_info.cur_allocated_seq_lens_cpu is not None + self.cur_allocated_seq_lens_cpu = torch.cat( + [self.cur_allocated_seq_lens_cpu, spec_info.cur_allocated_seq_lens_cpu] + ) + elif spec_info.cur_allocated_seq_lens_cpu is not None: + self.cur_allocated_seq_lens_cpu = spec_info.cur_allocated_seq_lens_cpu + + if self.planning_seq_lens_cpu is not None: + assert spec_info.planning_seq_lens_cpu is not None + self.planning_seq_lens_cpu = torch.cat( + [self.planning_seq_lens_cpu, spec_info.planning_seq_lens_cpu] + ) + self.planning_seq_lens_sum = int(self.planning_seq_lens_cpu.sum().item()) + elif spec_info.planning_seq_lens_cpu is not None: + self.planning_seq_lens_cpu = spec_info.planning_seq_lens_cpu + self.planning_seq_lens_sum = spec_info.planning_seq_lens_sum + + if self.reserved_seq_lens_cpu is not None: + assert spec_info.reserved_seq_lens_cpu is not None + self.reserved_seq_lens_cpu = torch.cat( + [self.reserved_seq_lens_cpu, spec_info.reserved_seq_lens_cpu] + ) + self.reserved_seq_lens_sum = int(self.reserved_seq_lens_cpu.sum().item()) + elif spec_info.reserved_seq_lens_cpu is not None: + self.reserved_seq_lens_cpu = spec_info.reserved_seq_lens_cpu + self.reserved_seq_lens_sum = spec_info.reserved_seq_lens_sum + + if self.future_indices is not None: + assert spec_info.future_indices is not None + self.future_indices = FutureIndices( + indices=torch.cat( + [self.future_indices.indices, spec_info.future_indices.indices] + ) + ) + return + + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p], dim=0) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index], dim=0) + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.new_seq_lens = torch.cat( + [self.new_seq_lens, spec_info.new_seq_lens], dim=0 + ) + self.hidden_states = torch.cat( + [self.hidden_states, spec_info.hidden_states], dim=0 + ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 2d7963532654..a4eaa1696e27 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -1,5 +1,8 @@ from __future__ import annotations +import logging +import os +from collections.abc import Sequence from dataclasses import dataclass from numbers import Integral from typing import Any, List, Optional, Tuple @@ -8,9 +11,13 @@ import torch.nn.functional as F from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.utils import is_cuda, is_musa DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" +DFLASH_PREFILL_REFILL_TARGET_ENV = "SGLANG_DFLASH_PREFILL_REFILL_TARGET" + +logger = logging.getLogger(__name__) _DFLASH_SAMPLING_VERIFY_AVAILABLE = False _DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} @@ -19,6 +26,7 @@ "FlashInferAttnBackend", "FlashInferMLAAttnBackend", "FlashAttentionBackend", + "TritonAttnBackend", "TRTLLMHAAttnBackend", "TRTLLMMLABackend", } @@ -48,6 +56,54 @@ def is_dflash_sampling_verify_available() -> bool: return _DFLASH_SAMPLING_VERIFY_AVAILABLE +def get_default_dflash_prefill_refill_target(max_running_requests: int) -> int: + """Choose how many free running-request slots DFlash waits for before refill.""" + max_running_requests = max(0, int(max_running_requests)) + if max_running_requests < 8: + return 1 + return min(4, max(2, (max_running_requests + 5) // 6)) + + +def get_dflash_prefill_refill_target_override() -> Optional[int]: + env_value = os.getenv(DFLASH_PREFILL_REFILL_TARGET_ENV) + if env_value is None or not env_value.strip(): + return None + try: + return int(env_value) + except ValueError: + logger.warning( + "Ignoring invalid %s=%r; using DFlash prefill refill heuristic.", + DFLASH_PREFILL_REFILL_TARGET_ENV, + env_value, + ) + return None + + +def resolve_dflash_prefill_refill_target(max_running_requests: int) -> int: + override = get_dflash_prefill_refill_target_override() + if override is not None: + return override + return get_default_dflash_prefill_refill_target(max_running_requests) + + +def should_delay_dflash_prefill_for_batching( + *, + running_bs: int, + num_allocatable_reqs: int, + max_running_requests: int, + prefill_refill_target: int, +) -> bool: + if running_bs <= 0: + return False + + target_prefill_bs = int(prefill_refill_target) + if target_prefill_bs <= 1: + return False + + target_prefill_bs = min(target_prefill_bs, int(max_running_requests)) + return int(num_allocatable_reqs) < target_prefill_bs + + def scale_kv_cell_size_per_token_for_dflash( *, target_cell_size_per_token: int, @@ -100,6 +156,95 @@ def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]: return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS) +def apply_dflash_verify_logits_adjustments( + *, + next_token_logits: torch.Tensor, + sampling_info: Any, + draft_token_num: int, +) -> None: + """Apply sampling-time logit adjustments for DFlash verify in place. + + This keeps v1 and v2 verify semantics aligned while letting overlap scheduling + use the cheaper precomputed `acc_linear_penalties` path instead of allocating a + repeated `[bs * draft_token_num, vocab]` penalty tensor every step. + """ + if sampling_info is None: + return + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." + ) + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + + bs = len(sampling_info) + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch for DFlash verify adjustments. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + + if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + next_token_logits, + sampling_info, + num_tokens_in_batch=draft_token_num, + ) + + acc_linear_penalties = getattr(sampling_info, "acc_linear_penalties", None) + penalizer = getattr(sampling_info, "penalizer_orchestrator", None) + vocab_mask = getattr(sampling_info, "vocab_mask", None) + logit_bias = getattr(sampling_info, "logit_bias", None) + + logits_3d: Optional[torch.Tensor] = None + + def get_logits_3d() -> torch.Tensor: + nonlocal logits_3d + if logits_3d is None: + logits_3d = next_token_logits.reshape(bs, draft_token_num, -1) + return logits_3d + + # Dense fallback only when we need live penalizer application or a vocab mask. + # In overlap scheduling the common path is `acc_linear_penalties`, which can be + # broadcast over the verify block without materializing a repeated buffer. + if ( + penalizer is not None and penalizer.is_required and acc_linear_penalties is None + ) or vocab_mask is not None: + linear_penalty = torch.zeros( + (bs, next_token_logits.shape[1]), + dtype=torch.float32, + device=next_token_logits.device, + ) + sampling_info.apply_logits_bias(linear_penalty) + get_logits_3d().add_( + linear_penalty[:, None, :].to(dtype=next_token_logits.dtype) + ) + return + + if acc_linear_penalties is not None: + if ( + acc_linear_penalties.device != next_token_logits.device + or acc_linear_penalties.dtype != next_token_logits.dtype + ): + acc_linear_penalties = acc_linear_penalties.to( + device=next_token_logits.device, + dtype=next_token_logits.dtype, + ) + get_logits_3d().add_(acc_linear_penalties[:, None, :]) + + if logit_bias is not None: + if ( + logit_bias.device != next_token_logits.device + or logit_bias.dtype != next_token_logits.dtype + ): + logit_bias = logit_bias.to( + device=next_token_logits.device, + dtype=next_token_logits.dtype, + ) + get_logits_3d().add_(logit_bias[:, None, :]) + + def _get_or_create_chain_verify_buffers( *, bs: int, @@ -199,6 +344,36 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis ] +def get_dflash_layer_types(config: Any) -> Optional[Sequence[str]]: + text_config = _get_text_config(config) + layer_types = _cfg_get(text_config, "layer_types", _cfg_get(config, "layer_types")) + if layer_types is None: + return None + if isinstance(layer_types, str) or not isinstance(layer_types, Sequence): + raise ValueError( + "DFLASH config.layer_types must be a sequence of attention type strings." + ) + return layer_types + + +def get_dflash_attention_sliding_window_size(config: Any) -> Optional[int]: + layer_types = get_dflash_layer_types(config) + if layer_types is None or "sliding_attention" not in layer_types: + return None + + text_config = _get_text_config(config) + sliding_window = _cfg_get( + text_config, "sliding_window", _cfg_get(config, "sliding_window") + ) + if sliding_window is None: + raise ValueError( + "DFLASH sliding_attention layers require config.sliding_window." + ) + + # HF sliding windows include the current token; SGLang stores window_left. + return int(sliding_window) - 1 + + def _cfg_get(config: Any, key: str, default: Any = None) -> Any: if isinstance(config, dict): return config.get(key, default) @@ -464,6 +639,8 @@ def compute_dflash_sampling_accept_len_and_bonus( candidates: torch.Tensor, next_token_logits: torch.Tensor, sampling_info: Any, + max_top_k: Optional[int] = None, + uniform_top_k_value: Optional[int] = None, threshold_single: Optional[float] = None, threshold_acc: Optional[float] = None, uniform_samples: Optional[torch.Tensor] = None, @@ -560,12 +737,19 @@ def compute_dflash_sampling_accept_len_and_bonus( ).to(dtype=torch.int64) vocab_size = int(scaled_logits.shape[-1]) repeated_top_ks.clamp_(min=1, max=vocab_size) - max_top_k = int(repeated_top_ks.max().item()) + if max_top_k is None: + max_top_k = int(repeated_top_ks.max().item()) + else: + max_top_k = int(max_top_k) + if max_top_k < 1: + max_top_k = 1 + elif max_top_k > vocab_size: + max_top_k = vocab_size # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. if 0 < max_top_k < vocab_size: topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) - if not torch.all(repeated_top_ks == max_top_k): + if uniform_top_k_value is None or int(uniform_top_k_value) != max_top_k: ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ None, : ] diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8d34db1748a5..4cf8dd06b7de 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -113,7 +113,7 @@ def __init__( _fb = "triton" if _torch.version.hip else "flashinfer" logger.warning( "DFLASH draft worker does not support 'trtllm_mha' because the " - "draft path requires non-causal attention. Falling back to " + "draft path requires per-layer DFlash attention. Falling back to " "'%s'.", _fb, ) @@ -157,6 +157,8 @@ def __init__( ) set_global_server_args_for_scheduler(saved_server_args) self.draft_model_runner = self.draft_worker.model_runner + # Keep the same alias that other spec-v2 workers expose. + self.draft_worker.draft_runner = self.draft_model_runner self.draft_model = self.draft_model_runner.model draft_config = parse_dflash_draft_config( draft_hf_config=self.draft_model_runner.model_config.hf_config @@ -177,6 +179,7 @@ def __init__( self.block_size, model_block_size, ) + self.speculative_num_draft_tokens = int(self.block_size) self._mask_token = draft_config.mask_token self._mask_token_id_override = draft_config.mask_token_id @@ -210,6 +213,9 @@ def __init__( self._draft_block_tokens_buf: Optional[torch.Tensor] = ( None # [cap_bs, block_size] ) + self._draft_verify_out_cache_loc_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU self._draft_block_spec_info = DFlashVerifyInput( @@ -222,11 +228,13 @@ def __init__( self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None self._draft_greedy_gather_cap: int = 0 + self._draft_greedy_local_max_buf: Optional[torch.Tensor] = None + self._draft_greedy_local_arg_buf: Optional[torch.Tensor] = None + self._draft_greedy_local_cap: int = 0 self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None self._draft_greedy_index_cap: int = 0 - self._use_fused_kv_materialize = is_cuda() self._fused_kv_helper: Optional[object] = None if self._use_fused_kv_materialize: @@ -294,6 +302,8 @@ def _init_fused_kv_helper(self) -> None: num_kv_heads=first_attn.num_kv_heads, head_dim=first_attn.head_dim, device=self.device, + max_position_hint=self.target_worker.model_runner.model_config.context_len + + int(self.block_size), ) if self.tp_rank == 0: logger.info( @@ -332,6 +342,9 @@ def _ensure_draft_block_buffers(self, bs: int) -> None: self._draft_block_tokens_buf = torch.empty( (new_cap, block_size), dtype=torch.long, device=device ) + self._draft_verify_out_cache_loc_buf = torch.empty( + (new_cap, block_size), dtype=torch.int64, device=device + ) self._draft_block_end_buf = torch.empty( (new_cap,), dtype=torch.int32, device=device ) @@ -343,6 +356,9 @@ def __getattr__(self, name): # Delegate anything not implemented yet to the target worker. return getattr(self.target_worker, name) + def on_verify_complete_cpu(self, num_accepted_drafts_per_req: list[int]) -> None: + pass + def clear_cache_pool(self): # The target worker owns the shared KV allocator/cache. For the compact # sliding-window path, the draft req->token view is rebuilt from committed @@ -550,17 +566,12 @@ def _prepare_for_speculative_decoding( target_model = self.target_worker.model_runner.model embed_module = target_model.get_input_embeddings() lm_head = getattr(target_model, "lm_head", None) - if ( - lm_head is None - or not hasattr(lm_head, "weight") - or not hasattr(lm_head, "shard_indices") - ): + if lm_head is None or not hasattr(lm_head, "weight"): raise RuntimeError( - "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " - "`shard_indices` attributes." + "DFLASH requires the target model to expose `lm_head` with `weight`." ) - # --- 2) Draft a non-causal block with the draft model. + # --- 2) Draft a fixed block with the draft model. self._ensure_draft_block_buffers(bs) assert self._draft_block_ids_buf is not None assert self._draft_block_positions_buf is not None @@ -568,38 +579,40 @@ def _prepare_for_speculative_decoding( assert self._draft_block_end_buf is not None assert self._draft_seq_lens_cpu_buf is not None - block_ids = self._draft_block_ids_buf[:bs] - block_ids.fill_(int(self._mask_token_id)) - block_ids[:, 0].copy_(draft_input.bonus_tokens.to(torch.long)) - - noise_embedding = embed_module(block_ids) - input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) - - # For spec-v1, the draft KV cache is always materialized before drafting the - # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` - # are the logical resident lengths in the draft-local cache. - target_prefix_lens = batch.seq_lens # int32, device - draft_prefix_lens = draft_input.draft_seq_lens - if draft_prefix_lens.dtype != torch.int32: - draft_prefix_lens = draft_prefix_lens.to(torch.int32) - if draft_prefix_lens.device != self.device: - draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) - - positions_2d = self._draft_block_positions_buf[:bs] - torch.add( - target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d - ) - positions = positions_2d.reshape(-1) - - block_start = draft_prefix_lens - block_end = self._draft_block_end_buf[:bs] - torch.add(block_start, int(self.block_size), out=block_end) - - seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] - seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.bonus_tokens.to(torch.long)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + # For spec-v1, the draft KV cache is always materialized before drafting the + # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` + # are the logical resident lengths in the draft-local cache. + target_prefix_lens = batch.seq_lens # int32, device + draft_prefix_lens = draft_input.draft_seq_lens + if draft_prefix_lens.dtype != torch.int32: + draft_prefix_lens = draft_prefix_lens.to(torch.int32) + if draft_prefix_lens.device != self.device: + draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) + + positions_2d = self._draft_block_positions_buf[:bs] + torch.add( + target_prefix_lens.unsqueeze(1), + self._block_pos_offsets, + out=positions_2d, + ) + positions = positions_2d.reshape(-1) + + block_start = draft_prefix_lens + block_end = self._draft_block_end_buf[:bs] + torch.add(block_start, int(self.block_size), out=block_end) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) if self.page_size == 1: block_cache_loc = allocator.alloc(bs * self.block_size) else: @@ -716,17 +729,27 @@ def _greedy_sample_from_vocab_parallel_head( if hidden_states.numel() == 0: return torch.empty((0,), dtype=torch.long, device=hidden_states.device) - tp_group = get_tp_group() - tp_size = int(tp_group.world_size) + weight = lm_head.weight # [local_vocab_padded, hidden] + weight_dtype = weight.dtype + num_tokens = int(hidden_states.shape[0]) + out_token_ids = torch.empty( + (num_tokens,), dtype=torch.long, device=hidden_states.device + ) - if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): - raise RuntimeError( - "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`." - ) + def _cast_hs(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype == weight_dtype else x.to(weight_dtype) + + if not hasattr(lm_head, "shard_indices"): + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + logits = torch.matmul(hs, weight.T) + out_token_ids[start:end] = torch.argmax(logits, dim=-1).to(torch.long) + return out_token_ids shard = lm_head.shard_indices - weight = lm_head.weight # [local_vocab_padded, hidden] - weight_dtype = weight.dtype + tp_group = get_tp_group() + tp_size = int(tp_group.world_size) # Valid ranges in the local shard (excluding padding): # base vocab: [0, num_org) @@ -737,26 +760,51 @@ def _greedy_sample_from_vocab_parallel_head( org_vocab_start = int(shard.org_vocab_start_index) added_vocab_start = int(shard.added_vocab_start_index) - num_tokens = int(hidden_states.shape[0]) - out_token_ids = torch.empty( - (num_tokens,), dtype=torch.long, device=hidden_states.device - ) - - def _cast_hs(x: torch.Tensor) -> torch.Tensor: - return x if x.dtype == weight_dtype else x.to(weight_dtype) + def _ensure_local_reduce_buffers( + chunk_len: int, + value_dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + if ( + self._draft_greedy_local_cap < chunk_len + or self._draft_greedy_local_max_buf is None + or self._draft_greedy_local_arg_buf is None + or self._draft_greedy_local_max_buf.dtype != value_dtype + or self._draft_greedy_local_max_buf.device != device + or self._draft_greedy_local_arg_buf.device != device + ): + cap = max(int(chunk_size), chunk_len) + self._draft_greedy_local_max_buf = torch.empty( + (cap,), dtype=value_dtype, device=device + ) + self._draft_greedy_local_arg_buf = torch.empty( + (cap,), dtype=torch.int64, device=device + ) + self._draft_greedy_local_cap = cap + return ( + self._draft_greedy_local_max_buf[:chunk_len], + self._draft_greedy_local_arg_buf[:chunk_len], + ) # Fast path (common): single-rank greedy sampling over the base vocab shard. # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab. + # + # DFLASH draft sampling only materializes a small fixed block of hidden states + # each step. On tp=1, splitting those states into many 256-token chunks adds + # extra matmul/argmax launches without reducing peak memory meaningfully. if tp_size == 1 and num_added == 0: - for start in range(0, num_tokens, int(chunk_size)): - end = min(num_tokens, start + int(chunk_size)) + fast_chunk_size = max(int(chunk_size), 1024) + for start in range(0, num_tokens, fast_chunk_size): + end = min(num_tokens, start + fast_chunk_size) hs = _cast_hs(hidden_states[start:end]) if num_org > 0: base_logits = torch.matmul(hs, weight[:num_org].T) - out_token_ids[start:end] = ( - torch.argmax(base_logits, dim=-1).to(torch.long) - + org_vocab_start + local_max, local_arg = _ensure_local_reduce_buffers( + end - start, base_logits.dtype, hs.device ) + torch.max(base_logits, dim=-1, out=(local_max, local_arg)) + out_token_ids[start:end].copy_(local_arg) + out_token_ids[start:end].add_(org_vocab_start) else: out_token_ids[start:end] = 0 return out_token_ids @@ -769,7 +817,10 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: # Base vocab logits. if num_org > 0: base_logits = torch.matmul(hs, weight[:num_org].T) - local_max, local_arg = torch.max(base_logits, dim=-1) + local_max, local_arg = _ensure_local_reduce_buffers( + chunk_len, base_logits.dtype, hs.device + ) + torch.max(base_logits, dim=-1, out=(local_max, local_arg)) else: local_max = torch.full( (chunk_len,), @@ -967,11 +1018,13 @@ def _append_target_hidden_to_draft_kv( f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." ) + wrote_with_fused_kv = False if self._use_fused_kv_materialize and self._fused_kv_helper is not None: try: self._append_target_hidden_fused( ctx_hidden, ctx_positions, ctx_cache_loc ) + wrote_with_fused_kv = True except Exception as e: logger.warning( "DFLASH fused KV append failed; falling back to sequential path: %s", @@ -979,10 +1032,7 @@ def _append_target_hidden_to_draft_kv( ) self._use_fused_kv_materialize = False self._fused_kv_helper = None - self._append_target_hidden_sequential( - ctx_hidden, ctx_positions, ctx_cache_loc - ) - else: + if not wrote_with_fused_kv: self._append_target_hidden_sequential( ctx_hidden, ctx_positions, ctx_cache_loc ) @@ -1012,6 +1062,160 @@ def _append_target_hidden_to_draft_kv( draft_input.ctx_lens = torch.zeros_like(ctx_lens) draft_input.target_hidden = draft_input.target_hidden[:0] + def _append_target_hidden_to_draft_kv_by_loc( + self, + *, + target_hidden: torch.Tensor, + cache_loc: torch.Tensor, + positions: torch.Tensor, + cache_loc_2d: Optional[torch.Tensor] = None, + commit_lens: Optional[torch.Tensor] = None, + ) -> None: + """Materialize target context features into the draft KV cache at explicit slots. + + For the spec-v2 overlap path, callers can pass dense `[bs, block_size]` + `cache_loc_2d` plus `commit_lens`; the prefix-valid writer then commits + only the live prefix rows without constructing masked/packed index tensors. + """ + if target_hidden is None: + raise RuntimeError("DFLASH missing target hidden context features.") + if target_hidden.numel() == 0: + return + if target_hidden.ndim != 2: + raise ValueError( + "DFLASH target_hidden must be 2D, " + f"got shape={tuple(target_hidden.shape)}." + ) + + if cache_loc.ndim != 1: + raise ValueError( + f"DFLASH cache_loc must be 1D, got shape={tuple(cache_loc.shape)}." + ) + if positions.ndim != 1: + raise ValueError( + f"DFLASH positions must be 1D, got shape={tuple(positions.shape)}." + ) + num_tokens = int(target_hidden.shape[0]) + if int(cache_loc.numel()) != num_tokens: + raise ValueError( + "DFLASH cache_loc length mismatch: " + f"cache_loc={int(cache_loc.numel())}, target_hidden={num_tokens}." + ) + if int(positions.numel()) != num_tokens: + raise ValueError( + "DFLASH positions length mismatch: " + f"positions={int(positions.numel())}, target_hidden={num_tokens}." + ) + if cache_loc_2d is not None: + if cache_loc_2d.ndim != 2: + raise ValueError( + "DFLASH cache_loc_2d must be 2D, " + f"got shape={tuple(cache_loc_2d.shape)}." + ) + if int(cache_loc_2d.numel()) != num_tokens: + raise ValueError( + "DFLASH cache_loc_2d size mismatch: " + f"cache_loc_2d={int(cache_loc_2d.numel())}, target_hidden={num_tokens}." + ) + if commit_lens is None: + raise ValueError( + "DFLASH cache_loc_2d requires commit_lens for prefix-valid writes." + ) + + device = self.model_runner.device + if cache_loc.device != device: + cache_loc = cache_loc.to(device, non_blocking=True) + if positions.device != device: + positions = positions.to(device, non_blocking=True) + if target_hidden.device != device: + target_hidden = target_hidden.to(device, non_blocking=True) + + if cache_loc.dtype != torch.int64: + cache_loc = cache_loc.to(torch.int64) + if positions.dtype != torch.int64: + positions = positions.to(torch.int64) + if cache_loc_2d is not None: + if cache_loc_2d.device != device: + cache_loc_2d = cache_loc_2d.to(device, non_blocking=True) + if cache_loc_2d.dtype != torch.int64: + cache_loc_2d = cache_loc_2d.to(torch.int64) + if commit_lens is not None: + if commit_lens.device != device: + commit_lens = commit_lens.to(device, non_blocking=True) + if commit_lens.dtype != torch.int32: + commit_lens = commit_lens.to(torch.int32) + + with torch.inference_mode(): + ctx_hidden = self.draft_model.project_target_hidden(target_hidden) + + if cache_loc_2d is not None: + bs = int(commit_lens.shape[0]) + if int(cache_loc_2d.shape[0]) != bs: + raise ValueError( + "DFLASH cache_loc_2d batch size mismatch: " + f"cache_loc_2d={tuple(cache_loc_2d.shape)}, commit_lens={tuple(commit_lens.shape)}." + ) + if bs == 0: + return + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + try: + self._append_target_hidden_fused( + ctx_hidden=ctx_hidden, + ctx_positions=positions, + ctx_cache_loc=cache_loc, + ctx_cache_loc_2d=cache_loc_2d, + commit_lens=commit_lens, + ) + return + except Exception as e: + logger.warning( + "DFLASH fused prefix-direct KV append failed; falling back to the per-layer prefix-direct path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + + for layer in self.draft_model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + + self.draft_model_runner.token_to_kv_pool.set_kv_buffer_prefix_valid( + attn.attn, + cache_loc_2d, + commit_lens, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + return + + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + try: + self._append_target_hidden_fused( + ctx_hidden=ctx_hidden, + ctx_positions=positions, + ctx_cache_loc=cache_loc, + ) + return + except Exception as e: + logger.warning( + "DFLASH fused KV append-by-loc failed; falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + + self._append_target_hidden_sequential( + ctx_hidden=ctx_hidden, + ctx_positions=positions, + ctx_cache_loc=cache_loc, + ) + def _append_target_hidden_sequential( self, ctx_hidden: torch.Tensor, @@ -1039,23 +1243,39 @@ def _append_target_hidden_fused( ctx_hidden: torch.Tensor, ctx_positions: torch.Tensor, ctx_cache_loc: torch.Tensor, + ctx_cache_loc_2d: Optional[torch.Tensor] = None, + commit_lens: Optional[torch.Tensor] = None, ) -> None: """Fused KV materialization using batched projection + Triton kernel.""" token_to_kv_pool = self.draft_model_runner.token_to_kv_pool - layers = self.draft_model.layers + if self._fused_kv_helper is None: + raise RuntimeError("DFLASH fused KV helper is not initialized.") def _write_layer_kv( - layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor + layer_idx: int, + cache_k: torch.Tensor, + cache_v: torch.Tensor, ) -> None: - attn = layers[layer_idx].self_attn.attn - token_to_kv_pool.set_kv_buffer( - attn, - ctx_cache_loc, - cache_k, - cache_v, - attn.k_scale, - attn.v_scale, - ) + attn = self.draft_model.layers[layer_idx].self_attn.attn + if ctx_cache_loc_2d is not None and commit_lens is not None: + token_to_kv_pool.set_kv_buffer_prefix_valid( + attn, + ctx_cache_loc_2d, + commit_lens, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) + else: + token_to_kv_pool.set_kv_buffer( + attn, + ctx_cache_loc, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) self._fused_kv_helper.materialize( ctx_hidden=ctx_hidden, diff --git a/python/sglang/srt/speculative/dflash_worker_v2.py b/python/sglang/srt/speculative/dflash_worker_v2.py new file mode 100644 index 000000000000..90b977a052a8 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker_v2.py @@ -0,0 +1,608 @@ +import logging +from typing import Optional + +import torch + +from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + compute_position, +) +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.dflash_info import DFlashVerifyInput +from sglang.srt.speculative.dflash_info_v2 import DFlashDraftInputV2 +from sglang.srt.speculative.dflash_utils import ( + apply_dflash_verify_logits_adjustments, + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.dflash_worker import DFlashWorker +from sglang.srt.speculative.eagle_info_v2 import assign_extend_cache_locs_func +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.speculative.triton_ops.dflash_accept_bonus import ( + _compute_dflash_accept_bonus_triton_unchecked, +) +from sglang.srt.speculative.triton_ops.dflash_prepare_block import ( + _prepare_dflash_draft_block_unchecked, +) +from sglang.srt.utils import is_cuda, is_hip + +logger = logging.getLogger(__name__) + + +class DFlashWorkerV2(DFlashWorker): + """DFLASH speculative decoding worker (spec-v2 overlap scheduling). + + This is intentionally implemented as a *separate* worker from the existing + spec-v1 `DFlashWorker` (non-overlap), to keep the v1 path stable and to + minimize risk while bringing up overlap scheduling. + """ + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + super().__init__( + server_args=server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + dp_rank=dp_rank, + moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, + nccl_port=nccl_port, + target_worker=target_worker, + ) + supports_gpu_triton = is_cuda() or is_hip() + self._use_triton_prepare_block = supports_gpu_triton + self._use_triton_accept_bonus = supports_gpu_triton + + def _validate_phase1_sampling_support( + self, model_worker_batch: ModelWorkerBatch + ) -> None: + sampling_info = model_worker_batch.sampling_info + if sampling_info is None or sampling_info.is_all_greedy: + return + + if ( + not is_dflash_sampling_verify_available() + and not self._warned_sampling_fallback + and self.tp_rank == 0 + ): + logger.warning( + "DFLASH non-greedy verification is unavailable on this build/device; " + "falling back to greedy argmax verification." + ) + self._warned_sampling_fallback = True + + def _make_next_draft_input_prefill( + self, + *, + verified_id: torch.Tensor, + seq_lens: torch.Tensor, + verify_done: Optional[torch.cuda.Event] = None, + cur_allocated_seq_lens_cpu: Optional[torch.Tensor] = None, + ) -> DFlashDraftInputV2: + bs = int(seq_lens.numel()) + device = verified_id.device + return DFlashDraftInputV2( + topk_p=torch.empty((bs, 0), device=device, dtype=torch.float32), + topk_index=torch.empty((bs, 0), device=device, dtype=torch.int64), + verified_id=verified_id.to(dtype=torch.int32), + new_seq_lens=seq_lens.to(dtype=torch.int32), + hidden_states=torch.empty((bs, 0), device=device, dtype=torch.float16), + verify_done=verify_done, + cur_allocated_seq_lens_cpu=cur_allocated_seq_lens_cpu, + ) + + def _make_next_draft_input_decode( + self, + *, + verified_id: torch.Tensor, + new_seq_lens: torch.Tensor, + verify_done: Optional[torch.cuda.Event] = None, + cur_allocated_seq_lens_cpu: Optional[torch.Tensor] = None, + ) -> DFlashDraftInputV2: + bs = int(new_seq_lens.numel()) + device = verified_id.device + return DFlashDraftInputV2( + topk_p=torch.empty((bs, 0), device=device, dtype=torch.float32), + topk_index=torch.empty((bs, 0), device=device, dtype=torch.int64), + verified_id=verified_id.to(dtype=torch.int32), + new_seq_lens=new_seq_lens.to(dtype=torch.int32), + hidden_states=torch.empty((bs, 0), device=device, dtype=torch.float16), + verify_done=verify_done, + cur_allocated_seq_lens_cpu=cur_allocated_seq_lens_cpu, + ) + + def forward_batch_generation( + self, + model_worker_batch: ModelWorkerBatch, + **kwargs, + ) -> GenerationBatchResult: + if getattr(model_worker_batch, "return_logprob", False): + raise ValueError( + "DFLASH speculative decoding does not support return_logprob yet." + ) + self._validate_phase1_sampling_support(model_worker_batch) + + if ( + model_worker_batch.forward_mode.is_extend() + or model_worker_batch.is_extend_in_batch + ): + # Target prefill: capture DFlash aux hidden states for prompt tokens. + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch_output = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + + logits_output, next_token_ids = ( + batch_output.logits_output, + batch_output.next_token_ids, + ) + + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, " + "but got None." + ) + + # Materialize prompt tokens into the draft KV cache immediately. This is required + # for radix cache safety (the scheduler may update radix after prefill returns). + device = next_token_ids.device + ctx_lens = torch.tensor( + model_worker_batch.extend_seq_lens, dtype=torch.int32, device=device + ) + draft_seq_lens = torch.tensor( + model_worker_batch.extend_prefix_lens, dtype=torch.int32, device=device + ) + + if model_worker_batch.out_cache_loc is None: + raise RuntimeError( + "DFLASH prefill expected out_cache_loc, but got None." + ) + positions, _ = compute_position( + self.model_runner.server_args.attention_backend, + draft_seq_lens, + ctx_lens, + int(sum(model_worker_batch.extend_seq_lens)), + ) + self._append_target_hidden_to_draft_kv_by_loc( + target_hidden=logits_output.hidden_states, + cache_loc=model_worker_batch.out_cache_loc, + positions=positions, + ) + + # Avoid copying large hidden-state buffers to CPU in overlap scheduling. + logits_output.hidden_states = None + + batch_output.next_draft_input = self._make_next_draft_input_prefill( + verified_id=next_token_ids, + seq_lens=model_worker_batch.seq_lens, + cur_allocated_seq_lens_cpu=model_worker_batch.seq_lens_cpu, + ) + verify_done = torch.get_device_module(device).Event() + verify_done.record() + batch_output.next_draft_input.verify_done = verify_done + return batch_output + + # Decode / target-verify stage. + if model_worker_batch.spec_info is None: + model_worker_batch.spec_info = DFlashDraftInputV2.create_idle_input( + device=self.device + ) + + draft_input = model_worker_batch.spec_info + if not isinstance(draft_input, DFlashDraftInputV2): + raise RuntimeError( + "DFLASH spec-v2 expected DFlashDraftInputV2 state on the running batch." + ) + + if model_worker_batch.forward_mode.is_idle(): + empty_ids = torch.empty((0,), dtype=torch.int64, device=self.device) + empty_lens = torch.empty((0,), dtype=torch.int32, device=self.device) + next_draft_input = self._make_next_draft_input_decode( + verified_id=torch.empty((0,), device=self.device, dtype=torch.int32), + new_seq_lens=torch.empty((0,), device=self.device, dtype=torch.int32), + ) + verify_done = torch.get_device_module(self.device).Event() + verify_done.record() + next_draft_input.verify_done = verify_done + return GenerationBatchResult( + logits_output=None, + next_token_ids=empty_ids, + accept_lens=empty_lens, + next_draft_input=next_draft_input, + can_run_cuda_graph=False, + speculative_num_draft_tokens=int(self.block_size), + ) + + # `seq_lens` is carried over from the previous overlap iteration and may have been + # produced on another stream. + model_worker_batch.seq_lens.record_stream( + torch.get_device_module(self.device).current_stream() + ) + + bs = len(model_worker_batch.seq_lens) + device = self.device + + # --- 1) Draft a fixed block with the draft model. + target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if lm_head is None or not hasattr(lm_head, "weight"): + raise RuntimeError( + "DFLASH requires the target model to expose `lm_head` with `weight`." + ) + + block_size = int(self.block_size) + self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_verify_out_cache_loc_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + prefix_lens = model_worker_batch.seq_lens + positions_2d = self._draft_block_positions_buf[:bs] + verify_out_cache_loc_2d = self._draft_verify_out_cache_loc_buf[:bs] + if self._use_triton_prepare_block: + try: + _prepare_dflash_draft_block_unchecked( + verified_id=draft_input.verified_id.view(-1), + prefix_lens=prefix_lens.view(-1), + req_pool_indices=model_worker_batch.req_pool_indices.view(-1), + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + block_ids_out=block_ids, + positions_out=positions_2d, + cache_loc_out=verify_out_cache_loc_2d, + mask_token_id=int(self._mask_token_id), + ) + except Exception as e: + self._use_triton_prepare_block = False + logger.warning( + "DFLASH Triton prepare_block failed; falling back to eager path: %s", + e, + ) + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id) + torch.add( + prefix_lens.unsqueeze(1), + self._block_pos_offsets, + out=positions_2d, + ) + end_offset = prefix_lens + block_size + verify_out_cache_loc = assign_extend_cache_locs_func( + req_pool_indices=model_worker_batch.req_pool_indices, + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + start_offset=prefix_lens, + end_offset=end_offset, + batch_size=bs, + draft_token_num=block_size, + device=device, + ) + verify_out_cache_loc_2d.copy_(verify_out_cache_loc.view(bs, block_size)) + else: + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id) + torch.add( + prefix_lens.unsqueeze(1), + self._block_pos_offsets, + out=positions_2d, + ) + end_offset = prefix_lens + block_size + verify_out_cache_loc = assign_extend_cache_locs_func( + req_pool_indices=model_worker_batch.req_pool_indices, + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + start_offset=prefix_lens, + end_offset=end_offset, + batch_size=bs, + draft_token_num=block_size, + device=device, + ) + verify_out_cache_loc_2d.copy_(verify_out_cache_loc.view(bs, block_size)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + positions = positions_2d.reshape(-1) + verify_out_cache_loc = verify_out_cache_loc_2d.reshape(-1) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + if self.use_compact_draft_cache: + # Rebuild the draft-local sliding-window view from committed target state. + draft_prefix_lens = self._compute_compact_draft_seq_lens(prefix_lens) + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) + + suffix_start = prefix_lens.to(torch.int64) - draft_prefix_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + req_pool_indices=model_worker_batch.req_pool_indices, + start=suffix_start, + lengths=draft_prefix_lens, + ) + assign_req_to_token_pool_func( + model_worker_batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + torch.zeros_like(draft_prefix_lens), + draft_prefix_lens, + suffix_cache_loc, + bs, + ) + + block_end = self._draft_block_end_buf[:bs] + torch.add(draft_prefix_lens, block_size, out=block_end) + assign_req_to_token_pool_func( + model_worker_batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + draft_prefix_lens, + block_end, + verify_out_cache_loc, + bs, + ) + draft_seq_lens = draft_prefix_lens + draft_seq_lens_sum = int(seq_lens_cpu.sum().item()) + else: + # Non-windowed path uses the shared overallocated mapping directly. + # Backend planning only needs a safe upper bound for the committed + # prefix lengths, not the full allocator reservation length. + draft_seq_lens = prefix_lens + if draft_input.planning_seq_lens_cpu is not None: + seq_lens_cpu.copy_(draft_input.planning_seq_lens_cpu) + draft_seq_lens_sum = int(draft_input.planning_seq_lens_sum) + elif draft_input.reserved_seq_lens_cpu is not None: + seq_lens_cpu.copy_(draft_input.reserved_seq_lens_cpu) + draft_seq_lens_sum = int(draft_input.reserved_seq_lens_sum) + elif model_worker_batch.seq_lens_cpu is not None: + seq_lens_cpu.copy_(model_worker_batch.seq_lens_cpu) + draft_seq_lens_sum = int(model_worker_batch.seq_lens_sum) + else: + seq_lens_cpu.copy_(prefix_lens.to("cpu", dtype=torch.int32)) + draft_seq_lens_sum = int(prefix_lens.sum().item()) + + forward_batch = ForwardBatch( + forward_mode=ForwardMode.TARGET_VERIFY, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=model_worker_batch.req_pool_indices, + seq_lens=draft_seq_lens, + out_cache_loc=verify_out_cache_loc, + seq_lens_sum=draft_seq_lens_sum, + seq_lens_cpu=seq_lens_cpu, + positions=positions, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, + input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=self._draft_block_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + with torch.inference_mode(): + draft_logits_output = self.draft_model_runner.forward( + forward_batch + ).logits_output + + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") + draft_hidden = draft_hidden.view(bs, int(self.block_size), -1) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, int(self.block_size) - 1) + + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + + # --- 2) Target verify. + # TARGET_VERIFY uses standard causal masking; custom masks are unnecessary here. + custom_mask = None + + verify_input_ids = draft_tokens.reshape(-1) + verify_input = DFlashVerifyInput( + draft_token=verify_input_ids, + positions=positions, + draft_token_num=int(self.block_size), + custom_mask=custom_mask, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) + + model_worker_batch.out_cache_loc = verify_out_cache_loc + sampling_info = model_worker_batch.sampling_info + + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + model_worker_batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + seq_lens_cpu_backup = model_worker_batch.seq_lens_cpu + seq_lens_sum_backup = model_worker_batch.seq_lens_sum + if draft_input.planning_seq_lens_cpu is not None: + model_worker_batch.seq_lens_cpu = draft_input.planning_seq_lens_cpu + model_worker_batch.seq_lens_sum = int(draft_input.planning_seq_lens_sum) + elif draft_input.reserved_seq_lens_cpu is not None: + model_worker_batch.seq_lens_cpu = draft_input.reserved_seq_lens_cpu + model_worker_batch.seq_lens_sum = int(draft_input.reserved_seq_lens_sum) + + verify_forward_batch, _ = verify_input.prepare_for_v2_verify( + model_worker_batch, self.target_worker + ) + model_worker_batch.seq_lens_cpu = seq_lens_cpu_backup + model_worker_batch.seq_lens_sum = seq_lens_sum_backup + + target_out = self.target_worker.forward_batch_generation( + model_worker_batch=None, + forward_batch=verify_forward_batch, + is_verify=True, + skip_attn_backend_init=True, + **kwargs, + ) + logits_output = target_out.logits_output + can_run_cuda_graph = target_out.can_run_cuda_graph + + if sampling_info is not None: + apply_dflash_verify_logits_adjustments( + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + draft_token_num=int(self.block_size), + ) + + candidates = draft_tokens + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + max_top_k=draft_input.max_top_k, + uniform_top_k_value=draft_input.uniform_top_k_value, + ) + commit_lens = accept_len.to(torch.int32) + 1 # [bs] + out_tokens = torch.empty( + (bs, int(self.block_size)), dtype=torch.int64, device=device + ) + if int(self.block_size) > 1: + out_tokens[:, : int(self.block_size) - 1].copy_(candidates[:, 1:]) + out_tokens[:, int(self.block_size) - 1].fill_(0) + out_tokens.scatter_(1, accept_len.to(torch.int64)[:, None], bonus[:, None]) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, int(self.block_size) + ) + if self._use_triton_accept_bonus: + try: + accept_len = torch.empty((bs,), dtype=torch.int32, device=device) + commit_lens = torch.empty((bs,), dtype=torch.int32, device=device) + bonus = torch.empty((bs,), dtype=candidates.dtype, device=device) + out_tokens = torch.empty( + (bs, int(self.block_size)), + dtype=candidates.dtype, + device=device, + ) + _compute_dflash_accept_bonus_triton_unchecked( + candidates=candidates, + target_top1=target_predict, + accept_lens_out=accept_len, + commit_lens_out=commit_lens, + bonus_ids_out=bonus, + out_tokens_out=out_tokens, + ) + except Exception as e: + self._use_triton_accept_bonus = False + logger.warning( + "DFLASH Triton accept/bonus failed; falling back to eager path: %s", + e, + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + commit_lens = accept_len.to(torch.int32) + 1 # [bs] + out_tokens = torch.empty( + (bs, int(self.block_size)), dtype=torch.int64, device=device + ) + if int(self.block_size) > 1: + out_tokens[:, : int(self.block_size) - 1].copy_( + candidates[:, 1:] + ) + out_tokens[:, int(self.block_size) - 1].fill_(0) + out_tokens.scatter_( + 1, accept_len.to(torch.int64)[:, None], bonus[:, None] + ) + else: + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + commit_lens = accept_len.to(torch.int32) + 1 # [bs] + out_tokens = torch.empty( + (bs, int(self.block_size)), dtype=torch.int64, device=device + ) + if int(self.block_size) > 1: + out_tokens[:, : int(self.block_size) - 1].copy_(candidates[:, 1:]) + out_tokens[:, int(self.block_size) - 1].fill_(0) + out_tokens.scatter_( + 1, accept_len.to(torch.int64)[:, None], bonus[:, None] + ) + + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=model_worker_batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + # --- 3) Materialize committed verify-input tokens into draft KV cache. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, int(self.block_size), -1) + + self._append_target_hidden_to_draft_kv_by_loc( + target_hidden=hidden.reshape(-1, hidden.shape[-1]), + cache_loc=verify_out_cache_loc, + cache_loc_2d=verify_out_cache_loc_2d, + positions=positions, + commit_lens=commit_lens, + ) + + # Avoid copying large hidden-state buffers to CPU in overlap scheduling. + logits_output.hidden_states = None + + new_seq_lens = prefix_lens + commit_lens.to(prefix_lens.dtype) + next_draft_input = self._make_next_draft_input_decode( + verified_id=bonus, + new_seq_lens=new_seq_lens, + cur_allocated_seq_lens_cpu=draft_input.reserved_seq_lens_cpu, + ) + verify_done = torch.get_device_module(device).Event() + verify_done.record() + next_draft_input.verify_done = verify_done + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=out_tokens.reshape(-1), + accept_lens=commit_lens, + can_run_cuda_graph=can_run_cuda_graph, + next_draft_input=next_draft_input, + prepared_kv_allocated_lens_cpu=draft_input.reserved_seq_lens_cpu, + speculative_num_draft_tokens=int(self.block_size), + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 8a0565e32f96..9c8dd4638ddb 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -110,7 +110,14 @@ def is_ngram(self) -> bool: return self == SpeculativeAlgorithm.NGRAM def supports_spec_v2(self) -> bool: - return (self.is_eagle() and not self.is_frozen_kv_mtp()) or self.is_standalone() + return ( + (self.is_eagle() and not self.is_frozen_kv_mtp()) + or self.is_standalone() + or self.is_dflash() + ) + + def need_topk(self) -> bool: + return self.is_eagle() or self.is_standalone() def create_worker( self, server_args: ServerArgs @@ -123,9 +130,9 @@ def create_worker( if self.is_dflash(): if enable_overlap: - raise ValueError( - "DFLASH does not support overlap scheduling (spec v2)." - ) + from sglang.srt.speculative.dflash_worker_v2 import DFlashWorkerV2 + + return DFlashWorkerV2 from sglang.srt.speculative.dflash_worker import DFlashWorker return DFlashWorker diff --git a/python/sglang/srt/speculative/triton_ops/dflash_accept_bonus.py b/python/sglang/srt/speculative/triton_ops/dflash_accept_bonus.py new file mode 100644 index 000000000000..c337aac2175a --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/dflash_accept_bonus.py @@ -0,0 +1,122 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dflash_accept_bonus_contig_kernel( + candidates_ptr, + target_top1_ptr, + accept_lens_out_ptr, + commit_lens_out_ptr, + bonus_ids_out_ptr, + out_tokens_ptr, + candidates_row_stride, + target_row_stride, + accept_stride, + commit_stride, + bonus_stride, + out_tokens_row_stride, + block_size, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + row_mask = cols < block_size + draft_mask = cols < (block_size - 1) + + candidate_row_ptr = candidates_ptr + row * candidates_row_stride + target_row_ptr = target_top1_ptr + row * target_row_stride + candidate_tail = tl.load(candidate_row_ptr + cols + 1, mask=draft_mask, other=0) + + accept_len = tl.full((), 0, tl.int32) + prefix_live = tl.full((), 1, tl.int32) + for col in range(BLOCK_SIZE - 1): + in_range = col < (block_size - 1) + candidate_id = tl.load(candidate_row_ptr + (col + 1), mask=in_range, other=0) + target_id = tl.load(target_row_ptr + col, mask=in_range, other=0) + match_i32 = (candidate_id == target_id).to(tl.int32) + keep = in_range & (prefix_live != 0) & (match_i32 != 0) + accept_len += keep.to(tl.int32) + prefix_live = tl.where(in_range, prefix_live & match_i32, prefix_live) + + commit_len = accept_len + 1 + bonus_id = tl.load(target_row_ptr + accept_len.to(tl.int64)) + + tl.store(accept_lens_out_ptr + row * accept_stride, accept_len) + tl.store(commit_lens_out_ptr + row * commit_stride, commit_len) + tl.store(bonus_ids_out_ptr + row * bonus_stride, bonus_id) + + out_val = tl.where(draft_mask, candidate_tail, 0) + out_val = tl.where(cols == accept_len, bonus_id, out_val) + tl.store( + out_tokens_ptr + row * out_tokens_row_stride + cols, out_val, mask=row_mask + ) + + +def _pick_num_warps(block_size: int) -> int: + if block_size <= 16: + return 1 + if block_size <= 32: + return 2 + if block_size <= 64: + return 4 + return 8 + + +def _is_row_major_contiguous_2d(x: torch.Tensor) -> bool: + return x.ndim == 2 and x.is_contiguous() + + +def _compute_dflash_accept_bonus_triton_unchecked( + candidates: torch.Tensor, + target_top1: torch.Tensor, + accept_lens_out: torch.Tensor, + commit_lens_out: torch.Tensor, + bonus_ids_out: torch.Tensor, + out_tokens_out: torch.Tensor, +) -> None: + batch_size, block_size = candidates.shape + if batch_size == 0: + return + + if not _is_row_major_contiguous_2d(candidates): + raise ValueError("DFLASH Triton accept_bonus requires contiguous candidates.") + if not _is_row_major_contiguous_2d(target_top1): + raise ValueError("DFLASH Triton accept_bonus requires contiguous target_top1.") + if not _is_row_major_contiguous_2d(out_tokens_out): + raise ValueError( + "DFLASH Triton accept_bonus requires contiguous out_tokens_out." + ) + if not accept_lens_out.is_contiguous(): + raise ValueError( + "DFLASH Triton accept_bonus requires contiguous accept_lens_out." + ) + if not commit_lens_out.is_contiguous(): + raise ValueError( + "DFLASH Triton accept_bonus requires contiguous commit_lens_out." + ) + if not bonus_ids_out.is_contiguous(): + raise ValueError( + "DFLASH Triton accept_bonus requires contiguous bonus_ids_out." + ) + + block = triton.next_power_of_2(block_size) + num_warps = _pick_num_warps(block) + _dflash_accept_bonus_contig_kernel[(batch_size,)]( + candidates, + target_top1, + accept_lens_out, + commit_lens_out, + bonus_ids_out, + out_tokens_out, + candidates.stride(0), + target_top1.stride(0), + accept_lens_out.stride(0), + commit_lens_out.stride(0), + bonus_ids_out.stride(0), + out_tokens_out.stride(0), + block_size, + BLOCK_SIZE=block, + num_warps=num_warps, + ) diff --git a/python/sglang/srt/speculative/triton_ops/dflash_prepare_block.py b/python/sglang/srt/speculative/triton_ops/dflash_prepare_block.py new file mode 100644 index 000000000000..4cfc8fe06ad0 --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/dflash_prepare_block.py @@ -0,0 +1,123 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _prepare_dflash_draft_block_contig_kernel( + verified_id_ptr, + prefix_lens_ptr, + req_pool_indices_ptr, + req_to_token_ptr, + block_ids_out_ptr, + positions_out_ptr, + cache_loc_out_ptr, + verified_id_stride, + prefix_lens_stride, + req_pool_indices_stride, + req_to_token_row_stride, + block_ids_row_stride, + positions_row_stride, + cache_loc_row_stride, + req_to_token_width, + block_size, + mask_token_id, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + row_mask = cols < block_size + + prefix_len = tl.load(prefix_lens_ptr + row * prefix_lens_stride) + req_idx = tl.load(req_pool_indices_ptr + row * req_pool_indices_stride) + verified_id = tl.load(verified_id_ptr + row * verified_id_stride) + + logical_pos = prefix_len.to(tl.int64) + cols + valid = row_mask & (logical_pos < req_to_token_width) + req_row_ptr = req_to_token_ptr + req_idx * req_to_token_row_stride + slot_ids = tl.load(req_row_ptr + logical_pos, mask=valid, other=0) + + block_ids = tl.full((BLOCK_SIZE,), mask_token_id, tl.int64) + block_ids = tl.where(cols == 0, verified_id.to(tl.int64), block_ids) + tl.store( + block_ids_out_ptr + row * block_ids_row_stride + cols, block_ids, mask=row_mask + ) + tl.store( + positions_out_ptr + row * positions_row_stride + cols, + logical_pos, + mask=row_mask, + ) + tl.store( + cache_loc_out_ptr + row * cache_loc_row_stride + cols, + slot_ids.to(tl.int64), + mask=row_mask, + ) + + +def _pick_num_warps(block_size: int) -> int: + if block_size <= 16: + return 1 + if block_size <= 32: + return 2 + if block_size <= 64: + return 4 + return 8 + + +def _is_row_major_contiguous_2d(x: torch.Tensor) -> bool: + return x.ndim == 2 and x.is_contiguous() + + +def _prepare_dflash_draft_block_unchecked( + verified_id: torch.Tensor, + prefix_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + req_to_token: torch.Tensor, + block_ids_out: torch.Tensor, + positions_out: torch.Tensor, + cache_loc_out: torch.Tensor, + mask_token_id: int, +) -> None: + batch_size = int(verified_id.numel()) + if batch_size == 0: + return + + if req_to_token.ndim != 2 or req_to_token.stride(1) != 1: + raise ValueError("DFLASH Triton prepare_block requires row-major req_to_token.") + if not _is_row_major_contiguous_2d(block_ids_out): + raise ValueError( + "DFLASH Triton prepare_block requires contiguous block_ids_out." + ) + if not _is_row_major_contiguous_2d(positions_out): + raise ValueError( + "DFLASH Triton prepare_block requires contiguous positions_out." + ) + if not _is_row_major_contiguous_2d(cache_loc_out): + raise ValueError( + "DFLASH Triton prepare_block requires contiguous cache_loc_out." + ) + + block_size = int(block_ids_out.shape[1]) + block = triton.next_power_of_2(block_size) + num_warps = _pick_num_warps(block) + _prepare_dflash_draft_block_contig_kernel[(batch_size,)]( + verified_id, + prefix_lens, + req_pool_indices, + req_to_token, + block_ids_out, + positions_out, + cache_loc_out, + verified_id.stride(0), + prefix_lens.stride(0), + req_pool_indices.stride(0), + req_to_token.stride(0), + block_ids_out.stride(0), + positions_out.stride(0), + cache_loc_out.stride(0), + int(req_to_token.shape[1]), + block_size, + int(mask_token_id), + BLOCK_SIZE=block, + num_warps=num_warps, + ) diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py index e7dc4c05ddfc..0b8f5284f3ec 100644 --- a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -13,10 +13,10 @@ # ============================================================================== """Fused Triton kernel for DFlash KV materialization. -Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes. +Combines: KV projection + RMSNorm + RoPE, then pool-managed KV writes. """ -from typing import Callable, List +from typing import Callable, List, Optional import torch import triton @@ -24,45 +24,58 @@ @triton.jit -def _fused_norm_rope_kernel( - kv_ptr, # [total_ctx, kv_size * 2] - k_norm_weight_ptr, # [head_dim] +def _fused_norm_rope_kernel_stacked( + kv_ptr, # [total_ctx, n_layers, kv_size * 2] + k_norm_weight_ptr, # [n_layers, head_dim] + eps_ptr, # [n_layers] cos_sin_cache_ptr, # [max_pos, rotary_dim] positions_ptr, # [total_ctx] - k_out_ptr, # [total_ctx, num_kv_heads, head_dim] - v_out_ptr, # [total_ctx, num_kv_heads, head_dim] + k_out_ptr, # [n_layers, total_ctx, num_kv_heads, head_dim] + v_out_ptr, # [n_layers, total_ctx, num_kv_heads, head_dim] kv_stride_ctx, + kv_stride_layer, + k_norm_weight_stride_layer, cos_sin_stride_pos, + k_out_stride_layer, k_out_stride_ctx, k_out_stride_head, + v_out_stride_layer, v_out_stride_ctx, v_out_stride_head, total_ctx, + n_layers: tl.constexpr, num_kv_heads: tl.constexpr, head_dim: tl.constexpr, kv_size: tl.constexpr, rotary_dim: tl.constexpr, half_rotary_dim: tl.constexpr, - eps: tl.constexpr, BLOCK_HD: tl.constexpr, ): - """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads).""" + """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads, n_layers).""" ctx_id = tl.program_id(0) head_id = tl.program_id(1) - if ctx_id >= total_ctx: + layer_id = tl.program_id(2) + if ctx_id >= total_ctx or layer_id >= n_layers: return - # Load metadata position = tl.load(positions_ptr + ctx_id) - - # Compute base pointers - kv_base = kv_ptr + ctx_id * kv_stride_ctx + eps = tl.load(eps_ptr + layer_id).to(tl.float32) + kv_base = kv_ptr + ctx_id * kv_stride_ctx + layer_id * kv_stride_layer k_base = kv_base + head_id * head_dim v_base = kv_base + kv_size + head_id * head_dim - k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head - v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head + k_write = ( + k_out_ptr + + layer_id * k_out_stride_layer + + ctx_id * k_out_stride_ctx + + head_id * k_out_stride_head + ) + v_write = ( + v_out_ptr + + layer_id * v_out_stride_layer + + ctx_id * v_out_stride_ctx + + head_id * v_out_stride_head + ) - # Load K and V offs = tl.arange(0, BLOCK_HD) mask_hd = offs < head_dim mask_half = offs < half_rotary_dim @@ -70,36 +83,38 @@ def _fused_norm_rope_kernel( k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32) v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0) - # RMSNorm on K inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps) - norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32) + norm_w = tl.load( + k_norm_weight_ptr + layer_id * k_norm_weight_stride_layer + offs, + mask=mask_hd, + other=1.0, + ).to(tl.float32) k_normed = k_raw * inv_rms * norm_w - # RoPE (neox style): k_first, k_second -> rotated cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32) sin_v = tl.load( cos_sin_base + half_rotary_dim + offs, mask=mask_half, other=0.0 ).to(tl.float32) - # Extract first/second halves of K for rotation k_first = tl.where(mask_half, k_normed, 0.0) k_second_raw = tl.load( k_base + half_rotary_dim + offs, mask=mask_half, other=0.0 ).to(tl.float32) norm_w_second = tl.load( - k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0 + k_norm_weight_ptr + + layer_id * k_norm_weight_stride_layer + + half_rotary_dim + + offs, + mask=mask_half, + other=1.0, ).to(tl.float32) k_second = k_second_raw * inv_rms * norm_w_second - # Apply rotation k_rot_first = k_first * cos_v - k_second * sin_v k_rot_second = k_second * cos_v + k_first * sin_v - # Store V (no transform) tl.store(v_write + offs, v_raw, mask=mask_hd) - - # Store K: rotated halves + pass-through tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half) tl.store( k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half @@ -108,70 +123,117 @@ def _fused_norm_rope_kernel( tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) -def _fused_norm_rope( - kv: torch.Tensor, # [total_ctx, kv_size*2] - k_norm_weight: torch.Tensor, # [head_dim] +def _fused_norm_rope_stacked( + kv: torch.Tensor, # [total_ctx, n_layers, kv_size*2] + k_norm_weight: torch.Tensor, # [n_layers, head_dim] + eps: torch.Tensor, # [n_layers] cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] positions: torch.Tensor, # [total_ctx] num_kv_heads: int, head_dim: int, rotary_dim: int, - eps: float = 1e-6, + k_out: Optional[torch.Tensor] = None, + v_out: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Fused RMSNorm + RoPE materialization for a single layer.""" - total_ctx = kv.shape[0] + """Fused RMSNorm + RoPE materialization for all layers.""" + if kv.ndim != 3: + raise ValueError( + "Invalid stacked fused KV projection shape: " + f"got {tuple(kv.shape)}, expected 3D [total_ctx, n_layers, kv_size*2]." + ) + + total_ctx, n_layers, kv_dim = kv.shape if total_ctx == 0: empty = torch.empty( - (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + (n_layers, 0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device ) return empty, empty kv_size = num_kv_heads * head_dim - if kv.shape[1] != kv_size * 2: + if kv_dim != kv_size * 2: raise ValueError( "Invalid fused KV projection shape: " - f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}." + f"got {tuple(kv.shape)}, expected trailing dim {kv_size * 2}." ) if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: raise ValueError( "Invalid fused KV rotary/head dim pair: " f"rotary_dim={rotary_dim}, head_dim={head_dim}." ) + if k_norm_weight.shape != (n_layers, head_dim): + raise ValueError( + "Invalid stacked k_norm_weight shape for fused KV materialization: " + f"got {tuple(k_norm_weight.shape)}, expected {(n_layers, head_dim)}." + ) + if eps.shape != (n_layers,): + raise ValueError( + "Invalid stacked eps shape for fused KV materialization: " + f"got {tuple(eps.shape)}, expected {(n_layers,)}." + ) half_rotary_dim = rotary_dim // 2 BLOCK_HD = triton.next_power_of_2(head_dim) - # Ensure int64 for indexing if positions.device != kv.device: positions = positions.to(device=kv.device, dtype=torch.int64) elif positions.dtype != torch.int64: positions = positions.to(torch.int64) - k_out = torch.empty( - (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device - ) - v_out = torch.empty_like(k_out) + expected_shape = (n_layers, total_ctx, num_kv_heads, head_dim) + if k_out is None: + k_out = torch.empty(expected_shape, dtype=kv.dtype, device=kv.device) + else: + if k_out.shape != expected_shape: + raise ValueError( + "Invalid k_out shape for fused KV materialization: " + f"got {tuple(k_out.shape)}, expected {expected_shape}." + ) + if k_out.device != kv.device or k_out.dtype != kv.dtype: + raise ValueError( + "Invalid k_out device/dtype for fused KV materialization: " + f"got device={k_out.device}, dtype={k_out.dtype}, " + f"expected device={kv.device}, dtype={kv.dtype}." + ) + if v_out is None: + v_out = torch.empty_like(k_out) + else: + if v_out.shape != expected_shape: + raise ValueError( + "Invalid v_out shape for fused KV materialization: " + f"got {tuple(v_out.shape)}, expected {expected_shape}." + ) + if v_out.device != kv.device or v_out.dtype != kv.dtype: + raise ValueError( + "Invalid v_out device/dtype for fused KV materialization: " + f"got device={v_out.device}, dtype={v_out.dtype}, " + f"expected device={kv.device}, dtype={kv.dtype}." + ) - _fused_norm_rope_kernel[(total_ctx, num_kv_heads)]( + _fused_norm_rope_kernel_stacked[(total_ctx, num_kv_heads, n_layers)]( kv, k_norm_weight, + eps, cos_sin_cache, positions, k_out, v_out, kv.stride(0), + kv.stride(1), + k_norm_weight.stride(0), cos_sin_cache.stride(0), k_out.stride(0), k_out.stride(1), + k_out.stride(2), v_out.stride(0), v_out.stride(1), + v_out.stride(2), total_ctx, + n_layers, num_kv_heads, head_dim, kv_size, rotary_dim, half_rotary_dim, - eps, BLOCK_HD, ) return k_out, v_out @@ -180,8 +242,8 @@ def _fused_norm_rope( class FusedKVMaterializeHelper: """Fused KV materialization helper using batched projection. - Uses torch.einsum for batched KV projection across all layers, - then a Triton kernel for fused RMSNorm + RoPE materialization per layer. + Uses a single large GEMM across all layers, then a Triton kernel for fused + RMSNorm + RoPE materialization across all layers. """ def __init__( @@ -191,12 +253,15 @@ def __init__( num_kv_heads: int, head_dim: int, device: torch.device, + max_position_hint: Optional[int] = None, ): self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.rotary_emb = rotary_emb self.n_layers = len(layers) self.device = device + self.kv_size = self.num_kv_heads * self.head_dim + self.layer_out_dim = 2 * self.kv_size self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim)) self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True)) @@ -209,10 +274,24 @@ def __init__( f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}." ) - # Pre-extract and stack weights for batched projection. + self.max_position_hint = ( + max(int(max_position_hint) - 1, 0) + if max_position_hint is not None + else None + ) + self._reserved_rope_cache_len = int( + getattr(self.rotary_emb, "cos_sin_cache", torch.empty((0,))).shape[0] + ) + self._mm_out_supported = True + self._workspace_capacity = 0 + self._workspace_dtype: Optional[torch.dtype] = None + self._proj_workspace: Optional[torch.Tensor] = None + self._k_workspace: Optional[torch.Tensor] = None + self._v_workspace: Optional[torch.Tensor] = None + kv_weights = [] - self.k_norm_weights = [] - self.eps_values = [] + k_norm_weights = [] + eps_values = [] for layer_id, layer in enumerate(layers): attn = layer.self_attn @@ -240,15 +319,72 @@ def __init__( f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}." ) - # Extract KV portion of QKV weight qkv_w = attn.qkv_proj.weight kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] kv_weights.append(kv_weight) - self.k_norm_weights.append(attn.k_norm.weight) - self.eps_values.append(attn.k_norm.variance_epsilon) + k_norm_weights.append(attn.k_norm.weight) + eps_values.append(float(attn.k_norm.variance_epsilon)) - # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] - self.batched_kv_weight = torch.stack(kv_weights) + flat_kv_weight = torch.stack(kv_weights).reshape( + self.n_layers * self.layer_out_dim, -1 + ) + self.flat_kv_weight_t = flat_kv_weight.transpose(0, 1).contiguous() + self.k_norm_weights = torch.stack(k_norm_weights).contiguous() + self.eps_values = torch.tensor( + eps_values, dtype=torch.float32, device=self.device + ) + + if self.max_position_hint is not None: + self._ensure_rope_cache(self.max_position_hint) + + def _ensure_rope_cache(self, max_position: int) -> torch.Tensor: + if max_position + 1 > self._reserved_rope_cache_len: + ensure_cos_sin_cache_length = getattr( + self.rotary_emb, "_ensure_cos_sin_cache_length", None + ) + if callable(ensure_cos_sin_cache_length): + ensure_cos_sin_cache_length(max_position) + self._reserved_rope_cache_len = int( + self.rotary_emb.cos_sin_cache.shape[0] + ) + + cos_sin_cache = self.rotary_emb.cos_sin_cache + if max_position >= int(cos_sin_cache.shape[0]): + raise RuntimeError( + "RoPE cos/sin cache is too short for fused KV materialization: " + f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." + ) + if cos_sin_cache.device != self.device: + cos_sin_cache = cos_sin_cache.to(self.device) + return cos_sin_cache + + def _ensure_workspace(self, total_ctx: int, dtype: torch.dtype) -> None: + if ( + self._workspace_capacity >= total_ctx + and self._workspace_dtype == dtype + and self._proj_workspace is not None + and self._k_workspace is not None + and self._v_workspace is not None + ): + return + + new_capacity = max(1, total_ctx) + if self._workspace_capacity > 0: + new_capacity = max(new_capacity, self._workspace_capacity * 2) + + self._proj_workspace = torch.empty( + (new_capacity, self.n_layers * self.layer_out_dim), + dtype=dtype, + device=self.device, + ) + self._k_workspace = torch.empty( + (self.n_layers, new_capacity, self.num_kv_heads, self.head_dim), + dtype=dtype, + device=self.device, + ) + self._v_workspace = torch.empty_like(self._k_workspace) + self._workspace_capacity = new_capacity + self._workspace_dtype = dtype def materialize( self, @@ -269,35 +405,53 @@ def materialize( f"positions={positions.numel()}, total_ctx={total_ctx}." ) - max_position = int(positions.max().item()) - ensure_cos_sin_cache_length = getattr( - self.rotary_emb, "_ensure_cos_sin_cache_length", None - ) - if callable(ensure_cos_sin_cache_length): - ensure_cos_sin_cache_length(max_position) - - cos_sin_cache = self.rotary_emb.cos_sin_cache - if max_position >= int(cos_sin_cache.shape[0]): - raise RuntimeError( - "RoPE cos/sin cache is too short for fused KV materialization: " - f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." - ) - if cos_sin_cache.device != ctx_hidden.device: - cos_sin_cache = cos_sin_cache.to(ctx_hidden.device) - - # Batched KV projection: [n_layers, total_ctx, kv_size*2] - kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) - - # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool. - for layer_id in range(self.n_layers): - cache_k, cache_v = _fused_norm_rope( - kv_all[layer_id], - self.k_norm_weights[layer_id], - cos_sin_cache, - positions, - self.num_kv_heads, - self.head_dim, - self.rotary_dim, - self.eps_values[layer_id], + if ctx_hidden.device != self.device: + ctx_hidden = ctx_hidden.to(self.device, non_blocking=True) + if ctx_hidden.dtype != self.flat_kv_weight_t.dtype: + ctx_hidden = ctx_hidden.to(self.flat_kv_weight_t.dtype) + if positions.device != self.device: + positions = positions.to( + device=self.device, dtype=torch.int64, non_blocking=True ) - write_layer_kv(layer_id, cache_k, cache_v) + elif positions.dtype != torch.int64: + positions = positions.to(torch.int64) + + max_position = ( + self.max_position_hint + if self.max_position_hint is not None + else int(positions.max().item()) + ) + cos_sin_cache = self._ensure_rope_cache(max_position) + + self._ensure_workspace(total_ctx, ctx_hidden.dtype) + assert self._proj_workspace is not None + assert self._k_workspace is not None + assert self._v_workspace is not None + + proj_out_2d = self._proj_workspace[:total_ctx] + if self._mm_out_supported: + try: + torch.mm(ctx_hidden, self.flat_kv_weight_t, out=proj_out_2d) + except Exception: + self._mm_out_supported = False + proj_out_2d = torch.mm(ctx_hidden, self.flat_kv_weight_t) + else: + proj_out_2d = torch.mm(ctx_hidden, self.flat_kv_weight_t) + + proj_out = proj_out_2d.view(total_ctx, self.n_layers, self.layer_out_dim) + tmp_k = self._k_workspace[:, :total_ctx] + tmp_v = self._v_workspace[:, :total_ctx] + cache_k, cache_v = _fused_norm_rope_stacked( + proj_out, + self.k_norm_weights, + self.eps_values, + cos_sin_cache, + positions, + self.num_kv_heads, + self.head_dim, + self.rotary_dim, + k_out=tmp_k, + v_out=tmp_v, + ) + for layer_idx in range(self.n_layers): + write_layer_kv(layer_idx, cache_k[layer_idx], cache_v[layer_idx]) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index d393e4019cdc..d06317a9f03e 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -1,4 +1,3 @@ -import os import unittest import openai @@ -8,7 +7,10 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin -from sglang.test.kits.radix_cache_server_kit import gen_radix_tree +from sglang.test.kits.radix_cache_server_kit import ( + gen_radix_tree, + run_radix_attention_test, +) from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_DFLASH, DEFAULT_TARGET_MODEL_DFLASH, @@ -26,6 +28,8 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): attention_backend = "flashinfer" page_size = 1 other_launch_args = [] + spec_v2 = False + overlap_plan_stream = False model = DEFAULT_TARGET_MODEL_DFLASH draft_model = DEFAULT_DRAFT_MODEL_DFLASH gsm8k_accuracy_thres = 0.75 @@ -50,31 +54,30 @@ def setUpClass(cls): *[str(i) for i in range(1, cls.max_running_requests + 1)], ] launch_args.extend(cls.other_launch_args) - old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - try: - with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 1 - ), envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override( - True - ): - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=launch_args, - ) - finally: - if old_value is None: - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - else: - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + with envs.SGLANG_ENABLE_SPEC_V2.override( + cls.spec_v2 + ), envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.override( + cls.overlap_plan_stream + ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( + 1 + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=launch_args, + ) @classmethod def tearDownClass(cls): - kill_process_tree(cls.process.pid) + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) def test_early_stop(self): client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") @@ -148,5 +151,17 @@ class TestDFlashServerNoCudaGraph(TestDFlashServerBase): other_launch_args = ["--disable-cuda-graph"] +class TestDFlashServerSpecV2(TestDFlashServerBase): + spec_v2 = True + + def test_radix_attention(self): + run_radix_attention_test(self.base_url) + assert self.process.poll() is None + + +class TestDFlashServerSpecV2PlanStream(TestDFlashServerSpecV2): + overlap_plan_stream = True + + if __name__ == "__main__": unittest.main()