diff --git a/daft/ai/_expressions.py b/daft/ai/_expressions.py
index 8cc0b133a9..53059c6dfa 100644
--- a/daft/ai/_expressions.py
+++ b/daft/ai/_expressions.py
@@ -4,7 +4,7 @@
 if TYPE_CHECKING:
     from daft import Series
-    from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
+    from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor, TextEmbedder, TextEmbedderDescriptor
     from daft.ai.typing import Embedding
@@ -19,3 +19,16 @@ def __init__(self, text_embedder: TextEmbedderDescriptor):
     def __call__(self, text_series: Series) -> list[Embedding]:
         text = text_series.to_pylist()
         return self.text_embedder.embed_text(text) if text else []
+
+
+class _ImageEmbedderExpression:
+    """Function expression implementation for an ImageEmbedder protocol."""
+
+    image_embedder: ImageEmbedder
+
+    def __init__(self, image_embedder: ImageEmbedderDescriptor):
+        self.image_embedder = image_embedder.instantiate()
+
+    def __call__(self, image_series: Series) -> list[Embedding]:
+        image = image_series.to_pylist()
+        return self.image_embedder.embed_image(image) if image else []
diff --git a/daft/ai/openai/__init__.py b/daft/ai/openai/__init__.py
index 7f400cd4f1..fcf423fc97 100644
--- a/daft/ai/openai/__init__.py
+++ b/daft/ai/openai/__init__.py
@@ -9,7 +9,7 @@
 if TYPE_CHECKING:
     from daft.ai.openai.typing import OpenAIProviderOptions
-    from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
+    from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
     from daft.ai.typing import Options
 
 __all__ = [
@@ -36,3 +36,6 @@ def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmb
             model_name=(model or "text-embedding-3-small"),
             model_options=options,
         )
+
+    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
+        raise NotImplementedError("embed_image is not currently implemented for the OpenAI provider")
diff --git a/daft/ai/protocols.py b/daft/ai/protocols.py
index 71cdbaca7c..4b34836e7a 100644
--- a/daft/ai/protocols.py
+++ b/daft/ai/protocols.py
@@ -6,7 +6,7 @@
 from daft.ai.typing import Descriptor
 
 if TYPE_CHECKING:
-    from daft.ai.typing import Embedding, EmbeddingDimensions
+    from daft.ai.typing import Embedding, EmbeddingDimensions, Image
 
 
 @runtime_checkable
@@ -24,3 +24,20 @@ class TextEmbedderDescriptor(Descriptor[TextEmbedder]):
     @abstractmethod
     def get_dimensions(self) -> EmbeddingDimensions:
         """Returns the dimensions of the embeddings produced by the described TextEmbedder."""
+
+
+@runtime_checkable
+class ImageEmbedder(Protocol):
+    """Protocol for image embedding implementations."""
+
+    def embed_image(self, images: list[Image]) -> list[Embedding]:
+        """Embeds a batch of images into embedding vectors."""
+        ...
+
+
+class ImageEmbedderDescriptor(Descriptor[ImageEmbedder]):
+    """Descriptor for an ImageEmbedder implementation."""
+
+    @abstractmethod
+    def get_dimensions(self) -> EmbeddingDimensions:
+        """Returns the dimensions of the embeddings produced by the described ImageEmbedder."""
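For readers of the patch: because `ImageEmbedder` is declared `@runtime_checkable`, any class with a compatible `embed_image` method satisfies the protocol structurally, with no subclassing required. A minimal sketch, assuming only what this diff defines (the `ZeroImageEmbedder` class is illustrative, not part of the change):

```python
import numpy as np

from daft.ai.protocols import ImageEmbedder


class ZeroImageEmbedder:
    """Toy embedder that maps every image to a fixed-size zero vector."""

    def embed_image(self, images: list[np.ndarray]) -> list[np.ndarray]:
        return [np.zeros(512, dtype=np.float32) for _ in images]


# Structural check works at runtime thanks to @runtime_checkable.
assert isinstance(ZeroImageEmbedder(), ImageEmbedder)
```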
diff --git a/daft/ai/provider.py b/daft/ai/provider.py
index 35ea43328d..06ee6cfd1f 100644
--- a/daft/ai/provider.py
+++ b/daft/ai/provider.py
@@ -7,7 +7,7 @@
 if TYPE_CHECKING:
     from daft.ai.openai.typing import OpenAIProviderOptions
-    from daft.ai.protocols import TextEmbedderDescriptor
+    from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
 
 
 class ProviderImportError(ImportError):
@@ -34,9 +34,19 @@ def load_sentence_transformers(name: str | None = None, **options: Any) -> Provi
         raise ProviderImportError(["sentence_transformers", "torch"]) from e
 
 
+def load_transformers(name: str | None = None, **options: Any) -> Provider:
+    try:
+        from daft.ai.transformers import TransformersProvider
+
+        return TransformersProvider(name, **options)
+    except ImportError as e:
+        raise ProviderImportError(["torch", "torchvision", "transformers", "Pillow"]) from e
+
+
 PROVIDERS: dict[str, Callable[..., Provider]] = {
     "openai": load_openai,
     "sentence_transformers": load_sentence_transformers,
+    "transformers": load_transformers,
 }
@@ -65,3 +75,8 @@ def name(self) -> str:
     def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
         """Returns a TextEmbedderDescriptor for this provider."""
         ...
+
+    @abstractmethod
+    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
+        """Returns an ImageEmbedderDescriptor for this provider."""
+        ...
diff --git a/daft/ai/sentence_transformers/__init__.py b/daft/ai/sentence_transformers/__init__.py
index fd675eab2c..9e8e6db3bf 100644
--- a/daft/ai/sentence_transformers/__init__.py
+++ b/daft/ai/sentence_transformers/__init__.py
@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
-    from daft.ai.protocols import TextEmbedderDescriptor
+    from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
     from daft.ai.typing import Options
 
 __all__ = [
@@ -28,3 +28,6 @@ def name(self) -> str:
     def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
         return SentenceTransformersTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)
+
+    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
+        raise NotImplementedError("embed_image is not currently implemented for the Sentence Transformers provider")
diff --git a/daft/ai/sentence_transformers/text_embedder.py b/daft/ai/sentence_transformers/text_embedder.py
index 328cfccfff..646c3517f5 100644
--- a/daft/ai/sentence_transformers/text_embedder.py
+++ b/daft/ai/sentence_transformers/text_embedder.py
@@ -30,7 +30,7 @@ def get_options(self) -> Options:
         return self.options
 
     def get_dimensions(self) -> EmbeddingDimensions:
-        dimensions = AutoConfig.from_pretrained(self.model).hidden_size
+        dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
        return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())
 
     def instantiate(self) -> TextEmbedder:
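With the registry entry above, "transformers" participates in the same dict-based dispatch as the existing providers, and a missing optional dependency surfaces as a `ProviderImportError` naming the packages to install. An illustrative sketch of that behavior (the try/except harness is example code, not part of the diff):

```python
from daft.ai.provider import PROVIDERS, ProviderImportError

try:
    # Registry values are loader callables; calling one imports the backend lazily.
    provider = PROVIDERS["transformers"]()
    descriptor = provider.get_image_embedder("openai/clip-vit-base-patch32")
except ProviderImportError as err:
    # Raised if torch, torchvision, transformers, or Pillow is missing.
    print(f"Missing provider dependencies: {err}")
```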
diff --git a/daft/ai/transformers/__init__.py b/daft/ai/transformers/__init__.py
new file mode 100644
index 0000000000..6bd7e2c24e
--- /dev/null
+++ b/daft/ai/transformers/__init__.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from daft.ai.provider import Provider
+
+from daft.ai.transformers.image_embedder import TransformersImageEmbedderDescriptor
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
+    from daft.ai.typing import Options
+
+__all__ = [
+    "TransformersProvider",
+]
+
+
+class TransformersProvider(Provider):
+    _name: str
+    _options: Options
+
+    def __init__(self, name: str | None = None, **options: Any):
+        self._name = name if name else "transformers"
+        self._options = options
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
+        return TransformersImageEmbedderDescriptor(model or "openai/clip-vit-base-patch32", options)
+
+    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
+        raise NotImplementedError("embed_text is not currently implemented for the Transformers provider")
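Note the laziness split in the new provider: `get_image_embedder` only records a model name and options, `get_dimensions` fetches just the model config, and weights load on `instantiate()`. A small sketch of that flow (the printed values are what the base CLIP config would yield, stated here as an assumption):

```python
from daft.ai.transformers import TransformersProvider

provider = TransformersProvider()
descriptor = provider.get_image_embedder()  # defaults to openai/clip-vit-base-patch32; no weights loaded yet

print(descriptor.get_provider())         # "transformers"
print(descriptor.get_dimensions().size)  # 512, from the CLIP config's projection_dim

embedder = descriptor.instantiate()      # now AutoModel/AutoProcessor are actually loaded
```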
diff --git a/daft/ai/transformers/image_embedder.py b/daft/ai/transformers/image_embedder.py
new file mode 100644
index 0000000000..b5ac4d973e
--- /dev/null
+++ b/daft/ai/transformers/image_embedder.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import torch
+from transformers import AutoConfig, AutoModel, AutoProcessor
+
+from daft import DataType
+from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor
+from daft.ai.typing import EmbeddingDimensions, Options
+from daft.ai.utils import get_device
+from daft.dependencies import pil_image
+
+if TYPE_CHECKING:
+    from daft.ai.typing import Embedding, Image
+
+
+@dataclass
+class TransformersImageEmbedderDescriptor(ImageEmbedderDescriptor):
+    model: str
+    options: Options
+
+    def get_provider(self) -> str:
+        return "transformers"
+
+    def get_model(self) -> str:
+        return self.model
+
+    def get_options(self) -> Options:
+        return self.options
+
+    def get_dimensions(self) -> EmbeddingDimensions:
+        config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
+        # For CLIP models, the image embedding dimension is typically in projection_dim or hidden_size.
+        embedding_size = getattr(config, "projection_dim", getattr(config, "hidden_size", 512))
+        return EmbeddingDimensions(size=embedding_size, dtype=DataType.float32())
+
+    def instantiate(self) -> ImageEmbedder:
+        return TransformersImageEmbedder(self.model, **self.options)
+
+
+class TransformersImageEmbedder(ImageEmbedder):
+    model: Any
+    options: Options
+
+    def __init__(self, model_name_or_path: str, **options: Any):
+        self.device = get_device()
+        self.model = AutoModel.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=True,
+            use_safetensors=True,
+        ).to(self.device)
+        self.processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=True)
+        self.options = options
+
+    def embed_image(self, images: list[Image]) -> list[Embedding]:
+        # TODO(desmond): There's potential for image decoding and processing on the GPU with greater
+        # performance. Methods differ a little between different models, so let's do it later.
+        pil_images = [pil_image.fromarray(image) for image in images]
+        processed = self.processor(images=pil_images, return_tensors="pt")
+        pixel_values = processed["pixel_values"].to(self.device)
+
+        with torch.inference_mode():
+            embeddings = self.model.get_image_features(pixel_values)
+        return embeddings.cpu().numpy().tolist()
diff --git a/daft/ai/typing.py b/daft/ai/typing.py
index edaf6316a3..d98072ab34 100644
--- a/daft/ai/typing.py
+++ b/daft/ai/typing.py
@@ -16,6 +16,7 @@
     "Descriptor",
     "Embedding",
     "EmbeddingDimensions",
+    "Image",
 ]
@@ -47,8 +48,10 @@ def instantiate(self) -> T:
     from daft.dependencies import np
 
     Embedding: TypeAlias = np.typing.NDArray[Any]
+    Image: TypeAlias = np.ndarray[Any, Any]
 else:
     Embedding: TypeAlias = Any
+    Image: TypeAlias = Any
diff --git a/daft/ai/utils.py b/daft/ai/utils.py
new file mode 100644
index 0000000000..a1a58b81ab
--- /dev/null
+++ b/daft/ai/utils.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch
+
+
+def get_device() -> torch.device:
+    """Get the best available PyTorch device for computation.
+
+    This function automatically selects the optimal device in order of preference:
+
+    1. CUDA GPU (if available) - for NVIDIA GPUs with CUDA support
+    2. MPS (Metal Performance Shaders) - for Apple Silicon Macs
+    3. CPU - as a fallback when no GPU acceleration is available
+    """
+    import torch
+
+    device = (
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        else torch.device("cpu")
+    )
+    return device
diff --git a/daft/functions/ai/__init__.py b/daft/functions/ai/__init__.py
index d5745114be..ba9838e83c 100644
--- a/daft/functions/ai/__init__.py
+++ b/daft/functions/ai/__init__.py
@@ -85,3 +85,35 @@ def embed_text(
     )
     expr = expr.with_init_args(text_embedder)
     return expr(text)
+
+
+def embed_image(
+    image: Expression,
+    *,
+    provider: str | Provider | None = None,
+    model: str | None = None,
+    **options: str,
+) -> Expression:
+    """Returns an expression that embeds images using the specified image model and provider.
+
+    Args:
+        image (Expression): The input image column expression.
+        provider (str | Provider | None): The provider to use for the image model. If None, the default provider is used.
+        model (str | None): The name of the image model to use. If None, the default model is used.
+        **options: Any additional options to pass to the model.
+
+    Note:
+        Make sure the required provider packages are installed (e.g. vllm, transformers, openai).
+
+    Returns:
+        Expression: An expression representing the embedded image vectors.
+    """
+    from daft.ai._expressions import _ImageEmbedderExpression
+    from daft.ai.protocols import ImageEmbedder
+
+    image_embedder = _resolve_provider(provider, "transformers").get_image_embedder(model, **options)
+    expr = udf(return_dtype=image_embedder.get_dimensions().as_dtype(), concurrency=1, use_process=False)(
+        _ImageEmbedderExpression
+    )
+    expr = expr.with_init_args(image_embedder)
+    return expr(image)
diff --git a/docs/install.md b/docs/install.md
index d5852d15f2..95ccf9d83b 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -41,6 +41,14 @@ Depending on your use case, you may need to install Daft with additional depende
+
+
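Taken together, the new function can be exercised end to end much like `embed_text`. A hedged usage sketch (file paths are placeholders; it assumes Daft's existing `url.download()` / `image.decode()` expression accessors and the transformers dependencies installed):

```python
import daft
from daft.functions.ai import embed_image

df = daft.from_pydict({"path": ["images/cat.png", "images/dog.png"]})  # placeholder paths
df = df.with_column("image", daft.col("path").url.download().image.decode())
df = df.with_column(
    "embedding",
    embed_image(daft.col("image"), provider="transformers", model="openai/clip-vit-base-patch32"),
)
df.show()
```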