diff --git a/pyproject.toml b/pyproject.toml index a7e172002214..d66b89769c2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,4 +32,5 @@ doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" doctest_glob="**/*.md" markers = [ "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", + "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", ] \ No newline at end of file diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0ff7e718af20..50e178fbea3f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -966,9 +966,17 @@ def require_aqlm(test_case): def require_bitsandbytes(test_case): """ - Decorator for bits and bytes (bnb) dependency + Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed. """ - return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case) + if is_bitsandbytes_available() and is_torch_available(): + try: + import pytest + + return pytest.mark.bitsandbytes(test_case) + except ImportError: + return test_case + else: + return unittest.skip("test requires bitsandbytes and torch")(test_case) def require_optimum(test_case):