Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize("m", [1, 123, 666])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
Expand Down
33 changes: 0 additions & 33 deletions tests/tensorizer_loader/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import functools
import gc
from typing import Callable, TypeVar

import pytest
import torch
from typing_extensions import ParamSpec

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
Expand All @@ -25,32 +18,6 @@ def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True)


_P = ParamSpec("_P")
_R = TypeVar("_R")


def retry_until_skip(n: int):

def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

@functools.wraps(func)
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
for i in range(n):
try:
return func(*args, **kwargs)
except AssertionError:
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip(f"Skipping test after {n} attempts.")

raise AssertionError("Code should not be reached")

return wrapper_retry

return decorator_retry


@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
Expand Down
3 changes: 1 addition & 2 deletions tests/tensorizer_loader/test_tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from vllm.utils import PlaceholderModule, import_from_path

from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip

try:
from tensorizer import EncryptionParams
Expand Down Expand Up @@ -325,7 +324,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
assert outputs == deserialized_outputs


@retry_until_skip(3)
@pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
gc.collect()
torch.cuda.empty_cache()
Expand Down