Added save_config_filename init argument to LightningCLI
mauvilsa committed May 27, 2021
1 parent 9304c0d commit 1be09d3
Showing 2 changed files with 8 additions and 2 deletions.
CHANGELOG.md (3 additions, 0 deletions)

@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
 
 
+- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))
+
+
 ### Changed
 
 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
pytorch_lightning/utilities/cli.py (5 additions, 2 deletions)

@@ -76,7 +76,7 @@ def __init__(
         self,
         parser: LightningArgumentParser,
         config: Union[Namespace, Dict[str, Any]],
-        config_filename: str = 'config.yaml'
+        config_filename: str,
     ) -> None:
         self.parser = parser
         self.config = config
@@ -96,6 +96,7 @@ def __init__(
         model_class: Type[LightningModule],
         datamodule_class: Type[LightningDataModule] = None,
         save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
+        save_config_filename: str = 'config.yaml',
         trainer_class: Type[Trainer] = Trainer,
         trainer_defaults: Dict[str, Any] = None,
         seed_everything_default: int = None,
@@ -154,6 +155,7 @@ def __init__(
         self.model_class = model_class
         self.datamodule_class = datamodule_class
         self.save_config_callback = save_config_callback
+        self.save_config_filename = save_config_filename
         self.trainer_class = trainer_class
         self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
         self.seed_everything_default = seed_everything_default
@@ -241,7 +243,8 @@ def instantiate_trainer(self) -> None:
             else:
                 self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
             if self.save_config_callback is not None:
-                self.config_init['trainer']['callbacks'].append(self.save_config_callback(self.parser, self.config))
+                config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
+                self.config_init['trainer']['callbacks'].append(config_callback)
             self.trainer = self.trainer_class(**self.config_init['trainer'])
 
     def prepare_fit_kwargs(self) -> None:
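For reference, a minimal sketch of how the new argument can be used from user code; the `DemoModel` class and the 'cli_config.yaml' name are illustrative, not part of this commit.

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import LightningCLI


    class DemoModel(LightningModule):
        """Illustrative placeholder; a real module defines training_step, etc."""


    # Save the parsed configuration as 'cli_config.yaml' instead of the default
    # 'config.yaml', so it does not clash with a config file that already exists
    # in the log directory. Instantiating LightningCLI parses the command line
    # and runs fit with the configured trainer.
    cli = LightningCLI(DemoModel, save_config_filename='cli_config.yaml')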
