Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention-backend flashinfer` instead to enable FlashInfer MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend.
2 changes: 1 addition & 1 deletion docs/references/deepseek.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [FlashInfer](https://docs.flashinfer.ai/api/mla.html), and [Triton](https://github.com/triton-lang/triton) backends. It can be set with the `--attention-backend` argument.

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

Expand Down
3 changes: 1 addition & 2 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
Expand Down Expand Up @@ -1481,7 +1480,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
global_server_args_dict["use_mla_backend"]
and global_server_args_dict["attention_backend"] == "flashinfer"
)
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
):
seq_lens_cpu = self.seq_lens.cpu()
Expand Down
18 changes: 7 additions & 11 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def __init__(
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
Expand Down Expand Up @@ -229,15 +228,7 @@ def initialize(self, min_per_gpu_memory: float):
def model_specific_adjustment(self):
server_args = self.server_args

if server_args.enable_flashinfer_mla:
# TODO: remove this branch after enable_flashinfer_mla is deprecated
logger.info("MLA optimization is turned on. Use flashinfer backend.")
server_args.attention_backend = "flashinfer"
elif server_args.enable_flashmla:
# TODO: remove this branch after enable_flashmla is deprecated
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend is None:
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
if not self.use_mla_backend:
if (
Expand All @@ -263,7 +254,12 @@ def model_specific_adjustment(self):
elif self.use_mla_backend:
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
if server_args.attention_backend in [
"flashinfer",
"fa3",
"triton",
"flashmla",
]:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
)
Expand Down
14 changes: 6 additions & 8 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,6 @@ class ServerArgs:
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False # TODO: remove this argument
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
Expand Down Expand Up @@ -255,7 +253,7 @@ def __post_init__(self):

assert self.chunked_prefill_size % self.page_size == 0

if self.enable_flashmla is True:
if self.attention_backend == "flashmla":
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
Expand Down Expand Up @@ -824,7 +822,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton", "torch_native", "fa3"],
choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
Expand All @@ -844,13 +842,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
action=DeprecatedAction,
help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
)
parser.add_argument(
"--enable-flashmla",
action="store_true",
help="Enable FlashMLA decode optimization",
action=DeprecatedAction,
help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
Expand Down
11 changes: 3 additions & 8 deletions scripts/playground/bench_speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,11 @@ def main(args, server_args):
]
)

if server_args.enable_flashinfer_mla:
if server_args.attention_backend:
other_args.extend(
[
"--enable-flashinfer-mla",
]
)
if server_args.enable_flashmla:
other_args.extend(
[
"--enable-flashmla",
"--attention-backend",
server_args.attention_backend,
]
)

Expand Down
Loading