-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Support skip-softmax attention #19089
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bfc6c16
11558e3
592a82b
fd4cf27
da87def
ff5fea2
6f33368
c20fa0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import random | ||
| from argparse import Namespace | ||
| from dataclasses import dataclass | ||
| from typing import List, Optional | ||
|
|
||
| from transformers import PreTrainedTokenizerBase | ||
|
|
||
| from sglang.benchmark.datasets.common import BaseDataset, DatasetRow | ||
|
|
||
| LONGBENCH_V2_REPO_ID = "THUDM/LongBench-v2" | ||
| LONGBENCH_V2_DEFAULT_OUTPUT_LEN = 10 # answer letter + short explanation | ||
|
|
||
|
|
||
| def _format_prompt(example: dict) -> str: | ||
| return ( | ||
| f"{example['context']}\n\n" | ||
| f"Question: {example['question']}\n" | ||
| f"A. {example['choice_A']}\n" | ||
| f"B. {example['choice_B']}\n" | ||
| f"C. {example['choice_C']}\n" | ||
| f"D. {example['choice_D']}\n" | ||
| f"Answer:" | ||
| ) | ||
|
|
||
|
|
||
@dataclass
class LongBenchV2Dataset(BaseDataset):
    """Benchmark-dataset adapter for LongBench-v2.

    Thin configuration holder: ``from_args`` maps benchmark CLI flags onto
    the fields, and ``load`` delegates the actual sampling/tokenization to
    ``sample_longbench_v2_requests``.
    """

    # Local parquet/JSONL file; falsy value means: fetch from the HF hub.
    dataset_path: str
    # Upper bound on the number of prompts produced.
    num_requests: int
    # Forced output length; None falls back to the dataset default.
    fixed_output_len: Optional[int]
    # If set, drop examples whose prompt + output would exceed this budget.
    context_len: Optional[int]

    @classmethod
    def from_args(cls, args: Namespace) -> "LongBenchV2Dataset":
        """Build the dataset configuration from parsed benchmark arguments."""
        return cls(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            fixed_output_len=args.sharegpt_output_len,
            context_len=args.sharegpt_context_len,
        )

    def load(
        self, tokenizer: PreTrainedTokenizerBase, model_id=None
    ) -> List[DatasetRow]:
        """Materialize the benchmark rows; ``model_id`` is accepted but unused."""
        kwargs = dict(
            dataset_path=self.dataset_path,
            num_requests=self.num_requests,
            tokenizer=tokenizer,
            fixed_output_len=self.fixed_output_len,
            context_len=self.context_len,
        )
        return sample_longbench_v2_requests(**kwargs)
|
|
||
|
|
||
def sample_longbench_v2_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
) -> List[DatasetRow]:
    """Sample up to ``num_requests`` LongBench-v2 prompts as ``DatasetRow``s.

    Examples come from ``dataset_path`` when given (``.parquet`` or JSON
    lines), otherwise from the ``THUDM/LongBench-v2`` hub dataset. They are
    shuffled, formatted via ``_format_prompt``, tokenized to measure prompt
    length, and filtered against ``context_len`` when that budget is set.

    Note: ``random.shuffle`` uses the global RNG, so sampling is only
    reproducible if the caller seeds ``random`` beforehand.
    """
    if fixed_output_len is None:
        output_len = LONGBENCH_V2_DEFAULT_OUTPUT_LEN
    else:
        output_len = fixed_output_len

    # --- Load raw examples -------------------------------------------------
    if dataset_path:
        # Local file: parquet, or JSON-lines for everything else.
        import pandas as pd

        if dataset_path.endswith(".parquet"):
            examples = pd.read_parquet(dataset_path).to_dict(orient="records")
        else:
            import json

            with open(dataset_path) as f:
                examples = [json.loads(line) for line in f if line.strip()]
    else:
        from datasets import load_dataset

        examples = list(load_dataset(LONGBENCH_V2_REPO_ID, split="train"))

    random.shuffle(examples)

    # --- Build rows until the request budget is met ------------------------
    rows: List[DatasetRow] = []
    for example in examples:
        if len(rows) >= num_requests:
            break

        prompt = _format_prompt(example)
        prompt_len = len(tokenizer(prompt).input_ids)

        # Skip examples whose prompt plus expected output overflows the budget.
        if context_len is not None and prompt_len + output_len > context_len:
            continue

        rows.append(
            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
        )

    return rows
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1790,6 +1790,7 @@ def _forward_standard_mha( | |
| enable_pdl=False, | ||
| is_causal=causal, | ||
| return_lse=False, | ||
| skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get(), | ||
| ) | ||
|
|
||
| # Use FA3 for SM90 (Hopper/H200) | ||
|
|
@@ -2025,6 +2026,7 @@ def _forward_trtllm( | |
| sparse_mla_top_k=self.nsa_index_topk, | ||
| bmm1_scale=bmm1_scale, | ||
| backend="trtllm-gen", | ||
| skip_softmax_threshold_scale_factor=envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get(), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This unconditionally uses the decode-specific skip-softmax factor. However, you should select the appropriate factor based on the forward mode: is_decode_like = (
forward_batch.forward_mode.is_decode_or_idle()
or forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
)
skip_softmax_factor = (
envs.SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR.get()
if is_decode_like
else envs.SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR.get()
)Then use Additionally, please note that |
||
| ) | ||
| # Output: [batch, q_len=1, heads, v_dim] -> [batch, heads, v_dim] | ||
| return out.squeeze(1) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.