pyproject.toml: 13 additions, 0 deletions
@@ -51,6 +51,19 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
+
+    # (8)
+    "vllm_ascend/ops/__init__.py",
+    "vllm_ascend/ops/activation.py",
+    "vllm_ascend/ops/flashcomm2_oshard_manager.py",
+    "vllm_ascend/ops/layernorm.py",
+    "vllm_ascend/ops/mla.py",
+    "vllm_ascend/ops/mm_encoder_attention.py",
+    "vllm_ascend/ops/register_custom_ops.py",
+    "vllm_ascend/ops/vocab_parallel_embedding.py",
+    "vllm_ascend/ops/weight_prefetch.py",
+    "vllm_ascend/spec_decode/**",
+
 ]
 
 [tool.ruff.lint]
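These entries extend ruff's `exclude` list, so the ten ops files and the `vllm_ascend/spec_decode/**` tree are presumably skipped by ruff formatting and lint and stay under the repo's yapf/isort style instead; that reading matches the per-file diffs below, which rewrap long 120-column lines into the shorter yapf style and swap PEP 604 annotations such as `X | None` back to `typing.Optional[X]`. A side effect worth noting: `Optional[X]` and `Dict[int, Any]` also evaluate on interpreters older than Python 3.10/3.9, where bare `X | None` and subscripted `dict[...]` annotations raise at runtime unless a module uses `from __future__ import annotations`.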
vllm_ascend/ops/__init__.py: 22 additions, 11 deletions
@@ -27,7 +27,8 @@
 
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
 from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-from vllm_ascend.ops.rotary_embedding import AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding
+from vllm_ascend.ops.rotary_embedding import (
+    AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
 
 
 class dummyFusionOp:
@@ -39,13 +40,23 @@ def __init__(self, name=""):
 
 def register_dummy_fusion_op() -> None:
     torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
-    torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
-    torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(name="static_scaled_fp8_quant")
-    torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(name="dynamic_scaled_fp8_quant")
-    torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(name="dynamic_per_token_scaled_fp8_quant")
-    torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(name="rms_norm_static_fp8_quant")
-    torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(name="fused_add_rms_norm_static_fp8_quant")
-    torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(name="rms_norm_dynamic_per_token_quant")
-
-
-__all__ = ["AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", "AscendDeepseekScalingRotaryEmbedding"]
+    torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
+        name="fused_add_rms_norm")
+    torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
+        name="static_scaled_fp8_quant")
+    torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
+        name="dynamic_scaled_fp8_quant")
+    torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
+        name="dynamic_per_token_scaled_fp8_quant")
+    torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
+        name="rms_norm_static_fp8_quant")
+    torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
+        name="fused_add_rms_norm_static_fp8_quant")
+    torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
+        name="rms_norm_dynamic_per_token_quant")
+
+
+__all__ = [
+    "AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding",
+    "AscendDeepseekScalingRotaryEmbedding"
+]
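For readers unfamiliar with the pattern being rewrapped here: `register_dummy_fusion_op` assigns placeholder objects to op names under `torch.ops._C_ascend`, so code that probes those attributes does not fail when the real fused kernels are unavailable. A minimal self-contained sketch of that idea follows; the `_OpNamespace` stand-in and the raising `__call__` are illustrative assumptions, not the vllm-ascend implementation.

# Hypothetical, self-contained reduction of the placeholder pattern;
# vllm-ascend assigns onto torch.ops._C_ascend instead of a plain object.
class DummyFusionOp:

    def __init__(self, name=""):
        self.name = name

    def __call__(self, *args, **kwargs):
        # A stub only satisfies attribute probes; calling it is an error.
        raise NotImplementedError(f"fusion op {self.name!r} has no kernel")


class _OpNamespace:
    """Stand-in for an extension op namespace."""


ops = _OpNamespace()
for op_name in ("rms_norm", "fused_add_rms_norm", "rms_norm_static_fp8_quant"):
    setattr(ops, op_name, DummyFusionOp(name=op_name))

print(ops.rms_norm.name)  # attribute lookups now succeed: "rms_norm"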
vllm_ascend/ops/activation.py: 2 additions, 2 deletions
@@ -17,11 +17,10 @@
 
 import torch
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
 
 from vllm_ascend.utils import get_weight_prefetch_method
 
 
 class AscendQuickGELU(QuickGELU):
+
     def forward_oot(self, x: torch.tensor) -> torch.Tensor:
         import torch_npu
 
@@ -30,6 +29,7 @@ def forward_oot(self, x: torch.tensor) -> torch.Tensor:
 
 
 class AscendSiluAndMul(SiluAndMul):
+
     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         import torch_npu
 
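The bodies of the two `forward_oot` overrides are collapsed in this view. For reference, vLLM's `SiluAndMul` splits the last dimension in half and gates the second half with SiLU of the first; an out-of-tree NPU override must reproduce those semantics. A plain-PyTorch sketch of the reference math only, not the torch_npu kernel call:

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style gating: SiLU of the first half scales the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


out = silu_and_mul_reference(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])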
vllm_ascend/ops/flashcomm2_oshard_manager.py: 9 additions, 10 deletions
@@ -1,14 +1,11 @@
-from typing import Any
+from typing import Any, Dict, Optional
 
 from vllm.model_executor.models.utils import extract_layer_index
 
 from vllm_ascend.distributed.parallel_state import get_shard_weight_group
 from vllm_ascend.ops.layer_shard_linear import (
-    is_hidden_layer,
-    post_process_after_loading_for_shard_weight_series,
-    reach_layer_for_shard_weight_series,
-    register_layer_to_shard_weight_series,
-)
+    is_hidden_layer, post_process_after_loading_for_shard_weight_series,
+    reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series)
 from vllm_ascend.utils import flashcomm2_enable, o_shard_enable
 
 
@@ -29,7 +26,7 @@ class Flashcomm2OShardManager:
     """
 
    def __init__(self):
-        self._shard_layers: dict[int, Any] = {}
+        self._shard_layers: Dict[int, Any] = {}
 
     def flashcomm2_oshard_enable(self):
         return flashcomm2_enable() and o_shard_enable()
@@ -55,10 +52,12 @@ def register_layer(self, layer: Any, prefetch_step: int = 1):
         self._shard_layers[layer_idx] = layer
 
         register_layer_to_shard_weight_series(
-            series_name="o_proj", group=get_shard_weight_group(), layer=layer, prefetch_step=prefetch_step
-        )
+            series_name="o_proj",
+            group=get_shard_weight_group(),
+            layer=layer,
+            prefetch_step=prefetch_step)
 
-    def get_layer(self, layer_idx: int) -> Any | None:
+    def get_layer(self, layer_idx: int) -> Optional[Any]:
         """Safely retrieves a registered layer by its index.
 
         Args:
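Besides the rewrapping, this file trades `dict[int, Any]` and `Any | None` for `typing.Dict` and `typing.Optional`, which behave identically but also evaluate on older interpreters. The registry pattern itself is small enough to sketch standalone; `LayerRegistry` below is a hypothetical reduction, not the `Flashcomm2OShardManager` API, and it assumes `dict.get`-style None-on-miss semantics for `get_layer`.

from typing import Any, Dict, Optional


class LayerRegistry:
    """Hypothetical reduction of the index -> layer registry pattern."""

    def __init__(self) -> None:
        self._shard_layers: Dict[int, Any] = {}

    def register_layer(self, layer_idx: int, layer: Any) -> None:
        self._shard_layers[layer_idx] = layer

    def get_layer(self, layer_idx: int) -> Optional[Any]:
        # dict.get returns None for unknown indices instead of raising.
        return self._shard_layers.get(layer_idx)


registry = LayerRegistry()
registry.register_layer(0, object())
assert registry.get_layer(0) is not None
assert registry.get_layer(99) is None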
vllm_ascend/ops/layernorm.py: 46 additions, 29 deletions
@@ -15,53 +15,56 @@
 # This file is a part of the vllm-ascend project.
 #
 
+from typing import Optional, Tuple, Union
+
 import torch
 from torch import nn
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated
 
 from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
-from vllm_ascend.utils import enable_custom_op, get_weight_prefetch_method
+from vllm_ascend.utils import enable_custom_op
+from vllm_ascend.utils import get_weight_prefetch_method
 
 
 class AscendRMSNorm(RMSNorm):
 
     def __init__(
         self,
         hidden_size: int,
         eps: float = 1e-6,
-        var_hidden_size: int | None = None,
+        var_hidden_size: Optional[int] = None,
         has_weight: bool = True,
-        dtype: torch.dtype | None = None,
+        dtype: Optional[torch.dtype] = None,
     ) -> None:
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
         vllm_config = get_current_vllm_config()
         self.bias = None
         # quantization with anti_method m4 will generate non-zero norm bias
-        if vllm_config.quant_config is not None and any(
-            "norm.bias" in name for name in vllm_config.quant_config.quant_description
-        ):
-            self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
+        if vllm_config.quant_config is not None and \
+            any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                           requires_grad=False)
 
     def forward_oot(
         self,
         x: torch.Tensor,
-        residual: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
         if residual is not None:
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-                    x, residual, self.weight, self.bias, self.variance_epsilon
-                )
+                    x, residual, self.weight, self.bias, self.variance_epsilon)
             else:
-                x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon)
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, self.weight, self.variance_epsilon)
             if self.bias is not None:
                 x.add_(self.bias)
             return x, residual
 
-        x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
         if self.bias is not None:
             x.add_(self.bias)
 
@@ -72,30 +75,42 @@ def forward_oot(
 
 
 class AscendGemmaRMSNorm(GemmaRMSNorm):
+
     def forward_oot(
         self,
         x: torch.Tensor,
-        residual: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
         from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
         if residual is not None:
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-                    x, residual, 1.0 + self.weight, None, self.variance_epsilon
-                )
+                    x, residual, 1.0 + self.weight, None,
+                    self.variance_epsilon)
             else:
-                x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight, self.variance_epsilon)
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, 1.0 + self.weight, self.variance_epsilon)
             return x, residual
 
-        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
+        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                      self.variance_epsilon)
         return x
 
 
 class LayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+    def forward(ctx,
+                x,
+                weight,
+                bias,
+                z=None,
+                eps=1e-6,
+                group_size=None,
+                norm_before_gate=True,
+                is_rms_norm=False):
+        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
+        """
 
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -128,16 +143,16 @@ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before
         ctx.is_rms_norm = is_rms_norm
         return y.reshape(x_shape_og)
 
 
 class AscendRMSNormGated(RMSNormGated):
 
     def __init__(
         self,
         hidden_size,
         eps: float = 1e-5,
-        group_size: int | None = None,
+        group_size: Optional[int] = None,
         norm_before_gate: bool = False,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         """If group_size is not None, we do GroupNorm with each group having group_size elements.
         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -155,5 +170,7 @@ def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
     def forward_oot(self, x, z=None):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, self.norm_before_gate, True)
+        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
+        """
+        return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size,
+                                 self.norm_before_gate, True)
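As context for the kernels being rewrapped above: RMSNorm scales x by the reciprocal root-mean-square over its last dimension, and the fused add variant folds the residual update into the same kernel, returning both the normalized output and the updated residual (the middle output the callers discard as `_` is an internal extra, omitted here). The Gemma variant passes `1.0 + self.weight` because Gemma parameterizes the scale as a zero-centered offset. A plain-PyTorch sketch of those semantics, assuming float32 accumulation:

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, accumulated in float32.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def add_rms_norm_reference(x, residual, weight, eps=1e-6):
    # Fused semantics: update the residual first, then normalize the sum.
    residual = x + residual
    return rms_norm_reference(residual, weight, eps), residual


x, res = torch.randn(4, 16), torch.randn(4, 16)
y, new_res = add_rms_norm_reference(x, res, torch.ones(16))
print(y.shape, new_res.shape)  # torch.Size([4, 16]) torch.Size([4, 16])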