Enums parsing in hparams.yaml generated #9170
Conversation
Passed functionality testing as well as unit testing when checked locally.

```python
import pytorch_lightning as pl
from enum import Enum
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="tb_logs", name="my_model")


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


class LitAutoEncoder(pl.LightningModule):
    def __init__(self, learning_rate: float = 0.0001,
                 switch: Options = Options.option3,  # argument of interest
                 batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)


# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()

# training
trainer = pl.Trainer(limit_train_batches=0.2, logger=logger, max_epochs=1)
trainer.fit(model, train_loader, val_loader)
```
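For reference, a hypothetical sketch of the hparams.yaml that `save_hyperparameters()` writes for this model. The three keys follow from the `__init__` signature; the representation of `switch` is illustrative only, since how the enum is serialized and parsed back is exactly what this PR addresses.

```yaml
# hparams.yaml (illustrative; the `switch` line is the problematic one)
learning_rate: 0.0001
switch: option3   # the Options type is lost unless the enum is stored specially
batch_size: 32
```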
Ready for review
Hi, I opened issue #8912 and I would like to point out a few things @tchaton
@grajat90 let's add a changelog entry too :)
Hey, sure. Also wanted some clarification:
Force-pushed from 5875d80 to 8aec4da
As part of this latest commit, I have created custom parsing for enum-type objects to store them in the hparams.yaml file as a dict. When read back, the entire enum is recreated from that stored definition and returned in the new object, retaining all enum-based functionality the model might rely on (the dot operator or any other enum special functions). This appears to be the best way to tackle it; please advise.
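A minimal sketch of the kind of dict round-trip described above. The names `enum_to_dict` and `dict_to_enum` are hypothetical, not the PR's actual helpers; the point is that storing the member names and values is enough to rebuild the class and reselect the member.

```python
from enum import Enum


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


def enum_to_dict(member: Enum) -> dict:
    # Record the class name, the full member definition, and the selected member
    cls = type(member)
    return {
        "class_name": cls.__name__,
        "members": {m.name: m.value for m in cls},
        "selected": member.name,
    }


def dict_to_enum(payload: dict) -> Enum:
    # Rebuild the Enum class via the functional API, then return the selected
    # member so dot-access keeps working. Note: mixed-in base types (like the
    # str base of Options) are not restored by this naive reconstruction.
    cls = Enum(payload["class_name"], payload["members"])
    return cls[payload["selected"]]


restored = dict_to_enum(enum_to_dict(Options.option3))
assert restored.name == "option3" and restored.value == "option3"
```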
Codecov Report

```diff
@@            Coverage Diff            @@
##           master   #9170     +/-   ##
=========================================
- Coverage      93%      89%        -4%
=========================================
  Files         182      182
  Lines       16140    16142         +2
=========================================
- Hits        14930    14316       -614
- Misses       1210     1826       +616
```
Hey, if anyone can let me know whether this is a better approach than the earlier one, I can just drop the commit and we can continue with the previous method.
@grajat90 I do not know which one is better since I'm not familiar with this enum logic, but the current approach does not look right to me due to the access to private member attributes.
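For context, a hypothetical illustration of the concern (assuming the implementation read Enum internals such as `_member_map_`; the public alternatives below expose the same information):

```python
from enum import Enum


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"


# Fragile: _member_map_ and _value2member_map_ are CPython enum internals,
# not part of the documented API, so they may change between versions.
members = Options._member_map_
by_value = Options._value2member_map_

# Stable public equivalents:
members_public = dict(Options.__members__)        # name -> member
by_value_public = {m.value: m for m in Options}   # value -> member
```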
LGTM !
LG! Just some minor doc updates.
…Comments implemented
for more information, see https://pre-commit.ci
Co-authored-by: Rohit Gupta <[email protected]>
Force-pushed from 0325aa7 to db52733
Force-pushed from 7ff7c33 to 13bbcaa
Force-pushed from 13bbcaa to aff6b1b
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
What does this PR do?
Fixes #8912
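For reviewers, a minimal sketch of how the round-trip could be checked outside a full training run. `save_hparams_to_yaml` and `load_hparams_from_yaml` are pytorch_lightning's public yaml helpers; the post-fix expectation in the comment is an assumption about this PR's intent, not verified output.

```python
from enum import Enum

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml


class Options(str, Enum):
    option1 = "option1"
    option2 = "option2"
    option3 = "option3"


save_hparams_to_yaml("hparams.yaml", {"switch": Options.option3})
hparams = load_hparams_from_yaml("hparams.yaml")

# After this PR, the expectation is that `switch` comes back usable as an
# enum member (name and value preserved) rather than failing to parse.
print(hparams["switch"])
```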
Does your PR introduce any breaking changes? If yes, please list them.
NONE
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Yesss
Make sure you had fun coding 🙃