Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion daft/ai/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

if TYPE_CHECKING:
from daft import Series
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor, TextEmbedder, TextEmbedderDescriptor
from daft.ai.typing import Embedding


Expand All @@ -19,3 +19,16 @@ def __init__(self, text_embedder: TextEmbedderDescriptor):
def __call__(self, text_series: Series) -> list[Embedding]:
    """Embed every string in the series; returns an empty list for an empty series."""
    text = text_series.to_pylist()
    # Skip the model call entirely when the batch is empty.
    return self.text_embedder.embed_text(text) if text else []


class _ImageEmbedderExpression:
"""Function expression implementation for an ImageEmbedder protocol."""

image_embedder: ImageEmbedder

def __init__(self, image_embedder: ImageEmbedderDescriptor):
self.image_embedder = image_embedder.instantiate()

def __call__(self, image_series: Series) -> list[Embedding]:
image = image_series.to_pylist()
return self.image_embedder.embed_image(image) if image else []
5 changes: 4 additions & 1 deletion daft/ai/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
from daft.ai.openai.typing import OpenAIProviderOptions
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
from daft.ai.typing import Options

__all__ = [
Expand All @@ -36,3 +36,6 @@ def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmb
model_name=(model or "text-embedding-3-small"),
model_options=options,
)

def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Not supported yet: no image embedding models are wired up for the OpenAI provider."""
    raise NotImplementedError("embed_image is not currently implemented for the OpenAI provider")
19 changes: 18 additions & 1 deletion daft/ai/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from daft.ai.typing import Descriptor

if TYPE_CHECKING:
from daft.ai.typing import Embedding, EmbeddingDimensions
from daft.ai.typing import Embedding, EmbeddingDimensions, Image


@runtime_checkable
Expand All @@ -24,3 +24,20 @@ class TextEmbedderDescriptor(Descriptor[TextEmbedder]):
@abstractmethod
def get_dimensions(self) -> EmbeddingDimensions:
"""Returns the dimensions of the embeddings produced by the described TextEmbedder."""


@runtime_checkable
class ImageEmbedder(Protocol):
    """Protocol for image embedding implementations."""

    def embed_image(self, images: list[Image]) -> list[Embedding]:
        """Embeds a batch of images into a batch of embedding vectors, one per input image."""
        ...


class ImageEmbedderDescriptor(Descriptor[ImageEmbedder]):
    """Descriptor for an ImageEmbedder implementation."""

    @abstractmethod
    def get_dimensions(self) -> EmbeddingDimensions:
        """Returns the dimensions of the embeddings produced by the described ImageEmbedder."""
17 changes: 16 additions & 1 deletion daft/ai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
from daft.ai.openai.typing import OpenAIProviderOptions
from daft.ai.protocols import TextEmbedderDescriptor
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor


class ProviderImportError(ImportError):
Expand All @@ -34,9 +34,19 @@ def load_sentence_transformers(name: str | None = None, **options: Any) -> Provi
raise ProviderImportError(["sentence_transformers", "torch"]) from e


def load_transformers(name: str | None = None, **options: Any) -> Provider:
    """Load the Transformers provider, translating missing packages into a ProviderImportError."""
    try:
        from daft.ai.transformers import TransformersProvider

        return TransformersProvider(name, **options)
    except ImportError as import_error:
        # Surface the full set of packages the transformers extra requires.
        raise ProviderImportError(["torch", "torchvision", "transformers", "Pillow"]) from import_error


PROVIDERS: dict[str, Callable[..., Provider]] = {
"openai": load_openai,
"sentence_transformers": load_sentence_transformers,
"transformers": load_transformers,
}


Expand Down Expand Up @@ -65,3 +75,8 @@ def name(self) -> str:
def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
"""Returns a TextEmbedderDescriptor for this provider."""
...

@abstractmethod
def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Returns an ImageEmbedderDescriptor for this provider.

    Implementations may raise NotImplementedError when the provider does not
    support image embedding.
    """
    ...
5 changes: 4 additions & 1 deletion daft/ai/sentence_transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from daft.ai.protocols import TextEmbedderDescriptor
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
from daft.ai.typing import Options

__all__ = [
Expand All @@ -28,3 +28,6 @@ def name(self) -> str:

def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
    """Returns a text embedder descriptor, defaulting to all-MiniLM-L6-v2 when no model is given."""
    return SentenceTransformersTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)

def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Not supported: the sentence-transformers integration currently covers text embedding only."""
    raise NotImplementedError("embed_image is not currently implemented for the Sentence Transformers provider")
2 changes: 1 addition & 1 deletion daft/ai/sentence_transformers/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_options(self) -> Options:
return self.options

def get_dimensions(self) -> EmbeddingDimensions:
    """Returns the embedding dimensions by reading ``hidden_size`` from the model config.

    NOTE(review): trust_remote_code=True allows hub-hosted custom config code to
    execute locally — confirm only trusted model names reach this path.
    """
    # Fetch the config exactly once; the pre-change duplicate call (without
    # trust_remote_code) was a leftover diff artifact and a redundant fetch.
    dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
    return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())

def instantiate(self) -> TextEmbedder:
Expand Down
33 changes: 33 additions & 0 deletions daft/ai/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from daft.ai.provider import Provider

from daft.ai.transformers.image_embedder import TransformersImageEmbedderDescriptor
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
from daft.ai.typing import Options

__all__ = [
"TransformersProvider",
]


class TransformersProvider(Provider):
    """Provider backed by Hugging Face Transformers models."""

    _name: str
    _options: Options

    def __init__(self, name: str | None = None, **options: Any):
        # Fall back to the canonical provider name when none is supplied.
        self._name = name or "transformers"
        self._options = options

    @property
    def name(self) -> str:
        return self._name

    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
        raise NotImplementedError("embed_text is not currently implemented for the Transformers provider")

    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
        default_model = "openai/clip-vit-base-patch32"
        return TransformersImageEmbedderDescriptor(model or default_model, options)
66 changes: 66 additions & 0 deletions daft/ai/transformers/image_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoConfig, AutoModel, AutoProcessor

from daft import DataType
from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor
from daft.ai.typing import EmbeddingDimensions, Options
from daft.ai.utils import get_device
from daft.dependencies import pil_image

if TYPE_CHECKING:
from daft.ai.typing import Embedding, Image


@dataclass
class TransformersImageEmbedderDescriptor(ImageEmbedderDescriptor):
    """Describes a Transformers image-embedding model without loading its weights."""

    model: str
    options: Options

    def get_provider(self) -> str:
        return "transformers"

    def get_model(self) -> str:
        return self.model

    def get_options(self) -> Options:
        return self.options

    def get_dimensions(self) -> EmbeddingDimensions:
        """Look up the output embedding width from the model's config.

        NOTE(review): trust_remote_code=True lets hub-hosted config code run — confirm models are trusted.
        """
        config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
        # CLIP-style models report the image embedding width as `projection_dim`;
        # fall back to `hidden_size`, then to 512 if neither attribute exists.
        if hasattr(config, "projection_dim"):
            embedding_size = config.projection_dim
        elif hasattr(config, "hidden_size"):
            embedding_size = config.hidden_size
        else:
            embedding_size = 512
        return EmbeddingDimensions(size=embedding_size, dtype=DataType.float32())

    def instantiate(self) -> ImageEmbedder:
        return TransformersImageEmbedder(self.model, **self.options)


class TransformersImageEmbedder(ImageEmbedder):
    """Embeds images with a Hugging Face Transformers vision model (e.g. CLIP)."""

    # The loaded AutoModel instance, moved to the selected device.
    model: Any
    options: Options

    def __init__(self, model_name_or_path: str, **options: Any):
        self.device = get_device()
        # NOTE(review): trust_remote_code=True executes model-repo code on load —
        # acceptable only for trusted model names; confirm upstream validation.
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            use_safetensors=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=True)
        self.options = options

    def embed_image(self, images: list[Image]) -> list[Embedding]:
        """Embed a batch of images, returning one embedding per input image."""
        # TODO(desmond): There's potential for image decoding and processing on the GPU with greater
        # performance. Methods differ a little between different models, so let's do it later.
        # Assumes each `image` is an ndarray PIL can consume (HxW[xC], uint8) — TODO confirm upstream dtype.
        pil_images = [pil_image.fromarray(image) for image in images]
        processed = self.processor(images=pil_images, return_tensors="pt")
        pixel_values = processed["pixel_values"].to(self.device)

        with torch.inference_mode():
            embeddings = self.model.get_image_features(pixel_values)
        # NOTE(review): tolist() yields list-of-lists (floats), not ndarrays as the
        # Embedding alias suggests — confirm consumers accept plain lists.
        return embeddings.cpu().numpy().tolist()
3 changes: 3 additions & 0 deletions daft/ai/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"Descriptor",
"Embedding",
"EmbeddingDimensions",
"Image",
]


Expand Down Expand Up @@ -47,8 +48,10 @@ def instantiate(self) -> T:
from daft.dependencies import np

Embedding: TypeAlias = np.typing.NDArray[Any]
Image: TypeAlias = np.ndarray[Any, Any]
else:
Embedding: TypeAlias = Any
Image: TypeAlias = Any
Comment on lines +51 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks. A little 'eh, why?' right now — but you can see how once the typing work catches up, how daft will have its own "Image" and "Embedding" type which is np.ndarray compatible via https://numpy.org/devdocs/reference/arrays.interface.html



@dataclass
Expand Down
26 changes: 26 additions & 0 deletions daft/ai/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import torch


def get_device() -> torch.device:
    """Get the best available PyTorch device for computation.

    Selects the optimal device in order of preference:
    1. CUDA GPU (if available) - for NVIDIA GPUs with CUDA support
    2. MPS (Metal Performance Shaders) - for Apple Silicon Macs
    3. CPU - as fallback when no GPU acceleration is available
    """
    import torch

    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds may lack the mps backend attribute entirely.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
32 changes: 32 additions & 0 deletions daft/functions/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,35 @@ def embed_text(
)
expr = expr.with_init_args(text_embedder)
return expr(text)


def embed_image(
    image: Expression,
    *,
    provider: str | Provider | None = None,
    model: str | None = None,
    **options: str,
) -> Expression:
    """Returns an expression that embeds images using the specified image model and provider.

    Args:
        image (Expression): The input image column expression.
        provider (str | Provider | None): The provider to use for the image model. If None, the default provider ("transformers") is used.
        model (str | None): The name of the image model to use. If None, the provider's default model is used.
        **options: Any additional options to pass for the model.

    Note:
        Make sure the required provider packages are installed (e.g. vllm, transformers, openai).

    Returns:
        Expression: An expression representing the embedded image vectors.
    """
    # Fix: the previous version also imported ImageEmbedder from daft.ai.protocols
    # but never used it; the unused import has been removed.
    from daft.ai._expressions import _ImageEmbedderExpression

    # Resolve the provider (defaulting to transformers) and build a model descriptor.
    image_embedder = _resolve_provider(provider, "transformers").get_image_embedder(model, **options)
    # Wrap the embedder class in a UDF; concurrency=1 keeps a single model instance
    # (presumably to avoid duplicate model loads — confirm against embed_text's rationale).
    expr = udf(return_dtype=image_embedder.get_dimensions().as_dtype(), concurrency=1, use_process=False)(
        _ImageEmbedderExpression
    )
    expr = expr.with_init_args(image_embedder)
    return expr(image)
14 changes: 11 additions & 3 deletions docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ Depending on your use case, you may need to install Daft with additional depende
</div>
</label>

<label class="checkbox-item">
<input type="checkbox" id="transformers" data-extra="transformers">
<span class="checkmark"></span>
<div class="checkbox-content">
<strong>Transformers</strong> <code>transformers</code>
</div>
</label>

<label class="checkbox-item">
<input type="checkbox" id="ray" data-extra="ray">
<span class="checkmark"></span>
Expand Down Expand Up @@ -171,7 +179,7 @@ pip install -U daft-lts
.checkbox-group {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 12px;
gap: 8px;
margin-bottom: 24px;
align-items: center;
}
Expand All @@ -180,7 +188,7 @@ pip install -U daft-lts
display: flex;
align-items: center;
cursor: pointer;
padding: 8px 8px 8px 16px;
padding: 4px 4px 4px 16px;
border: 1px solid var(--md-default-fg-color--lightest);
border-radius: 6px;
transition: all 0.2s ease;
Expand Down Expand Up @@ -228,7 +236,7 @@ pip install -U daft-lts

.checkbox-content strong {
display: block;
margin-bottom: -2px;
margin-bottom: -4px;
color: var(--md-default-fg-color);
font-size: 14px;
font-weight: normal;
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ homepage = "https://www.daft.ai"
repository = "https://github.com/Eventual-Inc/Daft"

[project.optional-dependencies]
all = ["daft[aws, azure, clickhouse, deltalake, gcp, hudi, huggingface, iceberg, lance, numpy, openai, pandas, ray, sentence-transformers, spark, sql, turbopuffer, unity]"]
all = ["daft[aws, azure, clickhouse, deltalake, gcp, hudi, huggingface, iceberg, lance, numpy, openai, pandas, ray, sentence-transformers, spark, sql, transformers, turbopuffer, unity]"]
aws = ["boto3"]
azure = []
clickhouse = ["clickhouse_connect"]
Expand All @@ -54,6 +54,7 @@ ray = [
"packaging"
]
sentence-transformers = ["sentence-transformers"]
transformers = ["transformers", "torch", "torchvision"]
spark = ["googleapis-common-protos >= 1.56.4", "grpcio >= 1.48", "grpcio-status >= 1.48", "numpy >= 1.15", "pandas >= 1.0.5", "py4j >= 0.10.9.7", "pyspark == 3.5.5"]
sql = ["connectorx", "sqlalchemy", "sqlglot"]
turbopuffer = ["turbopuffer"]
Expand Down
Loading
Loading