diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index cce4be83afcf..ae17ada40107 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -394,6 +394,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--enable-return-hidden-states` | Enable returning hidden states with responses. | `False` | bool flag (set to enable) | | `--scheduler-recv-interval` | The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this. | `1` | Type: int | | `--numa-node` | Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. | `None` | List[int] | +| `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) | ## Debug tensor dumps | Argument | Description | Defaults | Options | diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 4626afafc492..29d3ac8af6f5 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -11,15 +11,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +import logging +from contextlib import contextmanager from dataclasses import dataclass from enum import Enum, auto from functools import partial -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce, @@ -59,9 +61,10 @@ prepare_weight_cache, ) +_is_cuda = is_cuda() _is_flashinfer_available = is_flashinfer_available() -_is_sm90_supported = is_cuda() and is_sm90_supported() -_is_sm100_supported = is_cuda() and is_sm100_supported() +_is_sm90_supported = _is_cuda and is_sm90_supported() +_is_sm100_supported = _is_cuda and is_sm100_supported() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip() _is_gfx95_supported = is_gfx95_supported() @@ -92,6 +95,119 @@ def model_input_output(): return ScatterMode.TP_ATTN_FULL +class AttentionInputs: + + def __init__( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + qkv_latent_func: Callable, + ): + self.hidden_states_local = hidden_states + self.forward_batch = forward_batch + self.qkv_latent_func = qkv_latent_func + self.hidden_states_ = None + self.qkv_latent_ = None + + def tp_all_gather_hidden_states(self, hidden_states, forward_batch): + total_tokens = forward_batch.input_ids.shape[0] + output = hidden_states.new_empty((total_tokens, hidden_states.shape[-1])) + get_tp_group().all_gather_into_tensor(output, hidden_states) + return output + + def fetch_qkv_latent(self): + if self.qkv_latent_ is not None: + return self.qkv_latent_ + assert self.qkv_latent_func is not None + self.qkv_latent_ = self.qkv_latent_func( + self.hidden_states_local, self.forward_batch + ) + if get_attn_tp_context().input_scattered: + self.qkv_latent_ = self.tp_all_gather_hidden_states( + self.qkv_latent_, self.forward_batch + ) + return self.qkv_latent_ + + def fetch_hidden_states(self): + if self.hidden_states_ is not None: + return self.hidden_states_ + self.hidden_states_ = self.hidden_states_local + if get_attn_tp_context().input_scattered: + self.hidden_states_ = self.tp_all_gather_hidden_states( + self.hidden_states_, self.forward_batch + ) + return self.hidden_states_ + + +class AttnTpContext: + def __init__(self): + self.allow_input_scattered = False + self.input_scattered_ = False + self.attn_inputs_: Optional[AttentionInputs] = None + + def init_context(self, q_lora_rank, is_nsa): + self.allow_input_scattered = ( + get_global_server_args().enable_attn_tp_input_scattered + and _is_cuda + and q_lora_rank is not None + and not is_nsa + and get_tensor_model_parallel_world_size() > 1 + and not is_dp_attention_enabled() + and get_moe_a2a_backend().is_none() + and not enable_moe_dense_fully_dp() + and not get_global_server_args().enable_piecewise_cuda_graph + and get_global_server_args().speculative_algorithm != "EAGLE3" + ) + if get_global_server_args().enable_attn_tp_input_scattered: + if not self.allow_input_scattered: + logging.info( + "attn_tp_input_scattered is not enabled while other conditions are not met" + ) + else: + logging.info("attn_tp_input_scattered is enabled") + + def use_input_scattered(self, forward_batch: ForwardBatch): + return ( + self.allow_input_scattered + and forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + and forward_batch.input_ids is not None + and not forward_batch.can_run_tbo + ) + + @property + def input_scattered(self): + return self.input_scattered_ + + def set_attn_inputs(self, attn_inputs: AttentionInputs): + self.attn_inputs_ = attn_inputs + + def fetch_qkv_latent(self): + assert self.attn_inputs_ is not None + return self.attn_inputs_.fetch_qkv_latent() + + def fetch_hidden_states(self): + assert self.attn_inputs_ is not None + return self.attn_inputs_.fetch_hidden_states() + + @contextmanager + def maybe_input_scattered(self, forward_batch: ForwardBatch): + flag = self.use_input_scattered(forward_batch) + old_flag = self.input_scattered + self.input_scattered_ = flag + yield + self.input_scattered_ = old_flag + self.attn_inputs_ = None + + +ATTN_TP_CONTEXT = AttnTpContext() + + +def get_attn_tp_context(): + return ATTN_TP_CONTEXT + + @dataclass class _LayerModeComputationContext: num_layers: int @@ -188,12 +304,14 @@ def __init__( # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator. allow_reduce_scatter: bool = False, is_last_layer: bool = False, + qkv_latent_func: Optional[Callable] = None, ): self.layer_scatter_modes = layer_scatter_modes self.input_layernorm = input_layernorm self.post_attention_layernorm = post_attention_layernorm self.allow_reduce_scatter = allow_reduce_scatter self.is_last_layer = is_last_layer + self.qkv_latent_func = qkv_latent_func self._context = CommunicateContext.init_new() self._communicate_simple_fn = CommunicateSimpleFn.get_fn( @@ -252,6 +370,11 @@ def prepare_attn( forward_batch: ForwardBatch, quant_format: str = "", ): + if get_attn_tp_context().input_scattered: + hidden_states, residual = self._tp_reduce_scatter( + hidden_states, + residual, + ) if hidden_states.shape[0] == 0: residual = hidden_states else: @@ -335,9 +458,32 @@ def prepare_attn( forward_batch=forward_batch, context=self._context, ) - + if self.qkv_latent_func is not None: + attn_inputs = AttentionInputs( + hidden_states, forward_batch, self.qkv_latent_func + ) + get_attn_tp_context().set_attn_inputs(attn_inputs) return hidden_states, residual + def _tp_reduce_scatter( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.shape[0] == 0: + return hidden_states, hidden_states + assert ( + hidden_states.shape[0] % self._context.tp_size == 0 + ), f"Expected total tokens {hidden_states.shape[0]} % tp_size {self._context.tp_size} to be 0" + local_tokens = hidden_states.shape[0] // self._context.tp_size + output = hidden_states.new_empty(local_tokens, *hidden_states.shape[1:]) + get_tp_group().reduce_scatter_tensor(output, hidden_states) + if residual is not None: + residual = residual.tensor_split(self._context.tp_size)[ + self._context.tp_rank + ] + return output, residual + def prepare_mlp( self, hidden_states: torch.Tensor, @@ -371,12 +517,17 @@ def postprocess_layer( ) def should_use_reduce_scatter(self, forward_batch: ForwardBatch): - return ( - self.allow_reduce_scatter - and self._communicate_summable_tensor_pair_fn + if not self.allow_reduce_scatter: + return False + if ( + self._communicate_summable_tensor_pair_fn is CommunicateSummableTensorPairFn._scatter_hidden_states and forward_batch.dp_padding_mode.is_max_len() - ) + ): + return True + if get_attn_tp_context().input_scattered and not self.is_last_layer: + return True + return False def should_fuse_mlp_allreduce_with_next_layer( self, forward_batch: ForwardBatch @@ -388,6 +539,9 @@ def should_fuse_mlp_allreduce_with_next_layer( ): return False + if get_attn_tp_context().input_scattered: + return False + batch_size = ( forward_batch.input_ids.shape[0] if hasattr(forward_batch, "input_ids") @@ -422,6 +576,7 @@ class CommunicateContext: attn_dp_size: int tp_size: int cache = None + tp_rank: int def is_same_group_size(self, a: ScatterMode, b: ScatterMode): return self.process_group_sizes[a] == self.process_group_sizes[b] @@ -432,6 +587,7 @@ def init_new(cls): attn_tp_size = get_attention_tp_size() attn_dp_size = get_attention_dp_size() tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() process_group_sizes = { ScatterMode.SCATTERED: 1, ScatterMode.TP_ATTN_FULL: attn_tp_size, @@ -444,6 +600,7 @@ def init_new(cls): attn_tp_size=attn_tp_size, attn_dp_size=attn_dp_size, tp_size=tp_size, + tp_rank=tp_rank, ) @@ -566,6 +723,14 @@ def _gather_hidden_states_and_residual( *, residual_input_mode, ): + if get_attn_tp_context().input_scattered: + return CommunicateWithAllReduceAndLayerNormFn._tp_all_reduce_with_scattered_residual( + hidden_states, + residual, + layernorm, + context, + ) + if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1: residual, local_residual = ( get_local_dp_buffer(), @@ -637,6 +802,22 @@ def _scatter_hidden_states_and_residual( hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual + @staticmethod + def _tp_all_reduce_with_scattered_residual( + hidden_states: torch.Tensor, + residual: torch.Tensor, + layernorm: torch.nn.Module, + context: CommunicateContext, + ): + if hidden_states.shape[0] == 0: + return hidden_states, hidden_states + + scattered_states = hidden_states.tensor_split(context.tp_size)[context.tp_rank] + scattered_states += residual + residual = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = layernorm(residual) + return hidden_states, residual + class CommunicateSummableTensorPairFn: """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.""" diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 6c153b25051b..5912cb86d375 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -18,6 +18,7 @@ use_symmetric_memory, ) from sglang.srt.layers.amx_utils import PackWeightMethod +from sglang.srt.layers.communicator import get_attn_tp_context from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -478,11 +479,10 @@ def forward(self, input_): # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - return output + if not get_attn_tp_context().input_scattered: + # Reduce across all the model parallel GPUs. + output_parallel = tensor_model_parallel_all_reduce(output_parallel) + return output_parallel def extra_repr(self) -> str: s = f"num_embeddings={self.num_embeddings_per_partition}" diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a4813b988422..12f0af7428de 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,7 +38,10 @@ import triton import triton.language as tl -from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.distributed.parallel_state import ( + get_moe_expert_parallel_world_size, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import ( DpPaddingMode, @@ -766,6 +769,13 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): else: bs = self.batch_size = num_tokens + # padding + self._pad_inputs_to_size(model_runner, num_tokens, bs) + self.global_num_tokens_cpu = global_num_tokens + global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True) + self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True) + + def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs): # padding self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens) self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs) @@ -788,9 +798,6 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): if self.encoder_lens is not None: self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) self.positions = self._pad_tensor_to_size(self.positions, num_tokens) - self.global_num_tokens_cpu = global_num_tokens - global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True) - self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True) if self.mrope_positions is not None: self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs) @@ -818,6 +825,19 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): spec_info.hidden_states, num_tokens ) + def prepare_attn_tp_scatter_input(self, model_runner: ModelRunner): + from sglang.srt.layers.communicator import get_attn_tp_context + + attn_tp_context = get_attn_tp_context() + input_scattered = attn_tp_context.use_input_scattered(self) + if not input_scattered: + return + assert self.forward_mode.is_extend() + tokens = self.input_ids.shape[0] + rank_size = get_tensor_model_parallel_world_size() + tokens_padded = (tokens + rank_size - 1) // rank_size * rank_size + self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size) + def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput): self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 19e029d60e45..3d9bbb62cc11 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2200,6 +2200,8 @@ def _forward_raw( # For MLP sync if forward_batch.global_num_tokens_cpu is not None: forward_batch.prepare_mlp_sync_batch(self) + else: + forward_batch.prepare_attn_tp_scatter_input(self) if forward_batch.forward_mode.is_decode(): ret = self.forward_decode( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b174407c780e..1b282b8ceb5e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -59,6 +59,7 @@ LayerCommunicator, LayerScatterModes, enable_moe_dense_fully_dp, + get_attn_tp_context, ) from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, @@ -1409,13 +1410,19 @@ def forward_prepare( # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor if isinstance(hidden_states, tuple): - if hidden_states[0].shape[0] == 0: + if ( + not get_attn_tp_context().input_scattered + and hidden_states[0].shape[0] == 0 + ): assert ( not self.o_proj.reduce_results ), "short-circuiting allreduce will lead to hangs" return hidden_states[0] else: - if hidden_states.shape[0] == 0: + if ( + not get_attn_tp_context().input_scattered + and hidden_states.shape[0] == 0 + ): assert ( not self.o_proj.reduce_results ), "short-circuiting allreduce will lead to hangs" @@ -1498,6 +1505,23 @@ def forward_core(self, intermediate_state): else: raise NotImplementedError + def prepare_qkv_latent( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ): + assert self.q_lora_rank is not None + if ( + (not isinstance(hidden_states, tuple)) + and hidden_states.shape[0] >= 1 + and hidden_states.shape[0] <= 16 + and self.use_min_latency_fused_a_gemm + ): + qkv_latent = dsv3_fused_a_gemm( + hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T + ) + else: + qkv_latent = self.fused_qkv_a_proj_with_mqa(hidden_states)[0] + return qkv_latent + def forward_normal_prepare( self, positions: torch.Tensor, @@ -1506,8 +1530,13 @@ def forward_normal_prepare( zero_allocator: BumpAllocator, ): if self.q_lora_rank is not None: - q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) ) # NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk @@ -1630,18 +1659,13 @@ def forward_absorb_prepare( q_lora = None if self.q_lora_rank is not None: - if ( - (not isinstance(hidden_states, tuple)) - and hidden_states.shape[0] <= 16 - and self.use_min_latency_fused_a_gemm - ): - fused_qkv_a_proj_out = dsv3_fused_a_gemm( - hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, ) - else: - fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0] - q, latent_cache = fused_qkv_a_proj_out.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) k_nope = latent_cache[..., : self.kv_lora_rank] @@ -2738,6 +2762,7 @@ def __init__( is_last_layer=( is_nextn or (self.layer_id == self.config.num_hidden_layers - 1) ), + qkv_latent_func=self.self_attn.prepare_qkv_latent, ) def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: @@ -3159,6 +3184,9 @@ def __init__( ) self.capture_aux_hidden_states = False + q_lora_rank = config.q_lora_rank if hasattr(config, "q_lora_rank") else None + get_attn_tp_context().init_context(q_lora_rank, is_deepseek_nsa(config)) + @property def routed_experts_weights_of_layer(self): return self._routed_experts_weights_of_layer.value @@ -3217,9 +3245,10 @@ def forward( input_embeds: torch.Tensor = None, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors - ) + with get_attn_tp_context().maybe_input_scattered(forward_batch): + hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors + ) aux_hidden_states = None if self.capture_aux_hidden_states: hidden_states, aux_hidden_states = hidden_states diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f5cfd59acd55..6ccbd76e3609 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -523,6 +523,7 @@ class ServerArgs: numa_node: Optional[List[int]] = None enable_deterministic_inference: bool = False rl_on_policy_target: Optional[str] = None + enable_attn_tp_input_scattered: bool = False # Dynamic batch tokenizer enable_dynamic_batch_tokenizer: bool = False @@ -3474,6 +3475,11 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=RL_ON_POLICY_TARGET_CHOICES, help="The training system that SGLang needs to match for true on-policy.", ) + parser.add_argument( + "--enable-attn-tp-input-scattered", + action="store_true", + help="Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent.", + ) # Dynamic batch tokenizer parser.add_argument(