feat: Implement embed_image() #5101
Merged
11 commits (all by desmondcheongzx):

- cfdd91b impl embed_image
- d858ca0 add tests
- 7b2a06f address
- 7bae303 fix mypy
- a896aee ease up on ci
- e50ecd3 fix ci
- 9a6368c use safetensors
- 546126a address comments
- 4ff1e24 oops
- 0320345 Update daft/ai/utils.py
- fb7a5ea style
New file (+33 lines):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from daft.ai.provider import Provider
from daft.ai.transformers.image_embedder import TransformersImageEmbedderDescriptor

if TYPE_CHECKING:
    from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
    from daft.ai.typing import Options

__all__ = [
    "TransformersProvider",
]


class TransformersProvider(Provider):
    _name: str
    _options: Options

    def __init__(self, name: str | None = None, **options: Any):
        self._name = name if name else "transformers"
        self._options = options

    @property
    def name(self) -> str:
        return self._name

    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
        return TransformersImageEmbedderDescriptor(model or "openai/clip-vit-base-patch32", options)

    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
        raise NotImplementedError("embed_text is not currently implemented for the Transformers provider")
```
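The provider above is a thin factory: it maps a model name (with a CLIP default) to a descriptor, deferring any heavyweight model loading. A minimal, self-contained sketch of that pattern, using hypothetical `FakeProvider`/`FakeImageEmbedderDescriptor` names that stand in for the daft classes and are not part of the library:

```python
from __future__ import annotations

from typing import Any


class FakeImageEmbedderDescriptor:
    # Hypothetical stand-in for TransformersImageEmbedderDescriptor: holds the
    # resolved model name and options, but does not load any model weights.
    def __init__(self, model: str, options: dict[str, Any]):
        self.model = model
        self.options = options

    def get_model(self) -> str:
        return self.model


class FakeProvider:
    # Hypothetical stand-in for TransformersProvider, mirroring its defaults.
    def __init__(self, name: str | None = None, **options: Any):
        self._name = name if name else "transformers"
        self._options = options

    def get_image_embedder(self, model: str | None = None, **options: Any) -> FakeImageEmbedderDescriptor:
        # Fall back to a default checkpoint when none is given, as the PR does.
        return FakeImageEmbedderDescriptor(model or "openai/clip-vit-base-patch32", options)


desc = FakeProvider().get_image_embedder()
print(desc.get_model())  # → openai/clip-vit-base-patch32
```

The design choice here is that constructing a provider or descriptor is cheap and serializable; only `instantiate()` (in the next file) pays the cost of downloading and placing the model.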
New file (+66 lines):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoConfig, AutoModel, AutoProcessor

from daft import DataType
from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor
from daft.ai.typing import EmbeddingDimensions, Options
from daft.ai.utils import get_device
from daft.dependencies import pil_image

if TYPE_CHECKING:
    from daft.ai.typing import Embedding, Image


@dataclass
class TransformersImageEmbedderDescriptor(ImageEmbedderDescriptor):
    model: str
    options: Options

    def get_provider(self) -> str:
        return "transformers"

    def get_model(self) -> str:
        return self.model

    def get_options(self) -> Options:
        return self.options

    def get_dimensions(self) -> EmbeddingDimensions:
        config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
        # For CLIP models, the image embedding dimension is typically in projection_dim or hidden_size.
        embedding_size = getattr(config, "projection_dim", getattr(config, "hidden_size", 512))
        return EmbeddingDimensions(size=embedding_size, dtype=DataType.float32())

    def instantiate(self) -> ImageEmbedder:
        return TransformersImageEmbedder(self.model, **self.options)


class TransformersImageEmbedder(ImageEmbedder):
    model: Any
    options: Options

    def __init__(self, model_name_or_path: str, **options: Any):
        self.device = get_device()
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            use_safetensors=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=True)
        self.options = options

    def embed_image(self, images: list[Image]) -> list[Embedding]:
        # TODO(desmond): There's potential for image decoding and processing on the GPU with greater
        # performance. Methods differ a little between different models, so let's do it later.
        pil_images = [pil_image.fromarray(image) for image in images]
        processed = self.processor(images=pil_images, return_tensors="pt")
        pixel_values = processed["pixel_values"].to(self.device)

        with torch.inference_mode():
            embeddings = self.model.get_image_features(pixel_values)
        return embeddings.cpu().numpy().tolist()
```
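The `get_dimensions` fallback above reads `projection_dim` first, then `hidden_size`, then defaults to 512. This can be exercised without transformers installed by substituting simple config objects for the `AutoConfig` result (the configs below are illustrative stand-ins, not real model configs):

```python
from types import SimpleNamespace


def embedding_size(config) -> int:
    # Same attribute-fallback chain used in get_dimensions above. Note that the
    # inner getattr is evaluated eagerly, which is harmless here since it has
    # no side effects.
    return getattr(config, "projection_dim", getattr(config, "hidden_size", 512))


clip_like = SimpleNamespace(projection_dim=512, hidden_size=768)  # CLIP-style config
bert_like = SimpleNamespace(hidden_size=768)                      # encoder without a projection head
unknown = SimpleNamespace()                                       # neither attribute present

print(embedding_size(clip_like))  # 512: projection_dim wins for CLIP-style configs
print(embedding_size(bert_like))  # 768: falls back to hidden_size
print(embedding_size(unknown))    # 512: final default
```

For CLIP, `projection_dim` is the right choice: `get_image_features` returns the projected embedding, not the raw `hidden_size`-wide encoder output.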
New file (+26 lines):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch


def get_device() -> torch.device:
    """Get the best available PyTorch device for computation.

    This function automatically selects the optimal device in order of preference:
    1. CUDA GPU (if available) - for NVIDIA GPUs with CUDA support
    2. MPS (Metal Performance Shaders) - for Apple Silicon Macs
    3. CPU - as fallback when no GPU acceleration is available
    """
    import torch

    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("mps")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        else torch.device("cpu")
    )
    return device
```
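The chained conditional above encodes a strict precedence: CUDA beats MPS beats CPU. A torch-free sketch with the availability checks injected as parameters (a hypothetical `pick_device` helper, written only to make the precedence testable without GPU hardware):

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    # Same precedence as get_device above, with the torch.cuda.is_available()
    # and torch.backends.mps.is_available() probes passed in as booleans.
    return "cuda" if cuda_available else "mps" if mps_available else "cpu"


print(pick_device(True, True))    # cuda: CUDA wins even when MPS is also available
print(pick_device(False, True))   # mps
print(pick_device(False, False))  # cpu
```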
Nice, thanks. A little "eh, why?" right now, but you can see how, once the typing work catches up, daft will have its own "Image" and "Embedding" types that are np.ndarray-compatible via https://numpy.org/devdocs/reference/arrays.interface.html
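The np.ndarray compatibility the comment refers to works by exposing the NumPy array interface protocol, so `np.asarray()` can view a custom type's buffer directly. A minimal sketch, assuming a hypothetical `Embedding` wrapper class that is not daft's actual type:

```python
import numpy as np


class Embedding:
    # Hypothetical wrapper: stores float32 values and exposes them through the
    # NumPy array interface so np.asarray() can consume it without a copy of
    # the underlying buffer.
    def __init__(self, values):
        self._arr = np.asarray(values, dtype=np.float32)

    @property
    def __array_interface__(self):
        # Delegate to the backing array's interface dict (shape, typestr, data pointer).
        return self._arr.__array_interface__


emb = Embedding([0.1, 0.2, 0.3])
as_np = np.asarray(emb)  # NumPy reads the protocol; no explicit conversion method needed
print(as_np.shape, as_np.dtype)
```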