Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
from unittest.mock import patch

import pytest
import torch

from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
override_backend_env_variable)
from vllm.attention.selector import which_attn_to_use


@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str):
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = name

override_backend_env_variable(monkeypatch, name)

if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
Expand All @@ -32,14 +33,11 @@ def test_env(name: str, device: str):
torch.float16, 16)
assert backend.name == name

if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_flash_attn():
def test_flash_attn(monkeypatch):
"""Test FlashAttn validation."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
Expand Down Expand Up @@ -71,14 +69,9 @@ def test_flash_attn():
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"

if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_invalid_env():
def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
22 changes: 22 additions & 0 deletions tests/kernels/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Kernel test utils"""

import pytest

STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
backend_name: str) -> None:
'''
Override the environment variable indicating the vLLM backend temporarily,
using pytest monkeypatch to ensure that the env vars get
reset once the test context exits.

Arguments:

* mpatch: pytest monkeypatch instance
* backend_name: attention backend name to force
'''
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)