tests: switch imports for pytorch #16595

Merged 13 commits on Feb 2, 2023
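
Every hunk in this PR applies the same rename: the standalone pytorch_lightning and lightning_fabric imports used by the test suite become imports from the unified lightning.pytorch and lightning.fabric namespaces. A minimal before/after sketch of the pattern, built only from imports that appear in the diffs below (the two paths are assumed to expose the same API, which is the premise of this PR):

# Old standalone-package imports (the lines removed below)
# import pytorch_lightning as pl
# from pytorch_lightning import Trainer
# from lightning_fabric.plugins import TorchCheckpointIO

# New unified-namespace imports (the lines added below)
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.fabric.plugins import TorchCheckpointIO

trainer = Trainer(max_epochs=1)  # the Trainer API is unchanged; only the import path moves
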
6 changes: 3 additions & 3 deletions tests/tests_pytorch/accelerators/test_common.py
@@ -15,9 +15,9 @@

import torch

-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.strategies import DDPStrategy
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.strategies import DDPStrategy


def test_pluggable_accelerator():
14 changes: 7 additions & 7 deletions tests/tests_pytorch/accelerators/test_cpu.py
@@ -6,13 +6,13 @@
import pytest
import torch

-import pytorch_lightning as pl
-from lightning_fabric.plugins import TorchCheckpointIO
-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import CPUAccelerator
-from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.strategies import SingleDeviceStrategy
+import lightning.pytorch as pl
+from lightning.fabric.plugins import TorchCheckpointIO
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import CPUAccelerator
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
+from lightning.pytorch.strategies import SingleDeviceStrategy
from tests_pytorch.helpers.runif import RunIf


8 changes: 4 additions & 4 deletions tests/tests_pytorch/accelerators/test_gpu.py
@@ -16,10 +16,10 @@
import pytest
import torch

-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import CUDAAccelerator
-from pytorch_lightning.accelerators.cuda import get_nvidia_gpu_stats
-from pytorch_lightning.demos.boring_classes import BoringModel
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import CUDAAccelerator
+from lightning.pytorch.accelerators.cuda import get_nvidia_gpu_stats
+from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf


14 changes: 7 additions & 7 deletions tests/tests_pytorch/accelerators/test_hpu.py
@@ -17,13 +17,13 @@
import pytest
import torch

-from pytorch_lightning import Callback, seed_everything, Trainer
-from pytorch_lightning.accelerators import HPUAccelerator
-from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy
-from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy
-from pytorch_lightning.utilities import _HPU_AVAILABLE
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from lightning.pytorch import Callback, seed_everything, Trainer
+from lightning.pytorch.accelerators import HPUAccelerator
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.strategies.hpu_parallel import HPUParallelStrategy
+from lightning.pytorch.strategies.single_hpu import SingleHPUStrategy
+from lightning.pytorch.utilities import _HPU_AVAILABLE
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
22 changes: 11 additions & 11 deletions tests/tests_pytorch/accelerators/test_ipu.py
@@ -19,16 +19,16 @@
import torch.nn.functional as F
from torch.utils.data import DistributedSampler

-from pytorch_lightning import Callback, seed_everything, Trainer
-from pytorch_lightning.accelerators import IPUAccelerator
-from pytorch_lightning.core.module import LightningModule
-from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins import IPUPrecisionPlugin
-from pytorch_lightning.strategies.ipu import IPUStrategy
-from pytorch_lightning.trainer.states import RunningStage, TrainerFn
-from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _IPU_AVAILABLE
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from lightning.pytorch import Callback, seed_everything, Trainer
+from lightning.pytorch.accelerators import IPUAccelerator
+from lightning.pytorch.core.module import LightningModule
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.plugins import IPUPrecisionPlugin
+from lightning.pytorch.strategies.ipu import IPUStrategy
+from lightning.pytorch.trainer.states import RunningStage, TrainerFn
+from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities import _IPU_AVAILABLE
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -95,7 +95,7 @@ def test_auto_device_count():


@pytest.mark.skipif(_IPU_AVAILABLE, reason="test requires non-IPU machine")
-@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+@mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
def test_fail_if_no_ipus(_, tmpdir):
with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)
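
A detail worth noting from the IPU diff above: mock.patch targets are plain strings, so an import switch does not rewrite them and they have to be updated by hand to the new dotted path. A stale "pytorch_lightning..." target would not fail at import time; it would simply patch a module the tests no longer use. A short sketch of the updated usage (assuming the lightning package these tests target; the standalone assertion is illustrative and not taken from the PR):

from unittest import mock

from lightning.pytorch.accelerators import IPUAccelerator

# The target string is resolved when the patch is applied, so it must name the
# lightning.pytorch module path that the tests now import from.
with mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True):
    assert IPUAccelerator.is_available()  # returns the mocked True
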
6 changes: 3 additions & 3 deletions tests/tests_pytorch/accelerators/test_mps.py
@@ -18,9 +18,9 @@
import torch

import tests_pytorch.helpers.pipelines as tpipes
-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import MPSAccelerator
-from pytorch_lightning.demos.boring_classes import BoringModel
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import MPSAccelerator
+from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf


2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_registry.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.accelerators import AcceleratorRegistry
+from lightning.pytorch.accelerators import AcceleratorRegistry


def test_available_accelerators_in_registry():
16 changes: 8 additions & 8 deletions tests/tests_pytorch/accelerators/test_tpu.py
@@ -22,13 +22,13 @@
from torch import nn
from torch.utils.data import DataLoader

-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators.cpu import CPUAccelerator
-from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
-from pytorch_lightning.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO
-from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
-from pytorch_lightning.utilities import find_shared_parameters
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators.cpu import CPUAccelerator
+from lightning.pytorch.accelerators.tpu import TPUAccelerator
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
+from lightning.pytorch.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO
+from lightning.pytorch.strategies import DDPStrategy, TPUSpawnStrategy
+from lightning.pytorch.utilities import find_shared_parameters
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad

@@ -269,7 +269,7 @@ def test_xla_checkpoint_plugin_being_default(tpu_available):

@RunIf(tpu=True)
@patch("torch_xla.distributed.parallel_loader.MpDeviceLoader")
-@patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
+@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock):
dataset = RandomDataset(32, 64)
dataloader = DataLoader(dataset)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/benchmarks/test_basic_parity.py
@@ -20,7 +20,7 @@
import torch
from tqdm import tqdm

-from pytorch_lightning import LightningModule, seed_everything, Trainer
+from lightning.pytorch import LightningModule, seed_everything, Trainer
from tests_pytorch.helpers.advanced_models import ParityModuleCIFAR, ParityModuleMNIST, ParityModuleRNN

_EXTEND_BENCHMARKS = os.getenv("PL_RUNNING_BENCHMARKS", "0") == "1"
@@ -16,7 +16,7 @@
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler

-from pytorch_lightning import LightningModule, seed_everything, Trainer
+from lightning.pytorch import LightningModule, seed_everything, Trainer
from tests_pytorch.helpers.runif import RunIf


@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.trainer.trainer import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.trainer.trainer import Trainer


def test_main_progress_bar_with_val_check_interval_int():
32 changes: 16 additions & 16 deletions tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -19,11 +19,11 @@
import pytest
from torch.utils.data import DataLoader

-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
-from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
-from pytorch_lightning.loggers import CSVLogger
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ProgressBarBase, RichProgressBar
+from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
+from lightning.pytorch.loggers import CSVLogger
from tests_pytorch.helpers.runif import RunIf


@@ -75,26 +75,26 @@ def predict_dataloader(self):
)
model = TestModel()

-with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.fit(model)
# 3 for main progress bar and 1 for val progress bar
assert mocked.call_count == 4

-with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.validate(model)
assert mocked.call_count == 1

-with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.test(model)
assert mocked.call_count == 1

-with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.predict(model)
assert mocked.call_count == 1


def test_rich_progress_bar_import_error(monkeypatch):
-import pytorch_lightning.callbacks.progress.rich_progress as imports
+import lightning.pytorch.callbacks.progress.rich_progress as imports

monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
with pytest.raises(ModuleNotFoundError, match="`RichProgressBar` requires `rich` >= 10.2.2."):
@@ -105,7 +105,7 @@ def test_rich_progress_bar_import_error(monkeypatch):
def test_rich_progress_bar_custom_theme(tmpdir):
"""Test to ensure that custom theme styles are used."""
with mock.patch.multiple(
-"pytorch_lightning.callbacks.progress.rich_progress",
+"lightning.pytorch.callbacks.progress.rich_progress",
CustomBarColumn=DEFAULT,
BatchesProcessedColumn=DEFAULT,
CustomTimeColumn=DEFAULT,
@@ -142,7 +142,7 @@ def on_train_start(self) -> None:
model = TestModel()

with mock.patch(
-"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
+"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
progress_bar = RichProgressBar()
trainer = Trainer(
@@ -180,7 +180,7 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
model = BoringModel()

with mock.patch(
-"pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True
+"lightning.pytorch.callbacks.progress.rich_progress.Progress.reset", autospec=True
) as mock_progress_reset:
progress_bar = RichProgressBar(leave=leave)
trainer = Trainer(
@@ -199,7 +199,7 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):


@RunIf(rich=True)
-@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
+@mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
@@ -308,7 +308,7 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):


@RunIf(rich=True)
-@mock.patch("pytorch_lightning.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
+@mock.patch("lightning.pytorch.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
def test_rich_progress_bar_colab_light_theme_update(*_):
theme = RichProgressBar().theme
assert theme.description == "black"
@@ -499,7 +499,7 @@ def test_rich_progress_bar_disabled(tmpdir):
callbacks=[bar],
)

-with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.CustomProgress") as mocked:
+with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.CustomProgress") as mocked:
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
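
The rich progress bar tests above also rely on the availability-flag trick: test_rich_progress_bar_import_error monkeypatches the module-level _RICH_AVAILABLE constant instead of uninstalling rich. A self-contained sketch of that pattern with a hypothetical module and flag; only the mechanism mirrors the test above (none of these names come from the PR):

import types

import pytest

# Hypothetical stand-in for lightning.pytorch.callbacks.progress.rich_progress.
fake_module = types.SimpleNamespace(_RICH_AVAILABLE=True)


def build_bar():
    # Mirrors the guard RichProgressBar applies before constructing itself.
    if not fake_module._RICH_AVAILABLE:
        raise ModuleNotFoundError("`RichProgressBar` requires `rich` >= 10.2.2.")
    return "bar"


def test_import_error(monkeypatch):
    # Flip the flag for this test only; monkeypatch restores it on teardown.
    monkeypatch.setattr(fake_module, "_RICH_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError, match="requires `rich`"):
        build_bar()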