4 changes: 2 additions & 2 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -333,9 +333,9 @@ Attention type selection parameters

The inference backend will use flashinfer's attention kernel for decoding

.. option:: --enable_fa3
.. option:: --disable_fa3

The inference backend will use the fa3 attention kernel for prefill and decoding
The inference backend will not use the fa3 attention kernel for prefill and decoding (FA3 is enabled by default)

.. option:: --disable_cudagraph

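This option flip is the pattern applied throughout the PR: FA3 goes from opt-in to opt-out. A minimal sketch of the resulting check, using a plain argparse.Namespace as a stand-in for the parsed server arguments (the real code reads them via get_env_start_args()):

# Illustrative sketch only: how callers move from the old opt-in flag to the new opt-out flag.
import argparse

old_args = argparse.Namespace(enable_fa3=False)   # old CLI default: FA3 off unless --enable_fa3 is passed
new_args = argparse.Namespace(disable_fa3=False)  # new CLI default: FA3 on unless --disable_fa3 is passed

use_fa3_old = old_args.enable_fa3        # False
use_fa3_new = not new_args.disable_fa3   # True
print(use_fa3_old, use_fa3_new)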
14 changes: 1 addition & 13 deletions docs/CN/source/tutorial/deepseek_deployment.rst
@@ -32,13 +32,11 @@ LightLLM supports the following deployment modes:
# H200 single-node DeepSeek-R1 TP mode
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--enable_fa3
--tp 8

**Parameter description:**
- `LOADWORKER=18`: Number of model-loading threads; speeds up loading
- `--tp 8`: Tensor parallelism degree, using 8 GPUs
- `--enable_fa3`: Enable Flash Attention 3.0
- `--port 8088`: Service port

1.2 Single-node DP + EP mode (Data Parallel + Expert Parallel)
@@ -55,13 +53,11 @@
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--dp 8 \
--enable_fa3

**Parameter description:**
- `MOE_MODE=EP`: Set expert parallelism mode
- `--tp 8`: Tensor parallelism degree
- `--dp 8`: Data parallelism degree, usually set to the same value as tp
- `--enable_fa3`: Enable Flash Attention 3.0

**Optional optimization parameters:**
- `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap
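Note that `MOE_MODE` and `LOADWORKER` are process environment variables rather than CLI flags. A small sketch of launching the same DP + EP configuration from Python instead of a shell one-liner (the port is taken from the other examples in this document; everything else mirrors the command above):

# Sketch: the env settings from the example are ordinary environment variables.
import os
import subprocess

env = dict(os.environ, MOE_MODE="EP", LOADWORKER="18")
cmd = [
    "python", "-m", "lightllm.server.api_server",
    "--port", "8088",
    "--model_dir", "/path/DeepSeek-R1",
    "--tp", "8",
    "--dp", "8",
]
subprocess.run(cmd, env=env, check=True)  # blocks for the lifetime of the server process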
@@ -85,7 +81,6 @@ LightLLM supports the following deployment modes:
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 0 \
--nccl_host $nccl_host \
@@ -101,7 +96,6 @@ LightLLM supports the following deployment modes:
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 1 \
--nccl_host $nccl_host \
@@ -129,7 +123,6 @@ LightLLM supports the following deployment modes:
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--dp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 0 \
--nccl_host $nccl_host \
@@ -146,7 +139,6 @@ LightLLM supports the following deployment modes:
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--dp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 1 \
--nccl_host $nccl_host \
@@ -195,7 +187,6 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
--host $host \
--port 8019 \
--nccl_port 2732 \
--enable_fa3 \
--disable_cudagraph \
--pd_master_ip $pd_master_ip \
--pd_master_port 60011
@@ -219,7 +210,6 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
--host $host \
--port 8121 \
--nccl_port 12322 \
--enable_fa3 \
--disable_cudagraph \
--pd_master_ip $pd_master_ip \
--pd_master_port 60011
@@ -287,7 +277,6 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
--tp 8 \
--dp 8 \
--nccl_port 2732 \
--enable_fa3 \
--disable_cudagraph \
--config_server_host $config_server_host \
--config_server_port 60088
@@ -306,7 +295,6 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
--nccl_port 12322 \
--tp 8 \
--dp 8 \
--enable_fa3 \
--config_server_host $config_server_host \
--config_server_port 60088
# If you need to enable microbatch overlap, uncomment the following lines
4 changes: 2 additions & 2 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -332,9 +332,9 @@ Performance Optimization Parameters

The inference backend will use flashinfer's attention kernel for decoding

.. option:: --enable_fa3
.. option:: --disable_fa3

The inference backend will use fa3 attention kernel for prefill and decoding
The inference backend will not use fa3 attention kernel for prefill and decoding (FA3 is enabled by default)

.. option:: --disable_cudagraph

14 changes: 1 addition & 13 deletions docs/EN/source/tutorial/deepseek_deployment.rst
@@ -32,13 +32,11 @@ Suitable for deploying DeepSeek-R1 model on a single H200 node.
# H200 Single node DeepSeek-R1 TP Mode
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--enable_fa3
--tp 8

**Parameter Description:**
- `LOADWORKER=18`: Model loading thread count, improves loading speed
- `--tp 8`: Tensor parallelism, using 8 GPUs
- `--enable_fa3`: Enable Flash Attention 3.0
- `--port 8088`: Service port

1.2 Single node DP + EP Mode (Data Parallel + Expert Parallel)
@@ -55,13 +53,11 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3.
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--dp 8 \
--enable_fa3

**Parameter Description:**
- `MOE_MODE=EP`: Set expert parallelism mode
- `--tp 8`: Tensor parallelism
- `--dp 8`: Data parallelism, usually set to the same value as tp
- `--enable_fa3`: Enable Flash Attention 3.0

**Optional Optimization Parameters:**
- `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap
@@ -85,7 +81,6 @@ Suitable for deployment across multiple H200/H100 nodes.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 0 \
--nccl_host $nccl_host \
@@ -101,7 +96,6 @@ Suitable for deployment across multiple H200/H100 nodes.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 1 \
--nccl_host $nccl_host \
@@ -129,7 +123,6 @@ Suitable for deploying MoE models across multiple nodes.
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--dp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 0 \
--nccl_host $nccl_host \
@@ -146,7 +139,6 @@ Suitable for deploying MoE models across multiple nodes.
--model_dir /path/DeepSeek-R1 \
--tp 16 \
--dp 16 \
--enable_fa3 \
--nnodes 2 \
--node_rank 1 \
--nccl_host $nccl_host \
@@ -195,7 +187,6 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
--host $host \
--port 8019 \
--nccl_port 2732 \
--enable_fa3 \
--disable_cudagraph \
--pd_master_ip $pd_master_ip

@@ -216,7 +207,6 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
--host $host \
--port 8121 \
--nccl_port 12322 \
--enable_fa3 \
--disable_cudagraph \
--pd_master_ip $pd_master_ip \
--pd_master_port 60011
@@ -284,7 +274,6 @@ Supports multiple PD Master nodes, providing better load balancing and high avai
--tp 8 \
--dp 8 \
--nccl_port 2732 \
--enable_fa3 \
--disable_cudagraph \
--config_server_host $config_server_host \
--config_server_port 60088
@@ -303,7 +292,6 @@ Supports multiple PD Master nodes, providing better load balancing and high avai
--nccl_port 12322 \
--tp 8 \
--dp 8 \
--enable_fa3 \
--config_server_host $config_server_host \
--config_server_port 60088
# if you want to enable microbatch overlap, you can uncomment the following lines
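Whichever deployment mode is used, the resulting API server is queried over HTTP. Below is a minimal client sketch; it assumes the server listens on port 8088 as in the examples above and exposes LightLLM's usual /generate endpoint with an inputs/parameters JSON payload (the endpoint name and payload shape are assumptions, not part of this diff):

# Minimal client sketch (assumed endpoint: /generate with an inputs/parameters payload).
import requests

resp = requests.post(
    "http://127.0.0.1:8088/generate",
    json={
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 32, "temperature": 0.7},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # generated text plus metadata; exact shape depends on the server version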
14 changes: 7 additions & 7 deletions lightllm/common/offline_fp8_quant_mem_manager.py
@@ -32,7 +32,7 @@ def __init__(
self.abs_max = None

if is_export_mode:
scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
scales_shape = [layer_num, 2 * head_num] if not get_env_start_args().disable_fa3 else [layer_num, 2]
self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
elif get_env_start_args().kv_quant_calibration_config_path is not None:
logger.info(
@@ -43,15 +43,15 @@ def __init__(

self.scales_list = cfg["scales"]
self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
if not get_env_start_args().enable_fa3:
if get_env_start_args().disable_fa3:
self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
elif cfg["num_head"] > self.total_head_num:
factor = cfg["num_head"] // self.total_head_num
self.scales = self.scales[..., ::factor].contiguous()
elif cfg["num_head"] < self.total_head_num:
factor = self.total_head_num // cfg["num_head"]
self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
if not get_env_start_args().disable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
half_head = self.total_head_num // 2
start_head = dist.get_rank() * head_num
end_head = start_head + head_num
@@ -86,7 +86,7 @@ def _load_and_check_config(self):
raise ValueError(
f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
)
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
if cfg["quant_type"] != "per_head":
raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend")
else:
@@ -109,7 +109,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")

if self.abs_max is not None and self.count >= warmup_counts:
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
else:
k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
Expand All @@ -119,7 +119,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
final_abs_max = self.abs_max
if dist.is_initialized() and dist.get_world_size() > 1:
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
k_max = k_max.contiguous()
v_max = v_max.contiguous()
@@ -148,7 +148,7 @@ def _export_calibration_data(self):
cfg = {
"version": "1.0",
"architectures": model_arch,
"quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
"quant_type": "per_head" if not get_env_start_args().disable_fa3 else "per_tensor",
"qmin": self.qmin,
"qmax": self.qmax,
"num_layers": self.layer_num,
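The branches above differ only in the granularity of the collected statistics: with FA3 (now the default) calibration keeps one abs-max per K/V head ("per_head"), while the fallback keeps one value per K and per V tensor ("per_tensor"). A standalone sketch of the two granularities, assuming a kv_buffer laid out as [num_tokens, 2 * head_num, head_dim] with K heads followed by V heads and e4m3 as the FP8 target:

# Sketch only: per-head vs per-tensor abs-max for FP8 (e4m3) KV-cache scales.
import torch

num_tokens, head_num, head_dim = 16, 8, 128
kv_buffer = torch.randn(num_tokens, 2 * head_num, head_dim)  # K heads first, then V heads
FP8_E4M3_MAX = 448.0

# per_head (FA3 path): one abs-max per K/V head -> shape [2 * head_num]
per_head_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
per_head_scales = per_head_max / FP8_E4M3_MAX

# per_tensor (fallback path): one abs-max over all K heads and one over all V heads -> shape [2]
k_max = kv_buffer[:, :head_num, :].abs().amax().to(torch.float32)
v_max = kv_buffer[:, head_num:, :].abs().amax().to(torch.float32)
per_tensor_scales = torch.stack([k_max, v_max]) / FP8_E4M3_MAX

print(per_head_scales.shape, per_tensor_scales.shape)  # torch.Size([16]) torch.Size([2])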
@@ -95,7 +95,7 @@ def _bind_attention(self):
)
else:
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
)
@@ -118,7 +118,7 @@ def _bind_attention(self):
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
)
else:
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
)
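The hunks above only change which condition selects the FA3 kernels; the binding mechanism itself is untouched: unbound methods are attached to the instance with functools.partial. A reduced sketch of that dispatch pattern with placeholder kernel names and an illustrative selection order (the real implementations are the Deepseek2TransformerLayerInfer methods referenced above):

# Sketch of partial-based kernel binding; kernel names and ordering are placeholders.
from functools import partial


class LayerInfer:
    def __init__(self, disable_fa3: bool, enable_flashinfer: bool):
        # FA3 first (default on), then an optional alternative backend, then a fallback.
        if not disable_fa3:
            self._token_attention_kernel = partial(LayerInfer._decode_fa3, self)
        elif enable_flashinfer:
            self._token_attention_kernel = partial(LayerInfer._decode_flashinfer, self)
        else:
            self._token_attention_kernel = partial(LayerInfer._decode_triton, self)

    def _decode_fa3(self, q):
        return f"fa3({q})"

    def _decode_flashinfer(self, q):
        return f"flashinfer({q})"

    def _decode_triton(self, q):
        return f"triton({q})"


print(LayerInfer(disable_fa3=False, enable_flashinfer=False)._token_attention_kernel("q"))  # fa3(q)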
@@ -40,8 +40,8 @@ def _parse_config(self):
self.num_attention_heads = self.network_config_["num_attention_heads"]
self.kv_lora_rank = self.network_config_["kv_lora_rank"]
self.num_fused_shared_experts = 0
if get_env_start_args().enable_fused_shared_experts and self.is_moe:
# enable_fused_shared_experts can only be enabled when MOE_MODE is TP
if not get_env_start_args().disable_fused_shared_experts and self.is_moe:
# fused_shared_experts can only be enabled when MOE_MODE is TP
moe_mode = os.getenv("MOE_MODE", "TP")
assert moe_mode == "TP"
self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0)
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/model.py
@@ -69,7 +69,7 @@ def __init__(self, kvargs):
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
self.infer_state_class = Deepseek2FlashAttentionStateInfo
elif self.enable_flashinfer:
self.infer_state_class = Deepseek2FlashInferStateInfo
@@ -68,7 +68,7 @@ def _bind_norm(self):
return

def _bind_attention(self):
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
if "offline_calibration_fp8kv" in self.mode:
self._context_attention_kernel = partial(
LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self
2 changes: 1 addition & 1 deletion lightllm/models/llama/model.py
@@ -97,7 +97,7 @@ def _init_mem_manager(self):
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
self.infer_state_class = FlashAttentionStateInfo
elif self.enable_flashinfer:
self.infer_state_class = LlamaFlashInferStateInfo
2 changes: 1 addition & 1 deletion lightllm/models/qwen2_vl/model.py
@@ -106,7 +106,7 @@ def __init__(self, kvargs):
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
if not get_env_start_args().disable_fa3:
self.infer_state_class = Qwen2VLFlashAttentionStateInfo

def _init_config(self):
10 changes: 5 additions & 5 deletions lightllm/server/api_cli.py
@@ -16,7 +16,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--httpserver_workers", type=int, default=1)
parser.add_argument("--httpserver_workers", type=int, default=4)
parser.add_argument(
"--zmq_mode",
type=str,
@@ -303,9 +303,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
only deepseekv3 model supported now.""",
)
parser.add_argument(
"--enable_fa3",
"--disable_fa3",
action="store_true",
help="""inference backend will use the fa3 attention kernel for prefill and decode""",
help="""inference backend will not use the fa3 attention kernel for prefill and decode""",
)
parser.add_argument(
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
@@ -466,9 +466,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""",
)
parser.add_argument(
"--enable_fused_shared_experts",
"--disable_fused_shared_experts",
action="store_true",
help="""Whether to enable fused shared experts for deepseekv3 model. only work when MOE_MODE=TP """,
help="""Whether to disable fused shared experts for deepseekv3 model. only work when MOE_MODE=TP """,
)
parser.add_argument(
"--mtp_mode",
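The net effect of the renamed flags is that FA3 and fused shared experts are now on by default and only switched off explicitly, while --httpserver_workers defaults to 4 instead of 1. A minimal argparse sketch reproducing just these defaults (a reduced stand-in, not the full make_argument_parser):

# Reduced stand-in for the parser above, showing only the changed defaults.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--httpserver_workers", type=int, default=4)
parser.add_argument("--disable_fa3", action="store_true")
parser.add_argument("--disable_fused_shared_experts", action="store_true")

args = parser.parse_args([])  # no flags passed
print(args.httpserver_workers)                # 4
print(not args.disable_fa3)                   # True -> FA3 on by default
print(not args.disable_fused_shared_experts)  # True -> fused shared experts on by default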