89 changes: 69 additions & 20 deletions tests/kernels/attention/test_deepgemm_attention.py
@@ -10,10 +10,12 @@
_ceil_to_ue8m0,
calc_diff,
fp8_fp4_mqa_logits,
-    fp8_fp4_paged_mqa_logits,
get_num_sms,
get_paged_mqa_logits_metadata,
)
+from vllm.utils.deep_gemm import (
+    fp8_fp4_paged_mqa_logits as fp8_paged_mqa_logits,

[Member] Why do we want to change this? I'd prefer to keep the original name.

+)
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv

@@ -90,10 +92,64 @@ def _ref_fp8_mqa_logits(
return logits


def _supports_deepgemm_optimized_mqa_logits() -> bool:
return current_platform.is_cuda() and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability_family(100)
)


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(
not current_platform.is_device_capability_family(120), reason="SM120 only"
)
def test_sm120_fp8_mqa_logits_torch_path():
torch.manual_seed(0)

seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32
q = torch.randn(
seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16
)
kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16)
weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
cu_seqlen_ks = (torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3)
cu_seqlen_ke = torch.minimum(
torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4,
torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32),
)

q_fp8 = q.to(torch.float8_e4m3fn)
kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4)
kv_scale = (kv_amax / 448.0).squeeze(1).contiguous()
kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn)

logits = fp8_fp4_mqa_logits(
(q_fp8, None),
(kv_fp8, kv_scale),
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=True,
)

kv_dequant = kv_fp8.float() * kv_scale[:, None]
score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant)
ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0)
offsets = torch.arange(seq_len_kv, device="cuda")
valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
offsets[None, :] < cu_seqlen_ke[:, None]
)
ref_logits = ref_logits.masked_fill(~valid, float("-inf"))

assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits))
finite = torch.isfinite(ref_logits)
assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
-    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+    not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only"
)
@pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
@@ -150,7 +206,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
assert diff < 1e-3, f"{diff=}"


-def _ref_fp8_fp4_paged_mqa_logits(
+def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
@@ -203,12 +259,10 @@ def _ref_fp8_fp4_paged_mqa_logits(
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
-    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+    not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only"
)
-def test_deepgemm_fp8_fp4_paged_mqa_logits():
-    # NOTE: clean_logits=True is incompatible with the 2D context_lens
-    # required by csrc/apis/attention.hpp; only the False path is exercised.
-    clean_logits = False
+@pytest.mark.parametrize("clean_logits", [True, False])
+def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
torch.manual_seed(0)
random.seed(0)

@@ -260,29 +314,24 @@ def test_deepgemm_fp8_fp4_paged_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

-    # deep_gemm paged MQA logits requires 2D context_lens of
-    # shape (B, next_n) (csrc/apis/attention.hpp:332-335);
-    # see indexer.py:607-608. For each batch/next_n token, the
-    # effective context length is context_lens[b] - next_n + j + 1.
-    next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32)
-    context_lens_2d = (
-        context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange
-    ).contiguous()
+    deepgemm_context_lens = (
+        context_lens[:, None].expand(-1, next_n).contiguous()
+    )
schedule_metadata = get_paged_mqa_logits_metadata(
-        context_lens_2d, blocksize, get_num_sms()
+        deepgemm_context_lens, blocksize, get_num_sms()
)
-    logits = fp8_fp4_paged_mqa_logits(
+    logits = fp8_paged_mqa_logits(
(q_fp8, None),
kv_cache_fp8,
weights,
-        context_lens_2d,
+        deepgemm_context_lens,
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
)

-    ref_logits = _ref_fp8_fp4_paged_mqa_logits(
+    ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
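Note on the context-length change above: the removed comment describes per-position effective lengths, `context_lens[b] - next_n + 1 + j`, while the replacement simply replicates each sequence's length across the `next_n` axis. A standalone sketch of the two (B, next_n) layouts, with illustrative values not taken from the test:

```python
import torch

# Illustrative shapes only: B=2 sequences, next_n=2 positions per sequence.
context_lens = torch.tensor([5, 9], dtype=torch.int32)
next_n = 2

# Removed layout: per-position lengths, context_lens[b] - next_n + 1 + j.
offsets = torch.arange(next_n, dtype=torch.int32)
per_position = context_lens[:, None] - next_n + 1 + offsets  # [[4, 5], [8, 9]]

# New layout: one length per sequence, replicated across next_n.
replicated = context_lens[:, None].expand(-1, next_n).contiguous()  # [[5, 5], [9, 9]]
```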
60 changes: 60 additions & 0 deletions tests/kernels/moe/test_moe.py
@@ -29,10 +29,15 @@
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
batched_fused_marlin_moe,
fused_marlin_moe,
)
@@ -1007,6 +1012,61 @@ def test_fused_marlin_moe(
torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)


def test_marlin_experts_apply_forwards_gemm1_clamp_limit(monkeypatch):
captured: dict[str, float | None] = {}

def fake_fused_marlin_moe(**kwargs):
captured["clamp_limit"] = kwargs.get("clamp_limit")
kwargs["output"].zero_()

monkeypatch.setattr(
"vllm.model_executor.layers.fused_moe.fused_marlin_moe."
"fused_marlin_moe",
fake_fused_marlin_moe,
)

clamp_limit = 10.0
moe_config = FusedMoEConfig(
num_experts=2,
experts_per_token=1,
hidden_dim=16,
intermediate_size_per_partition=8,
num_local_experts=2,
num_logical_experts=2,
activation=MoEActivation.SILU,
device="cpu",
routing_method=RoutingMethodType.Default,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=torch.bfloat16,
)
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=torch.empty(0),
w2_scale=torch.empty(0),
gemm1_clamp_limit=clamp_limit,
)
experts = MarlinExperts(moe_config=moe_config, quant_config=quant_config)

experts.apply(
output=torch.empty((1, 16), dtype=torch.bfloat16),
hidden_states=torch.empty((1, 16), dtype=torch.bfloat16),
w1=torch.empty((2, 1, 1), dtype=torch.int32),
w2=torch.empty((2, 1, 1), dtype=torch.int32),
topk_weights=torch.ones((1, 1), dtype=torch.float32),
topk_ids=torch.zeros((1, 1), dtype=torch.int32),
activation=MoEActivation.SILU,
global_num_experts=2,
expert_map=None,
a1q_scale=None,
a2_scale=None,
workspace13=torch.empty(0),
workspace2=torch.empty(0),
expert_tokens_meta=None,
apply_router_weight_on_input=False,
)

assert captured["clamp_limit"] == clamp_limit


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
55 changes: 53 additions & 2 deletions tests/models/test_deepseek_v4_mega_moe.py
@@ -9,6 +9,7 @@
from vllm.model_executor.models.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
_use_deepseek_v4_mega_moe,
make_deepseek_v4_expert_params_mapping,
)
from vllm.platforms import current_platform
@@ -19,6 +20,52 @@
)


def _make_mega_moe_config(

[Member] If we're solely testing backend selection, we don't need to include this in the tests.

*,
enable_expert_parallel: bool = True,
moe_backend: str = "auto",
):
return SimpleNamespace(
parallel_config=SimpleNamespace(
enable_expert_parallel=enable_expert_parallel
),
kernel_config=SimpleNamespace(moe_backend=moe_backend),
)


def test_deepseek_v4_mega_moe_selection_preserves_kernel_config(monkeypatch):
from vllm import envs

monkeypatch.delenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", raising=False)
envs.disable_envs_cache()

assert _use_deepseek_v4_mega_moe(
_make_mega_moe_config(moe_backend="deep_gemm_mega_moe")
)
assert not _use_deepseek_v4_mega_moe(_make_mega_moe_config())
with pytest.raises(NotImplementedError, match="requires expert parallel"):
_use_deepseek_v4_mega_moe(
_make_mega_moe_config(
enable_expert_parallel=False,
moe_backend="deep_gemm_mega_moe",
)
)


def test_deepseek_v4_mega_moe_selection_env_override(monkeypatch):
from vllm import envs

monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "1")
envs.disable_envs_cache()
assert _use_deepseek_v4_mega_moe(_make_mega_moe_config())

monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0")
envs.disable_envs_cache()
assert not _use_deepseek_v4_mega_moe(
_make_mega_moe_config(moe_backend="deep_gemm_mega_moe")
)


def test_deepseek_v4_mega_moe_expert_mapping():
mapping = make_deepseek_v4_expert_params_mapping(2)

@@ -46,7 +93,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float():

def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
vllm_config = SimpleNamespace(
-        scheduler_config=SimpleNamespace(max_num_batched_tokens=4)
+        scheduler_config=SimpleNamespace(max_num_batched_tokens=4),
+        compilation_config=SimpleNamespace(static_forward_context={}),
)
experts = DeepseekV4MegaMoEExperts(
vllm_config,
@@ -111,7 +159,10 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.",
)
def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
-    from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8
+    per_token_cast_to_fp8 = pytest.importorskip(
+        "deep_gemm.utils",
+        reason="DeepGEMM helper package is required for FP8 staging parity.",
+    ).per_token_cast_to_fp8

device = torch.device("cuda")
num_tokens = 7
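On the import change above: `pytest.importorskip` imports and returns the named module, skipping the test (rather than erroring) when the package is absent, which is why the helper is pulled off its return value. An equivalent two-step sketch:

```python
import pytest

# Skip, rather than fail, the test when deep_gemm is not installed.
dg_utils = pytest.importorskip("deep_gemm.utils")
per_token_cast_to_fp8 = dg_utils.per_token_cast_to_fp8
```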
9 changes: 9 additions & 0 deletions tests/models/test_deepseek_v4_pp.py

[Member] I don't think this is needed either.

@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.models.deepseek_v4 import DeepseekV4ForCausalLM
from vllm.model_executor.models.interfaces import supports_pp


def test_deepseek_v4_declares_pipeline_parallel_support():
assert supports_pp(DeepseekV4ForCausalLM)
33 changes: 33 additions & 0 deletions tests/quantization/test_fp8_scale_parameter.py

[Member] Also no need to include this.

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.model_executor.parameter as parameter
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter,
)
from vllm.model_executor.parameter import BlockQuantScaleParameter


@pytest.mark.skipif(
not hasattr(torch, "float8_e8m0fnu"),
reason="torch does not expose float8_e8m0fnu",
)
def test_create_fp8_scale_parameter_initializes_e8m0(monkeypatch):
monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 0)
monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 1)

scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes=[128],
input_size_per_partition=128,
block_size=[128, 128],
weight_loader=None,
scale_dtype=torch.float8_e8m0fnu,
)

assert scale.dtype == torch.float8_e8m0fnu
raw_scale = scale.data.view(torch.uint8)
assert torch.equal(raw_scale, torch.zeros_like(raw_scale))
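The byte-level assertion is the natural check here: `float8_e8m0fnu` is an exponent-only format (byte `b` encodes `2**(b - 127)`, with `0xFF` reserved for NaN and no representable zero), so zero-initialized storage means the raw bytes are zero, not that the scales decode to 0.0. A minimal sketch, assuming a torch build that exposes the dtype and its float conversion:

```python
import torch

raw = torch.tensor([0, 127, 128], dtype=torch.uint8)
scales = raw.view(torch.float8_e8m0fnu)
# Byte 0 decodes to 2**-127, byte 127 to 1.0, byte 128 to 2.0.
print(scales.float())  # tensor([5.8775e-39, 1.0000e+00, 2.0000e+00])
```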
38 changes: 38 additions & 0 deletions tests/quantization/test_mxfp4.py

[Member] Also no need to include.

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

def test_mxfp4_e8m0_scale_loading_preserves_raw_bytes():
from types import SimpleNamespace

import pytest
import torch

from vllm.model_executor.layers.fused_moe.layer import FusedMoE

e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
if e8m0_dtype is None:
pytest.skip("torch does not expose float8_e8m0fnu")

layer = object.__new__(FusedMoE)
layer.moe_config = SimpleNamespace(is_act_and_mul=True)

expert_data = torch.zeros((4, 2), dtype=torch.uint8)
loaded_scale = torch.tensor(
[[0.0078125, 0.015625], [0.5, 1.0]],
dtype=e8m0_dtype,
)

layer._load_w13(
expert_data=expert_data,
shard_dim=0,
shard_id="w1",
loaded_weight=loaded_scale,
tp_rank=0,
)

torch.testing.assert_close(
expert_data[:2],
loaded_scale.view(torch.uint8),
rtol=0,
atol=0,
)
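As a quick sanity check on the literals above: all four scale values are exact powers of two, so (assuming the standard E8M0 encoding `2**(byte - 127)`) their raw bytes can be derived by hand:

```python
import torch

# 2**-7 -> 120, 2**-6 -> 121, 2**-1 -> 126, 2**0 -> 127 (exponent plus 127).
vals = torch.tensor([[0.0078125, 0.015625], [0.5, 1.0]])
raw_bytes = vals.log2().to(torch.int32) + 127
print(raw_bytes)  # tensor([[120, 121], [126, 127]], dtype=torch.int32)
```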