diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst
index ec4d67c53..aea023b6b 100755
--- a/docs/CN/source/tutorial/api_server_args_zh.rst
+++ b/docs/CN/source/tutorial/api_server_args_zh.rst
@@ -333,9 +333,9 @@ attention类型选择参数
 
     推理后端将为解码使用 flashinfer 的注意力 kernel
 
-.. option:: --enable_fa3
+.. option:: --disable_fa3
 
-    推理后端将为预填充和解码使用 fa3 注意力 kernel
+    推理后端将不为预填充和解码使用 fa3 注意力 kernel(FA3 默认启用)
 
 .. option:: --disable_cudagraph
 
diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst
index f59549fb3..1478034bb 100644
--- a/docs/CN/source/tutorial/deepseek_deployment.rst
+++ b/docs/CN/source/tutorial/deepseek_deployment.rst
@@ -32,13 +32,11 @@ LightLLM 支持以下几种部署模式:
     # H200 单机 DeepSeek-R1 TP 模式
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
-    --tp 8 \
-    --enable_fa3
+    --tp 8
 
 **参数说明:**
 - `LOADWORKER=18`: 模型加载线程数,提高加载速度
 - `--tp 8`: 张量并行度,使用8个GPU
-- `--enable_fa3`: 启用 Flash Attention 3.0
 - `--port 8088`: 服务端口
 
 1.2 单机 DP + EP 模式 (Data Parallel + Expert Parallel)
@@ -55,13 +53,11 @@ LightLLM 支持以下几种部署模式:
     --model_dir /path/DeepSeek-R1 \
     --tp 8 \
     --dp 8 \
-    --enable_fa3
 
 **参数说明:**
 - `MOE_MODE=EP`: 设置专家并行模式
 - `--tp 8`: 张量并行度
 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值
-- `--enable_fa3`: 启用 Flash Attention 3.0
 
 **可选优化参数:**
 - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠
@@ -85,7 +81,6 @@ LightLLM 支持以下几种部署模式:
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 0 \
     --nccl_host $nccl_host \
@@ -101,7 +96,6 @@ LightLLM 支持以下几种部署模式:
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 1 \
     --nccl_host $nccl_host \
@@ -129,7 +123,6 @@ LightLLM 支持以下几种部署模式:
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
     --dp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 0 \
     --nccl_host $nccl_host \
@@ -146,7 +139,6 @@ LightLLM 支持以下几种部署模式:
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
     --dp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 1 \
     --nccl_host $nccl_host \
@@ -195,7 +187,6 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
     --host $host \
     --port 8019 \
     --nccl_port 2732 \
-    --enable_fa3 \
     --disable_cudagraph \
     --pd_master_ip $pd_master_ip \
     --pd_master_port 60011
@@ -219,7 +210,6 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
     --host $host \
     --port 8121 \
     --nccl_port 12322 \
-    --enable_fa3 \
     --disable_cudagraph \
     --pd_master_ip $pd_master_ip \
     --pd_master_port 60011
@@ -287,7 +277,6 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
     --tp 8 \
     --dp 8 \
     --nccl_port 2732 \
-    --enable_fa3 \
     --disable_cudagraph \
     --config_server_host $config_server_host \
     --config_server_port 60088
@@ -306,7 +295,6 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
     --nccl_port 12322 \
     --tp 8 \
     --dp 8 \
-    --enable_fa3 \
     --config_server_host $config_server_host \
     --config_server_port 60088
     # 如果需要启用微批次重叠,可以取消注释以下行
diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst
index 0769cef55..723b4cf86 100755
--- a/docs/EN/source/tutorial/api_server_args_zh.rst
+++ b/docs/EN/source/tutorial/api_server_args_zh.rst
@@ -332,9 +332,9 @@ Performance Optimization Parameters
 
     The inference backend will use flashinfer's attention kernel for decoding
 
-.. option:: --enable_fa3
+.. option:: --disable_fa3
 
-    The inference backend will use fa3 attention kernel for prefill and decoding
+    The inference backend will not use the fa3 attention kernel for prefill and decoding (FA3 is enabled by default)
 
 .. option:: --disable_cudagraph
 
diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst
index 9e2624bb8..cc19a961d 100755
--- a/docs/EN/source/tutorial/deepseek_deployment.rst
+++ b/docs/EN/source/tutorial/deepseek_deployment.rst
@@ -32,13 +32,11 @@ Suitable for deploying DeepSeek-R1 model on a single H200 node.
     # H200 Single node DeepSeek-R1 TP Mode
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
-    --tp 8 \
-    --enable_fa3
+    --tp 8
 
 **Parameter Description:**
 - `LOADWORKER=18`: Model loading thread count, improves loading speed
 - `--tp 8`: Tensor parallelism, using 8 GPUs
-- `--enable_fa3`: Enable Flash Attention 3.0
 - `--port 8088`: Service port
 
 1.2 Single node DP + EP Mode (Data Parallel + Expert Parallel)
@@ -55,13 +53,11 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3.
     --model_dir /path/DeepSeek-R1 \
     --tp 8 \
     --dp 8 \
-    --enable_fa3
 
 **Parameter Description:**
 - `MOE_MODE=EP`: Set expert parallelism mode
 - `--tp 8`: Tensor parallelism
 - `--dp 8`: Data parallelism, usually set to the same value as tp
-- `--enable_fa3`: Enable Flash Attention 3.0
 
 **Optional Optimization Parameters:**
 - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap
@@ -85,7 +81,6 @@ Suitable for deployment across multiple H200/H100 nodes.
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 0 \
     --nccl_host $nccl_host \
@@ -101,7 +96,6 @@ Suitable for deployment across multiple H200/H100 nodes.
     LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 1 \
     --nccl_host $nccl_host \
@@ -129,7 +123,6 @@ Suitable for deploying MoE models across multiple nodes.
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
     --dp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 0 \
     --nccl_host $nccl_host \
@@ -146,7 +139,6 @@ Suitable for deploying MoE models across multiple nodes.
     --model_dir /path/DeepSeek-R1 \
     --tp 16 \
     --dp 16 \
-    --enable_fa3 \
     --nnodes 2 \
     --node_rank 1 \
     --nccl_host $nccl_host \
@@ -195,7 +187,6 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
     --host $host \
     --port 8019 \
     --nccl_port 2732 \
-    --enable_fa3 \
     --disable_cudagraph \
     --pd_master_ip $pd_master_ip
 
@@ -216,7 +207,6 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
     --host $host \
     --port 8121 \
     --nccl_port 12322 \
-    --enable_fa3 \
     --disable_cudagraph \
     --pd_master_ip $pd_master_ip \
     --pd_master_port 60011
@@ -284,7 +274,6 @@ Supports multiple PD Master nodes, providing better load balancing and high avai
     --tp 8 \
     --dp 8 \
     --nccl_port 2732 \
-    --enable_fa3 \
     --disable_cudagraph \
     --config_server_host $config_server_host \
     --config_server_port 60088
@@ -303,7 +292,6 @@ Supports multiple PD Master nodes, providing better load balancing and high avai
     --nccl_port 12322 \
     --tp 8 \
     --dp 8 \
-    --enable_fa3 \
     --config_server_host $config_server_host \
     --config_server_port 60088
     # if you want to enable microbatch overlap, you can uncomment the following lines
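Note: with FA3 on by default, the quick-start launches documented above need no attention flag at all, and `--disable_fa3` exists only to opt out. A minimal sketch (model path and port are placeholders taken from the examples above):

```bash
# Default: the FA3 attention kernels are used automatically on supported GPUs.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8

# Explicit opt-out, e.g. to compare against the triton or flashinfer kernels.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--disable_fa3
```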
diff --git a/lightllm/common/offline_fp8_quant_mem_manager.py b/lightllm/common/offline_fp8_quant_mem_manager.py
index 5cc0b12d0..3d0b9bad4 100755
--- a/lightllm/common/offline_fp8_quant_mem_manager.py
+++ b/lightllm/common/offline_fp8_quant_mem_manager.py
@@ -32,7 +32,7 @@ def __init__(
 
         self.abs_max = None
         if is_export_mode:
-            scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
+            scales_shape = [layer_num, 2 * head_num] if not get_env_start_args().disable_fa3 else [layer_num, 2]
             self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
         elif get_env_start_args().kv_quant_calibration_config_path is not None:
             logger.info(
@@ -43,7 +43,7 @@ def __init__(
 
             self.scales_list = cfg["scales"]
             self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
-            if not get_env_start_args().enable_fa3:
+            if get_env_start_args().disable_fa3:
                 self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
             elif cfg["num_head"] > self.total_head_num:
                 factor = cfg["num_head"] // self.total_head_num
@@ -51,7 +51,7 @@ def __init__(
             elif cfg["num_head"] < self.total_head_num:
                 factor = self.total_head_num // cfg["num_head"]
                 self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
-            if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
+            if not get_env_start_args().disable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
                 end_head = start_head + head_num
@@ -86,7 +86,7 @@ def _load_and_check_config(self):
             raise ValueError(
                 f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
             )
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             if cfg["quant_type"] != "per_head":
                 raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend")
         else:
@@ -109,7 +109,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
             logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")
 
         if self.abs_max is not None and self.count >= warmup_counts:
-            if get_env_start_args().enable_fa3:
+            if not get_env_start_args().disable_fa3:
                 kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
             else:
                 k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
@@ -119,7 +119,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
         if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
             final_abs_max = self.abs_max
             if dist.is_initialized() and dist.get_world_size() > 1:
-                if get_env_start_args().enable_fa3:
+                if not get_env_start_args().disable_fa3:
                     k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
                     k_max = k_max.contiguous()
                     v_max = v_max.contiguous()
@@ -148,7 +148,7 @@ def _export_calibration_data(self):
         cfg = {
             "version": "1.0",
             "architectures": model_arch,
-            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
+            "quant_type": "per_head" if not get_env_start_args().disable_fa3 else "per_tensor",
             "qmin": self.qmin,
             "qmax": self.qmax,
             "num_layers": self.layer_num,
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index ace54bba4..5992c87e3 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -95,7 +95,7 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             self._token_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
             )
@@ -118,7 +118,7 @@ def _bind_attention(self):
                 Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
             )
         else:
-            if get_env_start_args().enable_fa3:
+            if not get_env_start_args().disable_fa3:
                 self._context_attention_kernel = partial(
                     Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
                 )
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index f78f7e849..5ec89ba3c 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -40,8 +40,8 @@ def _parse_config(self):
         self.num_attention_heads = self.network_config_["num_attention_heads"]
         self.kv_lora_rank = self.network_config_["kv_lora_rank"]
         self.num_fused_shared_experts = 0
-        if get_env_start_args().enable_fused_shared_experts and self.is_moe:
-            # MOE_MODE 处于 TP 模式下才能使能 enable_fused_shared_experts
+        if not get_env_start_args().disable_fused_shared_experts and self.is_moe:
+            # MOE_MODE 处于 TP 模式下才能使能 fused_shared_experts
             moe_mode = os.getenv("MOE_MODE", "TP")
             assert moe_mode == "TP"
             self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0)
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index 9101cb963..04a63c2f9 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -69,7 +69,7 @@ def __init__(self, kvargs):
         return
 
     def _init_inferstate_cls(self):
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             self.infer_state_class = Deepseek2FlashAttentionStateInfo
         elif self.enable_flashinfer:
             self.infer_state_class = Deepseek2FlashInferStateInfo
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index b00215cff..646f7d1b9 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -68,7 +68,7 @@ def _bind_norm(self):
         return
 
     def _bind_attention(self):
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             if "offline_calibration_fp8kv" in self.mode:
                 self._context_attention_kernel = partial(
                     LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index abc258e8b..947fb9d29 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -97,7 +97,7 @@ def _init_mem_manager(self):
         return
 
     def _init_inferstate_cls(self):
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             self.infer_state_class = FlashAttentionStateInfo
         elif self.enable_flashinfer:
             self.infer_state_class = LlamaFlashInferStateInfo
diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py
index 6d179e6f9..b09acc78d 100644
--- a/lightllm/models/qwen2_vl/model.py
+++ b/lightllm/models/qwen2_vl/model.py
@@ -106,7 +106,7 @@ def __init__(self, kvargs):
         return
 
     def _init_inferstate_cls(self):
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             self.infer_state_class = Qwen2VLFlashAttentionStateInfo
 
     def _init_config(self):
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index ae9f7541d..f23c79b3a 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -16,7 +16,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--httpserver_workers", type=int, default=1)
+    parser.add_argument("--httpserver_workers", type=int, default=4)
     parser.add_argument(
         "--zmq_mode",
         type=str,
@@ -303,9 +303,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         only deepseekv3 model supported now.""",
     )
     parser.add_argument(
-        "--enable_fa3",
+        "--disable_fa3",
         action="store_true",
-        help="""inference backend will use the fa3 attention kernel for prefill and decode""",
+        help="""inference backend will not use the fa3 attention kernel for prefill and decode""",
     )
     parser.add_argument(
         "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
@@ -466,9 +466,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""",
    )
     parser.add_argument(
-        "--enable_fused_shared_experts",
+        "--disable_fused_shared_experts",
         action="store_true",
-        help="""Whether to enable fused shared experts for deepseekv3 model. only work when MOE_MODE=TP """,
+        help="""Whether to disable fused shared experts for deepseekv3 model. Only works when MOE_MODE=TP.""",
     )
     parser.add_argument(
         "--mtp_mode",
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 03c519d7b..c584e99af 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -112,21 +112,33 @@ def normal_or_p_d_start(args):
         assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache"
         assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill"
     if "offline_calibration_fp8kv" in args.mode:
-        assert args.enable_fa3 is True or (
+        assert args.disable_fa3 is False or (
             args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
         ), (
-            "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "offline_calibration_fp8kv mode needs fa3 or flashinfer enabled, remove --disable_fa3 or add "
             "--enable_flashinfer_prefill and --enable_flashinfer_decode"
         )
     if "export_fp8kv_calibration" in args.mode:
-        assert args.enable_fa3 is True or (
+        assert args.disable_fa3 is False or (
             args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
         ), (
-            "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "export_fp8kv_calibration mode needs fa3 or flashinfer enabled, remove --disable_fa3 or add "
             "--enable_flashinfer_prefill and --enable_flashinfer_decode"
         )
         assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph"
 
+    # Validate FA3 support when FA3 is enabled (i.e. --disable_fa3 was not passed)
+    if not args.disable_fa3:
+        from lightllm.utils.device_utils import is_fa3_supported
+
+        if not is_fa3_supported():
+            logger.warning(
+                "FA3 is enabled but not supported on this hardware/software environment. "
+                "FA3 requires Hopper architecture (H100, H200, H800) or newer, and the sgl_kernel package. "
+                "Disabling FA3 and falling back to other attention kernels."
+            )
+            args.disable_fa3 = True
+
     # 部分模式还不能支持与高级动态调度算法协同,to do.
     if args.diverse_mode:
         assert args.router_token_ratio == 0.0
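The reworked assertions above are easier to satisfy with the inverted flag: the FP8 KV calibration modes only require that FA3 (the default) or flashinfer stays enabled. A hedged sketch of an export run, assuming the mode name is passed through the existing `--mode` option that `args.mode` refers to:

```bash
# Collect KV cache calibration data: cudagraph must be disabled for export,
# and FA3 (default) or flashinfer must remain enabled -- no extra flag needed for FA3.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--mode export_fp8kv_calibration \
--disable_cudagraph
```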
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index 66a25e93b..8b93f49c0 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -24,6 +24,31 @@ def is_hopper():
     )
 
 
+@lru_cache(maxsize=None)
+def is_fa3_supported():
+    """
+    Check if the current hardware and software environment supports FA3.
+    FA3 requires:
+    1. Hopper architecture (H100, H200, H800) or newer (some Ada Lovelace GPUs may also work)
+    2. sgl_kernel package installed
+    """
+    # Check hardware support (Hopper or newer)
+    if not is_hopper():
+        # Check if it's Ada Lovelace (40 series) which might also support FA3
+        device_name = torch.cuda.get_device_name(0)
+        if not ("RTX_40" in device_name.replace(" ", "_") or "L40" in device_name or "Ada" in device_name):
+            return False
+
+    # Check software support (sgl_kernel installed)
+    try:
+        import sgl_kernel
+        from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+        return flash_attn_varlen_func is not None and flash_attn_with_kvcache is not None
+    except ImportError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def get_device_sm_count():
     import triton
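Because `is_fa3_supported()` is a cached, importable helper, the environment can also be probed before launching; a small sketch, assuming lightllm is installed and a CUDA device is visible:

```bash
# Prints True on Hopper-class GPUs (or the whitelisted Ada devices) with sgl_kernel installed.
python -c "from lightllm.utils.device_utils import is_fa3_supported; print(is_fa3_supported())"
```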
diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md
index aff1d973f..792b8a5e6 100644
--- a/test/start_scripts/README.md
+++ b/test/start_scripts/README.md
@@ -108,7 +108,7 @@ sh multi_pd_master/pd_decode.sh
 - `--model_dir`: Model file path
 - `--tp`: Tensor parallelism degree
 - `--dp`: Data parallelism degree
-- `--enable_fa3`: Enable Flash Attention 3.0
+- `--disable_fa3`: Disable Flash Attention 3.0 (FA3 is enabled by default; pass this flag to opt out)
 - `--nnodes`: Total number of nodes
 - `--node_rank`: Current node rank
 - `--nccl_host`: NCCL communication host address
diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh
index 3a139968a..cd72e6cfc 100644
--- a/test/start_scripts/multi_node_ep_node0.sh
+++ b/test/start_scripts/multi_node_ep_node0.sh
@@ -6,7 +6,6 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 16 \
 --dp 16 \
---enable_fa3 \
 --nnodes 2 \
 --node_rank 0 \
 --nccl_host $nccl_host \
diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh
index b24a59868..17b878a1b 100644
--- a/test/start_scripts/multi_node_ep_node1.sh
+++ b/test/start_scripts/multi_node_ep_node1.sh
@@ -6,7 +6,6 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 16 \
 --dp 16 \
---enable_fa3 \
 --nnodes 2 \
 --node_rank 1 \
 --nccl_host $nccl_host \
diff --git a/test/start_scripts/multi_node_tp_node0.sh b/test/start_scripts/multi_node_tp_node0.sh
index b86bdeb35..01a903b3e 100644
--- a/test/start_scripts/multi_node_tp_node0.sh
+++ b/test/start_scripts/multi_node_tp_node0.sh
@@ -5,7 +5,6 @@ export nccl_host=$1
 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 16 \
---enable_fa3 \
 --nnodes 2 \
 --node_rank 0 \
 --nccl_host $nccl_host \
diff --git a/test/start_scripts/multi_node_tp_node1.sh b/test/start_scripts/multi_node_tp_node1.sh
index 378977ab2..99a3ac130 100644
--- a/test/start_scripts/multi_node_tp_node1.sh
+++ b/test/start_scripts/multi_node_tp_node1.sh
@@ -5,7 +5,6 @@ export nccl_host=$1
 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 16 \
---enable_fa3 \
 --nnodes 2 \
 --node_rank 1 \
 --nccl_host $nccl_host \
diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh
index 4cefef6fb..cb55ec338 100644
--- a/test/start_scripts/multi_pd_master/pd_decode.sh
+++ b/test/start_scripts/multi_pd_master/pd_decode.sh
@@ -13,7 +13,6 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --nccl_port 12322 \
 --tp 8 \
 --dp 8 \
---enable_fa3 \
 --config_server_host $config_server_host \
 --config_server_port 60088
 # if you want to enable microbatch overlap, you can uncomment the following lines
diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh
index b845da435..41ad52551 100644
--- a/test/start_scripts/multi_pd_master/pd_prefill.sh
+++ b/test/start_scripts/multi_pd_master/pd_prefill.sh
@@ -13,7 +13,6 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --tp 8 \
 --dp 8 \
 --nccl_port 2732 \
---enable_fa3 \
 --disable_cudagraph \
 --config_server_host $config_server_host \
 --config_server_port 60088
diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh
index cad172d51..50a2adb5c 100644
--- a/test/start_scripts/single_node_ep.sh
+++ b/test/start_scripts/single_node_ep.sh
@@ -3,7 +3,6 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 8 \
 --dp 8 \
---enable_fa3
 # if you want to enable microbatch overlap, you can uncomment the following lines
 #--enable_prefill_microbatch_overlap \
 #--enable_decode_microbatch_overlap \
diff --git a/test/start_scripts/single_node_tp.sh b/test/start_scripts/single_node_tp.sh
index 1fb461bb1..237cec8ed 100644
--- a/test/start_scripts/single_node_tp.sh
+++ b/test/start_scripts/single_node_tp.sh
@@ -2,7 +2,6 @@
 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 8 \
---enable_fa3
 # if you want to enable microbatch overlap, you can uncomment the following lines
 #--enable_prefill_microbatch_overlap \
 #--enable_decode_microbatch_overlap \
diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh
index 1bf465746..2a898645f 100644
--- a/test/start_scripts/single_pd_master/pd_decode.sh
+++ b/test/start_scripts/single_pd_master/pd_decode.sh
@@ -13,7 +13,6 @@ MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_serve
 --host $host \
 --port 8121 \
 --nccl_port 12322 \
---enable_fa3 \
 --pd_master_ip $pd_master_ip \
 --pd_master_port 60011
 # if you want to enable microbatch overlap, you can uncomment the following lines
diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh
index b15e4ef70..a049ffeda 100644
--- a/test/start_scripts/single_pd_master/pd_prefill.sh
+++ b/test/start_scripts/single_pd_master/pd_prefill.sh
@@ -13,7 +13,6 @@ MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_serve
 --host $host \
 --port 8019 \
 --nccl_port 2732 \
---enable_fa3 \
 --disable_cudagraph \
 --pd_master_ip $pd_master_ip \
 --pd_master_port 60011
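On GPUs that the new check rejects, api_start.py now logs a warning and falls back automatically, but launch scripts can also opt out up front. A hedged sketch of such a fallback launch (model path is a placeholder; the flashinfer flags are the ones referenced in api_start.py):

```bash
# Pre-Hopper fallback: skip FA3 explicitly and use the flashinfer kernels instead.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--disable_fa3 \
--enable_flashinfer_prefill \
--enable_flashinfer_decode
```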