2 changes: 2 additions & 0 deletions .github/workflows/scripts/config.yaml
@@ -130,6 +130,8 @@ e2e-multicard-2-cards:
estimated_time: 215
- name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
estimated_time: 90
- name: tests/e2e/multicard/2-cards/test_sp_pass.py
estimated_time: 300

e2e-multicard-4-cards:
# TODO: recover skipped tests
@@ -45,6 +45,7 @@ The following table lists additional configuration options available in vLLM Ascend
| `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. |
| `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear |
| `sp_threshold` | int | `1000` | For dense models, sequence parallelism takes effect only when the number of tokens exceeds this threshold. |

The details of each configuration option are as follows:

1 change: 1 addition & 0 deletions docs/source/user_guide/feature_guide/index.md
@@ -24,4 +24,5 @@ speculative_decoding
context_parallel
npugraph_ex
weight_prefetch
sequence_parallelism
:::
57 changes: 57 additions & 0 deletions docs/source/user_guide/feature_guide/sequence_parallelism.md
@@ -0,0 +1,57 @@
# Sequence Parallelism

## What is Sequence Parallelism

Sequence Parallelism (SP) was first introduced in [Megatron](https://arxiv.org/pdf/2205.05198) to reduce training activation memory. The core modification is changing `Allreduce->LayerNorm` into `ReduceScatter->LayerNorm->Allgather`. The technique was later applied to inference by vLLM. Note that splitting Allreduce into ReduceScatter and Allgather does not inherently bring performance benefits: it reduces the computation load of LayerNorm, but this gain is minimal. The real benefits of SP come from:

1. LLM inference deployments often use quantization. Taking the INT8 quantization commonly used on NPUs as an example, a Quant operator after LayerNorm converts the hidden states from BF16 to INT8, so the Allgather communication volume is halved and its latency is almost halved.
2. ReduceScatter and Allgather can be fused with the preceding and following Matmul operations, respectively, into communication-computation parallel operators, reducing latency.
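
A minimal PyTorch sketch of the two communication patterns is shown below. It is illustrative only: the function names and the plain `torch.distributed` calls are not the vllm-ascend implementation, which fuses these steps into dedicated NPU operators.

```python
import torch
import torch.distributed as dist


def allreduce_then_norm(hidden: torch.Tensor, norm: torch.nn.Module) -> torch.Tensor:
    # Baseline TP pattern: every rank reduces the full [num_tokens, hidden] tensor,
    # then every rank normalizes all tokens redundantly.
    dist.all_reduce(hidden)
    return norm(hidden)


def reduce_scatter_norm_allgather(hidden: torch.Tensor, norm: torch.nn.Module) -> torch.Tensor:
    # SP pattern: each rank reduces and normalizes only a 1/tp_size slice of the
    # tokens, then the slices are gathered back. On its own this is roughly
    # communication-neutral; the real wins come from quantizing before the
    # Allgather (benefit 1) and from fusing the collectives with the adjacent
    # matmuls (benefit 2).
    tp_size = dist.get_world_size()
    num_tokens, hidden_size = hidden.shape
    assert num_tokens % tp_size == 0, "tokens are padded to a multiple of tp_size"
    local = torch.empty(num_tokens // tp_size, hidden_size,
                        dtype=hidden.dtype, device=hidden.device)
    dist.reduce_scatter_tensor(local, hidden)
    local = norm(local)
    out = torch.empty_like(hidden)
    dist.all_gather_into_tensor(out, local)
    return out
```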

## How to Use

Currently, vllm-ascend implements Sequence Parallelism for VL-class models as an Inductor-based compilation pass. It can be enabled as follows:

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --tensor-parallel-size 2 \
    --compilation-config '{"pass_config": {"enable_sp": true}}' \
    --additional-config '{"sp_threshold": 1000}'
```

- `"pass_config": {"enable_sp": true}`: This is the switch for SP. Since SP relies on graph mode, it must be enabled and is not supported in eager mode.
- `--additional-config '{"sp_threshold": 1000}'`: Based on our experiments, SP can actually be detrimental when the number of tokens is small (empirically, fewer than 1000): when the communication volume is small, the fixed overhead of the communication operator dominates, so splitting one communication operator (Allreduce) into two (ReduceScatter + Allgather) often makes end-to-end latency longer. We therefore expose the `sp_threshold` parameter; SP only takes effect when `num_tokens >= sp_threshold`. **The default value is 1000, which generally does not need to be modified.** `sp_threshold` is appended to `compile_ranges_split_points`, a parameter provided by vLLM that splits the graph compilation range `[1, max_num_batched_tokens]` into `{[1, split_points[0]], [split_points[0] + 1, split_points[1]], ..., [split_points[-1] + 1, max_num_batched_tokens]}` and, for each range, checks whether the pass's `is_applicable_for_range` returns `True` (see the sketch below).
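
As a rough illustration of how split points partition the compile range (the helper below is a sketch, not the vLLM implementation, and `max_num_batched_tokens = 8192` is just an example value):

```python
def compile_ranges(split_points: list[int], max_num_batched_tokens: int) -> list[tuple[int, int]]:
    # Partition [1, max_num_batched_tokens] at the given split points, mirroring
    # the description above.
    points = sorted(split_points)
    starts = [1] + [p + 1 for p in points]
    ends = points + [max_num_batched_tokens]
    return list(zip(starts, ends))


# With the default sp_threshold of 1000 appended as a split point:
print(compile_ranges([1000], 8192))  # [(1, 1000), (1001, 8192)]
# The SP pass then applies only to the compile range above the threshold.
```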

If you keep the default `sp_threshold`, the simplest and recommended way to enable SP is:

```bash
vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --tensor-parallel-size 2 \
    --compilation-config '{"pass_config": {"enable_sp": true}}'
```
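
For offline inference, the same configuration can be passed through the Python API. This is a sketch that assumes your vLLM build forwards `compilation_config` and `additional_config` from the `LLM` constructor to the engine in the same way the CLI flags do:

```python
from vllm import LLM, SamplingParams

# Enable the SP pass for a 2-way tensor-parallel VL model; sp_threshold keeps its default.
llm = LLM(
    model="Qwen/Qwen3-VL-2B-Instruct",
    tensor_parallel_size=2,
    compilation_config={"pass_config": {"enable_sp": True}},
    additional_config={"sp_threshold": 1000},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=10, temperature=0.0))
print(outputs[0].outputs[0].text)
```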

## Difference Between SP and Flash Comm V1

[Flash Comm V1 (FC1)](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm.md) is an enhanced version of Sequence Parallelism developed for NPUs. The enhancements include:

1. For models using the MLA structure, Allgather is postponed until after the QKV projection, further reducing communication volume (see the sketch below).
2. For MoE models, Allgather is postponed until after Gating + DynamicQuant, also to reduce communication volume.
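
A minimal sketch of point 1 follows. It assumes the MLA low-rank down-projection plays the role of the QKV projection; the helper names, shapes, and plain `torch.distributed` calls are illustrative and are not the actual vllm-ascend kernels.

```python
import torch
import torch.distributed as dist


def sp_qkv(local_hidden: torch.Tensor, qkv_down_proj: torch.nn.Linear) -> torch.Tensor:
    # SP ordering: Allgather the full hidden states first, then project.
    # Communicates n * tp * h elements.
    tp = dist.get_world_size()
    n, h = local_hidden.shape
    full = torch.empty(n * tp, h, dtype=local_hidden.dtype, device=local_hidden.device)
    dist.all_gather_into_tensor(full, local_hidden)
    return qkv_down_proj(full)


def fc1_qkv(local_hidden: torch.Tensor, qkv_down_proj: torch.nn.Linear) -> torch.Tensor:
    # FC1 ordering for MLA: project the local shard down to the low-rank dim r
    # first, then Allgather the much smaller tensor.
    # Communicates only n * tp * r elements, with r << h.
    tp = dist.get_world_size()
    local_low = qkv_down_proj(local_hidden)            # [n, r]
    n, r = local_low.shape
    full_low = torch.empty(n * tp, r, dtype=local_low.dtype, device=local_low.device)
    dist.all_gather_into_tensor(full_low, local_low)
    return full_low
```

Both orderings produce the same projected tensor; FC1 simply moves the collective to where the data is smallest.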

FC1 is an optimization unique to vllm-ascend and is currently implemented as a custom op, but it is difficult to extend to VL-class models (see [[RFC]: support sequence parallelism by pass](https://github.com/vllm-project/vllm-ascend/issues/5712) for the reasons). Therefore, FC1 and SP are currently complementary.

## Support Matrix

In the tables below, `graph` means supported in graph mode only, `eager/graph` means supported in both eager and graph mode, and `x` means not supported.

### Without Quantization

| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | graph | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |

### With Quantization

SP does not yet support quantization; adaptation is in progress.

| | VL + Dense | VL + MoE | non-VL + Dense | non-VL + MoE |
| -------------------- | ---------- | -------- | -------------- | ------------ |
| Sequence Parallelism | x | x | x | x |
| Flash Comm V1 | x | x | eager/graph | eager/graph |
64 changes: 64 additions & 0 deletions tests/e2e/multicard/2-cards/test_sp_pass.py
@@ -0,0 +1,64 @@
import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODELS = [
    "Qwen/Qwen3-VL-2B-Instruct",
]


@pytest.mark.parametrize("model", MODELS)
def test_qwen3_vl_sp_tp2(model: str) -> None:
    prompts = [
        "Hello, my name is", "The capital of the United States is",
        "The capital of France is", "The future of AI is"
    ]
    sampling_params = SamplingParams(max_tokens=10, temperature=0.0)

    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4],
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "pass_config": {"enable_sp": False}
            },
            additional_config={"npugraph_ex_config": {"enable": False}}
    ) as runner:
        no_sp_outputs = runner.model.generate(prompts, sampling_params)

    with VllmRunner(
            model,
            max_model_len=1024,
            tensor_parallel_size=2,
            compilation_config={
                "cudagraph_capture_sizes": [2, 4],
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "pass_config": {"enable_sp": True}
            },
            additional_config={"sp_threshold": 10, "npugraph_ex_config": {"enable": False}}
    ) as runner:
        sp_outputs = runner.model.generate(
            prompts, sampling_params)

    no_sp_outputs_list = []
    for output in no_sp_outputs:
        no_sp_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    sp_outputs_list = []
    for output in sp_outputs:
        sp_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    check_outputs_equal(
        outputs_0_lst=no_sp_outputs_list,
        outputs_1_lst=sp_outputs_list,
        name_0="no_sp_outputs",
        name_1="sp_outputs",
    )
2 changes: 1 addition & 1 deletion tests/ut/ops/test_mla.py
@@ -157,7 +157,7 @@ def test_forward(self, mock_get_forward_context, mock_tp_size,
hidden_states = torch.randn(3, self.hidden_size)

mock_forward_context = MagicMock(spec=ForwardContext)
mock_forward_context.sp_enabled = False
mock_forward_context.flash_comm_v1_enabled = False
mock_get_forward_context.return_value = mock_forward_context

mock_mla_forward.return_value = (3, self.hidden_size)
8 changes: 4 additions & 4 deletions tests/ut/spec_decode/test_eagle_proposer.py
@@ -390,7 +390,7 @@ def tearDown(self):

# cpu does not support parallel-group, let alone `sp`
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.sp_enabled": False})
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_basic(self, mock_context, mock_get_context):
num_tokens = 32
@@ -406,7 +406,7 @@ def test_dummy_run_basic(self, mock_context, mock_get_context):

# cpu does not support parallel-group, let alone `sp`
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context",
**{"return_value.sp_enabled": False})
**{"return_value.flash_comm_v1_enabled": False})
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
mock_context.return_value.__enter__.return_value = None
@@ -426,7 +426,7 @@ def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = True
# cpu does not support parallel-group, let alone `sp`
mock_return_context.sp_enabled = False
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
@@ -449,7 +449,7 @@ def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
mock_return_context.capturing = False
# cpu does not support parallel-group, let alone `sp`
mock_return_context.sp_enabled = False
mock_return_context.flash_comm_v1_enabled = False
mock_get_context.return_value = mock_return_context
self.proposer.use_cuda_graph = True
# cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
42 changes: 42 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -30,6 +30,7 @@ class AscendConfig:
"""

def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}

xlite_graph_config = additional_config.get("xlite_graph_config", {})
@@ -160,6 +161,47 @@ def _construct_weight_prefetch_config(self, additional_config):
stacklevel=2,
)

    def update_compile_ranges_split_points(self):
        vllm_config = self.vllm_config
        if self.npugraph_ex_config.enable:
            if self.npugraph_ex_config.fuse_allreduce_rms:
                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD

                new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
                logger.debug(
                    "set compile_ranges_split_points to "
                    f"{new_compile_ranges_split_points} for matmul and allreduce fusion"
                )

        else:
            new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
            if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD

                new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
                logger.debug(
                    "set compile_ranges_split_points to "
                    f"{new_compile_ranges_split_points} for matmul and allreduce fusion"
                )

        from vllm_ascend.utils import is_moe_model

        if vllm_config.compilation_config.pass_config.enable_sp and not is_moe_model(vllm_config):
            from vllm_ascend.compilation.passes.sequence_parallelism import get_sp_threshold

            sp_threshold = get_sp_threshold(vllm_config)
            new_compile_ranges_split_points.append(sp_threshold)
            logger.debug(f"add {sp_threshold} to compile_ranges_split_points for sequence parallelism")
        if len(new_compile_ranges_split_points) > len(vllm_config.compilation_config.compile_ranges_split_points):
            new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
            vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points


class FinegrainedTPConfig:
"""
16 changes: 8 additions & 8 deletions vllm_ascend/ascend_forward_context.py
@@ -12,7 +12,7 @@
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (
AscendDeviceType,
enable_sp,
enable_flash_comm_v1,
flashcomm2_enable,
get_ascend_device_type,
has_layer_idx,
@@ -92,22 +92,22 @@ def set_ascend_forward_context(
# main model and drafter model may have different architecture
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
if is_context_moe_model:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
mmrs_fusion = False
elif is_draft_model:
# TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
# Disable it to avoid more problems.
sp_enabled = False
flash_comm_v1_enabled = False
else:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000

flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled
forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
forward_context.pad_size = 0
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size

@@ -131,7 +131,7 @@
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length
5 changes: 5 additions & 0 deletions vllm_ascend/compilation/compiler_interface.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import functools
from collections.abc import Callable
from typing import Any
@@ -129,6 +130,10 @@ def compile(
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)

npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
assert hasattr(self, "vllm_config")
5 changes: 5 additions & 0 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -70,3 +70,8 @@ def configure(self, config: VllmConfig):
from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass

self.passes.append(MatmulAllReduceAddRMSNormPass(config))

if config.compilation_config.pass_config.enable_sp:
from .passes.sequence_parallelism import AscendSequenceParallelismPass

self.passes.append(AscendSequenceParallelismPass(config))