diff --git a/.gitignore b/.gitignore
index 2f4cdbfa03e9..118dd9ae462b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -179,6 +179,9 @@ benchmark/llava_bench/mme_pack
 *.jsonl
 tmp*.txt
 
+# Torch Compile logs
+tl_out/
+
 # Plots
 *.png
 *.pdf
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index ac5bff3dd383..fd513060a911 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -12,17 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 
-# Adapted from
-# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
-"""Inference-only Grok1 model."""
 import functools
 import logging
 import math
-import os
-import warnings
 from typing import Iterable, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import PretrainedConfig
@@ -30,7 +24,6 @@
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import GeluAndMul
@@ -62,22 +55,15 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
+from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
 
 
-# Dump tensors for debugging
-debug_tensor_dump_output_folder = None
-debug_tensor_dump_inject = False
-debug_tensor_dump_layers = None
-debug_tensor_dump_test = False
-
-
 class Grok1MLP(nn.Module):
     def __init__(
         self,
@@ -120,14 +106,6 @@ def forward(self, x):
 
 
 class Grok1MoE(nn.Module):
-    """A tensor-parallel MoE implementation for Grok1 that shards each expert
-    across all ranks.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -148,7 +126,6 @@ def __init__(
         super().__init__()
         self.hidden_size = hidden_size
 
-        # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
         self.gate = ReplicatedLinear(
             hidden_size,
             num_experts,
@@ -157,9 +134,7 @@ def __init__(
             quant_config=None,
         )
 
-        self.router_logit_softcapping = getattr(
-            config, "router_logit_softcapping", 30.0
-        )
+        self.router_logit_softcapping = 30.0
         custom_routing_function = functools.partial(
             fused_moe_router_shim, self.router_logit_softcapping
         )
@@ -186,7 +161,6 @@ def __init__(
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # need to assert self.gate.quant_method is unquantized
         topk_output = self.topk(hidden_states, self.gate.weight)
         return self.experts(hidden_states, topk_output)
 
@@ -447,75 +421,12 @@ def forward(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
-
-        if debug_tensor_dump_output_folder:
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                f"attn_input_{self.layer_id}",
-                hidden_states,
-            )
-
-        if debug_tensor_dump_inject:
-            name = os.path.join(
-                debug_tensor_dump_output_folder,
-                f"jax_dump_attn_input_{self.layer_id}.npy",
-            )
-            logger.info(f"Load {name} from jax.")
-            x = np.load(name)
-            hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
-                hidden_states
-            )
 
         qkv, _ = self.qkv_proj(hidden_states)
-        dispose_tensor(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-
-        if debug_tensor_dump_output_folder:
-            num_tokens = q.shape[0]
-            num_heads_q = self.num_heads
-            head_dim = self.head_dim
-            num_heads_kv = k.numel() // (num_tokens * head_dim)
-
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                f"q_{self.layer_id}",
-                tensor_model_parallel_all_gather(
-                    q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
-                ).contiguous(),
-            )
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                f"k_{self.layer_id}",
-                tensor_model_parallel_all_gather(
-                    k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
-                ).contiguous(),
-            )
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                f"v_{self.layer_id}",
-                tensor_model_parallel_all_gather(
-                    v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
-                ).contiguous(),
-            )
 
         attn_output = self.attn(q, k, v, forward_batch)
-        del q, k, v, qkv
-
-        if debug_tensor_dump_output_folder:
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                f"attn_output_{self.layer_id}",
-                tensor_model_parallel_all_gather(
-                    attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
-                    dim=1,
-                ).contiguous(),
-            )
 
         output, _ = self.o_proj(attn_output)
         return output
@@ -648,14 +559,6 @@ def forward(
                 hidden_states,
             )
 
-        if residual_original is not None:
-            dispose_tensor(residual_original)
-
-        dispose_flag = False
-        if residual is not hidden_states_original:
-            dispose_flag = True
-            dispose_tensor(hidden_states_original)
-
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -673,21 +576,21 @@ def forward(
                 self.post_attn_norm.variance_epsilon,
             )
 
-            if not dispose_flag:
-                dispose_tensor(hidden_states_original)
-
         # Fully Connected
         hidden_states = self.ffn(hidden_states)
 
         return hidden_states, residual, self.post_moe_norm  # defer layernorm
 
     def moe_with_rmoe(self, x):
-        current_stream = torch.cuda.current_stream()
-        self.alt_stream.wait_stream(current_stream)
-        mlp_result = self.mlp(x)
-        with torch.cuda.stream(self.alt_stream):
-            # moe should not be inplace because of stream race condition
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            mlp_result = self.mlp(x)
+            with torch.cuda.stream(self.alt_stream):
+                moe_result = self.block_sparse_moe(x)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            mlp_result = self.mlp(x)
             moe_result = self.block_sparse_moe(x)
-        current_stream.wait_stream(self.alt_stream)
 
         return (mlp_result + moe_result) / 1.4142135623730951
@@ -752,41 +655,13 @@ def forward(
             positions, hidden_states, forward_batch, residual, deferred_norm
         )
 
-        if debug_tensor_dump_output_folder:
-            hidden_states = (
-                fused_rmsnorm(
-                    hidden_states,
-                    deferred_norm.weight,
-                    deferred_norm.variance_epsilon,
-                )
-                + residual
-            )
-
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                "last_hidden_before_norm",
-                hidden_states,
-            )
-
-            hidden_states = fused_rmsnorm(
-                hidden_states,
-                self.norm.weight,
-                self.norm.variance_epsilon,
-            )
-
-            dump_to_file(
-                debug_tensor_dump_output_folder,
-                "last_hidden_after_norm",
-                hidden_states,
-            )
-        else:
-            hidden_states, _ = fused_dual_residual_rmsnorm(
-                hidden_states,
-                residual,
-                deferred_norm.weight,
-                self.norm.weight,
-                deferred_norm.variance_epsilon,
-            )
+        hidden_states, _ = fused_dual_residual_rmsnorm(
+            hidden_states,
+            residual,
+            deferred_norm.weight,
+            self.norm.weight,
+            deferred_norm.variance_epsilon,
+        )
 
         return hidden_states
 
@@ -862,21 +737,9 @@ def __init__(
         )
         self.logits_processor = LogitsProcessor(config)
 
-        # Dump tensors for debugging
-        global debug_tensor_dump_output_folder, debug_tensor_dump_inject
-        debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
-        debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
-        if get_tensor_model_parallel_rank() == 0:
-            logger.info(
-                f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
-                f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
-            )
         self.loaded_param_names = set()
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -884,9 +747,6 @@ def forward(
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if debug_tensor_dump_output_folder:
-            dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
-
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index cb55848cfc7d..c4f3e4c446f7 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -353,6 +353,7 @@ def __init__(
         )
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
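
Reviewer note on the `moe_with_rmoe` change in grok.py: the rewritten method only overlaps the dense MLP and the sparse MoE on a second CUDA stream when `self.alt_stream` is set and `get_is_capture_mode()` reports that a CUDA graph is being captured; otherwise both branches run sequentially on the current stream. The sketch below is a minimal, self-contained illustration of that dual-stream pattern, not sglang code; `run_two_branches`, `branch_a`, and `branch_b` are hypothetical names.

```python
import math

import torch


def run_two_branches(x, branch_a, branch_b, alt_stream=None):
    """Illustrative only: overlap two branches on separate CUDA streams."""
    if alt_stream is not None:
        current_stream = torch.cuda.current_stream()
        # The side stream must wait for work already queued on the current
        # stream (e.g. the producer of x) before it may read x.
        alt_stream.wait_stream(current_stream)
        a = branch_a(x)  # runs on the current stream
        with torch.cuda.stream(alt_stream):
            # Runs concurrently with branch_a; must not mutate x in place,
            # since branch_a is reading x at the same time on the other stream.
            b = branch_b(x)
        # The current stream must wait for the side stream to finish before
        # it may read b in the combine step below.
        current_stream.wait_stream(alt_stream)
    else:
        a = branch_a(x)
        b = branch_b(x)
    # Same 1/sqrt(2) rescaling as the patch's MLP + MoE combination.
    return (a + b) / math.sqrt(2)


if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda")
    out = run_two_branches(x, torch.tanh, torch.sigmoid, torch.cuda.Stream())
    print(out.shape)  # torch.Size([8, 16])
```

The `wait_stream` pair is what prevents the race the removed inline comment ("moe should not be inplace because of stream race condition") warned about: with two streams reading `x` concurrently, neither branch may write to it in place, and the combine step may not run until both streams are synchronized.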