|
63 | 63 | else:
|
64 | 64 | from argparse import Namespace
|
65 | 65 |
|
| 66 | + def lazy_instance(*args, **kwargs): |
| 67 | + return None |
| 68 | + |
66 | 69 |
|
67 | 70 | @contextmanager
|
68 | 71 | def mock_subclasses(baseclass, *subclasses):
|
@@ -176,7 +179,9 @@ def on_fit_start(self):
|
176 | 179 | self.trainer.ran_asserts = True
|
177 | 180 |
|
178 | 181 | with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]):
|
179 |
| - cli = LightningCLI(TestModel, trainer_defaults={"fast_dev_run": True, "logger": CSVLogger(".")}) |
| 182 | + cli = LightningCLI( |
| 183 | + TestModel, trainer_defaults={"fast_dev_run": True, "logger": lazy_instance(CSVLogger, save_dir=".")} |
| 184 | + ) |
180 | 185 |
|
181 | 186 | assert cli.trainer.ran_asserts
|
182 | 187 |
|
@@ -592,7 +597,7 @@ def on_fit_start(self):
|
592 | 597 |
|
593 | 598 | # mps not yet supported by distributed
|
594 | 599 | @RunIf(skip_windows=True, mps=False)
|
595 |
| -@pytest.mark.parametrize("logger", [False, TensorBoardLogger(".")]) |
| 600 | +@pytest.mark.parametrize("logger", [False, lazy_instance(TensorBoardLogger, save_dir=".")]) |
596 | 601 | @pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp"])
|
597 | 602 | def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
|
598 | 603 | from torch.multiprocessing import ProcessRaisedException
|
@@ -1478,9 +1483,7 @@ def test_tensorboard_logger_init_args():
|
1478 | 1483 | "TensorBoardLogger",
|
1479 | 1484 | {
|
1480 | 1485 | "save_dir": "tb", # Resolve from TensorBoardLogger.__init__
|
1481 |
| - }, |
1482 |
| - { |
1483 |
| - "comment": "tb", # Unsupported resolving from local imports |
| 1486 | + "comment": "tb", # Resolve from FabricTensorBoardLogger.experiment SummaryWriter local import |
1484 | 1487 | },
|
1485 | 1488 | )
|
1486 | 1489 |
|
|
0 commit comments