From 9c597c4b26cd5ecc0d38798818eaab99752ab03b Mon Sep 17 00:00:00 2001
From: Andrew Feldman
Date: Mon, 3 Jun 2024 01:29:22 -0400
Subject: [PATCH 1/9] FIX: test_attention_selector.py was leaking
 VLLM_ATTENTION_BACKEND values; fixed with backend context manager

---
 tests/kernels/test_attention_selector.py | 91 +++++++++++-------------
 tests/kernels/utils.py                   | 23 ++++++
 tests/utils.py                           | 27 +++++++
 3 files changed, 93 insertions(+), 48 deletions(-)
 create mode 100644 tests/kernels/utils.py

diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index f439afa9b7d2..1726f58cee08 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -1,10 +1,10 @@
-import os
 from unittest.mock import patch
 
 import pytest
 import torch
 
 from vllm.attention.selector import which_attn_to_use
+from tests.kernels.utils import backend_override_fixture
 
 
 @pytest.mark.parametrize(
@@ -14,71 +14,66 @@ def test_env(name: str, device: str):
     """Test that the attention selector can be set via environment variable.
 
     Note that we do not test FlashAttn because it is the default backend.
     """
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = name
 
-    if device == "cpu":
-        with patch("vllm.attention.selector.is_cpu", return_value=True):
+    with backend_override_fixture(name):
+
+        if device == "cpu":
+            with patch("vllm.attention.selector.is_cpu", return_value=True):
+                backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                            torch.float16, 16)
+            assert backend.name == "TORCH_SDPA"
+        elif device == "hip":
+            with patch("vllm.attention.selector.is_hip", return_value=True):
+                backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                            torch.float16, 16)
+            assert backend.name == "ROCM_FLASH"
+        else:
             backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                         torch.float16, 16)
-        assert backend.name == "TORCH_SDPA"
-    elif device == "hip":
-        with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
-        assert backend.name == "ROCM_FLASH"
-    else:
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                    torch.float16, 16)
-    assert backend.name == name
-
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
+            assert backend.name == name
 
 
 def test_flash_attn():
     """Test FlashAttn validation."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
 
-    # Unsupported CUDA arch
-    with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+    with backend_override_fixture("FLASH_ATTN"):
 
-    # Unsupported data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
-    assert backend.name != "FLASH_ATTN"
+        # Unsupported CUDA arch
+        with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
+            backend = which_attn_to_use(8, 16, 8, None, torch.float16, None,
+                                        16)
+            assert backend.name != "FLASH_ATTN"
 
-    # Unsupported kv cache data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
-    assert backend.name != "FLASH_ATTN"
+        # Unsupported data type
+        backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None,
+                                    16)
+        assert backend.name != "FLASH_ATTN"
 
-    # Unsupported block size
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
-    assert backend.name != "FLASH_ATTN"
+        # Unsupported kv cache data type
+        backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+        assert backend.name != "FLASH_ATTN"
 
-    # Unsupported sliding window
-    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+        # Unsupported block size
+        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+        assert backend.name != "FLASH_ATTN"
 
-    # flash-attn is not installed
-    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        # Unsupported sliding window
+        backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
         assert backend.name != "FLASH_ATTN"
 
-    # Unsupported head size
-    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+        # flash-attn is not installed
+        with patch.dict('sys.modules', {'vllm_flash_attn': None}):
+            backend = which_attn_to_use(8, 16, 8, None, torch.float16, None,
+                                        16)
+            assert backend.name != "FLASH_ATTN"
 
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
+        # Unsupported head size
+        backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+        assert backend.name != "FLASH_ATTN"
 
 
 def test_invalid_env():
     """Throw an exception if the backend name is invalid."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
-    with pytest.raises(ValueError):
+
+    with backend_override_fixture("INVALID"), pytest.raises(ValueError):
         which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-
-    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
new file mode 100644
index 000000000000..955a96bae2a8
--- /dev/null
+++ b/tests/kernels/utils.py
@@ -0,0 +1,23 @@
+"""Kernel test utils"""
+
+from tests.utils import env_var_fixture
+from contextlib import contextmanager
+from typing import Iterator
+
+@contextmanager
+def backend_override_fixture(backend_name: str) -> Iterator[None]:
+    '''
+    Test fixture which temporarily configures the vLLM backend by
+    setting VLLM_ATTENTION_BACKEND, then resets the environment once
+    the fixture exits.
+
+    Usage:
+
+    with backend_override_fixture("backend_name"):
+        # code that depends on vLLM backend
+
+    # VLLM_ATTENTION_BACKEND is returned to original value
+    # or unset
+    '''
+    with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name):
+        yield
diff --git a/tests/utils.py b/tests/utils.py
index 329842911e15..48666ca652dd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,6 +5,8 @@
 import warnings
 from contextlib import contextmanager
 
+from typing import Iterator
+
 import ray
 import requests
 
@@ -101,3 +103,28 @@ def error_on_warning():
     warnings.simplefilter("error")
 
     yield
+
+
+@contextmanager
+def env_var_fixture(var_name: str, value: str) -> Iterator[None]:
+    '''
+    Test fixture which temporarily assigns a value to the var_name
+    environment variable, then resets it once the fixture exits.
+
+    Usage:
+
+    with env_var_fixture("my_var", "my_val"):
+        # code that depends on my_var == "my_val"
+
+    # my_var is returned to original value or unset
+    '''
+    original_value = os.environ.get(var_name)  # Store the original value
+    os.environ[var_name] = value  # Set the new value
+    try:
+        yield
+    finally:
+        # Restore the original value
+        if original_value is None:
+            del os.environ[var_name]
+        else:
+            os.environ[var_name] = original_value

From 9831ce63077cd4b531979dde442b4185a18b16b8 Mon Sep 17 00:00:00 2001
From: Andrew Feldman
Date: Mon, 3 Jun 2024 03:06:50 -0400
Subject: [PATCH 2/9] formatting

---
 tests/kernels/test_attention_selector.py | 2 +-
 tests/kernels/utils.py                   | 4 +++-
 tests/utils.py                           | 1 -
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 1726f58cee08..b0b383974904 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -3,8 +3,8 @@
 import pytest
 import torch
 
-from vllm.attention.selector import which_attn_to_use
 from tests.kernels.utils import backend_override_fixture
+from vllm.attention.selector import which_attn_to_use
 
 
 @pytest.mark.parametrize(
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 955a96bae2a8..8ebc2fc5905a 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -1,9 +1,11 @@
 """Kernel test utils"""
 
-from tests.utils import env_var_fixture
 from contextlib import contextmanager
 from typing import Iterator
 
+from tests.utils import env_var_fixture
+
+
 @contextmanager
 def backend_override_fixture(backend_name: str) -> Iterator[None]:
     '''
diff --git a/tests/utils.py b/tests/utils.py
index 48666ca652dd..adbff8e8dc1c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,7 +4,6 @@
 import time
 import warnings
 from contextlib import contextmanager
-
 from typing import Iterator
 
 import ray
 import requests

From e738fb4ee2c98f6af912e61c9a4a99acfe887dcf Mon Sep 17 00:00:00 2001
From: Andrew Feldman
Date: Mon, 3 Jun 2024 11:47:40 -0400
Subject: [PATCH 3/9] reverted my custom env var patch impl

---
 tests/kernels/utils.py | 25 -------------------------
 tests/utils.py         | 26 --------------------------
 2 files changed, 51 deletions(-)
 delete mode 100644 tests/kernels/utils.py

diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
deleted file mode 100644
index 8ebc2fc5905a..000000000000
--- a/tests/kernels/utils.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""Kernel test utils"""
-
-from contextlib import contextmanager
-from typing import Iterator
-
-from tests.utils import env_var_fixture
-
-
-@contextmanager
-def backend_override_fixture(backend_name: str) -> Iterator[None]:
-    '''
-    Test fixture which temporarily configures the vLLM backend by
-    setting VLLM_ATTENTION_BACKEND, then resets the environment once
-    the fixture exits.
-
-    Usage:
-
-    with backend_override_fixture("backend_name"):
-        # code that depends on vLLM backend
-
-    # VLLM_ATTENTION_BACKEND is returned to original value
-    # or unset
-    '''
-    with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name):
-        yield
diff --git a/tests/utils.py b/tests/utils.py
index adbff8e8dc1c..329842911e15 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,7 +4,6 @@
 import time
 import warnings
 from contextlib import contextmanager
-from typing import Iterator
 
 import ray
 import requests
@@ -102,28 +101,3 @@ def error_on_warning():
     warnings.simplefilter("error")
 
     yield
-
-
-@contextmanager
-def env_var_fixture(var_name: str, value: str) -> Iterator[None]:
-    '''
-    Test fixture which temporarily assigns a value to the var_name
-    environment variable, then resets it once the fixture exits.
-
-    Usage:
-
-    with env_var_fixture("my_var", "my_val"):
-        # code that depends on my_var == "my_val"
-
-    # my_var is returned to original value or unset
-    '''
-    original_value = os.environ.get(var_name)  # Store the original value
-    os.environ[var_name] = value  # Set the new value
-    try:
-        yield
-    finally:
-        # Restore the original value
-        if original_value is None:
-            del os.environ[var_name]
-        else:
-            os.environ[var_name] = original_value

From dfe9c10389beccfc43b2bddf687075e07e7283b9 Mon Sep 17 00:00:00 2001
From: Andrew Feldman
Date: Mon, 3 Jun 2024 12:01:52 -0400
Subject: [PATCH 4/9] monkeypatch works

---
 tests/kernels/test_attention_selector.py | 93 ++++++++++++------------
 1 file changed, 47 insertions(+), 46 deletions(-)

diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index b0b383974904..ebd2d460dc45 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -1,79 +1,80 @@
+import os
 from unittest.mock import patch
 
 import pytest
 import torch
 
-from tests.kernels.utils import backend_override_fixture
 from vllm.attention.selector import which_attn_to_use
 
+_backend_env_var = "VLLM_ATTENTION_BACKEND"
+_flash_attn_val = "FLASH_ATTN"
+_invalid_val = "INVALID"
+
 
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
-def test_env(name: str, device: str):
+def test_env(name: str, device: str, monkeypatch):
     """Test that the attention selector can be set via environment variable.
 
     Note that we do not test FlashAttn because it is the default backend.
""" - with backend_override_fixture(name): - - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: + monkeypatch.setenv(_backend_env_var,name) + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == name + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name -def test_flash_attn(): +def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - with backend_override_fixture("FLASH_ATTN"): - - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + monkeypatch.setenv(_backend_env_var,_flash_attn_val) - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, - 16) + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" -def test_invalid_env(): - """Throw an exception if the backend name is invalid.""" - with backend_override_fixture("INVALID"), 
pytest.raises(ValueError): +def test_invalid_env(monkeypatch): + """Throw an exception if the backend name is invalid.""" + monkeypatch.setenv(_backend_env_var,_invalid_val) + with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From 822175834eb2846c826a63357f731966ed83abce Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:02:31 -0400 Subject: [PATCH 5/9] formatting --- tests/kernels/test_attention_selector.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ebd2d460dc45..0b4bb7c353cc 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,4 +1,3 @@ -import os from unittest.mock import patch import pytest @@ -19,7 +18,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(_backend_env_var,name) + monkeypatch.setenv(_backend_env_var, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -40,7 +39,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var,_flash_attn_val) + monkeypatch.setenv(_backend_env_var, _flash_attn_val) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -75,6 +74,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var,_invalid_val) + monkeypatch.setenv(_backend_env_var, _invalid_val) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From cbb89b1dd912cb816d3c7d721bb129baa08c83b3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:08:56 -0400 Subject: [PATCH 6/9] refactored constants into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 12 +++++------- tests/kernels/utils.py | 5 +++++ 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 0b4bb7c353cc..7bc0439f3ee8 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,12 +3,10 @@ import pytest import torch +from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_INVALID_VAL) from vllm.attention.selector import which_attn_to_use -_backend_env_var = "VLLM_ATTENTION_BACKEND" -_flash_attn_val = "FLASH_ATTN" -_invalid_val = "INVALID" - @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) @@ -18,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. 
""" - monkeypatch.setenv(_backend_env_var, name) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -39,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var, _flash_attn_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -74,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var, _invalid_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 000000000000..74ad9d8256e3 --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,5 @@ +"""Kernel test utils""" + +STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL = "FLASH_ATTN" +STR_INVALID_VAL = "INVALID" \ No newline at end of file From ca570e7a078540d8788dc6b7cf961039f699a761 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:17:57 -0400 Subject: [PATCH 7/9] a refactoring backend override functionality into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 10 +++++----- tests/kernels/utils.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 7bc0439f3ee8..ea3ccb026ea2 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,8 +3,8 @@ import pytest import torch -from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_INVALID_VAL) +from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, + override_backend) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. 
""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) + override_backend(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) + override_backend(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) + override_backend(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 74ad9d8256e3..3874fad57ae4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,4 +2,8 @@ STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" \ No newline at end of file +STR_INVALID_VAL = "INVALID" + + +def override_backend(mpatch, backend_name): + mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From ed8f8b3aa6eba958e0e527510f50aa3cc94f66c4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:40:06 -0400 Subject: [PATCH 8/9] Comments & type hints --- tests/kernels/utils.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 3874fad57ae4..fb28924c5f9c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,9 +1,21 @@ """Kernel test utils""" -STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" -STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" +import pytest +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch, backend_name): + +def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: + ''' + Override vLLM attention backend temporarily, + using pytest monkeypatch to ensure that the env vars get + reset once the test context exits. + + Arguments: + + * mpatch: pytest monkeypatch instance + * backend_name: attention backend name to force + ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From 8abe51c8b6091de871600ea83d6fc1837eb4db79 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 13:15:28 -0400 Subject: [PATCH 9/9] small refactors per @sroy745 suggestions --- tests/kernels/test_attention_selector.py | 8 ++++---- tests/kernels/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ea3ccb026ea2..79e03c7478de 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -4,7 +4,7 @@ import torch from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend) + override_backend_env_variable) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. 
""" - override_backend(monkeypatch, name) + override_backend_env_variable(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - override_backend(monkeypatch, STR_FLASH_ATTN_VAL) + override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - override_backend(monkeypatch, STR_INVALID_VAL) + override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fb28924c5f9c..b401eb87d3ec 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,9 +7,10 @@ STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: +def override_backend_env_variable(mpatch: pytest.MonkeyPatch, + backend_name: str) -> None: ''' - Override vLLM attention backend temporarily, + Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits.