Mechanism to skip certain hooks #5586
Comments
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
Hello! I think this would be a really useful feature. It would improve the flexibility of
I had the same issue as reported in the PL forums but used a slightly different work-around, which I'll provide here as an alternative solution. The issue revolved around PL trying to load the checkpoint weights into the not-yet-initialized submodule:

```python
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning import Callback

dm = MyDataModule()
model = MyNet()
trainer = pl.Trainer(callbacks=[ModuleInitializer()])
trainer.fit(model, dm)
trained_model = MyNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
```

where

```python
class ModuleInitializer(Callback):
    """Callback for dynamically initializing the model according to the learned data transformation
    from pl.LightningDataModule.setup. Called right after pl.LightningDataModule.setup."""

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None
    ):
        if stage == "fit":
            dim = trainer.datamodule.get_dim()
            pl_module.initialize(dim)
            pl_module.hparams["dim"] = dim
```

and

```python
class MyNet(pl.LightningModule):
    def __init__(
        self,
        lr: float = 1e-3,
        dim=None,
    ):
        super().__init__()
        self.lr = lr
        self.save_hyperparameters()
        if dim:
            self.initialize(dim)

    def initialize(self, dim):
        """Dynamically initialize the model according to the learned data transformation."""
        self.my_pytorch_module = MyPytorchModule(dim)
        ...
```

When the model is initialized for training, it gets the dimensionality from the datamodule via the callback's `setup`; when it is later restored with `load_from_checkpoint`, `dim` comes from the saved hyperparameters, so `my_pytorch_module` already exists when the weights are loaded.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
@awaelchli #6420 is a related view of this.
@awaelchli I would not pursue this, as the described alternative is the natural solution for all Python users. Doing this limits the flexibility in how hooks are skipped and requires knowing that such a mechanism exists. One more argument against is that "loop customization" is the natural vehicle for this, albeit more flexible and more complex.
🚀 Feature
Do we need a way to prevent certain hooks from being executed? I'm not entirely sure how solid this idea is, so I'm hoping for some discussion :)
Motivation
A user encountered a use case in which they wanted to build the model in the setup hook. However, because the setup hook is executed every time, regardless of whether the model was already trained or not, this would overwrite the weights of the model, making continued training impossible. #5410
Pitch
Something abstract like

```python
model.skip_hooks = ["setup", ...]
```

could be considered; a rough sketch of how such an attribute could be honored follows below.
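For illustration only, here is one way a hook dispatcher could consult such an attribute before invoking a hook. This is a hypothetical sketch: `skip_hooks` is not an existing Lightning API, and `call_hook` below is not the Trainer's actual hook-dispatch code.

```python
# Hypothetical sketch only -- `skip_hooks` is not an existing Lightning attribute.
def call_hook(model, hook_name: str, *args, **kwargs):
    """Invoke `hook_name` on `model` unless the user asked to skip it."""
    if hook_name in getattr(model, "skip_hooks", ()):
        return None  # hook explicitly skipped by the user
    hook = getattr(model, hook_name, None)
    if callable(hook):
        return hook(*args, **kwargs)
    return None

# Usage (hypothetical):
# model.skip_hooks = ["setup"]
# call_hook(model, "setup", stage="fit")  # does nothing
# call_hook(model, "on_fit_start")        # runs normally
```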
Alternatives
The user can handle it on their side with some conditional code in their hook (a sketch follows below),
or forcefully remove the method from the model object, for example:

```python
delattr(model, "setup")
```
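For concreteness, here is a minimal sketch of the conditional-code alternative; the layer size and the `_built` flag are purely illustrative, and only the standard `setup` hook is assumed.

```python
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self._built = False  # illustrative flag tracking whether the model was built

    def setup(self, stage=None):
        # Build the submodule only once, so re-running fit() or calling
        # setup() again does not overwrite already-trained weights.
        if self._built:
            return
        self.layer = torch.nn.Linear(32, 10)
        self._built = True
```

Note that `delattr(model, "setup")` only removes an attribute set on the instance; if `setup` is defined on the class, it would have to be deleted from the class itself (e.g. `delattr(type(model), "setup")`), which affects every instance and falls back to the base-class hook.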
Additional context
Originates from here
cc: @sjgosai
cc @Borda @tchaton @justusschock @awaelchli @carmocca @ninginthecloud @daniellepintz @rohitgr7