Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions tests/kernels/helion/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,26 @@
from collections.abc import Callable
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from unittest.mock import patch

import helion

from vllm.kernels.helion.case_key import CaseKey
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import register_kernel
from vllm.kernels.helion.utils import get_canonical_gpu_name

GPU_PLATFORM = get_canonical_gpu_name()

DEFAULT_CONFIGS: dict[str, helion.Config] = {
"default": helion.Config(block_sizes=[32]),
DEFAULT_CONFIGS: dict[CaseKey, helion.Config] = {
CaseKey.default(): helion.Config(block_sizes=[32]),
}


@contextmanager
def dummy_kernel_registry(
configs: dict[str, helion.Config] | None = None,
configs: dict[CaseKey, helion.Config] | None = None,
):
"""Context manager providing a register function with automatic config setup.

Expand All @@ -34,7 +36,13 @@ def dummy_kernel_registry(
"""
if configs is None:
configs = DEFAULT_CONFIGS
config_data = {k: v.__dict__["config"] for k, v in configs.items()}

def _to_config_entries(cfgs: dict) -> list[dict[str, Any]]:
pairs: list[dict[str, Any]] = []
for k, v in cfgs.items():
config_data = v.__dict__["config"]
pairs.append({"key": dict(k), "config": config_data})
return pairs

with tempfile.TemporaryDirectory() as tmpdir:
config_dir = Path(tmpdir)
Expand All @@ -55,7 +63,7 @@ def decorator(fn: Callable) -> Callable:
kernel_dir = config_dir / name
kernel_dir.mkdir(parents=True, exist_ok=True)
(kernel_dir / f"{GPU_PLATFORM}.json").write_text(
json.dumps(config_data)
json.dumps(_to_config_entries(configs))
)
return register_kernel(op_name, **kwargs)(fn)

Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/helion/test_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_autotune_disabled_kernel_produces_valid_config(self):
with dummy_kernel_registry(configs={}) as register:
wrapper = register(
"autotune_test_kernel",
config_picker=lambda args, keys: "default",
config_picker=lambda args, keys: None,
fake_impl=lambda *a, **kw: None,
input_generator=lambda: {
"small": (
Expand Down
68 changes: 68 additions & 0 deletions tests/kernels/helion/test_case_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.utils.import_utils import has_helion

if not has_helion():
pytest.skip(
"Helion is not installed. Install with: pip install vllm[helion]",
allow_module_level=True,
)

from vllm.kernels.helion.case_key import CaseKey


class TestCaseKey:
"""Test suite for CaseKey class."""

def test_construction_with_dict(self):
key = CaseKey({"intermediate": 2048, "numtokens": 256})
assert key["intermediate"] == 2048
assert key["numtokens"] == 256

def test_empty_construction_raises(self):
with pytest.raises(TypeError, match="at least one key-value pair"):
CaseKey()
with pytest.raises(TypeError, match="at least one key-value pair"):
CaseKey({})

def test_default_construction(self):
key = CaseKey.default()
assert len(key) == 0
assert key.is_default()

def test_non_default_is_not_default(self):
key = CaseKey({"intermediate": 2048})
assert not key.is_default()

def test_hashable_and_equality(self):
a = CaseKey({"intermediate": 2048, "numtokens": 256})
b = CaseKey({"numtokens": 256, "intermediate": 2048})
assert a == b
assert hash(a) == hash(b)
assert a != CaseKey({"intermediate": 4096})
assert CaseKey.default() == CaseKey.default()

configs = {
CaseKey.default(): "default_config",
a: "a_config",
}
assert configs[b] == "a_config"
assert configs[CaseKey.default()] == "default_config"

def test_str_is_sorted_json(self):
assert str(CaseKey({"z": 1, "a": 2})) == '{"a":2,"z":1}'
assert str(CaseKey.default()) == "{}"

def test_immutable(self):
key = CaseKey({"intermediate": 2048})
with pytest.raises(TypeError, match="immutable"):
key["intermediate"] = 4096
with pytest.raises(TypeError, match="immutable"):
del key["intermediate"]
with pytest.raises(TypeError, match="immutable"):
key.update({"numtokens": 256})
with pytest.raises(TypeError, match="immutable"):
key.clear()
127 changes: 85 additions & 42 deletions tests/kernels/helion/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import helion

from vllm.kernels.helion.case_key import CaseKey
from vllm.kernels.helion.config_manager import (
ConfigManager,
ConfigSet,
Expand All @@ -49,22 +50,25 @@ def test_config_set_creation(self):

def test_config_set_from_dict(self):
"""Test creating ConfigSet from dictionary data."""
# Use realistic config data that helion.Config can handle
config_data = {
"block_sizes": [32, 16],
"num_warps": 4,
"num_stages": 3,
"pid_type": "persistent_interleaved",
}
data = {"h100": {"batch_32_hidden_4096": config_data}}
data = {
"h100": [
{"key": {"batch": 32, "hidden": 4096}, "config": config_data},
]
}

config_set = ConfigSet.from_dict("test_kernel", data)

assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]

# Verify the config was created correctly
config = config_set.get_config("h100", "batch_32_hidden_4096")
internal_key = CaseKey({"batch": 32, "hidden": 4096})
config = config_set.get_config("h100", internal_key)
assert isinstance(config, helion.Config)
assert config.block_sizes == [32, 16]
assert config.num_warps == 4
Expand All @@ -76,17 +80,19 @@ def test_config_set_get_config_keyerror(self):
config_set = ConfigSet("test_kernel")

with pytest.raises(KeyError, match="platform 'h100' not found"):
config_set.get_config("h100", "batch_32_hidden_4096")
config_set.get_config("h100", "nonexistent")

# Use realistic config data
config_data = {"num_warps": 8, "num_stages": 4}
data = {"h100": {"batch_64_hidden_2048": config_data}}
data = {
"h100": [
{"key": {"batch": 64, "hidden": 2048}, "config": config_data},
]
}
config_set = ConfigSet.from_dict("test_kernel", data)

with pytest.raises(
KeyError, match="config_key 'batch_32_hidden_4096' not found"
):
config_set.get_config("h100", "batch_32_hidden_4096")
nonexistent_key = CaseKey({"batch": 32, "hidden": 4096})
with pytest.raises(KeyError, match="config_key .* not found"):
config_set.get_config("h100", nonexistent_key)

def test_config_set_get_platforms(self):
"""Test get_platforms method."""
Expand All @@ -95,8 +101,12 @@ def test_config_set_get_platforms(self):
config2 = {"num_warps": 8, "num_stages": 5}

data = {
"h100": {"batch_32_hidden_4096": config1},
"a100": {"batch_16_hidden_2048": config2},
"h100": [
{"key": {"batch": 32, "hidden": 4096}, "config": config1},
],
"a100": [
{"key": {"batch": 16, "hidden": 2048}, "config": config2},
],
}
config_set = ConfigSet.from_dict("test_kernel", data)

Expand All @@ -105,39 +115,49 @@ def test_config_set_get_platforms(self):

def test_config_set_get_config_keys(self):
"""Test get_config_keys method."""
# Use realistic config data
config1 = {"num_warps": 4, "num_stages": 3}
config2 = {"num_warps": 8, "num_stages": 5}

data = {
"h100": {
"batch_32_hidden_4096": config1,
"batch_64_hidden_2048": config2,
}
"h100": [
{"key": {"batch": 32, "hidden": 4096}, "config": config1},
{"key": {"batch": 64, "hidden": 2048}, "config": config2},
]
}
config_set = ConfigSet.from_dict("test_kernel", data)

config_keys = config_set.get_config_keys("h100")
assert config_keys == ["batch_32_hidden_4096", "batch_64_hidden_2048"]
expected_keys = sorted(
[
CaseKey({"batch": 32, "hidden": 4096}),
CaseKey({"batch": 64, "hidden": 2048}),
],
key=lambda k: str(k) if k is not None else "",
)
assert config_keys == expected_keys

assert config_set.get_config_keys("v100") == []

def test_config_set_to_dict(self):
"""Test converting ConfigSet to dictionary."""
# Use realistic config data
original_config = {
"block_sizes": [64, 32],
"num_warps": 16,
"num_stages": 4,
"pid_type": "persistent_blocked",
}
original_data = {"h100": {"batch_32_hidden_4096": original_config}}
original_data = {
"h100": [
{"key": {"batch": 32, "hidden": 4096}, "config": original_config},
]
}

config_set = ConfigSet.from_dict("test_kernel", original_data)
result_data = config_set.to_dict()

# The result should match the original (Config roundtrip should work)
assert result_data == original_data
internal_key = CaseKey({"batch": 32, "hidden": 4096})
assert internal_key in result_data["h100"]
assert result_data["h100"][internal_key] == original_config


class TestConfigManager:
Expand Down Expand Up @@ -202,7 +222,10 @@ def test_load_config_set_valid_file(self):
kernel_dir.mkdir()
platform_file = kernel_dir / "h100.json"
with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
json.dump(
[{"key": {"batch": 32, "hidden": 4096}, "config": kernel_config}],
f,
)

manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel")
Expand All @@ -211,7 +234,8 @@ def test_load_config_set_valid_file(self):
assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]

config = config_set.get_config("h100", "batch_32_hidden_4096")
internal_key = CaseKey({"batch": 32, "hidden": 4096})
config = config_set.get_config("h100", internal_key)
assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64]
assert config.num_warps == 8
Expand Down Expand Up @@ -241,7 +265,11 @@ def test_save_config_set(self):
"num_stages": 8,
"pid_type": "persistent_blocked",
}
data = {"h100": {"batch_32_hidden_4096": kernel_config}}
data = {
"h100": [
{"key": {"batch": 32, "hidden": 4096}, "config": kernel_config},
]
}
config_set = ConfigSet.from_dict("test_kernel", data)

manager = ConfigManager(base_dir=temp_dir)
Expand All @@ -255,13 +283,21 @@ def test_save_config_set(self):
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f)
assert loaded_data == data["h100"]
assert isinstance(loaded_data, list)
assert len(loaded_data) == 1
entry = loaded_data[0]
assert entry["key"] == {"batch": 32, "hidden": 4096}
assert entry["config"] == kernel_config

def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = Path(temp_dir) / "nested" / "configs"
data = {"h100": {"default": {"num_warps": 4}}}
data = {
"h100": [
{"key": {}, "config": {"num_warps": 4}},
]
}
config_set = ConfigSet.from_dict("test_kernel", data)

manager = ConfigManager(base_dir=nested_dir)
Expand All @@ -288,34 +324,41 @@ def test_get_platform_configs(self):
kernel_dir.mkdir()
with open(kernel_dir / "h100.json", "w") as f:
json.dump(
{
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
[
{"key": {"batch": 32, "hidden": 4096}, "config": config_1},
{"key": {"batch": 64, "hidden": 2048}, "config": config_2},
{"key": {}, "config": default_config},
],
f,
)
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
json.dump(
[{"key": {"batch": 16, "hidden": 1024}, "config": config_3}],
f,
)

manager = ConfigManager(base_dir=temp_dir)

key_b32_h4096 = CaseKey({"batch": 32, "hidden": 4096})
key_b64_h2048 = CaseKey({"batch": 64, "hidden": 2048})
key_b16_h1024 = CaseKey({"batch": 16, "hidden": 1024})

h100_configs = manager.get_platform_configs("test_kernel", "h100")
assert len(h100_configs) == 3
assert "batch_32_hidden_4096" in h100_configs
assert "batch_64_hidden_2048" in h100_configs
assert "default" in h100_configs
assert key_b32_h4096 in h100_configs
assert key_b64_h2048 in h100_configs
assert CaseKey.default() in h100_configs
for config in h100_configs.values():
assert isinstance(config, helion.Config)

assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7
assert h100_configs[key_b32_h4096].num_warps == 4
assert h100_configs[CaseKey.default()].num_stages == 7

a100_configs = manager.get_platform_configs("test_kernel", "a100")
assert len(a100_configs) == 1
assert "batch_16_hidden_1024" in a100_configs
assert isinstance(a100_configs["batch_16_hidden_1024"], helion.Config)
assert a100_configs["batch_16_hidden_1024"].num_warps == 2
assert key_b16_h1024 in a100_configs
assert isinstance(a100_configs[key_b16_h1024], helion.Config)
assert a100_configs[key_b16_h1024].num_warps == 2

nonexistent_configs = manager.get_platform_configs("test_kernel", "v100")
assert len(nonexistent_configs) == 0
Expand Down
Loading
Loading