Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |
| `SGLANG_MOE_NVFP4_DISPATCH` | Use nvfp4 for moe dispatch | `"false"` |
| `SGLANG_MOE_NVFP4_DISPATCH` | Use NVFP4 for MoE dispatch (only with the flashinfer_cutlass or flashinfer_cutedsl MoE runner backend) | `"false"` |
| `SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each GPU | `"128"` |
| `SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS` | Number of SMs used for DeepEP combine when single batch overlap is enabled | `"32"` |

## Memory Management

Expand All @@ -65,6 +67,8 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` | Clip max new tokens estimation for memory planning | `4096` |
| `SGLANG_DETOKENIZER_MAX_STATES` | Maximum states for detokenizer | Default value based on system |
| `SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK` | Disable checks for memory imbalance across Tensor Parallel ranks | Not set (defaults to enabled check) |
| `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize q_b_proj from BF16 to FP8 when loading a DeepSeek NVFP4 checkpoint | `false` |
| `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Use the per-token-group quantization kernel with fused SiLU-and-mul and masked M | `false` |

## Model-Specific Options

Expand Down
11 changes: 7 additions & 4 deletions python/sglang/srt/batch_overlap/single_batch_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

import torch

from sglang.srt.environ import envs
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.utils import get_int_env_var, is_blackwell
from sglang.srt.utils import is_blackwell


class SboFlags:
Expand Down Expand Up @@ -87,9 +88,11 @@ def compute_overlap_args(dispatch_output, alt_stream):
total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
communicate_num_sms = get_int_env_var(
"SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32 if is_blackwell() else 3
)

if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():
communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()
else:
communicate_num_sms = 32 if is_blackwell() else 3
compute_num_sms = total_num_sms - communicate_num_sms

assert alt_stream is not None
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ class Envs:
SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False)
SGLANG_FORCE_FP8_MARLIN = EnvBool(False)
SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False)
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False)
SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False)

# Flashinfer
SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)
Expand Down Expand Up @@ -276,6 +278,11 @@ class Envs:
SGLANG_DG_USE_NVRTC = EnvBool(False)
SGLANG_USE_DEEPGEMM_BMM = EnvBool(False)

# DeepEP
SGLANG_DEEPEP_BF16_DISPATCH = EnvBool(False)
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(128)
SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS = EnvInt(32)

# sgl-kernel
SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)

Expand Down Expand Up @@ -365,6 +372,9 @@ def _convert_SGL_to_SGLANG():
_print_deprecated_env(
"SGLANG_ENABLE_FLASHINFER_FP8_GEMM", "SGLANG_ENABLE_FLASHINFER_GEMM"
)
_print_deprecated_env(
"SGLANG_MOE_NVFP4_DISPATCH", "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH"
)

for key, value in os.environ.items():
if key.startswith("SGL_"):
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch

from sglang.srt.environ import envs
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
get_deepep_mode,
Expand Down Expand Up @@ -303,8 +304,8 @@ def forward_cutlass_w4afp8_masked(
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
assert get_bool_env_var(
"SGLANG_DEEPEP_BF16_DISPATCH"
assert (
envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
return self.quant_method.apply_deepep_ll(
layer=self,
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
Expand All @@ -26,7 +27,6 @@
)
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
is_blackwell,
is_hip,
is_npu,
Expand Down Expand Up @@ -317,8 +317,8 @@ def __init__(

self.params_bytes = 2
# A large value will lead to large memory occupation, thus users should change it accordingly
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
self.num_max_dispatch_tokens_per_rank = (
envs.SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK.get()
)
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
Expand Down Expand Up @@ -387,7 +387,7 @@ def dispatch_a(
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and not get_moe_runner_backend().is_cutlass()
and not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH")
and not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
):
# TODO hard code 128 block quant,use fp8 communication
hidden_states = sglang_per_token_group_quant_fp8(
Expand Down Expand Up @@ -609,7 +609,7 @@ def _dispatch_core(
input_global_scale = self.quant_config.get("input_global_scale", None)
if input_global_scale is not None:
use_nvfp4 = True
elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get():
use_fp8 = True

buffer = self._get_buffer()
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/moe/token_dispatcher/fuseep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch

from sglang.srt.environ import envs
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
CombineInput,
Expand All @@ -15,7 +16,6 @@
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.utils import get_int_env_var

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,8 +62,8 @@ def __init__(
self.deepep_mode = deepep_mode

self.params_bytes = 2
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
self.num_max_dispatch_tokens_per_rank = (
envs.SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK.get()
)

def dispatch(
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
Expand Down Expand Up @@ -2655,7 +2656,7 @@ def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):

@staticmethod
def _get_q_b_proj_quant_config(quant_config):
if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get():
# refer to real DeepSeek V3 quant config
return Fp8Config(
is_checkpoint_fp8_serialized=True,
Expand Down Expand Up @@ -3591,7 +3592,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
else:
raise ValueError("num_nextn_predict_layers is not in the config")

if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get():
weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
weights = self._quant_nextn_moe_to_fp8_ue8m0(
Expand Down
Loading