Skip to content

Commit

Permalink
tests: group skip wrappers & make them optional (#2372)
Browse files Browse the repository at this point in the history
(cherry picked from commit ee1a529)
  • Loading branch information
Borda committed Mar 16, 2024
1 parent 7f16839 commit 2df5098
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 54 deletions.
5 changes: 5 additions & 0 deletions .azure/gpu-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ jobs:
echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html"
displayName: "set Env. vars"
- bash: |
echo "##vso[task.setvariable variable=ALLOW_SKIP_IF_OUT_OF_MEMORY]1"
echo "##vso[task.setvariable variable=ALLOW_SKIP_IF_BAD_CONNECTION]1"
condition: eq(variables['Build.Reason'], 'PullRequest')
displayName: "set Env. vars for PRs"
- bash: |
printf "PR: $PR_NUMBER \n"
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ jobs:
--find-links $PYTORCH_URL -f $PYPI_CACHE
pip list
- name: set special vars for PR
if: ${{ github.event_name == 'pull_request' }}
run: |
echo 'ALLOW_SKIP_IF_OUT_OF_MEMORY=1' >> $GITHUB_ENV
echo 'ALLOW_SKIP_IF_BAD_CONNECTION=1' >> $GITHUB_ENV
- name: Sanity check
id: info
run: |
Expand Down
2 changes: 0 additions & 2 deletions tests/unittests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
NUM_PROCESSES,
THRESHOLD,
setup_ddp,
skip_on_running_out_of_memory,
)

# adding compatibility for numpy >= 1.24
Expand Down Expand Up @@ -50,5 +49,4 @@ class _GroupInput(NamedTuple):
"NUM_PROCESSES",
"THRESHOLD",
"setup_ddp",
"skip_on_running_out_of_memory",
]
20 changes: 0 additions & 20 deletions tests/unittests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import contextlib
import os
import sys
from functools import wraps
from typing import Any, Callable, Optional

import pytest
import torch
Expand Down Expand Up @@ -84,21 +82,3 @@ def pytest_sessionfinish():
return
pytest.pool.close()
pytest.pool.join()


def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."):
"""Handle tests that sometimes runs out of memory, by simply skipping them."""

def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
@wraps(function)
def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
try:
return function(*args, **kwargs)
except RuntimeError as ex:
if "DefaultCPUAllocator: not enough memory:" not in str(ex):
raise ex
pytest.skip(reason)

return run_test

return test_decorator
5 changes: 5 additions & 0 deletions tests/unittests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
import numpy
import torch

from unittests.helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory


def seed_all(seed):
"""Set the seed of all computational frameworks."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


__all__ = ["seed_all", "skip_on_connection_issues", "skip_on_running_out_of_memory"]
51 changes: 51 additions & 0 deletions tests/unittests/helpers/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from functools import wraps
from typing import Any, Callable, Optional

import pytest

ALLOW_SKIP_IF_OUT_OF_MEMORY = os.getenv("ALLOW_SKIP_IF_OUT_OF_MEMORY", "0") == "1"
ALLOW_SKIP_IF_BAD_CONNECTION = os.getenv("ALLOW_SKIP_IF_BAD_CONNECTION", "0") == "1"


def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."):
"""Handle tests that sometimes runs out of memory, by simply skipping them."""

def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
@wraps(function)
def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
try:
return function(*args, **kwargs)
except RuntimeError as ex:
if "DefaultCPUAllocator: not enough memory:" not in str(ex):
raise ex
if ALLOW_SKIP_IF_OUT_OF_MEMORY:
pytest.skip(reason)

return run_test

return test_decorator


def skip_on_connection_issues(reason: str = "Unable to load checkpoints from HuggingFace `transformers`."):
"""Handle download related tests if they fail due to connection issues.
The tests run normally if no connection issue arises, and they're marked as skipped otherwise.
"""
_error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"]

def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
@wraps(function)
def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
try:
return function(*args, **kwargs)
except (OSError, ValueError) as ex:
if all(msg_start not in str(ex) for msg_start in _error_msg_starts):
raise ex
if ALLOW_SKIP_IF_BAD_CONNECTION:
pytest.skip(reason)

return run_test

return test_decorator
3 changes: 1 addition & 2 deletions tests/unittests/image/test_perceptual_path_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from torchmetrics.image.perceptual_path_length import PerceptualPathLength
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

from unittests import skip_on_running_out_of_memory
from unittests.helpers import seed_all
from unittests.helpers import seed_all, skip_on_running_out_of_memory

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/multimodal/test_clip_iqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10
from torchvision.transforms import PILToTensor

from unittests.helpers import skip_on_connection_issues
from unittests.helpers.testers import MetricTester
from unittests.image import _SAMPLE_IMAGE
from unittests.text.helpers import skip_on_connection_issues


@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions tests/unittests/multimodal/test_clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

from unittests.helpers import seed_all
from unittests.helpers import seed_all, skip_on_connection_issues
from unittests.helpers.testers import MetricTester
from unittests.text.helpers import skip_on_connection_issues

seed_all(42)

Expand Down
25 changes: 1 addition & 24 deletions tests/unittests/text/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import pickle
import sys
from functools import partial, wraps
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -477,26 +477,3 @@ def run_differentiability_test(
if metric.is_differentiable:
# check for numerical correctness
assert torch.autograd.gradcheck(partial(metric_functional, **metric_args), (preds[0], targets[0]))


def skip_on_connection_issues(reason: str = "Unable to load checkpoints from HuggingFace `transformers`."):
"""Handle download related tests if they fail due to connection issues.
The tests run normally if no connection issue arises, and they're marked as skipped otherwise.
"""
_error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"]

def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
@wraps(function)
def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
try:
return function(*args, **kwargs)
except (OSError, ValueError) as ex:
if all(msg_start not in str(ex) for msg_start in _error_msg_starts):
raise ex
pytest.skip(reason)

return run_test

return test_decorator
3 changes: 2 additions & 1 deletion tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4
from typing_extensions import Literal

from unittests.text.helpers import TextTester, skip_on_connection_issues
from unittests.helpers import skip_on_connection_issues
from unittests.text.helpers import TextTester
from unittests.text.inputs import _inputs_single_reference

if _BERTSCORE_AVAILABLE:
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/text/test_infolm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from torchmetrics.text.infolm import InfoLM
from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4

from unittests.text.helpers import TextTester, skip_on_connection_issues
from unittests.helpers import skip_on_connection_issues
from unittests.text.helpers import TextTester
from unittests.text.inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference

# Small bert model with 2 layers, 2 attention heads and hidden dim of 128
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/text/test_rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
from typing_extensions import Literal

from unittests.text.helpers import TextTester, skip_on_connection_issues
from unittests.helpers import skip_on_connection_issues
from unittests.text.helpers import TextTester
from unittests.text.inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference

if _ROUGE_SCORE_AVAILABLE:
Expand Down

0 comments on commit 2df5098

Please sign in to comment.