This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

multi-label inconsistencies in the API #208

Closed
edgarriba opened this issue Apr 9, 2021 · 0 comments · Fixed by #230
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@edgarriba
Contributor

edgarriba commented Apr 9, 2021

🐛 Bug

When I try to use ImageClassificationData to solve a multi-label problem, there's no clear way to pass the label data.

Following the multi-class example in the documentation, I expect the multi-label case to work as follows:

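```python
import multiprocessing as mproc
from typing import List

import torch
from flash.vision import ImageClassificationData, ImageClassifier  # Flash 0.2.x import path

# TRAIN_TRANSFORM / VALID_TRANSFORM are placeholders for the transforms used (not shown).
train_images: List[str] = ['img1.jpg', 'img2.jpg', ...]
train_labels: List[List[int]] = [[0, 0, 1], [1, 1, 0], ...]  # multi-hot, one column per class

num_classes: int = len(train_labels[0])

datamodule = ImageClassificationData.from_filepaths(
    train_filepaths=train_images,
    train_labels=train_labels,
    batch_size=32,
    num_workers=mproc.cpu_count(),
    train_transform=TRAIN_TRANSFORM,
    val_transform=VALID_TRANSFORM,
    val_split=0.3,
)

model = ImageClassifier(
    num_classes=num_classes,
    backbone="resnet50",
    pretrained=True,
    optimizer=torch.optim.Adam,
    loss_fn=torch.nn.functional.binary_cross_entropy_with_logits,
)
```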

It seems that the DataModule doesn't cast the labels to torch tensors and doesn't prepare the batch correctly.
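One possible explanation, which is an assumption and not confirmed in this issue, is PyTorch's `default_collate`: when each sample's label is a plain Python list, it transposes the batch into one tensor per class position instead of stacking a `(batch_size, num_classes)` tensor. The sketch below only uses `torch.utils.data.dataloader.default_collate` and does not exercise Flash itself.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two samples' multi-hot labels as plain Python lists (the format in the example above).
list_labels = [[0, 0, 1], [1, 1, 0]]

# default_collate treats each label as a sequence and zips across the batch,
# producing one int64 tensor per class position rather than a single 2D tensor.
collated_lists = default_collate(list_labels)
print(type(collated_lists), len(collated_lists))  # <class 'list'> 3
print(collated_lists[0])  # tensor([0, 1])

# The same labels as per-sample tensors are stacked into a (batch, num_classes) tensor.
tensor_labels = [torch.tensor(l, dtype=torch.float32) for l in list_labels]
collated_tensors = default_collate(tensor_labels)
print(collated_tensors.shape, collated_tensors.dtype)  # torch.Size([2, 3]) torch.float32
```

If this is what happens inside the DataModule, it would explain why passing List[torch.Tensor] behaves differently from List[List[int]].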

As a workaround (NOT specified in the docs), you can pass the labels as List[torch.Tensor].
However, you then have to create and tweak a custom loss function as follows:

```python
import torch


def my_loss(x, y):
    # To make this work you have to pass the labels as List[torch.Tensor] to `from_filepaths`.
    # This is not specified anywhere, since the examples show that you must pass List[int].
    assert isinstance(x, torch.Tensor), type(x)
    assert isinstance(y, torch.Tensor), type(y)
    assert x.shape == y.shape, (x.shape, y.shape)
    # NOTE: even if you pass the tensors as torch.float, they are automatically cast to torch.int64
    y = y.to(x)  # labels have to be in floating-point precision to compute BCE
    return torch.nn.functional.binary_cross_entropy_with_logits(x, y)
```
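The workaround can be checked in isolation. This is a minimal sketch, assuming multi-hot list labels like the ones in the example above; `to_label_tensors` is a hypothetical helper, the random logits stand in for the model's output, and `my_loss` is repeated so the snippet runs on its own.

```python
from typing import List

import torch
import torch.nn.functional as F


def to_label_tensors(labels: List[List[int]]) -> List[torch.Tensor]:
    """Hypothetical helper: convert multi-hot lists into per-sample float tensors."""
    return [torch.tensor(l, dtype=torch.float32) for l in labels]


def my_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.shape == y.shape, (x.shape, y.shape)
    # Cast targets to the logits' dtype and device; BCE needs floating-point targets.
    y = y.to(x)
    return F.binary_cross_entropy_with_logits(x, y)


train_labels = [[0, 0, 1], [1, 1, 0]]
label_tensors = to_label_tensors(train_labels)

# Simulate a batch whose labels arrived as int64, as the NOTE in my_loss describes.
targets = torch.stack(label_tensors).long()
logits = torch.randn(2, 3)

loss = my_loss(logits, targets)
print(loss.item())
```

The `y.to(x)` call is what makes the loss tolerant of the int64 targets; without it the integer labels reach BCE unchanged.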

IMO, there should be a single way to pass the labels: either List[List[int]] or List[torch.Tensor].
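As an illustration of what a single entry point could look like, the sketch below normalizes both formats into one `(num_samples, num_classes)` float tensor before the dataset is built. `normalize_multilabel_targets` is hypothetical and not part of Flash's API; it is only a sketch under that assumption.

```python
from typing import Sequence, Union

import torch

LabelInput = Union[Sequence[Sequence[int]], Sequence[torch.Tensor]]


def normalize_multilabel_targets(labels: LabelInput) -> torch.Tensor:
    """Hypothetical: accept List[List[int]] or List[torch.Tensor] and return a
    float32 tensor of shape (num_samples, num_classes)."""
    rows = [torch.as_tensor(row) for row in labels]
    if not rows:
        raise ValueError("labels must not be empty")
    num_classes = rows[0].numel()
    for i, row in enumerate(rows):
        if row.dim() != 1 or row.numel() != num_classes:
            raise ValueError(f"label {i} has shape {tuple(row.shape)}, expected ({num_classes},)")
    # Stack into one 2D tensor and cast once, so BCE receives float targets.
    return torch.stack(rows).to(torch.float32)


from_lists = normalize_multilabel_targets([[0, 0, 1], [1, 1, 0]])
from_tensors = normalize_multilabel_targets([torch.tensor([0, 0, 1]), torch.tensor([1, 1, 0])])
print(torch.equal(from_lists, from_tensors))  # True
```

Validating shapes up front would also surface ragged or mis-encoded labels as a clear error instead of a shape mismatch inside the loss.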

edgarriba added the bug / fix and help wanted labels on Apr 9, 2021