@@ -8,6 +8,9 @@
import pytest
import torch

import vllm.model_executor.model_loader.weight_utils as weight_utils
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf,
fastsafetensors_weights_iterator,
@@ -16,6 +19,77 @@
from vllm.platforms import current_platform


def test_default_loader_filters_fastsafetensors_before_materializing(monkeypatch):
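    """The fastsafetensors path should drop weights matched by
    ``weight_name_filter`` (here the MTP weights) and experts outside
    ``local_expert_ids`` before ``get_tensor`` materializes them, and it
    should close both the file buffer and the loader afterwards."""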
class FakeProcessGroup:
def size(self):
return 1

class FakeFileBuffer:
def __init__(self):
self.key_to_rank_lidx = {
"model.layers.0.self_attn.q_proj.weight": (0, 0),
"model.layers.0.mlp.experts.0.gate_proj.weight": (0, 1),
"model.layers.0.mlp.experts.1.gate_proj.weight": (0, 2),
"model.mtp.0.weight": (0, 3),
}
self.loaded_keys: list[str] = []
self.closed = False

def get_tensor(self, key: str):
self.loaded_keys.append(key)
return torch.tensor([len(self.loaded_keys)])

def close(self):
self.closed = True

class FakeLoader:
def __init__(self, file_buffer):
self.file_buffer = file_buffer
self.closed = False

def copy_files_to_device(self):
return self.file_buffer

def close(self):
self.closed = True

file_buffer = FakeFileBuffer()
loader = FakeLoader(file_buffer)

model_loader = DefaultModelLoader(LoadConfig(load_format="fastsafetensors"))
model_loader.local_expert_ids = {0}
monkeypatch.setattr(
model_loader,
"_prepare_weights",
lambda *_args: ("/weights", ["model.safetensors"], True),
)
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False)
monkeypatch.setattr(weight_utils, "SingleGroup", FakeProcessGroup)
monkeypatch.setattr(
weight_utils,
"_init_fastsafetensors_loader",
lambda *_args, **_kwargs: loader,
)

loaded = dict(
model_loader._get_weights_iterator(
DefaultModelLoader.Source("model", revision=None),
weight_name_filter=lambda name: "model.mtp." in name,
)
)

assert set(loaded) == {
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
}
assert file_buffer.loaded_keys == [
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
]
assert file_buffer.closed
assert loader.closed


@pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="fastsafetensors requires NVIDIA/AMD GPUs",
14 changes: 14 additions & 0 deletions tests/model_executor/model_loader/test_ep_weight_filter.py
@@ -319,6 +319,20 @@ def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files):
assert "model.layers.0.input_layernorm.weight" in loaded
assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded

def test_weight_name_filter_skips_dense_weights(self, synthetic_moe_files):
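        """Weights whose names match ``weight_name_filter`` are skipped by
        the iterator, while all other weights are still yielded."""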
files, _ = synthetic_moe_files
loaded = dict(
safetensors_weights_iterator(
files,
False,
weight_name_filter=lambda name: "self_attn.q_proj" in name,
)
)

assert "model.layers.0.self_attn.q_proj.weight" not in loaded
assert "model.embed_tokens.weight" in loaded
assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded

def test_ep2_rank1_gets_other_half(self, synthetic_moe_files):
files, expected = synthetic_moe_files
local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1)
288 changes: 288 additions & 0 deletions tests/v1/attention/test_sm120_deepgemm_fallbacks.py
@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.utils.deep_gemm as deep_gemm_utils
from vllm.model_executor.layers.sparse_attn_indexer import (
_decode_logits_width,
_decode_topk_logits_width,
_sparse_indexer_requires_deep_gemm,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla import indexer as mla_indexer
from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks


def _make_indexer_kv_cache(
kv_fp8: torch.Tensor,
kv_scale: torch.Tensor,
) -> torch.Tensor:
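    """Pack FP8 KV values and their float32 scales into a fused uint8 cache:
    per block, all FP8 value bytes come first, followed by the per-token
    scale bytes, matching the layout consumed by the paged MQA kernels
    under test."""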
num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape
assert num_kv_heads == 1
fused_kv = torch.empty(
num_blocks,
block_size,
1,
head_dim + torch.float32.itemsize,
device=kv_fp8.device,
dtype=torch.uint8,
)
block_stride = fused_kv.stride(0)
kv_values = torch.as_strided(
fused_kv,
size=kv_fp8.shape,
stride=(block_stride, head_dim, head_dim, 1),
)
kv_scales = torch.as_strided(
fused_kv,
size=(num_blocks, block_size, 1, torch.float32.itemsize),
stride=(block_stride, torch.float32.itemsize, torch.float32.itemsize, 1),
storage_offset=block_size * head_dim,
)
kv_values.copy_(kv_fp8.view(torch.uint8))
kv_scales.copy_(kv_scale.contiguous().view(torch.uint8))
return fused_kv


def _reference_paged_mqa_logits(
q_fp8: torch.Tensor,
kv_fp8: torch.Tensor,
kv_scale: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
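    """Dense float32 reference: dequantize KV with its scales, score each
    token as ReLU(q . k) summed over head_dim, weight the per-head scores,
    and leave ``-inf`` beyond each request's context length."""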
batch_size, next_n, _, _ = q_fp8.shape
_, block_size, _, _ = kv_fp8.shape
logits = torch.full(
(batch_size * next_n, max_model_len),
float("-inf"),
device=q_fp8.device,
dtype=torch.float32,
)
q = q_fp8.float()
kv = kv_fp8.float() * kv_scale.float()
for batch_idx in range(batch_size):
for next_idx in range(next_n):
row = batch_idx * next_n + next_idx
context_len = min(
int(context_lens[batch_idx, next_idx].item()),
max_model_len,
)
for token_idx in range(context_len):
block_idx = block_tables[batch_idx, token_idx // block_size]
block_offset = token_idx % block_size
k = kv[block_idx, block_offset, 0]
scores = (q[batch_idx, next_idx] * k).sum(dim=-1).relu()
logits[row, token_idx] = (scores * weights[row]).sum()
return logits


def test_decode_logits_width_uses_active_context_bound():
assert _decode_logits_width(262144, 1024) == 1024
assert _decode_logits_width(4096, 8192) == 4096
assert _decode_logits_width(4096, 0) == 4096
assert _decode_logits_width(0, 1024) == 0


def test_decode_topk_logits_width_keeps_topk_kernel_width():
assert _decode_topk_logits_width(262144, 1024, 512) == 1024
assert _decode_topk_logits_width(262144, 128, 512) == 512
assert _decode_topk_logits_width(300, 128, 512) == 300
assert _decode_topk_logits_width(0, 128, 512) == 0


def test_sm120_sparse_indexer_does_not_require_deep_gemm(monkeypatch):
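    # Simulate an SM120-family CUDA device: the sparse indexer should not
    # report DeepGEMM as a requirement there.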
monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
monkeypatch.setattr(
current_platform,
"is_device_capability_family",
lambda capability: capability == 120,
)

assert _sparse_indexer_requires_deep_gemm() is False


def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch):
monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
monkeypatch.setattr(
current_platform,
"is_device_capability_family",
lambda capability: False,
)

assert _sparse_indexer_requires_deep_gemm() is True


def test_sm120_mla_indexer_skips_deep_gemm_scheduler_metadata(monkeypatch):
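    # Even with DeepGEMM importable, SM120-family devices should not use
    # DeepGEMM scheduler metadata in the MLA indexer.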
monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
monkeypatch.setattr(
current_platform,
"is_device_capability_family",
lambda capability: capability == 120,
)
monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True)

assert not mla_indexer._uses_deep_gemm_scheduler_metadata()


def test_cuda_mla_indexer_uses_deep_gemm_scheduler_metadata_off_sm12x(monkeypatch):
monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
monkeypatch.setattr(
current_platform,
"is_device_capability_family",
lambda capability: False,
)
monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True)

assert mla_indexer._uses_deep_gemm_scheduler_metadata()


def test_sm120_fp8_mqa_fallbacks_do_not_initialize_deep_gemm(monkeypatch):
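    # On SM120 the FP8 MQA entry points must dispatch to the SM12x fallback
    # kernels without ever triggering DeepGEMM lazy initialization.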
monkeypatch.setattr(
current_platform,
"is_device_capability_family",
lambda capability: capability == 120,
)

def fail_lazy_init():
raise AssertionError("SM120 FP8 MQA should not initialize DeepGEMM")

monkeypatch.setattr(deep_gemm_utils, "_lazy_init", fail_lazy_init)

mqa_result = torch.empty(1)
paged_result = torch.empty(1)
calls = []

def fake_mqa_fallback(*args, **kwargs):
calls.append("mqa")
return mqa_result

def fake_paged_fallback(*args, **kwargs):
calls.append("paged")
return paged_result

monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_sm12x", fake_mqa_fallback)
monkeypatch.setattr(
deep_gemm_utils, "_fp8_paged_mqa_logits_sm12x", fake_paged_fallback
)

assert (
deep_gemm_utils.fp8_fp4_mqa_logits(
(torch.empty(1, 1, 1), None),
(torch.empty(1, 1), torch.empty(1)),
torch.empty(1, 1),
torch.empty(1, dtype=torch.int32),
torch.empty(1, dtype=torch.int32),
clean_logits=False,
)
is mqa_result
)
assert (
deep_gemm_utils.fp8_fp4_paged_mqa_logits(
(torch.empty(1, 1, 1, 1), None),
torch.empty(1, 1, 1, 5, dtype=torch.uint8),
torch.empty(1, 1),
torch.empty(1, 1, dtype=torch.int32),
torch.empty(1, 1, dtype=torch.int32),
torch.empty(1, dtype=torch.int32),
max_model_len=1,
clean_logits=False,
)
is paged_result
)
assert calls == ["mqa", "paged"]


@pytest.mark.skipif(
not current_platform.is_device_capability_family(120), reason="SM120 only"
)
def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width(
monkeypatch: pytest.MonkeyPatch,
):
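    # Run the paged MQA top-k twice, once with the full max_model_len and
    # once truncated to the active context width; both runs must return
    # identical indices and match a dense float32 reference. The chunk size
    # is patched small so the fallback's chunked path is exercised.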
torch.manual_seed(7)
batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32
block_size, max_model_len, num_blocks = 4, 64, 16
active_max_len = 13
topk_tokens = 6
monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None)
monkeypatch.setattr(
sm12x_deep_gemm_fallbacks,
"_SM120_PAGED_MQA_TOPK_CHUNK_SIZE",
7,
)

q = torch.randn(
batch_size,
next_n,
num_heads,
head_dim,
device="cuda",
dtype=torch.bfloat16,
)
q_fp8 = q.to(torch.float8_e4m3fn).contiguous()
kv = torch.randn(
num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16
)
kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0
kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn)
fused_kv = _make_indexer_kv_cache(kv_fp8, kv_scale)

weights = torch.randn(
batch_size * next_n, num_heads, device="cuda", dtype=torch.float32
)
context_lens = torch.tensor(
[[7, active_max_len], [9, 12]], device="cuda", dtype=torch.int32
)
block_tables = (
torch.arange(
batch_size * cdiv(max_model_len, block_size),
device="cuda",
dtype=torch.int32,
).reshape(batch_size, -1)
% num_blocks
)

full_width_topk = torch.empty(
batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32
)
truncated_width_topk = torch.empty_like(full_width_topk)

assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices(
(q_fp8, None),
fused_kv,
weights,
context_lens,
block_tables,
max_model_len,
full_width_topk,
)
assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices(
(q_fp8, None),
fused_kv,
weights,
context_lens,
block_tables,
active_max_len,
truncated_width_topk,
)

reference_logits = _reference_paged_mqa_logits(
q_fp8,
kv_fp8,
kv_scale,
weights,
context_lens,
block_tables,
active_max_len,
)
expected_topk = torch.topk(reference_logits, topk_tokens, dim=1).indices.to(
torch.int32
)

torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0)
torch.testing.assert_close(truncated_width_topk, expected_topk, rtol=0, atol=0)