diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index e90bd5e6dbf5..1f4f983b7210 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -56,8 +56,8 @@ def pytest_configure(config): """Runs at pytest configure time, defines marks to be used later.""" - for markername, desc in MARKERS.items(): - config.addinivalue_line("markers", "{}: {}".format(markername, desc)) + for feature in utils.Feature._all_features.values(): + feature._register_marker(config) print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets()))) print("pytest marker:", config.option.markexpr) @@ -269,25 +269,26 @@ def _target_to_requirement(target): # mapping from target to decorator if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): - return utils.requires_cudnn() + return utils.requires_cudnn.marks() if target.kind.name == "cuda" and "cublas" in target.attrs.get("libs", []): - return utils.requires_cublas() + return utils.requires_cublas.marks() if target.kind.name == "cuda": - return utils.requires_cuda() + return utils.requires_cuda.marks() if target.kind.name == "rocm": - return utils.requires_rocm() + return utils.requires_rocm.marks() if target.kind.name == "vulkan": - return utils.requires_vulkan() + return utils.requires_vulkan.marks() if target.kind.name == "nvptx": - return utils.requires_nvptx() + return utils.requires_nvptx.marks() if target.kind.name == "metal": - return utils.requires_metal() + return utils.requires_metal.marks() if target.kind.name == "opencl": - return utils.requires_opencl() + return utils.requires_opencl.marks() if target.kind.name == "llvm": - return utils.requires_llvm() + return utils.requires_llvm.marks() if target.kind.name == "hexagon": - return utils.requires_hexagon() + return utils.requires_hexagon.marks() + return [] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 0e2d7be4a14e..939786c9294f 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -67,15 +67,20 @@ def test_something(): import copyreg import ctypes import functools +import itertools import logging import os +import pickle import platform import shutil import sys import time -import pickle + +from typing import Optional, Callable, Union, List + import pytest import numpy as np + import tvm import tvm.arith import tvm.tir @@ -84,9 +89,6 @@ def test_something(): from tvm.contrib import nvcc, cudnn from tvm.error import TVMError -from tvm.relay.op.contrib.ethosn import ethosn_available -from tvm.relay.op.contrib import cmsisnn -from tvm.relay.op.contrib import vitis_ai SKIP_SLOW_TESTS = os.getenv("SKIP_SLOW_TESTS", "").lower() in {"true", "1", "yes"} @@ -388,12 +390,9 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): ) -def _get_targets(target_str=None): - if target_str is None: - target_str = os.environ.get("TVM_TEST_TARGETS", "") - # Use dict instead of set for de-duplication so that the - # targets stay in the order specified. - target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()}) +def _get_targets(target_names=None): + if target_names is None: + target_names = _tvm_test_targets() if not target_names: target_names = DEFAULT_TEST_TARGETS @@ -429,7 +428,7 @@ def _get_targets(target_str=None): " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.", target_str, ) - return _get_targets("llvm") + return _get_targets(["llvm"]) raise TVMError( "None of the following targets are supported by this build of TVM: %s." @@ -515,458 +514,544 @@ def enabled_targets(): return [(t["target"], tvm.device(t["target"])) for t in _get_targets() if t["is_runnable"]] -def _compose(args, decs): - """Helper to apply multiple markers""" - if len(args) > 0: - f = args[0] - for d in reversed(decs): - f = d(f) - return f - return decs +class Feature: + """A feature that may be required to run a test. -def slow(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if SKIP_SLOW_TESTS: - pytest.skip("Skipping slow test since RUN_SLOW_TESTS environment variables is 'true'") - else: - fn(*args, **kwargs) + Parameters + ---------- + name: str - return wrapper + The short name of the feature. Should match the name in the + requires_* decorator. This is applied as a mark to all tests + using this feature, and can be used in pytests ``-m`` + argument. + long_name: Optional[str] -def uses_gpu(*args): - """Mark to differentiate tests that use the GPU in some capacity. + The long name of the feature, to be used in error messages. - These tests will be run on CPU-only test nodes and on test nodes with GPUs. - To mark a test that must have a GPU present to run, use - :py:func:`tvm.testing.requires_gpu`. + If None, defaults to the short name. - Parameters - ---------- - f : function - Function to mark - """ - _uses_gpu = [pytest.mark.gpu] - return _compose(args, _uses_gpu) + cmake_flag: Optional[str] + The flag that must be enabled in the config.cmake in order to + use this feature. -def requires_x86(*args): - """Mark a test as requiring the x86 Architecture to run. + If None, no flag is required to use this feature. - Tests with this mark will not be run unless on an x86 platform. + target_kind_enabled: Optional[str] - Parameters - ---------- - f : function - Function to mark - """ - _requires_x86 = [ - pytest.mark.skipif(platform.machine() != "x86_64", reason="x86 Architecture Required"), - ] - return _compose(args, _requires_x86) + The target kind that must be enabled to run tests using this + feature. If present, the target_kind must appear in the + TVM_TEST_TARGETS environment variable, or in + tvm.testing.DEFAULT_TEST_TARGETS if TVM_TEST_TARGETS is + undefined. + If None, this feature does not require a specific target to be + enabled. -def requires_gpu(*args): - """Mark a test as requiring a GPU to run. + compile_time_check: Optional[Callable[[], Union[bool,str]]] - Tests with this mark will not be run unless a gpu is present. + A check that returns True if the feature can be used at + compile-time. (e.g. Validating the version number of the nvcc + compiler.) If the feature does not have support to perform + compile-time tests, the check should returns False to display + a generic error message, or a string to display a more + specific error message. - Parameters - ---------- - f : function - Function to mark - """ - _requires_gpu = [ - pytest.mark.skipif( - not tvm.cuda().exist - and not tvm.rocm().exist - and not tvm.opencl().exist - and not tvm.metal().exist - and not tvm.vulkan().exist, - reason="No GPU present", - ), - *uses_gpu(), - ] - return _compose(args, _requires_gpu) + If None, no additional check is performed. + target_kind_hardware: Optional[str] -def requires_cuda(*args): - """Mark a test as requiring the CUDA runtime. + The target kind that must have available hardware in order to + run tests using this feature. This is checked using + tvm.device(target_kind_hardware).exist. If a feature requires + a different check, this should be implemented using + run_time_check. - This also marks the test as requiring a cuda gpu. + If None, this feature does not require a specific + tvm.device to exist. - Parameters - ---------- - f : function - Function to mark - """ - _requires_cuda = [ - pytest.mark.cuda, - pytest.mark.skipif(not device_enabled("cuda"), reason="CUDA support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_cuda) + run_time_check: Optional[Callable[[], Union[bool,str]]] + A check that returns True if the feature can be used at + run-time. (e.g. Validating the compute version supported by a + GPU.) If the feature does not have support to perform + run-time tests, the check should returns False to display a + generic error message, or a string to display a more specific + error message. -def requires_cudnn(*args): - """Mark a test as requiring the cuDNN library. + If None, no additional check is performed. - This also marks the test as requiring a cuda gpu. + parent_features: Optional[Union[str,List[str]]] - Parameters - ---------- - f : function - Function to mark - """ + The short name of a feature or features that are required in + order to use this feature. (e.g. Using cuDNN requires using + CUDA) This feature should inherit all checks of the parent + feature, with the exception of the `target_kind_enabled` + checks. - requirements = [ - pytest.mark.skipif( - not cudnn.exists(), reason="cuDNN library not enabled, or not installed" - ), - *requires_cuda(), - ] - return _compose(args, requirements) + If None, this feature does not require any other parent + features. + """ -def requires_cublas(*args): - """Mark a test as requiring the cuBLAS library. + _all_features = {} + + def __init__( + self, + name: str, + long_name: Optional[str] = None, + cmake_flag: Optional[str] = None, + target_kind_enabled: Optional[str] = None, + compile_time_check: Optional[Callable[[], Union[bool, str]]] = None, + target_kind_hardware: Optional[str] = None, + run_time_check: Optional[Callable[[], Union[bool, str]]] = None, + parent_features: Optional[Union[str, List[str]]] = None, + ): + self.name = name + self.long_name = long_name or name + self.cmake_flag = cmake_flag + self.target_kind_enabled = target_kind_enabled + self.compile_time_check = compile_time_check + self.target_kind_hardware = target_kind_hardware + self.run_time_check = run_time_check + + if parent_features is None: + self.parent_features = [] + elif isinstance(parent_features, str): + self.parent_features = [parent_features] + else: + self.parent_features = parent_features - This also marks the test as requiring a cuda gpu. + self._all_features[self.name] = self - Parameters - ---------- - f : function - Function to mark - """ + def _register_marker(self, config): + config.addinivalue_line("markers", f"{self.name}: Mark a test as using {self.long_name}") - requirements = [ - pytest.mark.skipif( - tvm.get_global_func("tvm.contrib.cublas.matmul", True), - reason="cuDNN library not enabled", - ), - *requires_cuda(), - ] - return _compose(args, requirements) + def _uses_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._uses_marks() + yield getattr(pytest.mark, self.name) -def requires_nvptx(*args): - """Mark a test as requiring the NVPTX compilation on the CUDA runtime + def _compile_only_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._compile_only_marks() - This also marks the test as requiring a cuda gpu, and requiring - LLVM support. + if self.compile_time_check is not None: + res = self.compile_time_check() + if isinstance(res, str): + yield pytest.mark.skipif(True, reason=res) + else: + yield pytest.mark.skipif( + not res, reason=f"Compile-time support for {self.long_name} not present" + ) - Parameters - ---------- - f : function - Function to mark + if self.target_kind_enabled is not None: + target_kind = self.target_kind_enabled.split()[0] + yield pytest.mark.skipif( + all(enabled.split()[0] != target_kind for enabled in _tvm_test_targets()), + reason=( + f"{self.target_kind_enabled} tests disabled " + f"by TVM_TEST_TARGETS environment variable" + ), + ) - """ - _requires_nvptx = [ - pytest.mark.skipif(not device_enabled("nvptx"), reason="NVPTX support not enabled"), - *requires_llvm(), - *requires_gpu(), - ] - return _compose(args, _requires_nvptx) + if self.cmake_flag is not None: + yield pytest.mark.skipif( + not _cmake_flag_enabled(self.cmake_flag), + reason=( + f"{self.long_name} support not enabled. " + f"Set {self.cmake_flag} in config.cmake to enable." + ), + ) + def _run_only_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._run_only_marks() + + if self.run_time_check is not None: + res = self.run_time_check() + if isinstance(res, str): + yield pytest.mark.skipif(True, reason=res) + else: + yield pytest.mark.skipif( + not res, reason=f"Run-time support for {self.long_name} not present" + ) -def requires_nvcc_version(major_version, minor_version=0, release_version=0): - """Mark a test as requiring at least a specific version of nvcc. + if self.target_kind_hardware is not None: + yield pytest.mark.skipif( + not tvm.device(self.target_kind_hardware).exist, + reason=f"No device exists for target {self.target_kind_hardware}", + ) - Unit test marked with this decorator will run only if the - installed version of NVCC is at least `(major_version, - minor_version, release_version)`. + def marks(self, support_required="compile-and-run"): + """Return a list of marks to be used - This also marks the test as requiring a cuda support. + Parameters + ---------- - Parameters - ---------- - major_version: int + support_required: str - The major version of the (major,minor,release) version tuple. + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". - minor_version: int + See Feature.__call__ for details. + """ + if support_required not in ["compile-and-run", "compile-only", "optional"]: + raise ValueError(f"Unknown feature support type: {support_required}") - The minor version of the (major,minor,release) version tuple. + if support_required == "compile-and-run": + marks = itertools.chain( + self._run_only_marks(), self._compile_only_marks(), self._uses_marks() + ) + elif support_required == "compile-only": + marks = itertools.chain(self._compile_only_marks(), self._uses_marks()) + elif support_required == "optional": + marks = self._uses_marks() + else: + raise ValueError(f"Unknown feature support type: {support_required}") - release_version: int + return list(marks) - The release version of the (major,minor,release) version tuple. + def __call__(self, func=None, *, support_required="compile-and-run"): + """Mark a pytest function as requiring this feature - """ + Can be used either as a bare decorator, or as a decorator with + arguments. - try: - nvcc_version = nvcc.get_cuda_version() - except RuntimeError: - nvcc_version = (0, 0, 0) + Parameters + ---------- - min_version = (major_version, minor_version, release_version) - version_str = ".".join(str(v) for v in min_version) - requires = [ - pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), - *requires_cuda(), - ] + func: Callable - def inner(func): - return _compose([func], requires) + The pytest test function to be marked - return inner + support_required: str + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". -def skip_if_32bit(reason): - def decorator(*args): - if "32bit" in platform.architecture()[0]: - return _compose(args, [pytest.mark.skip(reason=reason)]) + If "compile-and-run", the test case is marked as using the + feature, and is skipped if the environment lacks either + compile-time or run-time support for the feature. - return _compose(args, []) + If "compile-only", the test case is marked as using the + feature, and is skipped if the environment lacks + compile-time support. - return decorator + If "optional", the test case is marked as using the + feature, but isn't skipped. This is kept for backwards + compatibility for tests that use `enabled_targets()`, and + should be avoided in new test code. Instead, prefer + parametrizing over the target using the `target` fixture. + Examples + -------- -def requires_cudagraph(*args): - """Mark a test as requiring the CUDA Graph Feature + .. code-block:: python - This also marks the test as requiring cuda + @feature + def test_compile_and_run(): + ... - Parameters - ---------- - f : function - Function to mark - """ - _requires_cudagraph = [ - pytest.mark.skipif( - not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment" - ), - *requires_cuda(), - ] - return _compose(args, _requires_cudagraph) + @feature(compile_only=True) + def test_compile_only(): + ... + """ -def requires_opencl(*args): - """Mark a test as requiring the OpenCL runtime. + if support_required not in ["compile-and-run", "compile-only", "optional"]: + raise ValueError(f"Unknown feature support type: {support_required}") - This also marks the test as requiring a gpu. + def wrapper(func): + for mark in self.marks(support_required=support_required): + func = mark(func) + return func - Parameters - ---------- - f : function - Function to mark - """ - _requires_opencl = [ - pytest.mark.opencl, - pytest.mark.skipif(not device_enabled("opencl"), reason="OpenCL support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_opencl) + if func is None: + return wrapper + return wrapper(func) -def requires_corstone300(*args): - """Mark a test as requiring the corstone300 FVP + @classmethod + def require(cls, name, support_required="compile-and-run"): + """Returns a decorator that marks a test as requiring a feature - Parameters - ---------- - f : function - Function to mark - """ - _requires_corstone300 = [ - pytest.mark.corstone300, - pytest.mark.skipif( - shutil.which("arm-none-eabi-gcc") is None, reason="ARM embedded toolchain unavailable" - ), - ] - return _compose(args, _requires_corstone300) + Parameters + ---------- + name: str -def requires_rocm(*args): - """Mark a test as requiring the rocm runtime. + The name of the feature that is used by the test - This also marks the test as requiring a gpu. + support_required: str - Parameters - ---------- - f : function - Function to mark - """ - _requires_rocm = [ - pytest.mark.rocm, - pytest.mark.skipif(not device_enabled("rocm"), reason="rocm support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_rocm) + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". + See Feature.__call__ for details. -def requires_metal(*args): - """Mark a test as requiring the metal runtime. + Examples + -------- - This also marks the test as requiring a gpu. + .. code-block:: python - Parameters - ---------- - f : function - Function to mark - """ - _requires_metal = [ - pytest.mark.metal, - pytest.mark.skipif(not device_enabled("metal"), reason="metal support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_metal) + @Feature.require("cuda") + def test_compile_and_run(): + ... + @Feature.require("cuda", compile_only=True) + def test_compile_only(): + ... + """ + return cls._all_features[name](support_required=support_required) -def requires_vulkan(*args): - """Mark a test as requiring the vulkan runtime. - This also marks the test as requiring a gpu. +def _any_gpu_exists(): + return ( + tvm.cuda().exist + or tvm.rocm().exist + or tvm.opencl().exist + or tvm.metal().exist + or tvm.vulkan().exist + ) - Parameters - ---------- - f : function - Function to mark - """ - _requires_vulkan = [ - pytest.mark.vulkan, - pytest.mark.skipif(not device_enabled("vulkan"), reason="vulkan support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_vulkan) +# Mark a test as requiring llvm to run +requires_llvm = Feature( + "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" +) -def requires_tensorcore(*args): - """Mark a test as requiring a tensorcore to run. +# Mark a test as requiring a GPU to run. +requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists) - Tests with this mark will not be run unless a tensorcore is present. +# Mark to differentiate tests that use the GPU in some capacity. +# +# These tests will be run on CPU-only test nodes and on test nodes with GPUs. +# To mark a test that must have a GPU present to run, use +# :py:func:`tvm.testing.requires_gpu`. +uses_gpu = requires_gpu(support_required="optional") + +# Mark a test as requiring the x86 Architecture to run. +requires_x86 = Feature( + "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64" +) + +# Mark a test as requiring the CUDA runtime. +requires_cuda = Feature( + "cuda", + "CUDA", + cmake_flag="USE_CUDA", + target_kind_enabled="cuda", + target_kind_hardware="cuda", + parent_features="gpu", +) + +# Mark a test as requiring a tensorcore to run +requires_tensorcore = Feature( + "tensorcore", + "NVIDIA Tensor Core", + run_time_check=lambda: tvm.cuda().exist and nvcc.have_tensorcore(tvm.cuda().compute_version), + parent_features="cuda", +) + +# Mark a test as requiring the cuDNN library. +requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_features="cuda") + +# Mark a test as requiring the cuBLAS library. +requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda") + +# Mark a test as requiring the NVPTX compilation on the CUDA runtime +requires_nvptx = Feature( + "nvptx", + "NVPTX", + target_kind_enabled="nvptx", + target_kind_hardware="nvptx", + parent_features=["llvm", "cuda"], +) + +# Mark a test as requiring the CUDA Graph Feature +requires_cudagraph = Feature( + "cudagraph", + "CUDA Graph", + target_kind_enabled="cuda", + compile_time_check=nvcc.have_cudagraph, + parent_features="cuda", +) + +# Mark a test as requiring the OpenCL runtime +requires_opencl = Feature( + "opencl", + "OpenCL", + cmake_flag="USE_OPENCL", + target_kind_enabled="opencl", + target_kind_hardware="opencl", + parent_features="gpu", +) + +# Mark a test as requiring the rocm runtime +requires_rocm = Feature( + "rocm", + "ROCm", + cmake_flag="USE_ROCM", + target_kind_enabled="rocm", + target_kind_hardware="rocm", + parent_features="gpu", +) + +# Mark a test as requiring the metal runtime +requires_metal = Feature( + "metal", + "Metal", + cmake_flag="USE_METAL", + target_kind_enabled="metal", + target_kind_hardware="metal", + parent_features="gpu", +) + +# Mark a test as requiring the vulkan runtime +requires_vulkan = Feature( + "vulkan", + "Vulkan", + cmake_flag="USE_VULKAN", + target_kind_enabled="vulkan", + target_kind_hardware="vulkan", + parent_features="gpu", +) + +# Mark a test as requiring microTVM to run +requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO") + +# Mark a test as requiring rpc to run +requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC") + +# Mark a test as requiring Arm(R) Ethos(TM)-N to run +requires_ethosn = Feature("ethosn", "Arm(R) Ethos(TM)-N", cmake_flag="USE_ETHOSN") + +# Mark a test as requiring Hexagon to run +requires_hexagon = Feature( + "hexagon", + "Hexagon", + cmake_flag="USE_HEXAGON", + target_kind_enabled="hexagon", + compile_time_check=lambda: ( + (_cmake_flag_enabled("USE_LLVM") and tvm.target.codegen.llvm_version_major() >= 7) + or "Hexagon requires LLVM 7 or later" + ), + target_kind_hardware="hexagon", + parent_features="llvm", +) + +# Mark a test as requiring the CMSIS NN library +requires_cmsisnn = Feature("cmsisnn", "CMSIS NN", cmake_flag="USE_CMSISNN") + +# Mark a test as requiring the corstone300 FVP +requires_corstone300 = Feature( + "corstone300", + "Corstone-300", + compile_time_check=lambda: ( + (shutil.which("arm-none-eabi-gcc") is None) or "ARM embedded toolchain unavailable" + ), + parent_features="cmsisnn", +) + +# Mark a test as requiring Vitis AI to run +requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI") + + +def _cmake_flag_enabled(flag): + flag = tvm.support.libinfo()[flag] + + # Because many of the flags can be library flags, we check if the + # flag is not disabled, rather than checking if it is enabled. + return flag.lower() not in ["off", "false", "0"] + + +def _tvm_test_targets(): + target_str = os.environ.get("TVM_TEST_TARGETS", "").strip() + if target_str: + # Use dict instead of set for de-duplication so that the + # targets stay in the order specified. + return list({t.strip(): None for t in target_str.split(";") if t.strip()}) - Parameters - ---------- - f : function - Function to mark - """ - _requires_tensorcore = [ - pytest.mark.tensorcore, - pytest.mark.skipif( - not tvm.cuda().exist or not nvcc.have_tensorcore(tvm.cuda(0).compute_version), - reason="No tensorcore present", - ), - *requires_gpu(), - ] - return _compose(args, _requires_tensorcore) + return DEFAULT_TEST_TARGETS -def requires_llvm(*args): - """Mark a test as requiring llvm to run. +def _compose(args, decs): + """Helper to apply multiple markers""" + if len(args) > 0: + f = args[0] + for d in reversed(decs): + f = d(f) + return f + return decs - Parameters - ---------- - f : function - Function to mark - """ - _requires_llvm = [ - pytest.mark.llvm, - pytest.mark.skipif(not device_enabled("llvm"), reason="LLVM support not enabled"), - ] - return _compose(args, _requires_llvm) +def slow(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if SKIP_SLOW_TESTS: + pytest.skip("Skipping slow test since RUN_SLOW_TESTS environment variables is 'true'") + else: + fn(*args, **kwargs) -def requires_micro(*args): - """Mark a test as requiring microTVM to run. + return wrapper - Parameters - ---------- - f : function - Function to mark - """ - _requires_micro = [ - pytest.mark.skipif( - tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON", - reason="MicroTVM support not enabled. Set USE_MICRO=ON in config.cmake to enable.", - ) - ] - return _compose(args, _requires_micro) +def requires_nvcc_version(major_version, minor_version=0, release_version=0): + """Mark a test as requiring at least a specific version of nvcc. -def requires_rpc(*args): - """Mark a test as requiring rpc to run. + Unit test marked with this decorator will run only if the + installed version of NVCC is at least `(major_version, + minor_version, release_version)`. + + This also marks the test as requiring a cuda support. Parameters ---------- - f : function - Function to mark - """ - _requires_rpc = [ - pytest.mark.skipif( - tvm.support.libinfo().get("USE_RPC", "OFF") != "ON", - reason="RPC support not enabled. Set USE_RPC=ON in config.cmake to enable.", - ) - ] - return _compose(args, _requires_rpc) + major_version: int + The major version of the (major,minor,release) version tuple. -def requires_ethosn(*args): - """Mark a test as requiring Arm(R) Ethos(TM)-N to run. + minor_version: int - Parameters - ---------- - f : function - Function to mark - """ - marks = [ - pytest.mark.ethosn, - pytest.mark.skipif( - not ethosn_available(), - reason=( - "Arm(R) Ethos(TM)-N support not enabled. " - "Set USE_ETHOSN=ON in config.cmake to enable, " - "and ensure that hardware support is present." - ), - ), - ] - return _compose(args, marks) + The minor version of the (major,minor,release) version tuple. + release_version: int -def requires_hexagon(*args): - """Mark a test as requiring Hexagon to run. + The release version of the (major,minor,release) version tuple. - Parameters - ---------- - f : function - Function to mark """ - _requires_hexagon = [ - pytest.mark.hexagon, - pytest.mark.skipif(not device_enabled("hexagon"), reason="Hexagon support not enabled"), - *requires_llvm(), - pytest.mark.skipif( - tvm.target.codegen.llvm_version_major() < 7, reason="Hexagon requires LLVM 7 or later" - ), - ] - return _compose(args, _requires_hexagon) + try: + nvcc_version = nvcc.get_cuda_version() + except RuntimeError: + nvcc_version = (0, 0, 0) -def requires_cmsisnn(*args): - """Mark a test as requiring the CMSIS NN library. + min_version = (major_version, minor_version, release_version) + version_str = ".".join(str(v) for v in min_version) + requires = [ + pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), + *requires_cuda.marks(), + ] - Parameters - ---------- - f : function - Function to mark - """ + def inner(func): + return _compose([func], requires) - requirements = [pytest.mark.skipif(not cmsisnn.enabled(), reason="CMSIS NN not enabled")] - return _compose(args, requirements) + return inner -def requires_vitis_ai(*args): - """Mark a test as requiring Vitis AI to run. +def skip_if_32bit(reason): + def decorator(*args): + if "32bit" in platform.architecture()[0]: + return _compose(args, [pytest.mark.skip(reason=reason)]) - Parameters - ---------- - f : function - Function to mark - """ + return _compose(args, []) - requirements = [pytest.mark.skipif(not vitis_ai.enabled(), reason="Vitis AI not enabled")] - return _compose(args, requirements) + return decorator def requires_package(*packages): diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index fecd776d7065..bb1231e39438 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -32,8 +32,8 @@ ) run_module = tvm.testing.parameter( - pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]), - pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]), + pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), + pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), ids=["compile", "run"], ) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 982ec976d54e..cecb64785a49 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -44,9 +44,9 @@ ) run_module = tvm.testing.parameter( - pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda()]), + pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()]), pytest.param( - True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda()] + True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()] ), ids=["compile", "run"], ) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index d6ae27957de2..e8e93a6c7514 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -25,7 +25,7 @@ import tvm import tvm.testing -from tvm.testing.utils import ethosn_available +from tvm.relay.op.contrib.ethosn import ethosn_available from tvm.relay.backend import Runtime, Executor from tvm.contrib.target.vitis_ai import vitis_ai_available @@ -412,10 +412,7 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( assert len(c_source_files) == 4 -@pytest.mark.skipif( - not ethosn_available(), - reason="--target=Ethos(TM)-N78 is not available. TVM built with 'USE_ETHOSN OFF'", -) +@tvm.testing.requires_ethosn def test_compile_tflite_module_with_external_codegen_ethos_n78(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) @@ -430,10 +427,7 @@ def test_compile_tflite_module_with_external_codegen_ethos_n78(tflite_mobilenet_ assert os.path.exists(dumps_path) -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) +@tvm.testing.requires_vitis_ai def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index a40164ded941..f3886374ccb6 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -528,7 +528,7 @@ def check_target(device): check_target("rocm") -@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_reduce_storage_reuse(): target = tvm.target.Target("cuda")