
Enums parsing in hparams.yaml generated #9170

Merged
merged 8 commits into from
Oct 25, 2021

Conversation

grajat90
Contributor

@grajat90 grajat90 commented Aug 27, 2021

What does this PR do?

Fixes #8912

Does your PR introduce any breaking changes? If yes, please list them.

NONE

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Yesss

Make sure you had fun coding 🙃

@grajat90
Contributor Author

Passed functional and unit testing when checked locally, and does not break anything else when tested locally. Tested with the following piece of code, which runs a test classifier on MNIST and has TensorBoard log the hyperparameters in the hparams.yaml file. This test fails on master and passes on this branch:

import pytorch_lightning as pl
from enum import Enum
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="tb_logs", name="my_model")

class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"

class LitAutoEncoder(pl.LightningModule):
    def __init__(self,
                 learning_rate: float = 0.0001,
                 switch: Options = Options.option3,  # argument of interest
                 batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        # use the saved hyperparameter instead of a hard-coded learning rate
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()

# training
trainer = pl.Trainer(limit_train_batches=0.2, logger=logger, max_epochs=1)
trainer.fit(model, train_loader, val_loader)

@grajat90 grajat90 changed the title from "Hyperparams save #8912" to "Hyperparams save #8912 Ready for review" Aug 27, 2021
@grajat90 grajat90 changed the title from "Hyperparams save #8912 Ready for review" to "Enums parsing in hparams.yaml generated #8912 Ready for review" Aug 27, 2021
Review threads: pytorch_lightning/core/saving.py (outdated, resolved), tests/models/test_hparams.py (resolved)
@Borda Borda changed the title from "Enums parsing in hparams.yaml generated #8912 Ready for review" to "Enums parsing in hparams.yaml generated" Aug 27, 2021
@dasayan05

Hi, I opened issue #8912 and I would like to point out a few things @tchaton

  1. Please note that OmegaConf.save(..) writes the .name of an Enum member, whereas this PR is trying to write the .value. This will cause an inconsistency in the written yaml file.
  2. If you instead follow OmegaConf's way of writing .name, then the tests here will not pass, because the == operator matches the .value of an enum.
  3. I just noticed that save_hparams_to_yaml(..) does not have a use_omegaconf=.. guard. So the tests here always write the yaml with OmegaConf (whenever it is installed) and read it back without it.
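The .name vs .value distinction in points 1 and 2 can be illustrated with a small stand-alone snippet (the `Color` enum here is hypothetical, chosen so that the member name and value differ):

```python
from enum import Enum

class Color(Enum):
    RED = "crimson"  # member name "RED" differs from its value "crimson"

# OmegaConf-style serialization writes the member name; writing the value
# instead produces a different string in the yaml file
print(Color.RED.name)   # RED
print(Color.RED.value)  # crimson

# whichever string is written, the reader must use the matching lookup
assert Color["RED"] is Color.RED       # lookup by name
assert Color("crimson") is Color.RED   # lookup by value
```

A file written with one convention and read back with the other raises KeyError or ValueError, which is the inconsistency described above.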

awaelchli
awaelchli previously approved these changes Aug 28, 2021
@awaelchli awaelchli added the bug Something isn't working label Aug 28, 2021
@awaelchli awaelchli added this to the v1.4.x milestone Aug 28, 2021
@awaelchli
Contributor

@grajat90 let's add a changelog entry too :)

@grajat90
Contributor Author

@grajat90 let's add a changelog entry too :)

Hey sure

Also wanted some clarification:
The way OmegaConf saves enum-type hyperparameters is that it converts the chosen enum member to a string and saves just that string. I've replicated that functionality, but I think it would be better if the enum were converted to a dict, as you pointed out. What I've tried locally is custom parsing for enums: I save the enum definition, recreate the enum when loading it back from the file, select the chosen member, and return it to the hparams object. Is this the approach I should pursue, or should I just replicate how OmegaConf does it?
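For context, here is a minimal sketch of what the string-based round trip loses, reusing the str-based `Options` enum from the test script above and using json as a stand-in for the yaml serialization:

```python
import json
from enum import Enum

class Options(str, Enum):
    option1 = "option1"
    option3 = "option3"

# writing only the value yields a plain string after the round trip
# (json stands in for the yaml serialization here)
restored = json.loads(json.dumps({"switch": Options.option3.value}))["switch"]

assert type(restored) is str                  # the Enum type is lost
assert restored == Options.option3            # == still matches for str-Enums
assert Options(restored) is Options.option3   # lookup by value recovers the member
```

For str-Enums equality still works, but any code that touches Options-specific attributes (.name, methods on the enum class) breaks unless the loader knows which enum class to look the string up in.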

@grajat90 grajat90 force-pushed the HYPERPARAMS-SAVE-#8912 branch from 5875d80 to 8aec4da Compare August 29, 2021 06:34
@grajat90
Contributor Author

As part of this latest commit, I have created custom parsing for enum-type objects to store them in the hparams.yaml file as a dict. On reading it back, the entire enum is recreated from its definition and returned in the new object, thus retaining all enum-based functionality the model might rely on (the dot operator or any other enum special functions). This appears to be the best way to tackle it; please advise.
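A rough sketch of the dict-based round trip described above (the helper names `enum_to_dict` and `dict_to_enum` are hypothetical, not the actual names in the PR), using Python's functional Enum API to recreate the class:

```python
from enum import Enum

def enum_to_dict(member):
    # hypothetical helper: capture the whole definition, not just the chosen member
    return {
        "class_name": type(member).__name__,
        "members": {m.name: m.value for m in type(member)},
        "chosen": member.name,
    }

def dict_to_enum(d):
    # recreate the Enum class via the functional API, then pick the stored member
    cls = Enum(d["class_name"], d["members"])
    return cls[d["chosen"]]

class Options(Enum):
    option1 = "a"
    option3 = "c"

restored = dict_to_enum(enum_to_dict(Options.option3))
assert (restored.name, restored.value) == ("option3", "c")
```

Note that the recreated class is a new type, so identity checks against the original Options members will fail; only the names and values survive the round trip.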

@grajat90 grajat90 requested a review from tchaton August 29, 2021 10:26
@codecov

codecov bot commented Aug 29, 2021

Codecov Report

Merging #9170 (c83b9e2) into master (d9dfb2e) will decrease coverage by 4%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #9170    +/-   ##
=======================================
- Coverage      93%     89%    -4%     
=======================================
  Files         182     182            
  Lines       16140   16142     +2     
=======================================
- Hits        14930   14316   -614     
- Misses       1210    1826   +616     

@awaelchli awaelchli self-requested a review August 29, 2021 12:15
@awaelchli awaelchli dismissed their stale review August 29, 2021 19:11

there were substantial changes since

@grajat90
Contributor Author

Hey, if anyone can let me know whether this is a better approach than the earlier one: if not, I can just drop the commit and we can continue with the previous method.

@awaelchli
Contributor

@grajat90 I do not know which one is better, I'm not familiar with this enum logic, but the current approach does not look right to me due to access to private member attributes.

Contributor

@tchaton tchaton left a comment


LGTM !

@mergify mergify bot added the ready PRs ready to be merged label Oct 12, 2021
Contributor

@rohitgr7 rohitgr7 left a comment


LG! just some minor doc updates.

Review threads (outdated, resolved): pytorch_lightning/core/saving.py (×3), CHANGELOG.md, tests/models/test_hparams.py
@rohitgr7 rohitgr7 force-pushed the HYPERPARAMS-SAVE-#8912 branch from 0325aa7 to db52733 Compare October 21, 2021 10:55
@mergify mergify bot removed the has conflicts label Oct 21, 2021
@rohitgr7 rohitgr7 requested a review from Borda October 21, 2021 10:56
@rohitgr7 rohitgr7 enabled auto-merge (squash) October 21, 2021 10:56
@rohitgr7 rohitgr7 disabled auto-merge October 21, 2021 21:23
@rohitgr7 rohitgr7 enabled auto-merge (squash) October 21, 2021 21:24
@rohitgr7 rohitgr7 force-pushed the HYPERPARAMS-SAVE-#8912 branch from 7ff7c33 to 13bbcaa Compare October 21, 2021 21:29
@rohitgr7 rohitgr7 force-pushed the HYPERPARAMS-SAVE-#8912 branch from 13bbcaa to aff6b1b Compare October 25, 2021 11:35
@mergify mergify bot removed the has conflicts label Oct 25, 2021
@rohitgr7 rohitgr7 merged commit 47e7a28 into Lightning-AI:master Oct 25, 2021
ninginthecloud pushed a commit to ninginthecloud/pytorch-lightning that referenced this pull request Oct 27, 2021
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Labels
bug Something isn't working ready PRs ready to be merged
Development

Successfully merging this pull request may close these issues.

hparams.yaml do not write Enum-style arguments properly
7 participants