Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,6 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",

# (10)
"vllm_ascend/ops/*linear*.py",
"vllm_ascend/worker/worker.py",
"vllm_ascend/distributed/parallel_state.py",
"vllm_ascend/distributed/utils.py",
"vllm_ascend/xlite/*.py",
"vllm_ascend/patch/worker/patch_*.py",
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py",
]

[tool.ruff.lint]
Expand Down
177 changes: 63 additions & 114 deletions vllm_ascend/distributed/parallel_state.py

Large diffs are not rendered by default.

30 changes: 11 additions & 19 deletions vllm_ascend/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
Expand All @@ -8,7 +6,9 @@
from vllm_ascend.distributed.parallel_state import get_fc3_quant_x_group


def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
def fc3_all_gather_and_maybe_unpad_impl(
x: torch.Tensor,
) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
Expand All @@ -22,34 +22,26 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
else:
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
device=x.device,
dtype=x.dtype)
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
result[offset : offset + num_tokens_dp] = x[idx, :num_tokens_dp]
offset += num_tokens_dp
x = result

return x


def all_gather_async(input: torch.Tensor,
group: GroupCoordinator,
output: Optional[torch.Tensor] = None,
async_op: bool = True):
def all_gather_async(
input: torch.Tensor, group: GroupCoordinator, output: torch.Tensor | None = None, async_op: bool = True
):
if group.world_size == 1:
return input, None
if output is None:
input_size = input.size()
output_size = (input_size[0] * group.world_size, ) + input_size[1:]
output = torch.empty(output_size,
dtype=input.dtype,
device=input.device)
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)
output_size = (input_size[0] * group.world_size,) + input_size[1:]
output = torch.empty(output_size, dtype=input.dtype, device=input.device)
return output, dist.all_gather_into_tensor(output, input, group=group.device_group, async_op=async_op)
104 changes: 57 additions & 47 deletions vllm_ascend/ops/layer_shard_linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
Expand All @@ -17,39 +17,38 @@ def dispose_tensor(x: torch.Tensor):

@dataclass
class LayerMetadata:
"""Metadata for a layer.
"""
"""Metadata for a layer."""

layer_idx: int # The index of the layer.
layer: LinearBase # The layer object.
post_method: Callable[[
torch.nn.Module
], None] # The `process_weights_after_loading` method from the quant method.
post_method: Callable[[torch.nn.Module], None] # The `process_weights_after_loading` method from the quant method.
weight: torch.Tensor # The weight tensor.
window_idx: int # The index of the window.


@dataclass
class ShardWindowMetadata:
"""Metadata for a shard window.
"""
"""Metadata for a shard window."""

weight: torch.Tensor # The weight tensor to be shard by layers.
data_layer_idx: int # The index of the layer this window's weight is equal to.
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
work: torch.distributed.Work | None # The asynchronous broadcast work.


@dataclass
class SeriesMetadata:
"""Metadata for a weight shard series.
"""
"""Metadata for a weight shard series."""

group: GroupCoordinator
start_layer: int
end_layer: int
num_layers: int
prefetch_step: int
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix.
# All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata]
shard_windows: list[
ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
shard_windows: list[ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1),
# as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
window_offset: int # The index of the window for the next coming layer.

def is_source(self, layer_idx) -> bool:
Expand All @@ -63,35 +62,37 @@ def post_process_after_loading(self):
self.layers.sort(key=lambda x: x.layer_idx)
self.num_layers = len(self.layers)
assert self.num_layers > 0, "No layers in the series"
assert self.prefetch_step >= 0 and self.prefetch_step <= max(
0, self.num_layers -
2), "prefetch_step must be in [0, num_layers - 2]"
assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), (
"prefetch_step must be in [0, num_layers - 2]"
)
self.start_layer = self.layers[0].layer_idx
self.end_layer = self.layers[-1].layer_idx + 1

for layer_idx in range(self.start_layer, self.end_layer):
layer = self.layers[layer_idx - self.start_layer]
assert layer.layer_idx == layer_idx, "layer_idx must be consecutive"
is_source = self.is_source(layer_idx)
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
# If the weight uses dummy weight, make a copy temporary such that the post method call
# won't affect other layers which also uses dummy weight.
if not is_source:
layer.weight.set_(torch.empty_like(self.dummy_weight))
# Broadcast to get the true weight.
dist.broadcast(layer.weight,
src=self.group.ranks[layer_idx %
self.group.world_size],
group=self.group.device_group)
dist.broadcast(
layer.weight, src=self.group.ranks[layer_idx % self.group.world_size], group=self.group.device_group
)
# Call `process_weights_after_loading` from the quant method.
layer.post_method(layer.layer)
step = layer_idx - self.start_layer
if step < self.prefetch_step:
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
# Build the windows for the first `prefetch_step` layers. The weights can be used
# for the first `prefetch_step` layers in `forward()`, so also clone the weights.
self.shard_windows.append(
ShardWindowMetadata(
weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx,
work=None,
))
)
)
layer.window_idx = step
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
if not is_source:
Expand All @@ -104,7 +105,8 @@ def post_process_after_loading(self):
weight=torch.empty_like(layer.weight),
data_layer_idx=-1,
work=None,
))
)
)
# When the layer not intended to be stored in this device, dispose the tensor.
if not is_source:
dispose_tensor(layer.weight)
Expand All @@ -113,8 +115,7 @@ def post_process_after_loading(self):

def reach_layer(self, layer_idx: int):
# The index of the layer to be prefetched.
next_layer_idx = (layer_idx + self.prefetch_step
) % self.num_layers + self.start_layer
next_layer_idx = (layer_idx + self.prefetch_step) % self.num_layers + self.start_layer
next_layer = self.layers[next_layer_idx - self.start_layer]
# The index of the window to store the weight for the coming layer.
next_layer.window_idx = self.window_offset
Expand All @@ -123,22 +124,21 @@ def reach_layer(self, layer_idx: int):
if not self.is_source(next_layer_idx):
next_layer.weight.set_(window.weight)
# Update `window_offset` by rolling one step.
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
1)
self.window_offset = (self.window_offset + 1) % (self.prefetch_step + 1)
assert window.data_layer_idx != next_layer_idx
window.data_layer_idx = next_layer_idx
# Start asynchronous broadcast work.
window.work = dist.broadcast(
next_layer.weight,
src=self.group.ranks[next_layer_idx % self.group.world_size],
group=self.group.device_group,
async_op=True)
async_op=True,
)

def wait_weight(self, layer_idx: int):
# Find the asynchronous broadcast work and wait for it.
assert self.shard_windows
window = self.shard_windows[self.layers[layer_idx -
self.start_layer].window_idx]
window = self.shard_windows[self.layers[layer_idx - self.start_layer].window_idx]
# Make sure the data in the corresponding shard window is for the current layer.
assert window.data_layer_idx == layer_idx
if window.work is not None:
Expand All @@ -148,8 +148,8 @@ def wait_weight(self, layer_idx: int):

@dataclass
class LayerExternalMetadata:
"""External metadata for a layer.
"""
"""External metadata for a layer."""

series: SeriesMetadata
layer_idx: int

Expand All @@ -159,9 +159,7 @@ class LayerExternalMetadata:
_layer_external_dict: dict[int, LayerExternalMetadata] = {}


def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
layer_idx: int) -> Callable:

def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, layer_idx: int) -> Callable:
def wrapped_forward(*args, **kwargs):
# Wait for the weight.
series.wait_weight(layer_idx)
Expand All @@ -173,23 +171,32 @@ def wrapped_forward(*args, **kwargs):
"""
Register linear layers into a shard storage series.

In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series.
All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer
is stored on device (i % n), where n is the number of devices.

After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.
After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)`
on any layer of this series to complete the initialization.

During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)`
for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages
asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will
trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.

Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- start_layer is the index of the first layer in the series (inclusive).
- end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
- end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with
indices in the range [start_layer, end_layer).
- total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer

To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers
will be created for this series.

Arguments:
series_name: This name identifies which series this layer belongs to.
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group
coordinator for each new series.
layer: The linear layer object to register.
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
"""
Expand Down Expand Up @@ -224,7 +231,8 @@ def register_layer_to_shard_weight_series(
post_method=layer.quant_method.process_weights_after_loading,
weight=layer.weight,
window_idx=-1,
))
)
)
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
layer.quant_method.process_weights_after_loading = lambda layer: None
# When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
Expand Down Expand Up @@ -257,6 +265,7 @@ def wait_layer_for_shard_weight_series(layer: LinearBase):
@lru_cache(maxsize=1)
def get_current_model_num_hidden_layers() -> int:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return vllm_config.model_config.get_total_num_hidden_layers()

Expand All @@ -268,10 +277,11 @@ def is_hidden_layer(layer: LinearBase) -> bool:


def register_all_layers_to_shard_weight_series(
layer_sharding: List[LinearBase], ):
for curr_layer in (layer_sharding or []):
layer_sharding: list[LinearBase],
):
for curr_layer in layer_sharding or []:
if is_hidden_layer(curr_layer):
layer_name = curr_layer.prefix.split('.')[-1]
layer_name = curr_layer.prefix.split(".")[-1]
register_layer_to_shard_weight_series(
series_name=layer_name,
group=get_shard_weight_group(),
Expand Down
Loading
Loading