This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

SemanticSegmentationData: Paletted PIL Images Support #1400

Open · Nico995 opened this issue Jul 21, 2022 · Discussed in #1397 · 2 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Nico995 (Contributor) commented Jul 21, 2022

Discussed in #1397

Originally posted by Nico995 July 17, 2022
Hello, I have been trying to run a SemanticSegmentation task on a custom dataset with no luck. The problem I'm facing is that it looks like paletted PNG images are not supported as targets for the task.

TL;DR

SemanticSegmentationData.from_files breaks when the target files are paletted PNG images.

Reproducing a working example

To prove my point, I started from the segmentation example on the Flash website:

  1. I downloaded the dataset:
from flash.core.data.utils import download_data

download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)
  2. Next, I took a second to examine the format of the segmentation mask. As it turns out, the images are in RGB format, where the R channel contains the class index for each pixel and the G and B channels are set to 0. You can verify that with the following code:
from PIL import Image
import numpy as np

image = Image.open('./data/CameraRGB/F61-1.png')
seg_mask = Image.open('./data/CameraSeg/F61-1.png')

print('Mask mode:', seg_mask.mode)
print('Mask shape:', np.array(seg_mask).shape)
print('Mask pixel:', np.array(seg_mask)[0,0])
print('R channel values:', np.unique(np.array(seg_mask)[:,:,0]))
print('G channel values:', np.unique(np.array(seg_mask)[:,:,1]))
print('B channel values:', np.unique(np.array(seg_mask)[:,:,2]))
display(image.resize((200, 200)))
display(seg_mask.resize((200, 200)))

Mask mode: RGB
Mask shape: (600, 800, 3)
Mask pixel: [9 0 0]
R channel values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12]
G channel values: [0]
B channel values: [0]
(attached: the input image and the rendered segmentation mask)

The mask looks completely black because, in RGB mode, colors of the form (x, 0, 0) with 0 < x ≤ 12 are all very dark, near-black shades of red.
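If you want to actually see the regions, one quick option (just a sketch reusing the arrays loaded above) is to stretch the class indices over the full 8-bit range before displaying:

# stretch the class indices (0-12) over 0-255, purely for visualization
labels = np.array(seg_mask)[:, :, 0].astype(np.int32)
visible = Image.fromarray((labels * (255 // int(labels.max()))).astype(np.uint8))
display(visible.resize((200, 200)))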

  3. I then created a datamodule and a trainer to check that everything is working as intended (and it does):
import torch
import flash
from flash.image import SemanticSegmentation, SemanticSegmentationData

datamodule = SemanticSegmentationData.from_files(
    train_files=['./data/CameraRGB/F61-1.png'],
    train_targets=['./data/CameraSeg/F61-1.png'],
    num_classes=21,
    batch_size=4,
)

model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

trainer = flash.Trainer(fast_dev_run=1, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

Epoch 0: 100%
1/1 [00:00<00:00, 3.13it/s, loss=7.17, v_num=, train_jaccardindex_step=0.00086, train_cross_entropy_step=7.170, train_jaccardindex_epoch=0.00086, train_cross_entropy_epoch=7.170]

Testing Paletted images

The problem arises when I try to use the paletted version of the same mask, as follows:

  1. First, I convert the RGB image to a P image, assigning an arbitrary palette:
# build fake palette
palette = []
for channel in range(3):
    for value in np.linspace(64, 255, 4):
        palette.append(np.roll(np.array((int(value), 0, 0)), channel))

# convert RGB to paletted mode (labels are in channel 0)
seg_mask_p = Image.fromarray(np.array(seg_mask)[:,:,0]).convert('P')

# set fake palette
seg_mask_p.putpalette([c for rgb in palette for c in rgb])
seg_mask_p.resize((200, 200))

(attached: the paletted mask rendered with the fake palette)

  2. Next, I check the format of the new mask to compare it with the original mask:
print('Mask mode:', seg_mask_p.mode)
print('Mask shape:', np.array(seg_mask_p).shape)
print('Mask pixel:', np.array(seg_mask_p)[0,0])
print('P channel values:', np.unique(np.array(seg_mask_p)))

Mask mode: P
Mask shape: (600, 800)
Mask pixel: 9
P channel values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12]

  3. Finally, I run the trainer again and get an error:
import torch
from flash.image import SemanticSegmentation, SemanticSegmentationData

image.save('./tmp_image.png')
seg_mask_p.save('./tmp_mask_p.png')

datamodule = SemanticSegmentationData.from_files(
    train_files=['./tmp_image.png'],
    train_targets=['./tmp_mask_p.png'],
    num_classes=21,
    batch_size=4,
)

trainer.fit(model, datamodule=datamodule)

The whole traceback is very long, but the key line is the following:

../aten/src/ATen/native/cuda/NLLLoss2d.cu:93: nll_loss2d_forward_kernel: block: [0,0,0], thread: [961,0,0] Assertion `t >= 0 && t < n_classes` failed.

Suspected problem

I believe I have seen this error many times before, and I think it refers to a loss function call where the target contains values larger than the declared number of classes. My first guess is that this is happening because Flash's default behavior is to convert the mask to RGB and then read only the first channel. Below I try to justify my assumption.
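As a quick sanity check (just a sketch reusing the tmp_mask_p.png saved above), converting the paletted mask to RGB replaces the class indices with the palette colors, so the values reaching the loss are the palette's channel values rather than indices in 0..num_classes-1:

from PIL import Image
import numpy as np

mask_p = Image.open('./tmp_mask_p.png')    # mode 'P': pixel values are class indices
mask_rgb = mask_p.convert('RGB')           # roughly what the default RGB loader does

print('P values:    ', np.unique(np.array(mask_p)))              # class indices 0..12
print('RGB R values:', np.unique(np.array(mask_rgb)[:, :, 0]))   # palette values, e.g. 64, 127, ... > num_classes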

I believe this is the line responsible for extracting the first channel:

in /flash/image/segmentation/input.py:110

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    filepath = sample[DataKeys.INPUT]
    sample[DataKeys.INPUT] = to_tensor(load_image(filepath))
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = (to_tensor(load_image(sample[DataKeys.TARGET])) * 255).long()[0] <---
    sample = super().load_sample(sample)
    sample[DataKeys.METADATA]["filepath"] = filepath
    return sample

And this is the line responsible for the RGB conversion

in /flash/core/data/utils.py:169

def image_default_loader(file_path: str, drop_alpha: bool = True) -> Image:
    """Default loader for images.
    Args:
        file_path: The image file to load.
        drop_alpha: If ``True`` (default) then any alpha channels will be silently removed.
    """
    img = default_loader(file_path) <---
    if img.mode == "RGBA" and drop_alpha:
        img = img.convert("RGB")
    return img

Which in turn calls a torchvision function that simply loads the image as RGB:

in torchvision/datasets/folder.py:245

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB") <---

#(...)

def default_loader(path: str) -> Any:
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path) <---

Possible solution

I think it would be nice to have a parameter in SemanticSegmentationFilesInput.from_folders inside /flash/image/segmentation/data.py that allows the user to specify the mode of the image.
The parameter would default to 'RGB' for backward compatibility, and could be set to 'P' if one wants to work with paletted images.
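To make this concrete, here is a rough sketch of what the target loading could look like with such a parameter (label_mode is a hypothetical name, not an existing Flash argument):

import numpy as np
import torch
from PIL import Image


def load_target(path: str, label_mode: str = "RGB") -> torch.Tensor:
    """Hypothetical helper: load a segmentation target honoring the requested PIL mode."""
    img = Image.open(path)
    if label_mode == "P":
        # paletted masks already store one class index per pixel -> shape (H, W)
        return torch.from_numpy(np.array(img)).long()
    # current behavior: convert to RGB and keep only the first channel
    return torch.from_numpy(np.array(img.convert("RGB"))[:, :, 0]).long()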


Let me know if this looks like a possible enhancement, or if I just simply missed a parameter somewhere that would solve the problem :)

@krshrimali krshrimali self-assigned this Aug 22, 2022
@krshrimali krshrimali added this to the 0.8.0 milestone Aug 22, 2022
@krshrimali krshrimali removed this from the 0.8.0 milestone Aug 26, 2022
uakarsh (Contributor) commented Oct 1, 2022

Hi @Nico995, I have been able to reproduce the error (although I had to use two files, since with one file the shape comes out to be (256, 256)), and I got the error IndexError: Target 64 is out of bounds. I guess you are right as well, i.e. allowing the reading format to be configured would help in adapting the package to different sets of data and tasks. Are you still working on it, or have you found an approach that solves the problem?

Nico995 (Contributor, Author) commented Oct 13, 2022

Hi @uakarsh, sorry for the late reply.
I have not checked this in a long time. It appears that either I got the code reference from the wrong branch, or the master branch has changed. Anyways, the official code now looks like this:

in flash/image/segmentation/input.py:97

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0]
    return super().load_sample(sample)

which breaks my example at point n.3 with the following error:

Expected input batch_size (1) to match target batch_size (128).

I did not take the time to fully debug the code, but I suppose one could at least include support for P-formatted images since they fit the description of a segmentation target incredibly well.

Anyways, the quick&dirty solution I am using at the moment is removing the .transpose((2, 0, 1)) when working with paletted images.
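For clarity, with that change the snippet quoted above becomes (sketch of the hack only, everything else unchanged):

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    if DataKeys.TARGET in sample:
        # same as the official code, minus the .transpose((2, 0, 1))
        sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET]))[:, :, 0]
    return super().load_sample(sample)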

@Borda Borda added enhancement New feature or request help wanted Extra attention is needed labels Dec 23, 2022