feat(embed_text): Support LM Studio as a provider #5103
Merged
+180
−1
@@ -85,6 +85,48 @@ def instantiate(self) -> TextEmbedder:
        )


@dataclass
class LMStudioTextEmbedderDescriptor(TextEmbedderDescriptor):
    """LM Studio text embedder descriptor that dynamically discovers model dimensions.

    Unlike OpenAI, LM Studio can load different models with varying embedding dimensions.
    This descriptor queries the local server to get the actual model dimensions.
    """

    provider_name: str
    provider_options: OpenAIProviderOptions
    model_name: str
    model_options: Options

    def get_provider(self) -> str:
        return "lm_studio"

    def get_model(self) -> str:
        return self.model_name

    def get_options(self) -> Options:
        return self.model_options

    def get_dimensions(self) -> EmbeddingDimensions:
        try:
            client = OpenAI(**self.provider_options)
            response = client.embeddings.create(
                input="dimension probe",
                model=self.model_name,
                encoding_format="float",
            )
            size = len(response.data[0].embedding)
            return EmbeddingDimensions(size=size, dtype=DataType.float32())
        except Exception as ex:
            raise ValueError("Failed to determine embedding dimensions from LM Studio.") from ex

    def instantiate(self) -> TextEmbedder:
        return OpenAITextEmbedder(
            client=OpenAI(**self.provider_options),
            model=self.model_name,
        )


class OpenAITextEmbedder(TextEmbedder):
    """The OpenAI TextEmbedder will batch across rows, and split a large row into a batch request when necessary.

Comment on lines +120 to +121:
style: The broad exception catch could mask specific connection errors. Consider catching more specific exceptions like …
@@ -0,0 +1,65 @@
from __future__ import annotations

import pytest

pytest.importorskip("openai")

from unittest.mock import patch

import numpy as np
from openai.types.create_embedding_response import CreateEmbeddingResponse
from openai.types.embedding import Embedding as OpenAIEmbedding

from daft.ai.openai import LMStudioProvider
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor


@pytest.mark.parametrize(
    "model, embedding_dim",
    [
        ("text-embedding-qwen3-embedding-0.6b", 1024),
        ("text-embedding-nomic-embed-text-v1.5", 768),
    ],
)
def test_lm_studio_text_embedder(model, embedding_dim):
    text_data = [
        "Alice was beginning to get very tired of sitting by her sister on the bank.",
        "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid),",
        "whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies,",
        "when suddenly a White Rabbit with pink eyes ran close by her.",
        "There was nothing so very remarkable in that;",
        "nor did Alice think it so very much out of the way to hear the Rabbit say to itself, 'Oh dear! Oh dear! I shall be late!'",
    ]

    def mock_embedding_response(input_data):
        if isinstance(input_data, list):
            num_texts = len(input_data)
        else:
            num_texts = 1

        embeddings = []
        for i in range(num_texts):
            embedding_values = [0.1] * embedding_dim
            embedding_obj = OpenAIEmbedding(embedding=embedding_values, index=i, object="embedding")
            embeddings.append(embedding_obj)

        response = CreateEmbeddingResponse(
            data=embeddings, model=model, object="list", usage={"prompt_tokens": 0, "total_tokens": 0}
        )
        return response

    with patch("openai.resources.embeddings.Embeddings.create") as mock_embed:
        mock_embed.side_effect = lambda **kwargs: mock_embedding_response(kwargs.get("input"))

        descriptor = LMStudioProvider().get_text_embedder(model=model)
        assert isinstance(descriptor, TextEmbedderDescriptor)
        assert descriptor.get_provider() == "lm_studio"
        assert descriptor.get_model() == model
        assert descriptor.get_dimensions().size == embedding_dim

        embedder = descriptor.instantiate()
        assert isinstance(embedder, TextEmbedder)
        embeddings = embedder.embed_text(text_data)
        assert len(embeddings) == len(text_data)
        assert all(isinstance(embedding, np.ndarray) for embedding in embeddings)
        assert all(len(embedding) == embedding_dim for embedding in embeddings)
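The closing assertions of the test encode the shape contract for `embed_text`: one array per input row, each `embedding_dim` wide. A standalone sketch of that contract, independent of the mocked client (the variable names here are illustrative):

```python
import numpy as np

embedding_dim = 768
texts = ["row one", "row two", "row three"]

# Stand-in for a server reply: one flat float vector per input text.
raw = [[0.1] * embedding_dim for _ in texts]

# The embedder is expected to hand back one ndarray per row.
embeddings = [np.asarray(vec, dtype=np.float32) for vec in raw]

assert len(embeddings) == len(texts)
assert all(e.shape == (embedding_dim,) for e in embeddings)
```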
style: The dimension probing creates a new OpenAI client and makes a network request during descriptor creation. This could be expensive if called repeatedly and may fail if the LM Studio server is temporarily unavailable. Consider caching the dimensions or moving this logic to instantiation time.
This is a local network request
"I am not concerned"