diff --git a/tests/ci/slow_tests_diffusers.sh b/tests/ci/slow_tests_diffusers.sh index 6bcc6e4d5e..6e4293764c 100644 --- a/tests/ci/slow_tests_diffusers.sh +++ b/tests/ci/slow_tests_diffusers.sh @@ -2,4 +2,5 @@ python -m pip install --upgrade pip export RUN_SLOW=true +CUSTOM_BF16_OPS=1 python -m pytest tests/test_diffusers.py -v -s -k "test_no_throughput_regression_autocast" make slow_tests_diffusers diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 57e270d9c6..9545e021fa 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -21,7 +21,7 @@ import tempfile from io import BytesIO from pathlib import Path -from unittest import TestCase +from unittest import TestCase, skipUnless import numpy as np import requests @@ -31,7 +31,7 @@ from parameterized import parameterized from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from transformers.testing_utils import slow +from transformers.testing_utils import parse_flag_from_env, slow from optimum.habana import GaudiConfig from optimum.habana.diffusers import ( @@ -58,6 +58,20 @@ TEXTUAL_INVERSION_RUNTIME = 206.32180358597543 +_run_custom_bf16_ops_test_ = parse_flag_from_env("CUSTOM_BF16_OPS", default=False) + + +def custom_bf16_ops(test_case): + """ + Decorator marking a test as needing custom bf16 ops. + Custom bf16 ops must be declared before `habana_frameworks.torch.core` is imported, which is not possible if some other tests are executed before. + + Such tests are skipped by default. Set the CUSTOM_BF16_OPS environment variable to a truthy value to run them. + + """ + return skipUnless(_run_custom_bf16_ops_test_, "test requires custom bf16 ops")(test_case) + + class GaudiPipelineUtilsTester(TestCase): """ Tests the features added on top of diffusers/pipeline_utils.py. @@ -550,6 +564,7 @@ def test_no_throughput_regression_bf16(self): self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts)) self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_BASELINE_BF16) + @custom_bf16_ops @slow def test_no_throughput_regression_autocast(self): prompts = [