Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH` | Enable NCCL for gathering when preparing mlp sync batch under overlap scheduler (without this flag gloo is used for gathering) | `false` |
| `SGLANG_SYMM_MEM_PREALLOC_GB_SIZE` | Size of preallocated GPU buffer (in GB) for NCCL symmetric memory pool to limit memory fragmentation. Only has an effect when the server arg `--enable-symm-mem` is set. | `-1` |
| `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` |
| `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` |
| `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` |


## DeepGEMM Configuration (Advanced Optimization)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,6 +1937,7 @@ def __call__(self, parser, namespace, values, option_string=None):
"mmmu",
"image",
"mooncake",
"longbench_v2",
],
help="Name of the dataset to benchmark on.",
)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/benchmark/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
GeneratedSharedPrefixDataset,
)
from sglang.benchmark.datasets.image import ImageDataset
from sglang.benchmark.datasets.longbench_v2 import LongBenchV2Dataset
from sglang.benchmark.datasets.mmmu import MMMUDataset
from sglang.benchmark.datasets.mooncake import MooncakeDataset
from sglang.benchmark.datasets.openai_dataset import OpenAIDataset
Expand All @@ -24,6 +25,7 @@
"mmmu": MMMUDataset,
"image": ImageDataset,
"mooncake": MooncakeDataset,
"longbench_v2": LongBenchV2Dataset,
}


Expand Down
104 changes: 104 additions & 0 deletions python/sglang/benchmark/datasets/longbench_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import random
from argparse import Namespace
from dataclasses import dataclass
from typing import List, Optional

from transformers import PreTrainedTokenizerBase

from sglang.benchmark.datasets.common import BaseDataset, DatasetRow

LONGBENCH_V2_REPO_ID = "THUDM/LongBench-v2"
LONGBENCH_V2_DEFAULT_OUTPUT_LEN = 10 # answer letter + short explanation


def _format_prompt(example: dict) -> str:
return (
f"{example['context']}\n\n"
f"Question: {example['question']}\n"
f"A. {example['choice_A']}\n"
f"B. {example['choice_B']}\n"
f"C. {example['choice_C']}\n"
f"D. {example['choice_D']}\n"
f"Answer:"
)


@dataclass
class LongBenchV2Dataset(BaseDataset):
    """Benchmark dataset adapter for THUDM/LongBench-v2 multiple-choice QA."""

    # Local parquet/JSON-lines file; empty/None means download from the HF hub.
    dataset_path: str
    # Maximum number of prompts to produce.
    num_requests: int
    # Per-request output-length override; None uses the dataset default.
    fixed_output_len: Optional[int]
    # If set, examples whose prompt + output exceed this budget are skipped.
    context_len: Optional[int]

    @classmethod
    def from_args(cls, args: Namespace) -> "LongBenchV2Dataset":
        """Build the dataset configuration from parsed benchmark CLI args."""
        return cls(
            args.dataset_path,
            args.num_prompts,
            args.sharegpt_output_len,
            args.sharegpt_context_len,
        )

    def load(
        self, tokenizer: PreTrainedTokenizerBase, model_id=None
    ) -> List[DatasetRow]:
        """Sample up to ``num_requests`` tokenized LongBench-v2 prompts."""
        sample_kwargs = {
            "dataset_path": self.dataset_path,
            "num_requests": self.num_requests,
            "tokenizer": tokenizer,
            "fixed_output_len": self.fixed_output_len,
            "context_len": self.context_len,
        }
        return sample_longbench_v2_requests(**sample_kwargs)


def sample_longbench_v2_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
) -> List[DatasetRow]:
    """Sample multiple-choice benchmark requests from LongBench-v2.

    Args:
        dataset_path: Local ``.parquet`` or JSON-lines file. A falsy value
            downloads the ``train`` split of THUDM/LongBench-v2 from the
            Hugging Face hub.
        num_requests: Maximum number of requests to return.
        tokenizer: Tokenizer used to measure prompt lengths.
        fixed_output_len: Per-request output length; ``None`` falls back to
            ``LONGBENCH_V2_DEFAULT_OUTPUT_LEN``.
        context_len: If set, skip examples whose prompt plus output tokens
            would exceed this budget.

    Returns:
        Up to ``num_requests`` :class:`DatasetRow` objects. Fewer may be
        returned when the dataset runs out of examples that fit the budget.
    """
    output_len = (
        fixed_output_len
        if fixed_output_len is not None
        else LONGBENCH_V2_DEFAULT_OUTPUT_LEN
    )

    examples = _load_longbench_v2_examples(dataset_path)
    # Shuffle so repeated runs with a larger/smaller num_requests do not
    # always pick the same leading examples.
    random.shuffle(examples)

    rows: List[DatasetRow] = []
    for example in examples:
        if len(rows) >= num_requests:
            break

        prompt = _format_prompt(example)
        prompt_len = len(tokenizer(prompt).input_ids)

        # Respect the caller's context budget: drop examples whose prompt
        # plus the expected output would not fit.
        if context_len is not None and prompt_len + output_len > context_len:
            continue

        rows.append(
            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
        )

    if len(rows) < num_requests:
        # Not fatal — the context-length filter or a small dataset can leave
        # us short — but the benchmark operator should know the effective
        # request count shrank, since it skews throughput numbers.
        print(
            f"Warning: sampled only {len(rows)} LongBench-v2 requests "
            f"(requested {num_requests})."
        )

    return rows


def _load_longbench_v2_examples(dataset_path: str) -> List[dict]:
    """Load raw LongBench-v2 examples from a local file or the HF hub.

    Supports local parquet and JSON-lines files; a falsy ``dataset_path``
    downloads THUDM/LongBench-v2 via the ``datasets`` library. Imports are
    kept local so only the needed backend is loaded.
    """
    if dataset_path:
        if dataset_path.endswith(".parquet"):
            import pandas as pd

            return pd.read_parquet(dataset_path).to_dict(orient="records")

        import json

        with open(dataset_path) as f:
            return [json.loads(line) for line in f if line.strip()]

    from datasets import load_dataset

    return list(load_dataset(LONGBENCH_V2_REPO_ID, split="train"))
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,10 @@ class Envs:
# Default to the pick from flashinfer
SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("")
SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
# Skip-softmax threshold scale factor for TRT-LLM attention (prefill and decode separately).
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,7 @@ def _forward_standard_mha(
enable_pdl=False,
is_causal=causal,
return_lse=False,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(),
)

# Use FA3 for SM90 (Hopper/H200)
Expand Down Expand Up @@ -2025,6 +2026,7 @@ def _forward_trtllm(
sparse_mla_top_k=self.nsa_index_topk,
bmm1_scale=bmm1_scale,
backend="trtllm-gen",
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This unconditionally uses the decode-specific skip-softmax factor. However, _forward_trtllm can be called for prefill operations via forward_extend when nsa_prefill_impl is set to "trtllm". This could lead to using the wrong threshold for prefill.

You should select the appropriate factor based on the forward_mode of the forward_batch. For example:

is_decode_like = (
    forward_batch.forward_mode.is_decode_or_idle()
    or forward_batch.forward_mode.is_target_verify()
    or forward_batch.forward_mode.is_draft_extend(include_v2=True)
)
skip_softmax_factor = (
    envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get()
    if is_decode_like
    else envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get()
)

Then use skip_softmax_factor in the function call.

Additionally, please note that flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla is a decode kernel. Using it for prefill might be a separate issue to investigate.

)
# Output: [batch, q_len=1, heads, v_dim] -> [batch, heads, v_dim]
return out.squeeze(1)
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ def forward_decode(
bmm2_scale=bmm2_scale,
window_left=layer.sliding_window_size,
sinks=attention_sink,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
out_dtype=self.q_data_type, # model_runner.dtype
)

Expand Down Expand Up @@ -855,6 +856,7 @@ def forward_extend(
bmm2_scale=bmm2_scale,
window_left=layer.sliding_window_size,
sinks=attention_sink,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
out_dtype=self.q_data_type, # model_runner.dtype
q_len_per_req=self.forward_metadata.max_seq_len_q,
)
Expand All @@ -874,6 +876,7 @@ def forward_extend(
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
window_left=layer.sliding_window_size,
sinks=attention_sink,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(),
out_dtype=self.q_data_type, # model_runner.dtype
)

Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/attention/trtllm_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import triton.language as tl

from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
from sglang.srt.environ import envs
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
FlashInferMLAMultiStepDraftBackend,
Expand Down Expand Up @@ -875,6 +876,7 @@ def forward_decode(
seq_lens=forward_batch.seq_lens.to(torch.int32),
max_seq_len=metadata.max_seq_len_k,
bmm1_scale=bmm1_scale,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
)

# Reshape output directly without slicing
Expand Down Expand Up @@ -1062,6 +1064,7 @@ def forward_extend(
seq_lens=metadata.seq_lens_k,
max_seq_len=max_seq_len,
bmm1_scale=bmm1_scale,
skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(),
)

if needs_unpad:
Expand Down Expand Up @@ -1099,6 +1102,7 @@ def forward_extend(
"bmm1_scale": q_scale * k_scale * layer.scaling,
"bmm2_scale": v_scale,
"cum_seq_lens_q": self.forward_prefill_metadata.cum_seq_lens,
"skip_softmax_threshold_scale_factor": envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(),
}

# When chunked prefix cache is enabled, dispatch to different path for ragged attention.
Expand Down
Loading