From 995664ce47a3675b105616d34cdb4a82cd7583f2 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 8 Jan 2022 00:56:46 +0900 Subject: [PATCH] Mock _RICH_AVAILABLE --- tests/callbacks/test_rich_model_summary.py | 11 ++++++----- tests/callbacks/test_rich_progress_bar.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index 88c5f9ab531f0..c596557eed0dc 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -19,7 +19,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar -from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from pytorch_lightning.utilities.model_summary import summarize from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -33,10 +32,12 @@ def test_rich_model_summary_callback(): assert isinstance(trainer.progress_bar_callback, RichProgressBar) -def test_rich_progress_bar_import_error(): - if not _RICH_AVAILABLE: - with pytest.raises(ModuleNotFoundError, match="`RichModelSummary` requires `rich` to be installed."): - Trainer(callbacks=RichModelSummary()) +def test_rich_progress_bar_import_error(monkeypatch): + import pytorch_lightning.callbacks.rich_model_summary as imports + + monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError, match="`RichModelSummary` requires `rich` to be installed."): + RichModelSummary() @RunIf(rich=True) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index d9a9ad8c3c726..7c6d2b656d08e 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -20,7 +20,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme -from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -83,11 +82,12 @@ def predict_dataloader(self): assert progress_update.call_count == 8 -def test_rich_progress_bar_import_error(): - if not _RICH_AVAILABLE: - with pytest.raises(ModuleNotFoundError, match="`RichProgressBar` requires `rich` >= 10.2.2."): - Trainer(callbacks=RichProgressBar()) +def test_rich_progress_bar_import_error(monkeypatch): + import pytorch_lightning.callbacks.progress.rich_progress as imports + monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError, match="`RichProgressBar` requires `rich` >= 10.2.2."): + RichProgressBar() @RunIf(rich=True) def test_rich_progress_bar_custom_theme(tmpdir):