Mechanism to skip certain hooks #5586

Closed
awaelchli opened this issue Jan 20, 2021 · 6 comments
Labels: design (Includes a design discussion), discussion (In a discussion stage), feature (Is an improvement or enhancement), hooks (Related to the hooks API)

Comments

awaelchli (Contributor) commented Jan 20, 2021

🚀 Feature

Do we need a way to prevent certain hooks from being executed? I'm not entirely sure how solid this idea is, so I'm hoping for some discussion :)

Motivation

A user encountered a use case in which they wanted to build the model in the setup hook. However, because the setup hook is executed every time, regardless of whether the model was already trained or not, it would overwrite the weights of the model, making continued training impossible. #5410

Pitch

Something abstract like
model.skip_hooks = ["setup", ...]
could be considered.
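
A rough sketch of what a trainer-side check might look like under this pitch (the skip_hooks attribute and the call_hook helper below are hypothetical, not an existing Lightning API):

def call_hook(model, hook_name, *args, **kwargs):
    # skip any hook the user listed in model.skip_hooks
    if hook_name in getattr(model, "skip_hooks", ()):
        return None
    hook = getattr(model, hook_name, None)
    return hook(*args, **kwargs) if callable(hook) else None

# usage
model.skip_hooks = ["setup"]
call_hook(model, "setup", "fit")  # returns None without running setup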

Alternatives

The user can handle it on their side with some conditional code in their hook,

def setup(self, stage):
    if getattr(self, "has_setup", False):
        return
    # do your thing
    ...
    self.has_setup = True

or forcibly remove the method from the model object, e.g. delattr(model, "setup").
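
A quick sketch of that instance-level idea (note that delattr only removes attributes stored directly on the instance, so for a hook defined on the class, assigning a no-op override is another way to get a similar effect; model here is whatever LightningModule instance is passed to the Trainer):

# override the hook on this one instance only; the class definition is untouched
model.setup = lambda stage=None: None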

Additional context

Originates from here

cc: @sjgosai

cc @Borda @tchaton @justusschock @awaelchli @carmocca @ninginthecloud @daniellepintz @rohitgr7

@awaelchli awaelchli added feature Is an improvement or enhancement help wanted Open to be worked on discussion In a discussion stage labels Jan 20, 2021
stale bot commented Feb 19, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Feb 19, 2021
@awaelchli awaelchli removed the won't fix This will not be worked on label Feb 19, 2021
sjgosai commented Feb 22, 2021

Hello! I think this would be a really useful feature. It would improve the flexibility of setup in combination with the Trainer, since there is no other obvious solution for cases where the graph occasionally needs to be built dynamically. I haven't explored all of the hooks the trainer calls, but I'm certain there will be other use cases for skipping hooks in the trainer.

vamaral1 commented Mar 9, 2021

I had the same issue as reported in the PL forums but used a slightly different workaround, which I'll provide here as an alternative solution. The issue revolved around PL trying to load the weights into the non-existent MyPytorchModule. Keep in mind it doesn't conform to the suggested use of callbacks.

dm = MyDataModule()
model = MyNet()
trainer = pl.Trainer(callbacks=[ModuleInitializer()])
trainer.fit(model, dm)
trained_model = MyNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

where

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class ModuleInitializer(Callback):
    """Callback for dynamically initializing the model according to the learned data transformation
    from pl.LightningDataModule.setup. Called right after pl.LightningDataModule.setup."""

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None
    ):
        if stage == "fit":
            dim = trainer.datamodule.get_dim()
            pl_module.initialize(dim)
            pl_module.hparams["dim"] = dim

and

class MyNet(pl.LightningModule):

    def __init__(
        self,
        lr: float = 1e-3,
        dim = None
    ):
        super().__init__()
        self.lr = lr
        self.save_hyperparameters()
        if dim:
            self.initialize(dim)

    def initialize(self, dim):
        """Dynamically initialize the model according to the learned data transformation"""
        self.my_pytorch_module = MyPytorchModule(dim)
    ...

When the model is initialized for training, it gets the dimensionality from the ModuleInitializer callback and saves it as a hyperparameter. When the model is being loaded, PL passes the saved hyperparameters to the __init__ of the LightningModule, which allows MyPytorchModule to be initialized and the saved weights to be loaded.
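
For completeness, a minimal hypothetical MyPytorchModule that the snippets above assume (not part of the original workaround, just a placeholder to make the example self-contained):

import torch.nn as nn

class MyPytorchModule(nn.Module):
    """Placeholder inner module whose shape depends on the data-derived dim."""

    def __init__(self, dim: int):
        super().__init__()
        self.layer = nn.Linear(dim, 1)

    def forward(self, x):
        return self.layer(x)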

stale bot commented Apr 10, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Apr 10, 2021
@stale stale bot closed this as completed Apr 18, 2021
ananthsub (Contributor) commented:

@awaelchli #6420 is a related view of this

@awaelchli awaelchli reopened this Apr 19, 2021
@stale stale bot removed the won't fix This will not be worked on label Apr 19, 2021
@edenlightning edenlightning added the design Includes a design discussion label May 9, 2021
@edenlightning edenlightning added this to the v1.4 milestone May 9, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.5 Jul 1, 2021
@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
carmocca (Contributor) commented Feb 1, 2022

@awaelchli I would not pursue this, as the described alternative is the natural solution for all Python users.

Adding a .skip_hooks attribute limits the flexibility in how hooks are skipped and requires users to know that the attribute exists.

One more argument against it is that "loop customization" is the natural vehicle for this, albeit more flexible and more complex.

@carmocca carmocca added hooks Related to the hooks API and removed help wanted Open to be worked on labels Feb 1, 2022
@carmocca carmocca removed this from the 1.6 milestone Feb 1, 2022