-
Notifications
You must be signed in to change notification settings - Fork 5k
[Piecewise CUDA Graph] Fix recompile issue for Mixtral and Grok2 #13667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
6756662
fix router
hebiao064 e9c5f91
grok clean up
hebiao064 2f4084d
add no_grad()
hebiao064 f57df04
revert router change and fix mixtral no_grad
hebiao064 97a4273
Merge branch 'main' into bhe/grok2-piecewise
hebiao064 a817fd1
Merge branch 'main' into bhe/grok2-piecewise
hebiao064 a7f111b
fix alt stream
hebiao064 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -179,6 +179,9 @@ benchmark/llava_bench/mme_pack | |
| *.jsonl | ||
| tmp*.txt | ||
|
|
||
| # Torch Compile logs | ||
| tl_out/ | ||
|
|
||
| # Plots | ||
| *.png | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,25 +12,18 @@ | |
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| # Adapted from | ||
| # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 | ||
| """Inference-only Grok1 model.""" | ||
| import functools | ||
| import logging | ||
| import math | ||
| import os | ||
| import warnings | ||
| from typing import Iterable, Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from torch import nn | ||
| from transformers import PretrainedConfig | ||
|
|
||
| from sglang.srt.distributed import ( | ||
| get_tensor_model_parallel_rank, | ||
| get_tensor_model_parallel_world_size, | ||
| tensor_model_parallel_all_gather, | ||
| tensor_model_parallel_all_reduce, | ||
| ) | ||
| from sglang.srt.layers.activation import GeluAndMul | ||
|
|
@@ -62,22 +55,15 @@ | |
| ParallelLMHead, | ||
| VocabParallelEmbedding, | ||
| ) | ||
| from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode | ||
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | ||
| from sglang.srt.model_loader.loader import DefaultModelLoader | ||
| from sglang.srt.model_loader.weight_utils import default_weight_loader | ||
| from sglang.srt.server_args import get_global_server_args | ||
| from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file | ||
| from sglang.srt.utils import add_prefix | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| # Dump tensors for debugging | ||
| debug_tensor_dump_output_folder = None | ||
| debug_tensor_dump_inject = False | ||
| debug_tensor_dump_layers = None | ||
| debug_tensor_dump_test = False | ||
|
|
||
|
|
||
| class Grok1MLP(nn.Module): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -120,14 +106,6 @@ def forward(self, x): | |
|
|
||
|
|
||
| class Grok1MoE(nn.Module): | ||
| """A tensor-parallel MoE implementation for Grok1 that shards each expert | ||
| across all ranks. | ||
|
|
||
| Each expert's weights are sharded across all ranks and a fused MoE | ||
| kernel is used for the forward pass, and finally we reduce the outputs | ||
| across ranks. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: PretrainedConfig, | ||
|
|
@@ -148,7 +126,6 @@ def __init__( | |
| super().__init__() | ||
| self.hidden_size = hidden_size | ||
|
|
||
| # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961) | ||
| self.gate = ReplicatedLinear( | ||
| hidden_size, | ||
| num_experts, | ||
|
|
@@ -157,9 +134,7 @@ def __init__( | |
| quant_config=None, | ||
| ) | ||
|
|
||
| self.router_logit_softcapping = getattr( | ||
| config, "router_logit_softcapping", 30.0 | ||
| ) | ||
| self.router_logit_softcapping = 30.0 | ||
| custom_routing_function = functools.partial( | ||
| fused_moe_router_shim, self.router_logit_softcapping | ||
| ) | ||
|
|
@@ -186,7 +161,6 @@ def __init__( | |
| ) | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| # need to assert self.gate.quant_method is unquantized | ||
| topk_output = self.topk(hidden_states, self.gate.weight) | ||
| return self.experts(hidden_states, topk_output) | ||
|
|
||
|
|
@@ -447,75 +421,12 @@ def forward( | |
| hidden_states: torch.Tensor, | ||
| forward_batch: ForwardBatch, | ||
| ) -> torch.Tensor: | ||
| if hidden_states.shape[0] == 0: | ||
| assert ( | ||
| not self.o_proj.reduce_results | ||
| ), "short-circuiting allreduce will lead to hangs" | ||
| return hidden_states | ||
| if debug_tensor_dump_output_folder: | ||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| f"attn_input_{self.layer_id}", | ||
| hidden_states, | ||
| ) | ||
|
|
||
| if debug_tensor_dump_inject: | ||
| name = os.path.join( | ||
| debug_tensor_dump_output_folder, | ||
| f"jax_dump_attn_input_{self.layer_id}.npy", | ||
| ) | ||
| logger.info(f"Load {name} from jax.") | ||
| x = np.load(name) | ||
| hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to( | ||
| hidden_states | ||
| ) | ||
|
|
||
| qkv, _ = self.qkv_proj(hidden_states) | ||
| dispose_tensor(hidden_states) | ||
|
|
||
| q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) | ||
| q, k = self.rotary_emb(positions, q, k) | ||
|
|
||
| if debug_tensor_dump_output_folder: | ||
| num_tokens = q.shape[0] | ||
| num_heads_q = self.num_heads | ||
| head_dim = self.head_dim | ||
| num_heads_kv = k.numel() // (num_tokens * head_dim) | ||
|
|
||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| f"q_{self.layer_id}", | ||
| tensor_model_parallel_all_gather( | ||
| q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1 | ||
| ).contiguous(), | ||
| ) | ||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| f"k_{self.layer_id}", | ||
| tensor_model_parallel_all_gather( | ||
| k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 | ||
| ).contiguous(), | ||
| ) | ||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| f"v_{self.layer_id}", | ||
| tensor_model_parallel_all_gather( | ||
| v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 | ||
| ).contiguous(), | ||
| ) | ||
|
|
||
| attn_output = self.attn(q, k, v, forward_batch) | ||
| del q, k, v, qkv | ||
|
|
||
| if debug_tensor_dump_output_folder: | ||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| f"attn_output_{self.layer_id}", | ||
| tensor_model_parallel_all_gather( | ||
| attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(), | ||
| dim=1, | ||
| ).contiguous(), | ||
| ) | ||
|
|
||
| output, _ = self.o_proj(attn_output) | ||
| return output | ||
|
|
@@ -648,14 +559,6 @@ def forward( | |
| hidden_states, | ||
| ) | ||
|
|
||
| if residual_original is not None: | ||
| dispose_tensor(residual_original) | ||
|
|
||
| dispose_flag = False | ||
| if residual is not hidden_states_original: | ||
| dispose_flag = True | ||
| dispose_tensor(hidden_states_original) | ||
|
|
||
| hidden_states = self.self_attn( | ||
| positions=positions, | ||
| hidden_states=hidden_states, | ||
|
|
@@ -673,21 +576,21 @@ def forward( | |
| self.post_attn_norm.variance_epsilon, | ||
| ) | ||
|
|
||
| if not dispose_flag: | ||
| dispose_tensor(hidden_states_original) | ||
|
|
||
| # Fully Connected | ||
| hidden_states = self.ffn(hidden_states) | ||
| return hidden_states, residual, self.post_moe_norm # defer layernorm | ||
|
|
||
| def moe_with_rmoe(self, x): | ||
| current_stream = torch.cuda.current_stream() | ||
| self.alt_stream.wait_stream(current_stream) | ||
| mlp_result = self.mlp(x) | ||
| with torch.cuda.stream(self.alt_stream): | ||
| # moe should not be inplace because of stream race condition | ||
| if self.alt_stream is not None and get_is_capture_mode(): | ||
| current_stream = torch.cuda.current_stream() | ||
| self.alt_stream.wait_stream(current_stream) | ||
|
Comment on lines
+585
to
+586
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| mlp_result = self.mlp(x) | ||
| with torch.cuda.stream(self.alt_stream): | ||
| moe_result = self.block_sparse_moe(x) | ||
| current_stream.wait_stream(self.alt_stream) | ||
| else: | ||
| mlp_result = self.mlp(x) | ||
| moe_result = self.block_sparse_moe(x) | ||
| current_stream.wait_stream(self.alt_stream) | ||
| return (mlp_result + moe_result) / 1.4142135623730951 | ||
|
|
||
|
|
||
|
|
@@ -752,41 +655,13 @@ def forward( | |
| positions, hidden_states, forward_batch, residual, deferred_norm | ||
| ) | ||
|
|
||
| if debug_tensor_dump_output_folder: | ||
| hidden_states = ( | ||
| fused_rmsnorm( | ||
| hidden_states, | ||
| deferred_norm.weight, | ||
| deferred_norm.variance_epsilon, | ||
| ) | ||
| + residual | ||
| ) | ||
|
|
||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| "last_hidden_before_norm", | ||
| hidden_states, | ||
| ) | ||
|
|
||
| hidden_states = fused_rmsnorm( | ||
| hidden_states, | ||
| self.norm.weight, | ||
| self.norm.variance_epsilon, | ||
| ) | ||
|
|
||
| dump_to_file( | ||
| debug_tensor_dump_output_folder, | ||
| "last_hidden_after_norm", | ||
| hidden_states, | ||
| ) | ||
| else: | ||
| hidden_states, _ = fused_dual_residual_rmsnorm( | ||
| hidden_states, | ||
| residual, | ||
| deferred_norm.weight, | ||
| self.norm.weight, | ||
| deferred_norm.variance_epsilon, | ||
| ) | ||
| hidden_states, _ = fused_dual_residual_rmsnorm( | ||
| hidden_states, | ||
| residual, | ||
| deferred_norm.weight, | ||
| self.norm.weight, | ||
| deferred_norm.variance_epsilon, | ||
| ) | ||
|
|
||
| return hidden_states | ||
|
|
||
|
|
@@ -862,31 +737,16 @@ def __init__( | |
| ) | ||
| self.logits_processor = LogitsProcessor(config) | ||
|
|
||
| # Dump tensors for debugging | ||
| global debug_tensor_dump_output_folder, debug_tensor_dump_inject | ||
| debug_tensor_dump_output_folder = ( | ||
| get_global_server_args().debug_tensor_dump_output_folder | ||
| ) | ||
| debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject | ||
| warnings.filterwarnings("ignore", category=FutureWarning) | ||
|
|
||
| if get_tensor_model_parallel_rank() == 0: | ||
| logger.info( | ||
| f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " | ||
| f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" | ||
| ) | ||
| self.loaded_param_names = set() | ||
|
|
||
| @torch.no_grad() | ||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| forward_batch: ForwardBatch, | ||
| input_embeds: torch.Tensor = None, | ||
| ) -> torch.Tensor: | ||
| if debug_tensor_dump_output_folder: | ||
| dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids) | ||
|
|
||
| hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) | ||
| return self.logits_processor( | ||
| input_ids, hidden_states, self.lm_head, forward_batch | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.