feat: Implement embed_image() #5101
Greptile Summary

This PR implements comprehensive image embedding functionality for the Daft framework, extending the existing text-only AI capabilities to support multimodal operations. The implementation follows the established provider architecture pattern used by text embeddings, adding a new `embed_image()` function that works with transformer-based models like CLIP.
The changes introduce several key components (a rough sketch of these interfaces follows after the summary):

- **Core Image Embedding Infrastructure**: A new `ImageEmbedder` protocol and `ImageEmbedderDescriptor` abstract class in `protocols.py` that mirror the existing text embedding pattern, providing a standardized interface for image embedding implementations.
- **Transformers Provider Support**: A complete `TransformersImageEmbedder` implementation that leverages HuggingFace transformers models with automatic device selection (CUDA/MPS/CPU), proper PIL image conversion, and batch processing capabilities. The implementation defaults to the `openai/clip-vit-base-patch32` model.
- **Provider Architecture Extensions**: Updates to the provider system, including a new `load_transformers()` function, the addition of transformers to the PROVIDERS registry, and an implementation of the abstract `get_image_embedder()` method across all providers. Non-supporting providers (OpenAI and SentenceTransformers) raise `NotImplementedError` with descriptive messages.
- **Expression Layer**: An `_ImageEmbedderExpression` class that integrates with Daft's UDF system, enabling image embedding to work seamlessly within dataframe operations.
- **Public API**: The main `embed_image()` function in `daft.functions.ai`, giving users a simple interface for image embedding operations and following the same pattern as the existing `embed_text()` function.
The implementation maintains API consistency across the framework while extending capabilities to support computer vision tasks. All changes follow the established patterns for dependency management, error handling, and provider resolution that users are already familiar with from text embedding operations.
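For orientation, here is a rough sketch of how these pieces could fit together, using the names from the summary above. The exact signatures in the PR may differ; the `instantiate()` method name, the `pick_device()` helper, and the type placeholders are illustrative assumptions, not the PR's actual code.

```python
from typing import Any, Protocol

import numpy as np

# Placeholder aliases; see the TypeAlias discussion later in this thread.
Image = np.ndarray
Embedding = np.ndarray


class ImageEmbedder(Protocol):
    """Mirrors the existing text embedder protocol, but for images."""

    def embed_image(self, images: list[Image]) -> list[Embedding]: ...


class ImageEmbedderDescriptor:
    """Lightweight, serializable description of an embedder. The heavyweight
    model object is built lazily (e.g. on a worker) rather than shipped with
    the query plan."""

    def __init__(self, model: str, options: dict[str, Any]) -> None:
        self.model = model
        self.options = options

    def instantiate(self) -> ImageEmbedder:  # assumed method name
        raise NotImplementedError


def pick_device() -> str:
    """Assumed device-selection order from the summary: CUDA, then MPS, then CPU."""
    import torch

    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```

The descriptor split is also what lets non-supporting providers simply raise `NotImplementedError` from `get_image_embedder()`, while the transformers provider returns a `TransformersImageEmbedderDescriptor`.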
Confidence score: 4/5
- This PR introduces significant new functionality but follows well-established patterns from the existing text embedding implementation
- Score reflects the complexity of the multimodal AI integration and potential for device/dependency-related issues
- Pay close attention to the transformers image embedder implementation and provider registration changes
9 files reviewed, 5 comments
tests/ai/test_transformers.py (Outdated)

```python
with pytest.raises(ImportError, match="Pillow is required for image processing but not available"):
    test_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
    (
        daft.from_pydict({"image": [test_image]})
        .select(daft.col("image").cast(daft.DataType.image()))
        .select(embed_image(daft.col("image")))
    )
```
logic: The test creates a dataframe but never executes it; with lazy evaluation, the PIL check would only happen on execution.
Suggested change:

```diff
 with pytest.raises(ImportError, match="Pillow is required for image processing but not available"):
     test_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
     (
         daft.from_pydict({"image": [test_image]})
         .select(daft.col("image").cast(daft.DataType.image()))
         .select(embed_image(daft.col("image")))
+        .collect()
     )
```
Incorrect. I specifically want to check that the PIL image check happens without needing to execute the query.
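For context on the distinction being debated here: in a lazy dataframe API, errors raised inside the UDF body surface only at `collect()`, while errors raised while the expression is being constructed surface immediately, with no execution needed. A toy illustration of the two failure points (the `LazyQuery` class and helper names are hypothetical, not Daft's internals):

```python
import importlib.util


def pil_available() -> bool:
    # Cheap availability probe that does not import the package.
    return importlib.util.find_spec("PIL") is not None


class LazyQuery:
    """Toy lazy query: select() only stages work; collect() runs it."""

    def __init__(self) -> None:
        self._stages = []

    def select(self, fn):
        self._stages.append(fn)  # nothing executes here
        return self

    def collect(self):
        for fn in self._stages:
            fn()  # execution-time errors surface only now
        return self


def embed_image_expr():
    # Construction-time check: fails as soon as the expression is built,
    # which is the behavior the author wants the test to assert.
    if not pil_available():
        raise ImportError("Pillow is required for image processing but not available")

    def run():
        import PIL.Image  # noqa: F401  # execution-time failure point

    return run
```

Under this model, wrapping just `LazyQuery().select(embed_image_expr())` in `pytest.raises(ImportError)` passes without any `collect()`, matching the author's reply above.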
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #5101      +/-   ##
==========================================
+ Coverage   75.27%   76.54%    +1.26%
==========================================
  Files         949      952        +3
  Lines      132520   130581     -1939
==========================================
+ Hits        99761    99948      +187
+ Misses      32759    30633     -2126
```
daft/ai/transformers/__init__.py (Outdated)

```python
def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    # Raise an error early if PIL is not available.
    if not pil_image.module_available():
        raise ImportError("Pillow is required for image processing but not available")

    return TransformersImageEmbedderDescriptor(model or "openai/clip-vit-base-patch32", options)
```
You should just depend on PIL directly in the module; the `ProviderImportError` within `load_transformers()` will handle the appropriate error messaging. PIL will need to be added to that dep list, and I will leave a comment there as well.
Ah perfect
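For reference, a hedged sketch of the dependency-gating pattern the reviewer describes. `ProviderImportError` and `load_transformers()` are named in the comment, but their bodies here, and the exact dep list, are assumptions for illustration only:

```python
import importlib.util


class ProviderImportError(ImportError):
    """Assumed shape: one consistent error naming the provider's missing deps."""

    def __init__(self, provider: str, missing: list[str]) -> None:
        super().__init__(
            f"Provider '{provider}' requires missing dependencies: {', '.join(missing)}"
        )


def load_transformers():
    # With PIL added to the provider's dep list (assumed here), a missing
    # Pillow install is reported once at provider load time, so
    # get_image_embedder() no longer needs its own ad-hoc ImportError.
    missing = [
        dep
        for dep in ("torch", "transformers", "PIL")  # assumed dep list
        if importlib.util.find_spec(dep) is None
    ]
    if missing:
        raise ProviderImportError("transformers", missing)
    return importlib.import_module("daft.ai.transformers")
```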
```python
    Image: TypeAlias = np.ndarray[Any, Any]
else:
    Embedding: TypeAlias = Any
    Image: TypeAlias = Any
```
Nice, thanks. A little 'eh, why?' right now, but you can see how, once the typing work catches up, daft will have its own "Image" and "Embedding" types that are np.ndarray-compatible via https://numpy.org/devdocs/reference/arrays.interface.html
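To make that concrete: any object exposing `__array_interface__` can be consumed by `np.asarray` as a zero-copy view, which is how a future daft `Image` or `Embedding` type could stay np.ndarray-compatible. A tiny hypothetical sketch (not Daft's actual class):

```python
import numpy as np


class DaftImage:
    """Hypothetical Image type that is np.ndarray-compatible via the
    NumPy array interface protocol; not Daft's actual class."""

    def __init__(self, pixels: np.ndarray) -> None:
        self._pixels = np.ascontiguousarray(pixels)

    @property
    def __array_interface__(self) -> dict:
        # Delegating to the underlying buffer is enough for np.asarray()
        # to view this object as an ndarray without copying.
        return self._pixels.__array_interface__


img = DaftImage(np.zeros((2, 2, 3), dtype=np.uint8))
arr = np.asarray(img)  # zero-copy view through the interface
assert arr.shape == (2, 2, 3)
```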
## Changes Made

Adds support for `embed_image()`, e.g.

```python
import daft
from daft.functions.ai import embed_image
import numpy as np

provider = "transformers"
model = "openai/clip-vit-base-patch32"

test_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
(
    daft.from_pydict({"image": [test_image] * 16})
    .select(daft.col("image").cast(daft.DataType.image()))
    .select(embed_image(daft.col("image"), provider=provider, model=model))
    .show()
)
```

**!! Currently only supports OpenAI CLIP models: https://huggingface.co/docs/transformers/en/model_doc/clip**

Co-authored-by: R. C. Howell <[email protected]>