Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LMStudioClient and update __init__.py #210

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions adalflow/adalflow/components/model_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@
"adalflow.components.model_client.openai_client.get_probabilities",
OptionalPackages.OPENAI,
)
LMStudioClient = LazyImport(
"adalflow.components.model_client.lm_studio_client.LMStudioClient",
OptionalPackages.LMSTUDIO,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Optional package is not defined, please add it.

This dependency will also be added in the pyproejct.toml under /adalflow

)


__all__ = [
Expand All @@ -72,6 +76,7 @@
"GroqAPIClient",
"OpenAIClient",
"GoogleGenAIClient",
"LMStudioClient",
]

for name in __all__:
Expand Down
120 changes: 120 additions & 0 deletions adalflow/adalflow/components/model_client/lm_studio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import os
import re
from typing import Dict, Optional, Any, List, Union, Sequence
import logging
import backoff
import requests
from adalflow.core.types import ModelType, GeneratorOutput, EmbedderOutput, Embedding, Usage
from adalflow.core.model_client import ModelClient

log = logging.getLogger(__name__)

class LMStudioClient(ModelType if False else ModelClient):
    """A component wrapper for the LM Studio OpenAI-style REST API.

    Talks to a locally running LM Studio server — default base URL
    ``http://localhost:1234/v1``, configurable via the ``LMSTUDIO_HOST`` /
    ``LMSTUDIO_PORT`` environment variables or the ``host`` / ``port``
    constructor arguments. Supports chat completions (``ModelType.LLM``,
    ``/chat/completions``) and embeddings (``ModelType.EMBEDDER``,
    ``/embeddings``).

    NOTE(review): per PR feedback, add setup instructions (how to start the
    LM Studio local server) and a short usage example here before merging.
    """
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add relevant links and instructions on how to set up the client and additionally some example in this doc_string.


def __init__(self, host: Optional[str] = None, port: Optional[int] = None):
super().__init__()
self._host = host or os.getenv("LMSTUDIO_HOST", "http://localhost")
self._port = port or int(os.getenv("LMSTUDIO_PORT", "1234"))
self._base_url = f"{self._host}:{self._port}/v1"
self.init_sync_client()
self.async_client = None # To be added

def init_sync_client(self):
"""Create the synchronous client"""
self.sync_client = requests.Session()

def convert_inputs_to_api_kwargs(
self,
input: Optional[Any] = None,
model_kwargs: Dict = {},
model_type: ModelType = ModelType.UNDEFINED,
) -> Dict:
"""Convert the input and model_kwargs to api_kwargs for the LM Studio API."""
final_model_kwargs = model_kwargs.copy()
if model_type == ModelType.EMBEDDER:
if isinstance(input, str):
input = [input]
assert isinstance(input, Sequence), "input must be a sequence of text"
final_model_kwargs["input"] = input
elif model_type == ModelType.LLM:
messages = []

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use type hints for messages

Suggested change
messages = []
messages: List[Dict[str, str]] = []

ref: model_client/openai_client.py line 234

if input is not None and input != "":
messages.append({"role": "system", "content": "You are a helpful assistant. Provide a direct and concise answer to the user's question. Do not include any URLs or references in your response."})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we use uni-prompt where both system and user prompt are in the same jinja2 syntax, we only need one message here, and in default, we use role system.

messages.append({"role":"system", "content": input})

Please modify it to this.

messages.append({"role": "user", "content": input})
assert isinstance(messages, Sequence), "input must be a sequence of messages"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this assert statement is needed since messages is explicitly created as a list just a few lines above.

final_model_kwargs["messages"] = messages

# Set default values for controlling response length if not provided

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using a for-loop instead

Suggested change
# Set default values for controlling response length if not provided
default_values = [("temperature", 0.1), ("frequency_penalty", 0.0), ("presence_penalty", 0.0), ("stop", ["\n", "###", "://"])]
for key, val in default_values:
final_model_kwargs.setdefault(key, val)

final_model_kwargs.setdefault("max_tokens", 50)
final_model_kwargs.setdefault("temperature", 0.1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use 0 as default temperature

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we should let the model provider decides the default behavior.

final_model_kwargs.setdefault("top_p", 0.9)
final_model_kwargs.setdefault("frequency_penalty", 0.0)
final_model_kwargs.setdefault("presence_penalty", 0.0)
final_model_kwargs.setdefault("stop", ["\n", "###", "://"])
else:
raise ValueError(f"model_type {model_type} is not supported")
return final_model_kwargs

@backoff.on_exception(backoff.expo, requests.RequestException, max_time=10)
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
if model_type == ModelType.EMBEDDER:
response = self.sync_client.post(f"{self._base_url}/embeddings", json=api_kwargs)
elif model_type == ModelType.LLM:
response = self.sync_client.post(f"{self._base_url}/chat/completions", json=api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")

response.raise_for_status()
return response.json()

def parse_chat_completion(self, completion: Dict) -> GeneratorOutput:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest writting more precise error messages

Suggested change
def parse_chat_completion(self, completion: Dict) -> GeneratorOutput:
def parse_chat_completion(self, completion: Dict) -> GeneratorOutput:
"""Parse the completion to a GeneratorOutput."""
try:
if "choices" not in completion:
return GeneratorOutput(data=None, error="Error parsing the completion: 'choices' not in 'completion'.", raw_response=content)
elif not len(completion["choices"]) > 0:
return GeneratorOutput(data=None, error="Error parsing the completion: 'choices' length is 0.", raw_response=content)
else:
content = completion["choices"][0]["message"]["content"]
# Clean up the content
content = self._clean_response(content)
return GeneratorOutput(data=None, raw_response=content)
except Exception as e:
log.error(f"Error parsing the completion: {e}")
return GeneratorOutput(data=None, error=str(e), raw_response=completion)

"""Parse the completion to a GeneratorOutput."""
if "choices" in completion and len(completion["choices"]) > 0:
content = completion["choices"][0]["message"]["content"]

# Clean up the content
content = self._clean_response(content)

return GeneratorOutput(data=None, raw_response=content)
else:
log.error(f"Error parsing the completion: {completion}")
return GeneratorOutput(data=None, error="Error parsing the completion", raw_response=completion)

def _clean_response(self, content: str) -> str:
"""Clean up the response content."""
# Remove any URLs
content = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', content)

# Remove any content after "###" or "://"
content = re.split(r'###|://', content)[0]

# Remove any remaining HTML-like tags
content = re.sub(r'<[^>]+>', '', content)

# Remove any repeated information
sentences = content.split('.')
unique_sentences = []
for sentence in sentences:
if sentence.strip() and sentence.strip() not in unique_sentences:
unique_sentences.append(sentence.strip())
content = '. '.join(unique_sentences)

return content.strip()

def parse_embedding_response(self, response: Dict) -> EmbedderOutput:
"""Parse the embedding response to an EmbedderOutput."""
try:
embeddings = [Embedding(embedding=data["embedding"], index=i) for i, data in enumerate(response["data"])]
usage = Usage(
prompt_tokens=response["usage"]["prompt_tokens"],
total_tokens=response["usage"]["total_tokens"]
)
return EmbedderOutput(data=embeddings, model=response["model"], usage=usage)
except Exception as e:
log.error(f"Error parsing the embedding response: {e}")
return EmbedderOutput(data=[], error=str(e), raw_response=response)

async def acall(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
"""LM Studio doesn't support async calls natively, so we use the sync method."""
return self.call(api_kwargs, model_type)