diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1963f805c7..2cb3147523 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -49,6 +49,7 @@ get_gpu_count, get_tests_dir, is_staging_test, + parse_flag_from_env, require_optuna, require_safetensors, require_sentencepiece, @@ -90,6 +91,20 @@ PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" +_run_safe_loading_tests_ = parse_flag_from_env("SAFE_LOADING_TESTS", default=False) + + +def safe_loading_test(test_case): + """ + Decorator marking a test as a safe-loading test (one exercising safetensors checkpoint resuming/loading). + The flag is read once, at module import, from the environment through `parse_flag_from_env`. + + Such tests are skipped by default. Set the SAFE_LOADING_TESTS environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_safe_loading_tests_, "test requires SAFE_LOADING_TESTS")(test_case) + + class RegressionDataset: def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) @@ -1465,6 +1480,7 @@ def test_training_with_resume_from_checkpoint_false(self): trainer.train(resume_from_checkpoint=False) + @safe_loading_test @require_safetensors def test_resume_training_with_safe_checkpoint(self): # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of @@ -1658,6 +1674,7 @@ def test_load_best_model_at_end(self): self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False) self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False) + @safe_loading_test @require_safetensors def test_load_best_model_from_safetensors(self): total = int(self.n_epochs * 64 / self.batch_size)