diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 3c96a681695..70a180a7779 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
 
 ## Kernel backend
 
-* `attention_backend`: The backend for attention computation and KV cache management.
+* `attention_backend`: The backend for attention computation and KV cache management. It can be one of `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, this argument also selects the MLA backend to use.
 * `sampling_backend`: The backend for sampling.
 
 ## Constrained Decoding
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
-* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
+* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated soon. Please use `--attention-backend flashinfer` instead to enable FlashInfer MLA.**
+* `flashinfer_mla_disable_ragged`: Disable usage of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only applies when FlashInfer is used as the MLA backend.
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 9a079b9e77c..0e4cea70e56 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
 
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
+- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including FlashAttention3, [FlashInfer](https://docs.flashinfer.ai/api/mla.html), and Triton. The backend can be selected with the `--attention-backend` argument.
 
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
 
@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
 ```
 - The draft model are available at huggingface: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). It can also be exported from original DeepSeek-V3/R1 model with [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
 - The best configuratin for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
-- Currently when using flashinfer mla wrapper (`--enable-flashinfer-mla`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`.
+- Currently, when using the FlashInfer MLA backend (`--attention-backend flashinfer`) together with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on the FlashAttention 3 backend is still in beta.
 - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
   - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value.
   - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 65bcdf51391..81afcb9dac5 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -71,8 +71,6 @@ def __init__(
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
 
-        global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a8796cb421d..107765eded1 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -76,7 +76,6 @@
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -1435,7 +1434,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
 
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["enable_flashinfer_mla"]
+            global_server_args_dict["attention_backend"] == "flashinfer_mla"
             or global_server_args_dict["enable_flashmla"]
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f42ea02d529..3b8b769a671 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -151,7 +151,6 @@ def __init__(
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -223,10 +222,14 @@ def model_specific_adjustment(self):
         ):
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
+                if (
+                    server_args.attention_backend == "flashinfer"
+                    or server_args.enable_flashinfer_mla
+                ):
                     logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
+                        "MLA optimization is turned on. Use flashinfer backend."
                     )
+                    # Here we use a special flashinfer_mla tag to differentiate it from the normal flashinfer backend
                     server_args.attention_backend = "flashinfer_mla"
                 elif server_args.enable_flashmla:
                     logger.info("MLA optimization is turned on. Use flashmla decode.")
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 6aaa3744a86..775b7413c10 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -684,7 +684,6 @@ def __init__(
         self.w_vc = None
         self.w_scale = None
 
-        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -692,7 +691,7 @@ def __init__(
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.enable_flashinfer_mla:
+        if self.attention_backend == "flashinfer_mla":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 1a19bbea225..1cd44862fc5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -179,7 +179,7 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False
+    enable_flashinfer_mla: bool = False  # TODO: remove this argument
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
@@ -836,7 +836,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization",
+            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead to enable FlashInfer MLA.",
         )
         parser.add_argument(
             "--enable-flashmla",
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
index 81de94d3afd..4f0953e6a3e 100644
--- a/test/srt/test_mla_flashinfer.py
+++ b/test/srt/test_mla_flashinfer.py
@@ -26,7 +26,8 @@ def setUpClass(cls):
                 "--enable-torch-compile",
                 "--cuda-graph-max-bs",
                 "2",
-                "--enable-flashinfer-mla",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
@@ -69,8 +70,8 @@ def setUpClass(cls):
                 "--disable-cuda-graph",
                 "--cuda-graph-max-bs",
                 "4",
-                "--enable-flashinfer-mla",
-                "--flashinfer-mla-disable-ragged",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
@@ -125,7 +126,8 @@ def setUpClass(cls):
                 "1",
                 "--speculative-num-draft-tokens",
                 "4",
-                "--enable-flashinfer-mla",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
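
Below is a minimal launch sketch showing the replacement flag in use. It assumes the `deepseek-ai/DeepSeek-V3-0324` checkpoint and `--trust-remote-code` from the existing DeepSeek docs; any other MLA model path works the same way.

```bash
# Select the FlashInfer MLA backend explicitly; this replaces the
# deprecated --enable-flashinfer-mla flag.
python3 -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V3-0324 \
    --trust-remote-code \
    --attention-backend flashinfer
```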