From 198a779a72535ead9ed655503602e91af5d14ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 9 Sep 2021 13:41:12 +0200 Subject: [PATCH] remove repro --- pl_examples/repro.py | 84 -------------------------------------------- 1 file changed, 84 deletions(-) delete mode 100644 pl_examples/repro.py diff --git a/pl_examples/repro.py b/pl_examples/repro.py deleted file mode 100644 index 22f1650852a06..0000000000000 --- a/pl_examples/repro.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import unittest.mock - -import torch -from torch.utils.data import DataLoader, Dataset - -from pytorch_lightning import LightningModule, Trainer - - -class RandomDataset(Dataset): - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return self.len - - -class BoringModel(LightningModule): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - def training_step(self, batch, batch_idx): - print(batch.sum()) - loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("test_loss", loss) - - def on_train_epoch_end(self): - print("epoch ended") - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) - - def on_load_checkpoint(self, checkpoint): - pass - - -@unittest.mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=3, - limit_val_batches=0, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - ) - trainer.fit(model, train_dataloader=train_data) - - trainer.save_checkpoint("lightning_logs/auto.pt") - - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=3, - limit_val_batches=0, - num_sanity_val_steps=0, - max_epochs=3, - weights_summary=None, - resume_from_checkpoint="lightning_logs/auto.pt", - ) - trainer.fit(model, train_dataloader=train_data) - - -if __name__ == "__main__": - run()