Transfer learning example #1564
Conversation
pls add argparse to be able to run with diff params
pls use Napoleon docs style https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
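For reference, these two requests together might look roughly like the sketch below; the argument names and defaults are only illustrative, not the final code of this PR, and the docstring follows the Google/Napoleon convention:

```python
import argparse


def get_args() -> argparse.Namespace:
    """Parse the command-line arguments of the transfer learning example.

    Returns:
        argparse.Namespace: The parsed arguments (backbone name, number of
        epochs, learning rate, ...).
    """
    parser = argparse.ArgumentParser(description='Transfer learning example')
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help='Name of the torchvision backbone to fine-tune')
    parser.add_argument('--epochs', default=15, type=int,
                        help='Number of training epochs')
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Initial learning rate')
    return parser.parse_args()
```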
Hi @jbschiratti Thanks for such a nice example. I'm rather new to the field, so hopefully my question is not too far off base. Going over the code, I noticed that in your example the BatchNorm layers always remain in training mode (as train_bn is always set to self.hparams.train_bn when calling the freeze function), even when performing validation or evaluation. I understand the code potentially allows the BN layers to be set to eval (if train_bn=False), but I am wondering whether there is a specific reason why you always leave BN in train mode. Why not have them train in the training stage and eval in validation/testing? Just to clarify, I'm not arguing it should be different; I'm just asking for the reasoning behind it. |
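For readers following along, the freeze/train_bn logic the question refers to looks roughly like this (a simplified sketch, not the exact code of this PR):

```python
import torch.nn as nn


def freeze(module: nn.Module, train_bn: bool = True) -> None:
    """Freeze the parameters of ``module``.

    Args:
        module: The module (e.g. the pretrained backbone) to freeze.
        train_bn: If ``True``, leave BatchNorm layers untouched so they keep
            training; if ``False``, freeze them too and put them in eval mode
            so their running statistics stop updating.
    """
    for child in module.children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm):
            if not train_bn:
                child.eval()
                for param in child.parameters():
                    param.requires_grad = False
        elif list(child.children()):
            # Recurse into containers (Sequential, Bottleneck, ...).
            freeze(child, train_bn=train_bn)
        else:
            for param in child.parameters():
                param.requires_grad = False
```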
Thank you for your interest and help with this addition. Could you please use the review tab of this PR to write your comments directly on the sections you are talking about... it will make the discussion clearer and a bit more concrete :] |
@hcjghr Thank you for spotting this. It was a bug! The way I see it, in the evaluation loop when |
@Borda I fixed the docstrings and added |
pls add note to changelog 🐰 |
This pull request is now in conflict... :( |
Codecov Report
@@           Coverage Diff           @@
##           master   #1564   +/-   ##
=======================================
  Coverage      88%     88%
=======================================
  Files          69      69
  Lines        4133    4133
=======================================
  Hits         3656    3656
  Misses        477     477 |
Thanks @awaelchli for the review and the comments! |
I noticed that the downloaded dataset is not ignored in version control. Could we maybe redirect it to a datasets subfolder and add a .gitignore in the domain templates folder? |
This is strange because the context manager with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir: ... should delete the temporary folder in which the data is downloaded. |
ah ok, so it is also supposed to do that on keyboard interrupt? Maybe it's because I'm on Windows currently. |
I tried to stop the script with CTRL+C during the 1st epoch and the temporary folder was deleted (on Linux). But I cannot guarantee this always works. |
Can now also confirm it works fine on Linux, so it's just a Windows thing; I guess we can keep it like that. |
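For reference, the cleanup behaviour discussed above comes from Python's tempfile.TemporaryDirectory; a minimal sketch of the pattern (root_data_path mirrors hparams.root_data_path, and the download step is only a placeholder):

```python
from pathlib import Path
from tempfile import TemporaryDirectory


def run_experiment(root_data_path: str) -> None:
    # The context manager deletes tmp_dir, and everything downloaded into it,
    # when the block exits, including when an exception such as
    # KeyboardInterrupt propagates out of it.
    with TemporaryDirectory(dir=root_data_path) as tmp_dir:
        data_dir = Path(tmp_dir) / 'data'
        data_dir.mkdir()
        # ... download the dataset into data_dir, then train/evaluate ...
    # At this point tmp_dir no longer exists on disk. On Windows the cleanup
    # can still fail with a PermissionError if a file inside tmp_dir is held
    # open (e.g. by a DataLoader worker), which may explain the leftover folder.
```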
Nice, minimal and clean. I like it very much.
def loss(self, labels, logits):
    return self.loss_func(input=logits, target=labels)

def train(self, mode=True):
what is the mode for? could it be more descriptive?
See https://github.com/pytorch/pytorch/blob/d37a4861b8a5eed3d9a1340484d1efb0f48aa59e/torch/nn/modules/module.py#L1067. This line overrides the train method of the PyTorch module. I will add a docstring specifying what mode does.
you are right, we may rename it...
do you have a suggestion for a better name? @PyTorchLightning/core-contributors
I am not sure we can rename it. In the evaluation loop (L330), model.train() is called and here, model refers (if I am not mistaken) to the LightningModule. We want to override this train method to ensure that, at the end of this evaluation loop when model.train() is called, some parameters (in specific layers) remain frozen (that is, with requires_grad=False) if needed.
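A condensed sketch of that override (the class name and the backbone_should_stay_frozen flag are illustrative, and freeze is a helper like the one sketched earlier in this thread):

```python
import pytorch_lightning as pl


class TransferLearningModel(pl.LightningModule):  # class name is illustrative

    def train(self, mode: bool = True):
        """Override nn.Module.train so frozen layers stay frozen.

        The evaluation loop calls model.train() when validation ends; without
        this override, that call would also put frozen BatchNorm layers back
        into train mode.

        Args:
            mode: Whether to set the module in training mode.
        """
        super().train(mode=mode)
        if mode and self.backbone_should_stay_frozen:  # hypothetical flag
            # Re-apply the freeze (see the freeze helper sketched earlier) so
            # that requires_grad and the BatchNorm mode are restored.
            freeze(self.feature_extractor, train_bn=self.hparams.train_bn)
        return self
```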
@staticmethod
def add_model_specific_args(parent_parser):
    parser = argparse.ArgumentParser(parents=[parent_parser])
    parser.add_argument('--backbone',
use add_argparse_args so we limit duplication and add just the new arguments needed for the model
by "limit code duplication", you want me to remove this line?
I mean remove lines which are generated from Trainer arguments... does it make sense?
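A sketch of what that suggestion amounts to: Trainer.add_argparse_args injects all of the Trainer flags, so the example only declares its own, model-specific arguments (the two arguments below are illustrative):

```python
from argparse import ArgumentParser

import pytorch_lightning as pl


def get_args():
    parser = ArgumentParser()
    # Adds every Trainer argument (--max_epochs, --gpus, ...) automatically
    # instead of re-declaring them one by one in the example.
    parser = pl.Trainer.add_argparse_args(parser)
    # Only the model-specific arguments are added by hand.
    parser.add_argument('--backbone', default='resnet50', type=str)
    parser.add_argument('--lr', default=1e-3, type=float)
    return parser.parse_args()
```

The parsed namespace can then be handed back to the Trainer via Trainer.from_argparse_args(args).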
to a temporary directory.
"""

with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir:
I guess we want to keep the output folder
@jbschiratti ^^
The folder in which the data was downloaded is deleted after the experiment. If you think we should leave the data untouched after the example has run, I can make another PR to fix this :-)
@jbschiratti this is super cool. |
I would keep it here in examples.... |
This pull request is now in conflict... :( |
Thank you @williamFalcon 👍 |
What does this PR do?
Addresses issue #514. Following up on this discussion, this PR proposes to add a self-contained example which shows how a pretrained network (such as ResNet50) can be fine-tuned within a LightningModule.
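At a high level, the pattern the example demonstrates is roughly the following (a condensed sketch under the 0.7-era API where hparams is assigned in __init__; milestones, schedulers, validation and data loading are omitted, and fields such as num_classes are illustrative):

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchvision import models


class TransferLearningModel(pl.LightningModule):
    """Fine-tune a small classifier on top of a frozen, pretrained ResNet50."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Pretrained backbone with its classification head removed.
        backbone = models.resnet50(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        # Freeze the whole backbone here; the real example can optionally keep
        # the BatchNorm layers trainable (the train_bn flag discussed above).
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        # New, trainable classification head.
        self.classifier = nn.Linear(backbone.fc.in_features, hparams.num_classes)
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x):
        features = self.feature_extractor(x).flatten(1)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_func(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        # Only the parameters that are not frozen are optimized.
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        return torch.optim.Adam(trainable, lr=self.hparams.lr)
```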
PR review
Anyone in the community is free to review the PR 🙂