This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Can't save and load models using the SemanticSegmentation Task #388

Closed
ido-greenfeld opened this issue Jun 9, 2021 · 3 comments

Labels: bug / fix (Something isn't working), Epic, help wanted (Extra attention is needed), Priority, won't fix (This will not be worked on)

@ido-greenfeld

🐛 Bug

Unable to properly load a saved SemanticSegmentation model using Flash.

Tried three methods:
1-2. Saving with jit using model.to_torchscript(method='trace', ...)
3. Saving weights using torch.save(model.backbone.state_dict(), ...)

Code sample

The code mostly follows the SemanticSegmentation example, with the backbone changed to UNet and a saving/loading phase added.

import flash
import torch

from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
from flash.image.segmentation.serialization import SegmentationLabels

download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/"
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    batch_size=2,
    val_split=0.3,
    image_size=(200, 200),  # original image size is (600, 800)
    num_classes=21,
)

model = SemanticSegmentation(
    backbone="unet", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
)

trainer = flash.Trainer(
    max_epochs=1,
    gpus=1,
)

trainer.finetune(model, datamodule=datamodule, strategy='freeze')
# ~~~Saving~~~
# Method 1-2: Saving model using torchscript
script = model.to_torchscript(method="trace", example_inputs=torch.rand(1, 3, 200, 200))
torch.jit.save(script, "./save_load_torchscript_test.pt")

# Method 3: Saving weights using state_dict
torch.save(model.backbone.state_dict(), "./save_load_weights_test.ckpt")
# ~~~Loading~~~
# Method 1: This results in KeyError(f"Key: {key} is not in {repr(self)}")
model_from_torchscript = SemanticSegmentation(backbone=torch.jit.load('./save_load_torchscript_test.pt'), num_classes = 21)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-10-1bb5b1a4118b> in <module>
      1 # Method 1: This results in KeyError(f"Key: {key} is not in {repr(self)}")
----> 2 model_from_torchscript = SemanticSegmentation(backbone=torch.jit.load('./save_load_torchscript_test.pt'), num_classes = 21)

~/repositories/lightning-flash/flash/image/segmentation/model.py in __init__(self, num_classes, backbone, backbone_kwargs, pretrained, loss_fn, optimizer, metrics, learning_rate, multi_label, serializer, postprocess)
    115 
    116         # TODO: pretrained to True causes some issues
--> 117         self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs)
    118 
    119     def training_step(self, batch: Any, batch_idx: int) -> Any:

~/repositories/lightning-flash/flash/core/registry.py in get(self, key, with_metadata, strict, **metadata)
     60         matches = [e for e in self.functions if key == e["name"]]
     61         if not matches:
---> 62             raise KeyError(f"Key: {key} is not in {repr(self)}")
     63 
     64         if metadata:

KeyError: "Key: RecursiveScriptModule(\n  original_name=SemanticSegmentation\n  (metrics): RecursiveScriptModule(\n    original_name=ModuleDict\n    (iou): RecursiveScriptModule(original_name=IoU)\n  )\n  (_postprocess): RecursiveScriptModule(original_name=SemanticSegmentationPostprocess)\n  (backbone): RecursiveScriptModule(\n    original_name=UNet\n    (layers): RecursiveScriptModule(\n      original_name=ModuleList\n      (0): RecursiveScriptModule(\n        original_name=DoubleConv\n        (net): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(original_name=Conv2d)\n          (1): RecursiveScriptModule(original_name=BatchNorm2d)\n          (2): RecursiveScriptModule(original_name=ReLU)\n          (3): RecursiveScriptModule(original_name=Conv2d)\n          (4): RecursiveScriptModule(original_name=BatchNorm2d)\n          (5): RecursiveScriptModule(original_name=ReLU)\n        )\n      )\n      (1): RecursiveScriptModule(\n        original_name=Down\n        (net): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(original_name=MaxPool2d)\n          (1): RecursiveScriptModule(\n            original_name=DoubleConv\n            (net): RecursiveScriptModule(\n              original_name=Sequential\n              (0): RecursiveScriptModule(original_name=Conv2d)\n              (1): RecursiveScriptModule(original_name=BatchNorm2d)\n              (2): RecursiveScriptModule(original_name=ReLU)\n              (3): RecursiveScriptModule(original_name=Conv2d)\n              (4): RecursiveScriptModule(original_name=BatchNorm2d)\n              (5): RecursiveScriptModule(original_name=ReLU)\n            )\n          )\n        )\n      )\n      (2): RecursiveScriptModule(\n        original_name=Down\n        (net): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(original_name=MaxPool2d)\n          (1): RecursiveScriptModule(\n        
    original_name=DoubleConv\n            (net): RecursiveScriptModule(\n              original_name=Sequential\n              (0): RecursiveScriptModule(original_name=Conv2d)\n              (1): RecursiveScriptModule(original_name=BatchNorm2d)\n              (2): RecursiveScriptModule(original_name=ReLU)\n              (3): RecursiveScriptModule(original_name=Conv2d)\n              (4): RecursiveScriptModule(original_name=BatchNorm2d)\n              (5): RecursiveScriptModule(original_name=ReLU)\n            )\n          )\n        )\n      )\n      (3): RecursiveScriptModule(\n        original_name=Down\n        (net): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(original_name=MaxPool2d)\n          (1): RecursiveScriptModule(\n            original_name=DoubleConv\n            (net): RecursiveScriptModule(\n              original_name=Sequential\n              (0): RecursiveScriptModule(original_name=Conv2d)\n              (1): RecursiveScriptModule(original_name=BatchNorm2d)\n              (2): RecursiveScriptModule(original_name=ReLU)\n              (3): RecursiveScriptModule(original_name=Conv2d)\n              (4): RecursiveScriptModule(original_name=BatchNorm2d)\n              (5): RecursiveScriptModule(original_name=ReLU)\n            )\n          )\n        )\n      )\n      (4): RecursiveScriptModule(\n        original_name=Down\n        (net): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(original_name=MaxPool2d)\n          (1): RecursiveScriptModule(\n            original_name=DoubleConv\n            (net): RecursiveScriptModule(\n              original_name=Sequential\n              (0): RecursiveScriptModule(original_name=Conv2d)\n              (1): RecursiveScriptModule(original_name=BatchNorm2d)\n              (2): RecursiveScriptModule(original_name=ReLU)\n              (3): RecursiveScriptModule(original_name=Conv2d)\n              (4): 
RecursiveScriptModule(original_name=BatchNorm2d)\n              (5): RecursiveScriptModule(original_name=ReLU)\n            )\n          )\n        )\n      )\n      (5): RecursiveScriptModule(\n        original_name=Up\n        (upsample): RecursiveScriptModule(original_name=ConvTranspose2d)\n        (conv): RecursiveScriptModule(\n          original_name=DoubleConv\n          (net): RecursiveScriptModule(\n            original_name=Sequential\n            (0): RecursiveScriptModule(original_name=Conv2d)\n            (1): RecursiveScriptModule(original_name=BatchNorm2d)\n            (2): RecursiveScriptModule(original_name=ReLU)\n            (3): RecursiveScriptModule(original_name=Conv2d)\n            (4): RecursiveScriptModule(original_name=BatchNorm2d)\n            (5): RecursiveScriptModule(original_name=ReLU)\n          )\n        )\n      )\n      (6): RecursiveScriptModule(\n        original_name=Up\n        (upsample): RecursiveScriptModule(original_name=ConvTranspose2d)\n        (conv): RecursiveScriptModule(\n          original_name=DoubleConv\n          (net): RecursiveScriptModule(\n            original_name=Sequential\n            (0): RecursiveScriptModule(original_name=Conv2d)\n            (1): RecursiveScriptModule(original_name=BatchNorm2d)\n            (2): RecursiveScriptModule(original_name=ReLU)\n            (3): RecursiveScriptModule(original_name=Conv2d)\n            (4): RecursiveScriptModule(original_name=BatchNorm2d)\n            (5): RecursiveScriptModule(original_name=ReLU)\n          )\n        )\n      )\n      (7): RecursiveScriptModule(\n        original_name=Up\n        (upsample): RecursiveScriptModule(original_name=ConvTranspose2d)\n        (conv): RecursiveScriptModule(\n          original_name=DoubleConv\n          (net): RecursiveScriptModule(\n            original_name=Sequential\n            (0): RecursiveScriptModule(original_name=Conv2d)\n            (1): RecursiveScriptModule(original_name=BatchNorm2d)\n            (2): 
RecursiveScriptModule(original_name=ReLU)\n            (3): RecursiveScriptModule(original_name=Conv2d)\n            (4): RecursiveScriptModule(original_name=BatchNorm2d)\n            (5): RecursiveScriptModule(original_name=ReLU)\n          )\n        )\n      )\n      (8): RecursiveScriptModule(\n        original_name=Up\n        (upsample): RecursiveScriptModule(original_name=ConvTranspose2d)\n        (conv): RecursiveScriptModule(\n          original_name=DoubleConv\n          (net): RecursiveScriptModule(\n            original_name=Sequential\n            (0): RecursiveScriptModule(original_name=Conv2d)\n            (1): RecursiveScriptModule(original_name=BatchNorm2d)\n            (2): RecursiveScriptModule(original_name=ReLU)\n            (3): RecursiveScriptModule(original_name=Conv2d)\n            (4): RecursiveScriptModule(original_name=BatchNorm2d)\n            (5): RecursiveScriptModule(original_name=ReLU)\n          )\n        )\n      )\n      (9): RecursiveScriptModule(original_name=Conv2d)\n    )\n  )\n  (_preprocess): RecursiveScriptModule(\n    original_name=SemanticSegmentationPreprocess\n    (_train_transform): RecursiveScriptModule(\n      original_name=ModuleDict\n      (post_tensor_transform): RecursiveScriptModule(\n        original_name=Sequential\n        (0): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(\n            original_name=ApplyToKeys\n            (0): RecursiveScriptModule(\n              original_name=KorniaParallelTransforms\n              (0): RecursiveScriptModule(original_name=Resize)\n            )\n          )\n        )\n        (1): RecursiveScriptModule(\n          original_name=Sequential\n          (0): RecursiveScriptModule(\n            original_name=ApplyToKeys\n            (0): RecursiveScriptModule(\n              original_name=KorniaParallelTransforms\n              (0): RecursiveScriptModule(original_name=RandomHorizontalFlip)\n            )\n          )\n     
   )\n      )\n      (collate): RecursiveScriptModule(original_name=FuncModule)\n    )\n    (_val_transform): RecursiveScriptModule(\n      original_name=ModuleDict\n      (post_tensor_transform): RecursiveScriptModule(\n        original_name=Sequential\n        (0): RecursiveScriptModule(\n          original_name=ApplyToKeys\n          (0): RecursiveScriptModule(\n            original_name=KorniaParallelTransforms\n            (0): RecursiveScriptModule(original_name=Resize)\n          )\n        )\n      )\n      (collate): RecursiveScriptModule(original_name=FuncModule)\n    )\n    (_test_transform): RecursiveScriptModule(\n      original_name=ModuleDict\n      (post_tensor_transform): RecursiveScriptModule(\n        original_name=Sequential\n        (0): RecursiveScriptModule(\n          original_name=ApplyToKeys\n          (0): RecursiveScriptModule(\n            original_name=KorniaParallelTransforms\n            (0): RecursiveScriptModule(original_name=Resize)\n          )\n        )\n      )\n      (collate): RecursiveScriptModule(original_name=FuncModule)\n    )\n    (_predict_transform): RecursiveScriptModule(\n      original_name=ModuleDict\n      (post_tensor_transform): RecursiveScriptModule(\n        original_name=Sequential\n        (0): RecursiveScriptModule(\n          original_name=ApplyToKeys\n          (0): RecursiveScriptModule(\n            original_name=KorniaParallelTransforms\n            (0): RecursiveScriptModule(original_name=Resize)\n          )\n        )\n      )\n      (collate): RecursiveScriptModule(original_name=FuncModule)\n    )\n  )\n) is not in FlashRegistry(name=backbones, functions=[{'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363ab00>, 'name': 'fcn_resnet50', 'metadata': {'namespace': 'image/segmentation', 'package': 'torchvision', 'type': 'fcn'}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363ac20>, 'name': 'fcn_resnet101', 'metadata': {'namespace': 'image/segmentation', 'package': 
'torchvision', 'type': 'fcn'}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363acb0>, 'name': 'deeplabv3_resnet50', 'metadata': {'namespace': 'image/segmentation', 'package': 'torchvision', 'type': 'deeplabv3'}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363ad40>, 'name': 'deeplabv3_resnet101', 'metadata': {'namespace': 'image/segmentation', 'package': 'torchvision', 'type': 'deeplabv3'}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363add0>, 'name': 'deeplabv3_mobilenet_v3_large', 'metadata': {'namespace': 'image/segmentation', 'package': 'torchvision', 'type': 'deeplabv3'}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb28363af80>, 'name': 'torchvision/fcn_resnet50', 'metadata': {}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb2836420e0>, 'name': 'torchvision/fcn_resnet101', 'metadata': {}}, {'fn': <function catch_url_error.<locals>.wrapper at 0x7fb283642170>, 'name': 'lraspp_mobilenet_v3_large', 'metadata': {'namespace': 'image/segmentation', 'package': 'torchvision', 'type': 'lraspp'}}, {'fn': <function load_bolts_unet at 0x7fb283642200>, 'name': 'unet', 'metadata': {'namespace': 'image/segmentation', 'package': 'bolts', 'type': 'unet'}}])"
# Method 2: This results in a NotImplementedError
model_from_torchscript = SemanticSegmentation.load_from_checkpoint('./save_load_torchscript_test.pt')
/opt/conda/lib/python3.7/site-packages/torch/serialization.py:589: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
  " silence this warning)", UserWarning)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-11-68abae8bd5fd> in <module>
      1 # Method 1: This results in NotImplementedError Error
----> 2 model_from_torchscript = SemanticSegmentation.load_from_checkpoint('./save_load_torchscript_test.pt')

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    150 
    151         # for past checkpoint need to add the new key
--> 152         if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
    153             checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
    154         # override the hparams with values that were passed in

/opt/conda/lib/python3.7/site-packages/torch/jit/_script.py in __contains__(self, key)
    625 
    626         def __contains__(self, key):
--> 627             return self.forward_magic_method("__contains__", key)
    628 
    629         # dir is defined by the base nn.Module, so instead of throwing if

/opt/conda/lib/python3.7/site-packages/torch/jit/_script.py in forward_magic_method(self, method_name, *args, **kwargs)
    612                 RecursiveScriptModule, method_name
    613             ):
--> 614                 raise NotImplementedError()
    615             return self_method(*args, **kwargs)
    616 

NotImplementedError: 
# Method 3: Trying to load the weights and predict
model_from_weights = SemanticSegmentation(
    backbone="unet", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
)
model_from_weights.backbone.load_state_dict(torch.load('./save_load_weights_test.ckpt'))

# This results in TypeError: list indices must be integers or slices, not DefaultDataKeys
predictions = model_from_weights.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-7e6aa120c044> in <module>
      3     "data/CameraRGB/F61-1.png",
      4     "data/CameraRGB/F62-1.png",
----> 5     "data/CameraRGB/F63-1.png",
      6 ])

~/repositories/lightning-flash/flash/core/model.py in wrapper(self, *args, **kwargs)
     66         torch.set_grad_enabled(False)
     67 
---> 68         result = func(self, *args, **kwargs)
     69 
     70         if is_training:

~/repositories/lightning-flash/flash/core/model.py in predict(self, x, data_source, data_pipeline)
    200         x = self.transfer_batch_to_device(x, next(self.parameters()).device)
    201         x = data_pipeline.device_preprocessor(running_stage)(x)
--> 202         predictions = self.predict_step(x, 0)  # batch_idx is always 0 when running with `model.predict`
    203         predictions = data_pipeline.postprocessor(running_stage)(predictions)
    204         return predictions

~/repositories/lightning-flash/flash/image/segmentation/model.py in predict_step(self, batch, batch_idx, dataloader_idx)
    130 
    131     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
--> 132         batch_input = (batch[DefaultDataKeys.INPUT])
    133         preds = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx)
    134         batch[DefaultDataKeys.PREDS] = preds

TypeError: list indices must be integers or slices, not DefaultDataKeys
Environment

  • PyTorch Version: 1.8.0
  • OS: Debian 10 on GCP
  • Python version: 3.7.10
  • CUDA/cuDNN version: 11.0
  • GPU models and configuration: 1x Nvidia K80
@ido-greenfeld ido-greenfeld added bug / fix Something isn't working help wanted Extra attention is needed labels Jun 9, 2021
@ethanwharris (Collaborator)

Hi @ido-greenfeld thanks for the issue 😃

Some comments / observations:

  • It turns out support for jit script was easier than I thought, so we'll be adding it in Jit support #389.
  • With jit, only the forward pass of the model is usable. This doesn't include the data preprocessing, so you need to give the model a tensor (in the case of segmentation) rather than a filename.
  • When we save checkpoints with Lightning, the data preprocessing classes are included in the checkpoint. This doesn't happen when you just call state_dict, which is why you get an error.
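The second bullet can be illustrated with a minimal, self-contained sketch. Note this uses a stand-in `nn.Conv2d` in place of the real Flash backbone, so the module, file name, and shapes are illustrative, not the actual Flash API:

```python
import torch
from torch import nn

# Stand-in for the segmentation model; in Flash this would be the
# SemanticSegmentation task traced via model.to_torchscript(method="trace", ...).
net = nn.Conv2d(3, 21, kernel_size=1)

# Trace the forward pass with an example input, then save the script module.
scripted = torch.jit.trace(net, torch.rand(1, 3, 200, 200))
torch.jit.save(scripted, "jit_model.pt")

# Load with torch.jit.load and call it on a tensor, not a filename:
# the traced graph contains only the forward pass, no data preprocessing.
loaded = torch.jit.load("jit_model.pt")
out = loaded(torch.rand(1, 3, 200, 200))
print(out.shape)  # torch.Size([1, 21, 200, 200])
```

Any resizing or normalization that the Flash preprocessing pipeline would normally apply has to be done manually before feeding the tensor in.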

Proposal:

I've broken this down into two issues: one for giving a better error message when this happens (rather than the current DefaultDataKeys message) in #391, and one for making checkpoints work with state_dict rather than just through Lightning in #392. Once these two issues are fixed, I'll consider this issue resolved.
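For what it's worth, the `state_dict` round trip itself already works at the plain PyTorch level; what a bare `state_dict` does not capture is the preprocessing, which is what #392 concerns. A sketch with a hypothetical stand-in backbone (not the actual Flash UNet):

```python
import torch
from torch import nn

def make_net():
    # Stand-in backbone; in Flash this would be
    # SemanticSegmentation(backbone="unet", num_classes=21).backbone.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 21, kernel_size=1),
    )

net = make_net()
torch.save(net.state_dict(), "weights.ckpt")

# Rebuild the architecture with identical code, then load the saved weights.
restored = make_net()
restored.load_state_dict(torch.load("weights.ckpt"))
net.eval()
restored.eval()

# The restored model reproduces the original's outputs exactly.
x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    assert torch.equal(net(x), restored(x))
print("weights round-trip OK")
```

Inference via `model.predict(...)` still fails in this scenario because the data pipeline (resize, normalize, collate) lives outside the `state_dict`.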

Hope that helps 😃 let me know if you have any questions!

@ido-greenfeld (Author)

Thanks @ethanwharris !

@stale

stale bot commented Aug 10, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Aug 10, 2021
@stale stale bot closed this as completed Aug 20, 2021