38 changes: 37 additions & 1 deletion daft/ai/openai/__init__.py
@@ -2,7 +2,7 @@

from daft.ai.provider import Provider

from daft.ai.openai.text_embedder import OpenAITextEmbedderDescriptor
from daft.ai.openai.text_embedder import OpenAITextEmbedderDescriptor, LMStudioTextEmbedderDescriptor
from typing import TYPE_CHECKING, Any, TypedDict

from typing_extensions import Unpack
@@ -13,6 +13,7 @@
from daft.ai.typing import Options

__all__ = [
"LMStudioProvider",
"OpenAIProvider",
]

@@ -39,3 +40,38 @@ def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:

def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
raise NotImplementedError("embed_image is not currently implemented for the OpenAI provider")


class LMStudioProvider(OpenAIProvider):
"""LM Studio provider that extends OpenAI provider with local server configuration.

LM Studio runs a local server that's API-compatible with OpenAI, so we can reuse
all the OpenAI logic and just configure the base URL to point to the local instance.
"""

def __init__(
self,
name: str | None = None,
**options: Unpack[OpenAIProviderOptions],
):
if "api_key" not in options:
options["api_key"] = "not-needed-for-lm-studio"
if "base_url" not in options:
options["base_url"] = "http://localhost:1234/v1"
else:
# Ensure base_url ends with /v1 for LM Studio compatibility.
base_url = options["base_url"]
if base_url is not None and not base_url.endswith("/v1"):
options["base_url"] = base_url.rstrip("/") + "/v1"
super().__init__(name or "lm_studio", **options)

def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
return LMStudioTextEmbedderDescriptor(
provider_name=self._name,
provider_options=self._options,
model_name=(model or "text-embedding-3-small"),
model_options=options,
)

def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
raise NotImplementedError("embed_image is not currently implemented for the LM Studio provider")
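As a quick illustration of the constructor above (not part of the diff; the port and model name are only examples), the defaults and the `base_url` normalization behave like this:

```python
from daft.ai.openai import LMStudioProvider

# No options: api_key is stubbed and base_url defaults to http://localhost:1234/v1.
provider = LMStudioProvider()

# A base_url without the /v1 suffix is normalized to end with it:
# http://127.0.0.1:1235 -> http://127.0.0.1:1235/v1
provider = LMStudioProvider(base_url="http://127.0.0.1:1235")

# Text embedding goes through the LM Studio-specific descriptor.
descriptor = provider.get_text_embedder(model="text-embedding-nomic-embed-text-v1.5")
```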
42 changes: 42 additions & 0 deletions daft/ai/openai/text_embedder.py
@@ -85,6 +85,48 @@ def instantiate(self) -> TextEmbedder:
)


@dataclass
class LMStudioTextEmbedderDescriptor(TextEmbedderDescriptor):
"""LM Studio text embedder descriptor that dynamically discovers model dimensions.

Unlike OpenAI, LM Studio can load different models with varying embedding dimensions.
This descriptor queries the local server to get the actual model dimensions.
"""

provider_name: str
provider_options: OpenAIProviderOptions
model_name: str
model_options: Options

def get_provider(self) -> str:
return "lm_studio"

def get_model(self) -> str:
return self.model_name

def get_options(self) -> Options:
return self.model_options

def get_dimensions(self) -> EmbeddingDimensions:
try:
client = OpenAI(**self.provider_options)
response = client.embeddings.create(
input="dimension probe",
model=self.model_name,
encoding_format="float",
)
size = len(response.data[0].embedding)
return EmbeddingDimensions(size=size, dtype=DataType.float32())
except Exception as ex:
raise ValueError("Failed to determine embedding dimensions from LM Studio.") from ex
Comment on lines +110 to +121

Contributor: style: The dimension probing creates a new OpenAI client and makes a network request during descriptor creation. This could be expensive if called repeatedly and may fail if the LM Studio server is temporarily unavailable. Consider caching the dimensions or moving this logic to instantiation time.

Collaborator (Author): This is a local network request

Contributor: "I am not concerned"

Comment on lines +120 to +121

Contributor: style: The broad exception catch could mask specific connection errors. Consider catching more specific exceptions like OpenAIError or connection-related exceptions to provide better error messages.
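A minimal sketch of the reviewer's suggestion (not part of the PR; it assumes the v1 OpenAI SDK, whose `openai.APIConnectionError` and `openai.OpenAIError` types distinguish connection failures from other client errors):

```python
import openai
from openai import OpenAI

def probe_dimensions(provider_options: dict, model_name: str) -> int:
    """Probe an OpenAI-compatible server for its embedding size, with narrower error handling."""
    client = OpenAI(**provider_options)
    try:
        response = client.embeddings.create(
            input="dimension probe",
            model=model_name,
            encoding_format="float",
        )
    except openai.APIConnectionError as ex:
        # The server is unreachable (LM Studio not running, wrong port, etc.).
        raise ValueError("Could not reach the LM Studio server to probe embedding dimensions.") from ex
    except openai.OpenAIError as ex:
        # Any other client error (unknown model, malformed response, ...).
        raise ValueError("Failed to determine embedding dimensions from LM Studio.") from ex
    return len(response.data[0].embedding)
```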


def instantiate(self) -> TextEmbedder:
return OpenAITextEmbedder(
client=OpenAI(**self.provider_options),
model=self.model_name,
)


class OpenAITextEmbedder(TextEmbedder):
"""The OpenAI TextEmbedder will batch across rows, and split a large row into a batch request when necessary.

10 changes: 10 additions & 0 deletions daft/ai/provider.py
@@ -16,6 +16,15 @@ def __init__(self, dependencies: list[str]):
super().__init__(f"Missing required dependencies: {deps}. " f"Please install {deps} to use this provider.")


def load_lm_studio(name: str | None = None, **options: Any) -> Provider:
try:
from daft.ai.openai import LMStudioProvider

return LMStudioProvider(name, **options)
except ImportError as e:
raise ProviderImportError(["openai"]) from e


def load_openai(name: str | None = None, **options: Unpack[OpenAIProviderOptions]) -> Provider:
try:
from daft.ai.openai import OpenAIProvider
@@ -44,6 +53,7 @@ def load_transformers(name: str | None = None, **options: Any) -> Provider:


PROVIDERS: dict[str, Callable[..., Provider]] = {
"lm_studio": load_lm_studio,
"openai": load_openai,
"sentence_transformers": load_sentence_transformers,
"transformers": load_transformers,
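For reference, this registry makes name-based loading a plain dict lookup; a small sketch of the dispatch (equivalent to what `load_provider` does under the hood, assuming it resolves names through this table):

```python
from daft.ai.provider import PROVIDERS

# Look up the factory registered under "lm_studio" and call it with options;
# this is the same path a call like load_provider("lm_studio", ...) takes.
factory = PROVIDERS["lm_studio"]  # -> load_lm_studio
provider = factory(base_url="http://localhost:1234")
```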
26 changes: 26 additions & 0 deletions docs/modalities/text.md
@@ -95,6 +95,32 @@ model = "text-embedding-3-small"

In this case you could either use a different model with a larger maximum context length, or could chunk your text into smaller segments before generating embeddings. See our [text embeddings guide](../examples/text-embeddings.md) for examples of text chunking strategies, or refer to the section below on [text chunking](#chunk-text-into-smaller-pieces).

#### Using LM Studio

[LM Studio](https://lmstudio.ai/) is a local AI model platform that lets you run large language models like Qwen, Mistral, Gemma, or gpt-oss on your own machine. If you're running an LM Studio server, Daft can use it as a provider for computing embeddings.

First, install Daft's optional OpenAI dependency. It's needed because LM Studio exposes an OpenAI-compatible API.

```bash
pip install -U "daft[openai]"
```

LM Studio runs on `localhost` port `1234` by default, but you can customize the `base_url` as needed in Daft. In this example, we use the [`nomic-ai/nomic-embed-text-v1.5`](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) embedding model.

```python
import daft
from daft.ai.provider import load_provider
from daft.functions.ai import embed_text

provider = load_provider("lm_studio", base_url="http://127.0.0.1:1235")  # base_url is optional; omit it to use LM Studio's default address.
model = "text-embedding-nomic-embed-text-v1.5" # Select a text embedding model that you've loaded into LM Studio.

(
daft.read_huggingface("Open-Orca/OpenOrca")
.with_column("embedding", embed_text(daft.col("response"), provider=provider, model=model))
.show()
)
```
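
Because LM Studio can serve models with different embedding widths, Daft's descriptor probes the server for the actual dimensions. If you want to check this yourself, here is a small sketch using the plain OpenAI client (the model name and default address are only examples):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed-for-lm-studio")
resp = client.embeddings.create(
    input="dimension probe",
    model="text-embedding-nomic-embed-text-v1.5",
    encoding_format="float",
)
print(len(resp.data[0].embedding))  # e.g. 768 for nomic-embed-text-v1.5
```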

### How to work with embeddings

65 changes: 65 additions & 0 deletions tests/ai/test_lm_studio.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import pytest

pytest.importorskip("openai")

from unittest.mock import patch

import numpy as np
from openai.types.create_embedding_response import CreateEmbeddingResponse
from openai.types.embedding import Embedding as OpenAIEmbedding

from daft.ai.openai import LMStudioProvider
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor


@pytest.mark.parametrize(
"model, embedding_dim",
[
("text-embedding-qwen3-embedding-0.6b", 1024),
("text-embedding-nomic-embed-text-v1.5", 768),
],
)
def test_lm_studio_text_embedder(model, embedding_dim):
text_data = [
"Alice was beginning to get very tired of sitting by her sister on the bank.",
"So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid),",
"whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies,",
"when suddenly a White Rabbit with pink eyes ran close by her.",
"There was nothing so very remarkable in that;",
"nor did Alice think it so very much out of the way to hear the Rabbit say to itself, 'Oh dear! Oh dear! I shall be late!'",
]

def mock_embedding_response(input_data):
if isinstance(input_data, list):
num_texts = len(input_data)
else:
num_texts = 1

embeddings = []
for i in range(num_texts):
embedding_values = [0.1] * embedding_dim
embedding_obj = OpenAIEmbedding(embedding=embedding_values, index=i, object="embedding")
embeddings.append(embedding_obj)

response = CreateEmbeddingResponse(
data=embeddings, model=model, object="list", usage={"prompt_tokens": 0, "total_tokens": 0}
)
return response

with patch("openai.resources.embeddings.Embeddings.create") as mock_embed:
mock_embed.side_effect = lambda **kwargs: mock_embedding_response(kwargs.get("input"))

descriptor = LMStudioProvider().get_text_embedder(model=model)
assert isinstance(descriptor, TextEmbedderDescriptor)
assert descriptor.get_provider() == "lm_studio"
assert descriptor.get_model() == model
assert descriptor.get_dimensions().size == embedding_dim

embedder = descriptor.instantiate()
assert isinstance(embedder, TextEmbedder)
embeddings = embedder.embed_text(text_data)
assert len(embeddings) == len(text_data)
assert all(isinstance(embedding, np.ndarray) for embedding in embeddings)
assert all(len(embedding) == embedding_dim for embedding in embeddings)