This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Semantic Segmentation task documentation for from_datasets #938

Closed
klwetstone opened this issue Nov 5, 2021 · 10 comments
Labels
bug / fix (Something isn't working), documentation (Improvements or additions to documentation), won't fix (This will not be worked on)
Milestone

Comments

@klwetstone

Currently, the docs for Semantic Segmentation tasks say that to use the from_datasets method:

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input and target images as tensors.

When __getitem__ is defined this way, samples are not interpreted correctly: the full dictionary is treated as the input. __getitem__ actually has to return a tuple of the input tensor and the target tensor.

Issue using current docs instructions

The code below is a simplified version of what was run:

Create dataset class

import rasterio
import torch

class NewDataset(torch.utils.data.Dataset):
    def __init__(self, input_data_dir, target_data_dir):
        # Directories (pathlib.Path objects) containing <image_id>.tif files
        self.input_data_dir = input_data_dir
        self.target_data_dir = target_data_dir

    def __getitem__(self, image_id):
        # Load input image and target based on image id
        input_img_path = self.input_data_dir / f"{image_id}.tif"
        with rasterio.open(input_img_path) as im:
            input_arr = im.read(1).astype("float32")

        target_img_path = self.target_data_dir / f"{image_id}.tif"
        with rasterio.open(target_img_path) as im:
            target_arr = im.read(1).astype("float32")

        # Return a dictionary, as the current docs instruct
        return {
            "input": torch.from_numpy(input_arr),
            "target": torch.from_numpy(target_arr),
        }

Run model

from flash.image import SemanticSegmentation, SemanticSegmentationData
import flash

datamodule = SemanticSegmentationData.from_datasets(
    train_dataset=NewDataset(
        input_data_dir=INPUT_DATA_DIR, target_data_dir=TARGET_DATA_DIR
    ),
    num_classes=2,
)

model = SemanticSegmentation(
    backbone="resnet18", head="fpn", num_classes=datamodule.num_classes,
)

trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/hh/_nbd8pkn08x5ty_mx954bjvm0000gn/T/ipykernel_23248/1059754187.py in <module>
      1 # using float32
      2 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
----> 3 trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

...

/opt/anaconda3/envs/clouds-planetary-computer/lib/python3.9/site-packages/kornia/geometry/transform/affwarp.py in resize(input, size, interpolation, align_corners, side, antialias)
    541     """
    542     if not isinstance(input, torch.Tensor):
--> 543         raise TypeError("Input tensor type is not a torch.Tensor. Got {}"
    544                         .format(type(input)))
    545 

TypeError: Input tensor type is not a torch.Tensor. Got <class 'dict'>

Correct usage

The modeling code above works correctly when __getitem__ returns a tuple of the input image and the target image, each as a tensor:

class NewDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, image_id):
        # Load input image and target based on image id
        input_img_path = self.input_data_dir / f"{image_id}.tif"
        with rasterio.open(input_img_path) as im:
            input_arr = im.read(1).astype("float32")

        target_img_path = self.target_data_dir / f"{image_id}.tif"
        with rasterio.open(target_img_path) as im:
            target_arr = im.read(1).astype("float32")

        # Return a tuple instead of a dictionary
        return torch.from_numpy(input_arr), torch.from_numpy(target_arr)

Suggested docs fix

In the from_datasets section of the docs, change:

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input and target images as tensors.

To:

The __getitem__ of your datasets should return a tuple where the first item is the input image as a tensor and the second item is the target image as a tensor, e.g. return torch.from_numpy(input_image_array), torch.from_numpy(target_image_array)

klwetstone added the documentation label on Nov 5, 2021
@ethanwharris
Collaborator

Hi @klwetstone, thanks for reporting this! I think this is also a bug, as I was expecting both of the variations you have there to work 😃

ethanwharris added the bug / fix label on Nov 5, 2021
@klwetstone
Author

@ethanwharris glad it's helpful! Happy to send more of the details I found from poking around in the debugger too 👍

@ethanwharris
Collaborator

Definitely helpful as I think this invalidates our docs / recommendations for every task! AFAICT the error is here: https://github.com/PyTorchLightning/lightning-flash/blob/d1be93cd2d5b59af8bc40db2e6a606688b9d071c/flash/core/data/data_source.py#L371
Changing that line to return sample should solve it, but we should add some proper tests to make sure we don't regress in future.
@klwetstone Would you be interested in contributing the fix?
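For reference, a rough sketch of the change being suggested, assuming the linked load_sample currently wraps whatever the dataset's __getitem__ returns under DefaultDataKeys.INPUT (the exact signature and surrounding code in flash/core/data/data_source.py may differ):

from typing import Any, Optional


def load_sample(sample: Any, dataset: Optional[Any] = None) -> Any:
    # Current behaviour (roughly): wrap the raw sample, which nests a dict
    # returned by __getitem__ one level deep under DefaultDataKeys.INPUT.
    # return {DefaultDataKeys.INPUT: sample}

    # Suggested change: pass the sample straight through.
    return sample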

@klwetstone
Author

@ethanwharris sure!

I agree - I think that's where the error is. When I was in debugging mode I basically got:

DefaultDataKeys.INPUT: {
    "input": <input image tensor>,
    "target": <target image tensor>,
}

So if __getitem__ does return a dictionary, maybe:

return {DefaultDataKeys.INPUT: sample["input"]}

@ethanwharris
Collaborator

Yeah, I think a special case for when __getitem__ returns a dictionary is the best idea (in case anything else we have depends on the current behaviour). Worth noting that DefaultDataKeys.INPUT will behave the same as "input", so I think if it's a dictionary you can just return the sample 😃
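A hedged sketch of what that dictionary special case could look like, assuming the method also needs to keep handling the (input, target) tuple case shown earlier in this issue; the names and signature are illustrative rather than the actual Flash code:

from typing import Any, Optional

from flash.core.data.data_source import DefaultDataKeys


def load_sample(sample: Any, dataset: Optional[Any] = None) -> Any:
    if isinstance(sample, dict):
        # DefaultDataKeys.INPUT compares equal to the string "input", so a
        # dict with "input"/"target" keys can be passed through unchanged.
        return sample
    if isinstance(sample, tuple) and len(sample) == 2:
        # An (input, target) tuple keeps its current behaviour.
        return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]}
    return {DefaultDataKeys.INPUT: sample}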

@klwetstone
Author

Ah I see. That sounds like a good fix, thanks!

If you want a tester once it's implemented let me know, happy to try it out

@ethanwharris
Collaborator

@klwetstone Would you like to try the fix? Should be possible to just update the load_sample method I linked above and then I guess we should have some tests too. If so, I can assign the issue to you 😃 These methods are moving / being renamed in #929 so if you wait to get started until that is merged (should be done today) then that should avoid any conflicts.
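In case it helps, here is a rough idea of the kind of regression test this could get. The dataset shapes, the batch_size argument, and the exact batch keys are illustrative assumptions rather than the actual Flash test suite:

import torch

from flash.core.data.data_source import DefaultDataKeys
from flash.image import SemanticSegmentationData


class DictReturningDataset(torch.utils.data.Dataset):
    """Toy dataset whose __getitem__ returns the dict form from the docs."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {
            "input": torch.rand(3, 32, 32),
            "target": torch.randint(0, 2, (32, 32)),
        }


def test_from_datasets_accepts_dict_samples():
    datamodule = SemanticSegmentationData.from_datasets(
        train_dataset=DictReturningDataset(),
        num_classes=2,
        batch_size=2,
    )
    batch = next(iter(datamodule.train_dataloader()))
    # After the fix, the collated batch should expose the input and target
    # tensors directly rather than a nested dictionary.
    assert torch.is_tensor(batch[DefaultDataKeys.INPUT])
    assert torch.is_tensor(batch[DefaultDataKeys.TARGET])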

@stale

stale bot commented Jan 9, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the won't fix label on Jan 9, 2022
ethanwharris removed the won't fix label on Jan 17, 2022
@stale

stale bot commented Mar 20, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the won't fix label on Mar 20, 2022
ethanwharris added this to the 0.7.x milestone on Mar 31, 2022
The stale bot removed the won't fix label on Mar 31, 2022
@stale

stale bot commented Jun 5, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the won't fix label on Jun 5, 2022
The stale bot closed this as completed on Jun 15, 2022