74 changes: 22 additions & 52 deletions tests/models/language/generation/test_hybrid.py
@@ -34,17 +34,6 @@
     "LiquidAI/LFM2-1.2B",
 ]
 
-HF_UNSUPPORTED_MODELS = [
-    # The HF transformers implementation of
-    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
-    # doesn't compare vLLM output with HF output.
-    # See https://github.com/huggingface/transformers/pull/35943
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
-    # transformers 4.55 is still producing garbage for this model
-    # TODO(tdoublep): follow-up on transformers side
-    "ibm-granite/granite-4.0-tiny-preview"
-]
-
 V1_SUPPORTED_MODELS = [
     "state-spaces/mamba-130m-hf",
     "ai21labs/Jamba-tiny-dev",
@@ -85,20 +74,13 @@ def test_models(
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
-        hf_version_check = model_info.check_transformers_version(
-            on_fail="return")
+        model_info.check_transformers_version(on_fail="skip")
     except ValueError:
-        hf_version_check = None
-
-    if hf_version_check is not None:
-        print(f"Skipping transformers comparison because: {hf_version_check}")
+        pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -116,7 +98,7 @@
         else:
             vllm_v1_outputs = None
 
-    if hf_outputs is not None and vllm_v0_outputs is not None:
+    if vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
@@ -125,12 +107,10 @@
         )
 
     if model in V1_SUPPORTED_MODELS:
-        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
-        assert ref_outputs is not None
         check_logprobs_close(
-            outputs_0_lst=ref_outputs,
+            outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v1_outputs,
-            name_0="hf" if hf_outputs is not None else "vllm-v0",
+            name_0="hf",
             name_1="vllm-v1",
         )
 
@@ -397,11 +377,8 @@ def test_full_cuda_graph(
         pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -416,20 +393,18 @@
         vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
-    if hf_outputs is not None and vllm_v0_outputs is not None:
+    if vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
             name_0="hf",
             name_1="vllm-v0",
         )
 
-    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
-    assert ref_outputs is not None
     check_logprobs_close(
-        outputs_0_lst=ref_outputs,
+        outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_v1_outputs,
-        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_0="hf",
         name_1="vllm-v1",
     )
@@ -455,11 +430,8 @@ def test_fp32_state(
         pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -475,18 +447,16 @@
         vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
-    if hf_outputs is not None:
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_v0_outputs,
-            name_0="hf",
-            name_1="vllm-v0",
-        )
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_v0_outputs,
+        name_0="hf",
+        name_1="vllm-v0",
+    )
 
-    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
     check_logprobs_close(
-        outputs_0_lst=ref_outputs,
+        outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_v1_outputs,
-        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_0="hf",
         name_1="vllm-v1",
     )
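
Taken together, the hunks above converge on one invariant: HF outputs are always generated and always serve as the reference, because models the HF implementation cannot handle are now skipped via the registry's version gate instead of being special-cased per test. A condensed sketch of the resulting flow, assuming the same fixtures the test file already uses (hf_runner, vllm_runner, check_logprobs_close); this is an abbreviation, not the full test:

def test_models_sketch(hf_runner, vllm_runner, example_prompts, model,
                       max_tokens, num_logprobs):
    # No HF_UNSUPPORTED_MODELS gate and no hf_outputs = None fallback:
    # unsupported combinations were already skipped before reaching here.
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    # HF is unconditionally the reference; the old ref_outputs indirection
    # (falling back from "hf" to "vllm-v0") is gone.
    check_logprobs_close(outputs_0_lst=hf_outputs,
                         outputs_1_lst=vllm_outputs,
                         name_0="hf",
                         name_1="vllm")
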
13 changes: 9 additions & 4 deletions tests/models/registry.py
@@ -154,7 +154,7 @@ def check_available_online(
     "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
                                              trust_remote_code=True),
     "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
-                                        min_transformers_version="4.56.0",
+                                        min_transformers_version="4.55.3",
                                         extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}),  # noqa: E501
     "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
                                         {"1b": "bigscience/bloomz-1b1"}),
@@ -208,7 +208,8 @@ def check_available_online(
     "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"),
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
-    "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"),  # noqa: E501
+    "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview",  # noqa: E501
+                                                   min_transformers_version="4.55.3"),
     "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"),  # noqa: E501
     "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
                                              trust_remote_code=True),
@@ -228,7 +229,7 @@
                                        trust_remote_code=True),
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
     "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
-                                        min_transformers_version="4.56.0",
+                                        min_transformers_version="4.55.3",
                                         extras={
                                             "tiny": "ai21labs/Jamba-tiny-dev",
                                             "random": "ai21labs/Jamba-tiny-random",  # noqa: E501
@@ -244,7 +245,11 @@
     "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                          is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
-    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
+                                         min_transformers_version="4.55.3",
+                                         extras={
+                                             "random": "yujiepan/mamba2-codestral-v0.1-tiny-random",  # noqa: E501
+                                         }),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),
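
The registry side shows the mechanism that replaces the deleted HF_UNSUPPORTED_MODELS list: the buggy Codestral tiny model becomes a "random" extra on the Mamba2 entry, and a min_transformers_version="4.55.3" floor (assumed here to be the release carrying the upstream n_groups fix) gates the affected architectures. A hedged sketch of what such an entry might look like as a dataclass; field names follow the diff, but the real _HfExamplesInfo in tests/models/registry.py has more fields:

from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class _HfExamplesInfoSketch:
    # Default checkpoint used when a test asks for this architecture.
    default: str
    # Alternate checkpoints, e.g. tiny/random variants for fast CI runs.
    extras: Dict[str, str] = field(default_factory=dict)
    # Oldest transformers release whose implementation is trusted;
    # None means any installed version is acceptable.
    min_transformers_version: Optional[str] = None


mamba2 = _HfExamplesInfoSketch(
    "mistralai/Mamba-Codestral-7B-v0.1",
    extras={"random": "yujiepan/mamba2-codestral-v0.1-tiny-random"},
    min_transformers_version="4.55.3",
)
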