
Conversation

@younesbelkada (Contributor) commented Nov 25, 2022

What does this PR do?

This PR introduces an input casting mechanism for image processors. Since the introduction of accelerate-supported models for Vision, I have been playing around with half-precision models. I found it a bit unintuitive to have to manually cast the pixel_values outside the ImageProcessor class, so for some models small hacks have been introduced to make the casting operation more user-friendly.
With this PR, it will be possible to cast the input tensors to any floating-point precision, for any framework, at the ImageProcessor level as follows:

from transformers import ViTFeatureExtractor
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')
inputs = feature_extractor(images=image, return_tensors="np", float_precision="float16")
print(inputs.pixel_values.dtype)
>>> float16

The casting skips non-floating-point tensors, so those tensors are not affected by the casting mechanism (thinking, e.g., of ViLT, which takes both text and image inputs).
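For illustration, here is a minimal sketch of that selective casting, assuming NumPy outputs; the helper name _cast_floating_to below is hypothetical and not the exact implementation in this PR:

import numpy as np

def _cast_floating_to(tensors, dtype):
    # Cast only floating-point arrays; integer tensors (e.g. ViLT's input_ids) stay untouched.
    casted = {}
    for key, value in tensors.items():
        if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating):
            casted[key] = value.astype(dtype)
        else:
            casted[key] = value
    return casted

batch = {
    "pixel_values": np.random.rand(1, 3, 224, 224).astype(np.float32),
    "input_ids": np.array([[101, 2023, 102]]),
}
batch = _cast_floating_to(batch, "float16")
print(batch["pixel_values"].dtype, batch["input_ids"].dtype)
# pixel_values is cast to float16, input_ids keeps its integer dtype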

With this PR, the hacks introduced on ViT and OWLViT will be removed!

cc @amyeroberts @ydshieh

@younesbelkada changed the title from "add v1 float_precision" to "[Vision] Support different floating precision inputs from the ImageProcessor" on Nov 25, 2022
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada (Contributor, Author)

I think this PR is ready, at least as a PoC!
To make the PR complete, the float_precision arg currently needs to be manually added to each image processor. Before moving forward, doing this for all image processors, and adding tests, I would love to hear from @sgugger, @amyeroberts & @ydshieh whether this is the approach we would like to follow!
Thanks again!

@sgugger (Collaborator) left a comment

I am not 100% sure how the float precision is used in other frameworks. If we use it only for PyTorch, I would concentrate support in the to method, just making sure it also accepts a torch.device without adding a new argument for the float_precision (something that should be added to this PR in any case).

If it is used in other frameworks, your approach seems right, though for readability I would just do the loop in every branch of the test instead of creating cast_fun and is_floating.
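For concreteness, a rough sketch of what concentrating the support in the to method could look like on the PyTorch side (the class below is a hypothetical stand-in for BatchFeature, not the code proposed in this PR):

import torch

class BatchFeatureSketch(dict):
    def to(self, device=None, dtype=None):
        # Accept an optional device and/or floating-point dtype instead of a
        # separate float_precision argument on the image processor.
        new_data = {}
        for key, value in self.items():
            if isinstance(value, torch.Tensor):
                if dtype is not None and torch.is_floating_point(value):
                    value = value.to(dtype)   # cast only floating-point tensors
                if device is not None:
                    value = value.to(device)  # move everything to the device
            new_data[key] = value
        return BatchFeatureSketch(new_data)

# Hypothetical usage:
# inputs = BatchFeatureSketch(pixel_values=torch.rand(1, 3, 224, 224))
# inputs = inputs.to(device="cuda", dtype=torch.float16)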

@ydshieh (Collaborator) left a comment

I am fine with supporting float precision for all frameworks.
Thank you for working on this.

for key, value in self.items():
    # sanity check: only operate on tensors
    if is_tensor(value):
        if is_floating(value):
@ydshieh (Collaborator) commented on this snippet:

I am not very comfortable calling these is_tensor and is_floating without checking whether value comes from target_framework.

For example, what about the case where tensor_type is pt but value is a tf tensor?
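For illustration, a framework-guarded check along these lines could look like the hypothetical helper below (not part of the PR):

def _is_torch_floating_tensor(value):
    # Only return True for an actual torch tensor with a floating dtype, so a
    # TF tensor cannot slip through when tensor_type is "pt".
    try:
        import torch
    except ImportError:
        return False
    return isinstance(value, torch.Tensor) and torch.is_floating_point(value)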

@younesbelkada (Contributor, Author) replied:

Thanks for the feedback! I guess this would not happen, since the check is already done in the convert_to_tensors function, which is called right before.

@ydshieh (Collaborator) replied:

I agree with you @younesbelkada. However, this method is added as a public method, so the concern is there (although I doubt any user will use it). If it is prefixed with _, I won't complain at all :-)

Let @sgugger review and give us his opinion if we should make any effort on such things.

@younesbelkada (Contributor, Author) replied:

Ahh, I see now! Yes, then it makes sense to have it prefixed with _ 💪

@sgugger (Collaborator) replied:

My opinion has been stated above. I don't think any of this is useful, as Flax and TensorFlow deal with dtypes differently; there should only be a slight adaptation of the to method.

@amyeroberts (Contributor) left a comment

Thanks for implementing the idea in this PR. In terms of structure it looks good - just a few small comments.

I don't have a very strong opinion on how dtypes should be cast. I was originally for the float_precision flag. However, @sgugger has raised good points about using the .to API, and I agree that it should be focused on first.

@gante @Rocketknight1 - how useful would this be in TF land?

@younesbelkada (Contributor, Author) commented Dec 1, 2022

Thanks so much everyone for your comments!
After thinking a bit and trying to see whether this could be useful for Flax:

import jax.numpy as jnp
from transformers import FlaxViTForImageClassification, ViTFeatureExtractor

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224", dtype=jnp.float16)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

inputs = feature_extractor(images=image, return_tensors="np")
outputs = model(**inputs)
print(outputs)

it seems that Flax can deal properly with different dtypes without having to explicitly cast the input. I think a good point has been raised by @sgugger; however, this could still be useful if it is needed on the TF side. If not, I am happy to change the PR to something that modifies only the .to function, as this will be intended only for PyTorch.

@ydshieh (Collaborator) commented Dec 1, 2022

I don't have a strong opinion though, so you can follow what @sgugger suggests. If we find it's useful for other frameworks, we can add it back.

@younesbelkada (Contributor, Author)

Thanks everyone!
Let's keep this PR open in case we find out this is needed for TF. I have opened #20536 to support dtypes in .to.
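For reference, with the .to-based approach the usage would look roughly like the following (a sketch assuming the dtype support added in #20536; the exact call signature may differ):

import torch
from transformers import ViTFeatureExtractor
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')
inputs = feature_extractor(images=image, return_tensors="pt")

# Cast the floating-point tensors (and optionally move them to a device) in one call.
inputs = inputs.to(torch.float16)
print(inputs.pixel_values.dtype)
# >>> torch.float16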

@gante (Contributor) commented Dec 16, 2022

> @gante @Rocketknight1 - how useful would this be in TF land?

I don't think our TF models are compatible with half-precision, right @Rocketknight1? At least I haven't used TF with half-precision :D

@Rocketknight1 (Member)

Extremely late reply on the TF front, but yeah, we aren't really running TF models in half precision right now. We do support mixed precision (similar to Torch AMP), but we don't officially support splatting the whole model to (b)float16 yet.
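For context, the TF mixed-precision support mentioned here is enabled through a global Keras policy rather than by casting the inputs, e.g. (illustrative only):

import tensorflow as tf

# Compute in float16 where safe while keeping variables in float32,
# similar in spirit to torch.cuda.amp rather than a full half-precision model.
tf.keras.mixed_precision.set_global_policy("mixed_float16")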

@github-actions bot commented Apr 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this Apr 12, 2023