Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
#

import contextlib
import functools
import gc
import json
import logging
import multiprocessing
import os
import shlex
import subprocess
Expand Down Expand Up @@ -78,6 +80,77 @@
_TEST_DIR = os.path.dirname(__file__)


def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    """Poll NPU memory until enough of it is free, or exit non-zero on timeout.

    Meant to be used as the ``target`` of a separate process (it is launched
    that way by ``wait_until_npu_memory_free``), so the result is reported
    through the process exit status rather than a return value.

    Args:
        target_free_percentage (float): Required ratio of free to total NPU
            memory. It is compared directly against ``free / total``, so it
            is a fraction (e.g. ``0.5``), not a value in the 0-100 range.
        max_wait_seconds (float): Maximum time to keep polling, in seconds.

    Returns:
        None once the free-memory ratio reaches the target.

    Raises:
        SystemExit: With status 1 (via ``sys.exit(1)``) when the target is
            not reached within ``max_wait_seconds``.
    """
    # Imported only for its side effect; the name is never used directly.
    # NOTE(review): presumably this registers the ``torch.npu`` backend - confirm.
    import torch_npu  # type: ignore  # noqa: F401

    # We can try to clean up memory in this subprocess, though it mostly
    # affects this process. If there are any lingering contexts in this
    # process (unlikely for a fresh spawn), it helps.
    gc.collect()
    torch.npu.empty_cache()

    _, total_npu_memory = torch.npu.mem_get_info()
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    start_time = time.monotonic()

    while True:
        free_bytes, _ = torch.npu.mem_get_info()
        if free_bytes / total_npu_memory >= target_free_percentage:
            print("check_npu_memory_worker: npu free memory reached target value.")
            return  # Success

        elapsed = time.monotonic() - start_time
        if elapsed > max_wait_seconds:
            # Print to stderr so it's visible in test logs even if captured.
            print(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
                file=sys.stderr,
            )
            sys.exit(1)  # Failure

        print(
            f"Waiting for NPU memory to be free: "
            f"{free_bytes / 1024**3:.2f} GB available, "
            f"Elapsed time: {elapsed:.2f} s."
        )
        # Try to clean up, then wait a second before polling again.
        gc.collect()
        torch.npu.empty_cache()
        time.sleep(1)


def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
    """Decorator to wait until the NPU memory free size is above target_free_percentage.

    Before each call of the decorated function, the wrapper runs
    ``cleanup_dist_env_and_memory()`` in the current process and then starts
    ``_check_npu_memory_worker`` in a spawned subprocess. The decorated
    function is only invoked when that subprocess exits with status 0.

    Args:
        target_free_percentage (float): Target ratio of free to total NPU
            memory. Despite the name this is a fraction (default ``0.5``),
            because the worker compares it against ``free / total``.
        max_wait_seconds (float): Maximum wait time in seconds that the
            worker polls for.

    Returns:
        A decorator that wraps a callable and preserves its metadata via
        ``functools.wraps``.

    Raises:
        TimeoutError: Raised by the wrapper when the checker subprocess exits
            with a non-zero status, or when it is still running after
            ``max_wait_seconds`` plus a start-up allowance.
    """
    # Allowance for interpreter start-up and heavy imports in the spawned
    # process, which happen before the worker starts its own timer.
    startup_grace_seconds = 120.0

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Clean up non-NPU resources in the main process.
            cleanup_dist_env_and_memory()

            # Use a spawned subprocess to check NPU memory to avoid
            # initializing NPU in the main process.
            ctx = multiprocessing.get_context("spawn")
            p = ctx.Process(
                target=_check_npu_memory_worker,
                args=(target_free_percentage, max_wait_seconds),
            )
            p.start()
            # Bound the wait: a worker stuck before or inside a device call
            # would otherwise block the calling test forever.
            p.join(max_wait_seconds + startup_grace_seconds)

            if p.is_alive():
                # Escalate from SIGTERM to SIGKILL so no checker is left behind.
                p.terminate()
                p.join(5)
                if p.is_alive():
                    p.kill()
                    p.join()
                raise TimeoutError(
                    f"Timeout: NPU memory checker process did not finish within "
                    f"{max_wait_seconds + startup_grace_seconds} seconds and was terminated."
                )

            if p.exitcode != 0:
                # The worker exits with 1 on timeout, but a crash in the
                # worker also yields a non-zero status; report the code so
                # the two cases can be told apart from the logs.
                raise TimeoutError(
                    f"Timeout: NPU memory free size did not reach "
                    f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
                    f" (checker process exit code: {p.exitcode})"
                )

            return func(*args, **kwargs)

        return wrapper

    return decorator


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
Expand All @@ -87,8 +160,13 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

# Only clean NPU cache if NPU is already initialized/available in this process.
# This prevents accidental initialization of NPU context in the main process,
# which would break subsequent forks.
if hasattr(torch, "npu") and torch.npu.is_initialized():
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()


class RemoteOpenAIServer:
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/multicard/2-cards/test_aclgraph_capture_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from vllm.utils.network_utils import get_open_port

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from tests.e2e.conftest import wait_until_npu_memory_free

MODELS = [
# Offline data parallel mode will not be supported/useful for dense models
Expand Down Expand Up @@ -137,6 +138,7 @@ def _run_worker_process(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4, 36])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
@wait_until_npu_memory_free(target_free_percentage=0.6)
def test_models_aclgraph_capture_replay_metrics_dp2(
model: str,
max_tokens: int,
Expand Down
3 changes: 2 additions & 1 deletion tests/e2e/multicard/4-cards/long_sequence/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free

os.environ["HCCL_BUFFSIZE"] = "768"

Expand Down Expand Up @@ -126,6 +126,7 @@ def test_models_pcp_dcp_piece_wise():
runner.model.generate(prompts, sampling_params)


@wait_until_npu_memory_free()
def test_pcp_basic():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
Expand Down