8 changes: 3 additions & 5 deletions tests/ut/attention/test_mla_v1.py
@@ -4,7 +4,8 @@
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
@@ -972,16 +973,13 @@ def test_q_proj_and_k_up_proj(self):
def test_process_weights_after_loading(self, mock_format_cast):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
quant_method = MagicMock(spec=UnquantizedLinearMethod)
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
mock_format_cast.return_value = layer.weight
self.impl.process_weights_after_loading(torch.bfloat16)

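For context on the change above: constructing the mock with `spec=UnquantizedLinearMethod` lets the mocked `quant_method` pass the `isinstance` assertion that `process_weights_after_loading` now performs (see the `mla_v1.py` diff below). A minimal, self-contained sketch of this standard `unittest.mock` behavior, using a hypothetical `Base` class rather than the real vLLM types:

```python
from unittest.mock import MagicMock


class Base:  # hypothetical stand-in for UnquantizedLinearMethod
    pass


plain = MagicMock()
specced = MagicMock(spec=Base)

# A spec'd mock reports the spec class as its type, so isinstance checks pass.
assert not isinstance(plain, Base)
assert isinstance(specced, Base)

# A spec'd mock also rejects attribute *access* outside the spec class.
try:
    specced.undefined_method()
except AttributeError:
    print("attribute access outside the spec raises AttributeError")
```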
21 changes: 12 additions & 9 deletions tests/ut/ops/test_linear.py
@@ -1,3 +1,4 @@
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -61,22 +62,24 @@ def setUp(self):
mock_dtype = mock.PropertyMock(return_value=torch.float16)
type(self.layer.weight.data).dtype = mock_dtype

@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_enable_nz(self, mock_format_cast,
mock_is_nz):
mock_is_nz.return_value = 1
def test_process_weights_after_loading_with_nz0(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()
mock_format_cast.assert_not_called()

@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_disable_nz(self, mock_format_cast,
mock_is_nz):
mock_is_nz.return_value = 0
def test_process_weights_after_loading_with_nz1(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_not_called()

@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_with_nz2(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()


class TestAscendRowParallelLinear(BaseLinearTest):

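The tests above drop the mocked `is_enable_nz` helper and instead patch `VLLM_ASCEND_ENABLE_NZ` directly with `patch.dict(os.environ, ...)`. A rough, self-contained sketch of that pattern; the 0/1/2 level semantics and the `nz_for_unquantized_linear` helper are assumptions inferred from the test expectations, not the repo's actual implementation:

```python
import os
import unittest
from unittest import mock


def nz_for_unquantized_linear() -> bool:
    # Hypothetical helper: judging from the expectations above, level "2"
    # appears to enable the NZ cast for unquantized linear layers, while
    # "0" and "1" leave the weight in ND format.
    return os.environ.get("VLLM_ASCEND_ENABLE_NZ", "0") == "2"


class TestEnvDrivenNz(unittest.TestCase):

    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
    def test_level_two_enables_nz(self):
        # patch.dict restores the original environment when the test exits.
        self.assertTrue(nz_for_unquantized_linear())

    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
    def test_level_one_keeps_nd(self):
        self.assertFalse(nz_for_unquantized_linear())


if __name__ == "__main__":
    unittest.main()
```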
30 changes: 0 additions & 30 deletions tests/ut/quantization/test_w4a4_flatquant_dynamic.py
@@ -199,7 +199,6 @@ def test_process_weights_after_loading(self, mock_pack_weights):
(self.output_size, self.input_size // 8),
dtype=torch.int32)
mock_pack_weights.return_value = mock_packed
self.method.transpose_weight = False
self.method.process_weights_after_loading(layer)
mock_pack_weights.assert_called_once()
self.assertFalse(hasattr(layer, 'weight'))
@@ -212,35 +211,6 @@ def test_process_weights_after_loading(self, mock_pack_weights):
self.assertEqual(layer.left_trans.shape, (24, 24))
self.assertTrue(layer.left_trans.is_contiguous())

@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
def test_process_weights_after_loading_with_transpose(
self, mock_pack_weights):
"""Tests weight processing after loading, with transpose."""
layer = nn.Module()
layer.weight = torch.randint(-8,
7, (self.output_size, self.input_size),
dtype=torch.int8)
layer.weight_scale = torch.randn(self.output_size,
1,
dtype=torch.bfloat16)
layer.weight_offset = torch.randn(self.output_size,
1,
dtype=torch.bfloat16)
layer.left_trans = torch.randn(24, 24)
layer.right_trans = torch.randn(32, 32)
layer.clip_ratio = torch.tensor([0.9])
mock_packed = torch.randint(0,
100,
(self.output_size, self.input_size // 8),
dtype=torch.int32)
mock_pack_weights.return_value = mock_packed
self.method.transpose_weight = True
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, 'weight_packed'))
self.assertEqual(layer.weight_packed.shape,
(self.input_size // 8, self.output_size))
self.assertTrue(layer.weight_packed.is_contiguous())


if __name__ == '__main__':
unittest.main(argv=['first-arg-is-ignored'], exit=False)
7 changes: 6 additions & 1 deletion tests/ut/quantization/test_w4a8_dynamic.py
@@ -62,7 +62,8 @@ def test_get_pergroup_param(self):

@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
@patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading(self, mock_format_cast, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
@@ -85,6 +86,8 @@ def test_process_weights_after_loading(self, mock_npu,
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
mock_format_cast.return_value = layer.weight.data.transpose(
0, 1).contiguous()
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
@@ -110,6 +113,8 @@ def test_process_weights_after_loading(self, mock_npu,
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
mock_format_cast.return_value = new_layer.weight.data.transpose(
0, 1).contiguous()
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
52 changes: 42 additions & 10 deletions tests/ut/quantization/test_w8a8.py
@@ -1,3 +1,4 @@
import os
from unittest.mock import MagicMock, patch

import torch
@@ -132,20 +133,21 @@ def test_apply_with_x_is_310p(self, mock_npu_quant_matmul,
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))

@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
mock_is_nz):
def test_process_weights_after_loading_with_nz0(self,
mock_npu_format_cast):
layer = MagicMock()

layer.weight.data = torch.randn(128, 256)
layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

mock_is_nz.return_value = 0
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)

@@ -160,20 +162,50 @@ def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_not_called()

@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
mock_is_nz):
def test_process_weights_after_loading_with_nz1(self,
mock_npu_format_cast):
layer = MagicMock()

layer.weight.data = torch.randn(128, 256)
layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)

expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
self.assertTrue(
torch.equal(layer.aclnn_input_offset.data, expected_offset))
self.assertFalse(layer.aclnn_input_offset.requires_grad)

self.assertFalse(layer.deq_scale.requires_grad)

self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()

@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_with_nz2(self,
mock_npu_format_cast):
layer = MagicMock()

layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

mock_is_nz.return_value = 1
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)

8 changes: 0 additions & 8 deletions tests/ut/test_utils.py
@@ -35,14 +35,6 @@ def setUp(self):
from vllm_ascend import platform
importlib.reload(platform)

def test_is_enable_nz(self):
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
1):
self.assertTrue(utils.is_enable_nz())
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
0):
self.assertFalse(utils.is_enable_nz())

def test_nd_to_nz_2d(self):
# can be divided by 16
input_tensor = torch.randn(32, 64)
60 changes: 13 additions & 47 deletions vllm_ascend/attention/mla_v1.py
@@ -14,8 +14,7 @@
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -38,8 +37,8 @@
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
flashcomm2_o_shared_enabled, is_enable_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

@@ -796,40 +795,11 @@ def _q_proj_and_k_up_proj(self, x):
return ql_nope.transpose(0, 1), q_pe

def process_weights_after_loading(self, act_dtype: torch.dtype):

def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
try:
return getattr(layer, attr)
except AttributeError:
pass
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")

def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
# Weight will be reshaped next. To be on the safe side, the format
# of the weight should be reverted to FRACTAL_AND.
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
return layer.weight

# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -852,15 +822,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()

# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
if is_enable_nz():
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)

# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)

if self.enable_mlapo:
# Currently mlapo only supports W8A8 quantization in MLA scenario
@@ -875,6 +838,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
"thus mlapo is disabled for these layers.")
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
else:
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
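To illustrate the simplified path above: with `kv_b_proj` asserted to be unquantized, its weight is cast back to ND, transposed, and split per head into `W_UK`/`W_UV` for the 16-bit bmm's. A minimal pure-torch sketch of that reshape; the dimension values are illustrative and the NPU format cast is omitted:

```python
import torch

# Illustrative dimensions, not values taken from any particular model config.
kv_lora_rank = 512
num_heads = 16
qk_nope_head_dim = 128
v_head_dim = 128

# (output, input) weight already transposed to
# (kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)), matching the
# shape assertion in the diff.
kv_b_proj_weight = torch.randn(
    kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))

kv_b_proj_weight = kv_b_proj_weight.view(
    kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)

# Split the last dim into the K up-projection and V up-projection blocks.
W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)
assert W_UV.shape == (kv_lora_rank, num_heads, v_head_dim)

# Convert W_UK from (L, N, P) to (N, P, L) for the batched matmul, as in the diff.
W_UK_T = W_UK.permute(1, 2, 0).contiguous()
assert W_UK_T.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
```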