Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 49 additions & 36 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import time
import warnings
from collections.abc import Callable, Iterable, Sequence
from contextlib import ExitStack, contextmanager, suppress
from contextlib import ExitStack, contextmanager
from multiprocessing import Process
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -1414,52 +1414,65 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:


Comment thread
dzhengAP marked this conversation as resolved.
def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to spawn a new process for each test function.

    The wrapped test is serialized with cloudpickle (together with its
    positional/keyword arguments) and executed in a fresh Python
    subprocess. Any exception raised in the child is written to a
    temporary traceback file and re-raised in the parent as a
    ``RuntimeError``, so test failures are never silently swallowed
    (fixes https://github.com/vllm-project/vllm/issues/41415).

    Args:
        f: The test function to run in its own process.

    Returns:
        A wrapper with the same signature that runs ``f`` in a subprocess.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            child's traceback (or raw stderr) is embedded in the message.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Already inside a spawned child: run the test body directly to
        # avoid spawning processes recursively.
        if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
            return f(*args, **kwargs)

        # File the child uses to hand its formatted traceback back to us.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".tb", mode="wb"
        ) as tmp:
            tb_file = tmp.name

        try:
            # cloudpickle (unlike plain pickle) can serialize closures and
            # functions defined inside test modules.
            payload = cloudpickle.dumps((f, args, kwargs, tb_file))

            # Bootstrap executed in the child. pytest.skip raises
            # _pytest.outcomes.Skipped; treat it as success so skips do
            # not count as failures. Every other exception is written to
            # tb_file and converted into exit code 1.
            child_script = (
                "import sys, cloudpickle, traceback\n"
                "try:\n"
                "    from _pytest.outcomes import Skipped\n"
                "except ImportError:\n"
                "    class Skipped(BaseException): pass\n"
                "f, args, kwargs, tb_file = "
                "cloudpickle.loads(sys.stdin.buffer.read())\n"
                "try:\n"
                "    f(*args, **kwargs)\n"
                "except Skipped:\n"
                "    sys.exit(0)\n"
                "except BaseException:\n"
                "    open(tb_file, 'w').write(traceback.format_exc())\n"
                "    sys.exit(1)\n"
            )

            repo_root = str(VLLM_PATH.resolve())

            env = os.environ.copy()
            # Set AFTER copying the environment so the marker is not lost:
            # the child must know it is already running in a subprocess.
            env["RUNNING_IN_SUBPROCESS"] = "1"
            env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")

            result = subprocess.run(
                [sys.executable, "-c", child_script],
                input=payload,
                capture_output=True,
                env=env,
            )

            if result.returncode != 0:
                # Prefer the traceback the child wrote; fall back to raw
                # stderr (e.g. the interpreter died before the bootstrap
                # could run).
                if os.path.exists(tb_file) and os.path.getsize(tb_file) > 0:
                    with open(tb_file) as fp:
                        tb = fp.read()
                else:
                    tb = result.stderr.decode()
                raise RuntimeError(
                    f"Test subprocess '{f.__name__}' failed "
                    f"(exit code {result.returncode}):\n{tb}"
                )
        finally:
            # Best-effort cleanup. Avoid contextlib.suppress here: the
            # file's contextlib import only brings in ExitStack and
            # contextmanager, so `contextlib.suppress` would be a
            # NameError unless `import contextlib` exists elsewhere.
            try:
                os.remove(tb_file)
            except OSError:
                pass

    return wrapper

Expand Down
33 changes: 33 additions & 0 deletions tests/utils_/test_spawn_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for spawn_new_process_for_each_test decorator."""

import pytest

from tests.utils import spawn_new_process_for_each_test


@spawn_new_process_for_each_test
def test_spawn_decorator_passing():
    """A trivially passing test body must complete without error."""
    assert 2 == 1 + 1


@pytest.mark.xfail(raises=RuntimeError, strict=True)
@spawn_new_process_for_each_test
def test_spawn_decorator_failure_is_caught():
    """A failure inside the subprocess must surface as RuntimeError in the parent."""
    raise ValueError("intentional failure")


@spawn_new_process_for_each_test
def test_spawn_decorator_skip():
    """A pytest.skip raised inside the child must propagate as a skip, not a failure."""
    pytest.skip("intentional skip")


# parametrize is placed OUTERMOST so the mark is applied to the object
# pytest actually collects, instead of relying on functools.wraps inside
# the spawn decorator to copy `pytestmark` from the inner function.
@pytest.mark.parametrize("x,y,expected", [(1, 2, 3), (0, 0, 0)])
@spawn_new_process_for_each_test
def test_spawn_decorator_parametrized(x, y, expected):
    """Args and kwargs must be forwarded correctly to the subprocess."""
    assert x + y == expected
Loading
Loading