3 changes: 3 additions & 0 deletions .gitignore
@@ -179,6 +179,9 @@ benchmark/llava_bench/mme_pack
*.jsonl
tmp*.txt

# Torch Compile logs
tl_out/

# Plots
*.png
*.pdf
180 changes: 20 additions & 160 deletions python/sglang/srt/models/grok.py
@@ -12,25 +12,18 @@
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import functools
import logging
import math
import os
import warnings
from typing import Iterable, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import GeluAndMul
@@ -62,22 +55,15 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


# Dump tensors for debugging
debug_tensor_dump_output_folder = None
debug_tensor_dump_inject = False
debug_tensor_dump_layers = None
debug_tensor_dump_test = False


class Grok1MLP(nn.Module):
def __init__(
self,
@@ -120,14 +106,6 @@ def forward(self, x):


class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
across all ranks.

Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""

def __init__(
self,
config: PretrainedConfig,
@@ -148,7 +126,6 @@ def __init__(
super().__init__()
self.hidden_size = hidden_size

# Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
@@ -157,9 +134,7 @@ def __init__(
quant_config=None,
)

self.router_logit_softcapping = getattr(
config, "router_logit_softcapping", 30.0
)
self.router_logit_softcapping = 30.0
custom_routing_function = functools.partial(
fused_moe_router_shim, self.router_logit_softcapping
)
@@ -186,7 +161,6 @@ def __init__(
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# need to assert self.gate.quant_method is unquantized
topk_output = self.topk(hidden_states, self.gate.weight)
return self.experts(hidden_states, topk_output)

@@ -447,75 +421,12 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_input_{self.layer_id}",
hidden_states,
)

if debug_tensor_dump_inject:
name = os.path.join(
debug_tensor_dump_output_folder,
f"jax_dump_attn_input_{self.layer_id}.npy",
)
logger.info(f"Load {name} from jax.")
x = np.load(name)
hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
hidden_states
)

qkv, _ = self.qkv_proj(hidden_states)
dispose_tensor(hidden_states)

q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)

if debug_tensor_dump_output_folder:
num_tokens = q.shape[0]
num_heads_q = self.num_heads
head_dim = self.head_dim
num_heads_kv = k.numel() // (num_tokens * head_dim)

dump_to_file(
debug_tensor_dump_output_folder,
f"q_{self.layer_id}",
tensor_model_parallel_all_gather(
q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"k_{self.layer_id}",
tensor_model_parallel_all_gather(
k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"v_{self.layer_id}",
tensor_model_parallel_all_gather(
v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)

attn_output = self.attn(q, k, v, forward_batch)
del q, k, v, qkv

if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_output_{self.layer_id}",
tensor_model_parallel_all_gather(
attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
dim=1,
).contiguous(),
)

output, _ = self.o_proj(attn_output)
return output
@@ -648,14 +559,6 @@ def forward(
hidden_states,
)

if residual_original is not None:
dispose_tensor(residual_original)

dispose_flag = False
if residual is not hidden_states_original:
dispose_flag = True
dispose_tensor(hidden_states_original)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
Expand All @@ -673,21 +576,21 @@ def forward(
self.post_attn_norm.variance_epsilon,
)

if not dispose_flag:
dispose_tensor(hidden_states_original)

# Fully Connected
hidden_states = self.ffn(hidden_states)
return hidden_states, residual, self.post_moe_norm # defer layernorm

def moe_with_rmoe(self, x):
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
mlp_result = self.mlp(x)
with torch.cuda.stream(self.alt_stream):
# moe should not be inplace because of stream race condition
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
Comment on lines +585 to +586 (Contributor review, severity: medium):

These two lines are redundant. current_stream is already defined on line 584, and self.alt_stream.wait_stream(current_stream) is called on line 585. These lines can be safely removed to improve code clarity.
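For context, a minimal sketch of how `moe_with_rmoe` could read once the duplicated synchronization is resolved, here by keeping the stream setup inside the guarded branch. It is an illustration of the suggestion above, not the committed code, and it assumes the attributes and helpers already shown in this diff (`self.mlp`, `self.block_sparse_moe`, `self.alt_stream`, `get_is_capture_mode`) together with the module's existing imports:

```python
def moe_with_rmoe(self, x: torch.Tensor) -> torch.Tensor:
    # The MoE branch must not run in place: the dense MLP and the sparse MoE
    # both read `x`, potentially from two different CUDA streams.
    if self.alt_stream is not None and get_is_capture_mode():
        # Overlap the dense MLP (current stream) with the sparse MoE
        # (alternate stream); synchronize once before and once after.
        current_stream = torch.cuda.current_stream()
        self.alt_stream.wait_stream(current_stream)
        mlp_result = self.mlp(x)
        with torch.cuda.stream(self.alt_stream):
            moe_result = self.block_sparse_moe(x)
        current_stream.wait_stream(self.alt_stream)
    else:
        # Fallback: run both branches sequentially on the current stream.
        mlp_result = self.mlp(x)
        moe_result = self.block_sparse_moe(x)
    return (mlp_result + moe_result) / 1.4142135623730951  # divide by sqrt(2)
```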

mlp_result = self.mlp(x)
with torch.cuda.stream(self.alt_stream):
moe_result = self.block_sparse_moe(x)
current_stream.wait_stream(self.alt_stream)
else:
mlp_result = self.mlp(x)
moe_result = self.block_sparse_moe(x)
current_stream.wait_stream(self.alt_stream)
return (mlp_result + moe_result) / 1.4142135623730951


@@ -752,41 +655,13 @@ def forward(
positions, hidden_states, forward_batch, residual, deferred_norm
)

if debug_tensor_dump_output_folder:
hidden_states = (
fused_rmsnorm(
hidden_states,
deferred_norm.weight,
deferred_norm.variance_epsilon,
)
+ residual
)

dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_before_norm",
hidden_states,
)

hidden_states = fused_rmsnorm(
hidden_states,
self.norm.weight,
self.norm.variance_epsilon,
)

dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_after_norm",
hidden_states,
)
else:
hidden_states, _ = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.norm.weight,
deferred_norm.variance_epsilon,
)
hidden_states, _ = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.norm.weight,
deferred_norm.variance_epsilon,
)

return hidden_states

@@ -862,31 +737,16 @@ def __init__(
)
self.logits_processor = LogitsProcessor(config)

# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
warnings.filterwarnings("ignore", category=FutureWarning)

if get_tensor_model_parallel_rank() == 0:
logger.info(
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
)
self.loaded_param_names = set()

@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if debug_tensor_dump_output_folder:
dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)

hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
1 change: 1 addition & 0 deletions python/sglang/srt/models/mixtral.py
@@ -353,6 +353,7 @@ def __init__(
)
self.logits_processor = LogitsProcessor(config)

@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
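The one-line mixtral.py change applies the same decorator that already appears on the Grok1 forward above: the inference-only forward pass is wrapped in torch.no_grad(), so no autograd graph is recorded for it. A minimal, self-contained sketch of the effect (illustrative only, not sglang code; the TinyLM class and its names are hypothetical):

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()  # same pattern as the diff: no autograd tracking during inference
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyLM()
out = model(torch.randn(2, 8))
print(out.requires_grad)  # False: activations are not tracked, saving memory
```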