Skip to content

Commit

Permalink
Restore test
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 18, 2022
1 parent 6e14209 commit 18067c5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
6 changes: 5 additions & 1 deletion pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st

def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
# `ArgumentParser` is un-pickleable. Drop it
return self.__class__, (None, self.config, self.config_filename), {}
return (
self.__class__,
(None, self.config, self.config_filename),
{"overwrite": self.overwrite, "multifile": self.multifile},
)


class LightningCLI:
Expand Down
19 changes: 5 additions & 14 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call

torchvision_version = version.parse("0")
Expand Down Expand Up @@ -577,18 +576,8 @@ def on_fit_start(self):


@pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize(
"trainer_kwargs",
(
# dict(strategy="ddp_spawn")
# dict(strategy="ddp")
# the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
# TODO revisit this test as it never worked with DDP or DDPSpawn
dict(strategy="single_device"),
pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
),
)
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
MisconfigurationException, match=r"Error on fit start"
):
Expand All @@ -599,7 +588,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
"logger": logger,
"max_steps": 1,
"max_epochs": 1,
**trainer_kwargs,
"strategy": strategy,
"accelerator": "cpu",
"devices": 1,
},
)
if logger:
Expand Down

0 comments on commit 18067c5

Please sign in to comment.