
ImageEmbedder default behavior is not a flattened output #656

Closed
tszumowski opened this issue Aug 13, 2021 · 3 comments · Fixed by #665 or #666
Labels
bug / fix, help wanted

Comments

@tszumowski
Contributor

🐛 Bug

I discovered this issue while testing PR #655. If you run the Image Embedding README example code, it returns a 3D tensor.
My understanding from the use of embeddings in general, and from how they are used in FiftyOne, is that each embedding is expected to be 1D.

The reason it returns a 3D tensor is that the output shape depends on the backbone used. The default there is resnet101, which returns a tensor of shape 2048x7x7. Other backbones, such as inception, return a flat 1D tensor, i.e. of length X.
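
For illustration, a minimal sketch (using torchvision directly rather than Flash, with a randomly initialised network) of why a ResNet trunk without its pooling/classification head emits a spatial feature map rather than a flat vector:

import torch
import torchvision.models as models

# Build a ResNet-50 (a stand-in for the Flash backbone) and drop its
# average-pool and fully-connected layers, keeping only the convolutional trunk.
resnet = models.resnet50()
trunk = torch.nn.Sequential(*list(resnet.children())[:-2])

with torch.no_grad():
    features = trunk(torch.randn(1, 3, 224, 224))

print(features.shape)  # torch.Size([1, 2048, 7, 7]) -- a spatial map, not a flat vector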

To Reproduce

Steps to reproduce the behavior:

Run the README example, but remove the embedding_dim parameter (see the code sample below).

Note: as-is, this will error on print(embeddings.shape), regardless of configuration, since embeddings is a list. But the question here is about the logic inside ImageEmbedder.

Code sample

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Create an ImageEmbedder with resnet50 trained on imagenet.
embedder = ImageEmbedder(backbone="resnet50")

# 3. Generate an embedding from an image path.
embeddings = embedder.predict("data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg")

# 4. Print embeddings shape
print(embeddings.shape)

Expected behavior

Expect to see a flat tensor of length 100352 (= 2048 × 7 × 7) as the output, instead of a 2048x7x7 tensor.
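
As a quick shape check, a sketch of the flattening this expects (plain PyTorch, independent of Flash):

import torch

# The backbone's raw output for a single image: a 2048x7x7 spatial feature map.
features = torch.randn(2048, 7, 7)

# Flattening gives the 1D embedding that downstream tools such as FiftyOne expect.
embedding = torch.flatten(features)
print(embedding.shape)  # torch.Size([100352]), i.e. 2048 * 7 * 7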

Environment

  • PyTorch Version (e.g., 1.0): 1.9
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8.6
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • Any other relevant information: N/A

Additional context

I believe the question is around what the logic should be here:
https://github.com/PyTorchLightning/lightning-flash/blob/075de3a46d74d9fc0e769401063fede1f12d0518/flash/image/embedding/model.py#L85-L92

If embedding_dim is None, then the head is nn.Identity(). If we desire a flat 1D embedding, then the question is: should nn.Identity() change to nn.Flatten()?
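
To make the difference concrete, a minimal sketch of the two head choices (assuming a batch of one and a ResNet-style 2048x7x7 backbone output):

import torch
from torch import nn

features = torch.randn(1, 2048, 7, 7)  # backbone output for a single image

identity_head = nn.Identity()
flatten_head = nn.Flatten()  # flattens everything after the batch dimension

print(identity_head(features).shape)  # torch.Size([1, 2048, 7, 7]) -- current behaviour
print(flatten_head(features).shape)   # torch.Size([1, 100352])     -- flattened alternative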

It could be argued that the user should be left to flatten the output on their own, but per the contributing guidelines, I thought flattening by default would align with "Force User Decisions To Best Practices".

Let me know your thoughts. If that makes sense, then I can update the code, run some tests, and update docs in a PR.

@tszumowski added the bug / fix and help wanted labels on Aug 13, 2021
@ethanwharris
Collaborator

@tszumowski great catch! I think this was supposed to be handled by the apply_pool, but that only gets called if embedding_dim was set haha:
https://github.com/PyTorchLightning/lightning-flash/blob/075de3a46d74d9fc0e769401063fede1f12d0518/flash/image/embedding/model.py#L110

I guess that's meant to be if not embedding_dim?

@tszumowski
Contributor Author

Ok, to start: I am no longer able to replicate the error I described above. It prints out:

torch.Size([2048])

So no issues there.

Note: there is a typo where:

print(embeddings.shape)

should be

print(embeddings[0].shape)

in the examples and docs. That can be a small separate PR.

I believe @ethanwharris's comment is valid: we want to apply that pool only if the output dimensionality is too high AND embedding_dim is not set.
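
A sketch of what that conditional might look like (a hypothetical helper for illustration, not the actual model.py code):

import torch
from torch import nn

def maybe_pool(features: torch.Tensor, embedding_dim=None) -> torch.Tensor:
    # Hypothetical illustration: pool and flatten only when the backbone output
    # is still spatial and no embedding_dim was requested.
    if embedding_dim is None and features.dim() > 2:
        features = nn.functional.adaptive_avg_pool2d(features, (1, 1))
        features = torch.flatten(features, start_dim=1)
    return features

print(maybe_pool(torch.randn(1, 2048, 7, 7)).shape)  # torch.Size([1, 2048])
print(maybe_pool(torch.randn(1, 512)).shape)         # torch.Size([1, 512])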

I'll cut two PRs to address the issues here.

@ethanwharris
Collaborator

@tszumowski Awesome 😃 thanks for your work on this!
