8 changes: 2 additions & 6 deletions tests/ut/attention/test_attention_v1.py
@@ -41,9 +41,7 @@ def test_get_builder_cls(self):
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
AscendAttentionMetadataBuilder)

@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
def test_get_kv_cache_shape_not(self):
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 20, 30, 40))

@@ -92,9 +90,7 @@ def test_reorder_batch(self):
self.assertFalse(result)

@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
def test_build_non_310p(self, mock_soc_version, mock_ascend_metadata):
def test_build(self, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
69 changes: 44 additions & 25 deletions tests/ut/ops/test_activation.py
@@ -52,45 +52,64 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):
mock_gelu.assert_called_once()


@pytest.mark.skipif(is_310p_hw(), reason="310P operator classes have already been refactored.")
@pytest.mark.parametrize("is_310p", [True, False])
@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_swiglu,
is_310p,
dummy_tensor,
default_vllm_config,
):
if is_310p and (not is_310p_hw()):
pytest.skip("Pseudo-310P param case is not valid on non-310P CI after refactor.")
layer = SiluAndMul()
out = layer.forward(dummy_tensor)
expected_arg = dummy_tensor

# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()

# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()

# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()

actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "npu_swiglu called with unexpected input"

expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)

with patch(
"vllm_ascend.utils.get_ascend_device_type",
return_value=AscendDeviceType._310P if is_310p else AscendDeviceType.A3,
):
layer = SiluAndMul()
out = layer.forward(dummy_tensor)

if is_310p:
expected_arg = dummy_tensor.to(torch.float32)
else:
expected_arg = dummy_tensor
@pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
@patch("torch.nn.functional.silu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward_310p(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_silu,
dummy_tensor,
default_vllm_config,
):
layer = SiluAndMul()
out = layer.forward(dummy_tensor)
h = dummy_tensor.shape[-1] // 2
expected_arg = dummy_tensor[..., :h]

# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()

# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
# assert mock_silu.call_count == 1
mock_silu.assert_called_once()

# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()

actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "npu_swiglu called with unexpected input"
actual_arg = mock_silu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input"

expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)
expected_out = (dummy_tensor[..., :h] + 1) * dummy_tensor[..., h:]
assert torch.allclose(out, expected_out)
71 changes: 37 additions & 34 deletions tests/ut/ops/test_layernorm.py
@@ -40,43 +40,46 @@ def default_vllm_config():
yield mock_config


@pytest.mark.skipif(is_310p_hw(), reason="310P operator classes have already been refactored.")
@pytest.mark.parametrize("is_310p", [True, False])
@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops._C_ascend.npu_add_rms_norm_bias", side_effect=mock_add_rms_norm_bias)
def test_RMSNorm_forward(
mock_add_rms_norm_bias, mock_add_rmsnorm, mock_rmsnorm, is_310p, residual, dummy_tensor, default_vllm_config
mock_add_rms_norm_bias, mock_add_rmsnorm, mock_rmsnorm, residual, dummy_tensor, default_vllm_config
):
if is_310p and (not is_310p_hw()):
pytest.skip("Pseudo-310P branch is invalid on non-310P CI after refactor.")

with patch(
"vllm_ascend.utils.get_ascend_device_type",
return_value=AscendDeviceType._310P if is_310p else AscendDeviceType.A3,
):
layer = RMSNorm(hidden_size=8, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)

if is_310p:
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)

mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_add_rms_norm_bias.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
out_x = layer.forward_oot(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1

mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
layer = RMSNorm(hidden_size=8, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_add_rms_norm_bias.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
out_x = layer.forward_oot(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1

mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)


@pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
@pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float16)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
def test_RMSNorm_forward_310p(
mock_rmsnorm, residual, dummy_tensor, default_vllm_config
):
layer = RMSNorm(hidden_size=8, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
expected_out_residual = dummy_tensor + residual
expected_out_x = expected_out_residual + 1
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
out_x = layer.forward_oot(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)