feat: groq support via official tool-calling API #1257

Open · wants to merge 1 commit into main
87 changes: 86 additions & 1 deletion memgpt/cli/cli_config.py
@@ -18,7 +18,7 @@
from memgpt.constants import MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES
from memgpt.data_types import User, LLMConfig, EmbeddingConfig
from memgpt.llm_api.openai import openai_get_model_list
from memgpt.llm_api.openai import openai_get_model_list, openai_get_model_context_window
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.llm_api.anthropic import anthropic_get_model_list, antropic_get_model_context_window
@@ -122,6 +122,43 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
raise KeyboardInterrupt
provider = "openai"

elif provider == "groq":
# NOTE: basically same as OpenAI
# check for key
if credentials.groq_key is None:
# allow key to get pulled from env vars
groq_api_key = os.getenv("GROQ_API_KEY", None)
# if we still can't find it, ask for it as input
if groq_api_key is None:
while groq_api_key is None or len(groq_api_key) == 0:
# Ask for API key as input
groq_api_key = questionary.password("Enter your Groq API key (see https://console.groq.com/keys):").ask()
if groq_api_key is None:
raise KeyboardInterrupt
credentials.groq_key = groq_api_key
credentials.save()
else:
# Give the user an opportunity to overwrite the key
groq_api_key = None
default_input = shorten_key_middle(credentials.groq_key) if credentials.groq_key.startswith("gsk_") else credentials.groq_key  # Groq keys use the "gsk_" prefix
groq_api_key = questionary.password(
"Enter your Groq API key (see https://console.groq.com/keys):",
default=default_input,
).ask()
if groq_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if groq_api_key != default_input:
credentials.groq_key = groq_api_key
credentials.save()

model_endpoint_type = "groq"
model_endpoint = "https://api.groq.com/openai/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = model_endpoint_type

elif provider == "azure":
# check for necessary vars
azure_creds = get_azure_credentials()
@@ -344,6 +381,13 @@ def get_model_options(
else:
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]

elif model_endpoint_type == "groq":
if credentials.groq_key is None:
raise ValueError("Missing Groq API key")
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=credentials.groq_key)

model_options = [obj["id"] for obj in fetched_model_options_response["data"]]

elif model_endpoint_type == "azure":
if credentials.azure_key is None:
raise ValueError("Missing Azure key")
@@ -457,6 +501,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "groq":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e

model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "google_ai":
try:
fetched_model_options = get_model_options(
@@ -740,6 +804,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if context_window_input is None:
raise KeyboardInterrupt

elif model_endpoint_type == "groq":
try:
fetched_context_window = str(
openai_get_model_context_window(url=model_endpoint, api_key=credentials.groq_key, model=model)
)
print(f"Got context window {fetched_context_window} for model {model}")
context_length_options = [
fetched_context_window,
"custom",
]
except Exception as e:
print(f"Failed to get model details for model '{model}' ({str(e)})")
context_length_options = ["custom"]  # fall back to manual entry so the select below still has options

context_window_input = questionary.select(
"Select your model's context window (see https://console.groq.com/docs/models):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt

else:

# Ask the user to specify the context length
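
Once these prompts complete, `memgpt configure` persists the choices; the saved model settings would look roughly like this (a sketch of the `[model]` section of `~/.memgpt/config` — the field names are assumptions based on how MemGPTConfig is written out, and the model is just an example):

[model]
model = llama3-70b-8192
model_endpoint = https://api.groq.com/openai/v1
model_endpoint_type = groq
context_window = 8192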
8 changes: 8 additions & 0 deletions memgpt/credentials.py
@@ -38,6 +38,9 @@ class MemGPTCredentials:
# cohere config
cohere_key: Optional[str] = None

# groq config
groq_key: Optional[str] = None

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
@@ -87,6 +90,8 @@ def load(cls) -> "MemGPTCredentials":
"anthropic_key": get_field(config, "anthropic", "key"),
# cohere
"cohere_key": get_field(config, "cohere", "key"),
# groq
"groq_key": get_field(config, "groq", "key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
@@ -129,6 +134,9 @@ def save(self):
# cohere
set_field(config, "cohere", "key", self.cohere_key)

# groq
set_field(config, "groq", "key", self.groq_key)

# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)
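
With `groq_key` wired into load() and save() above, the credentials file gains a matching section (a sketch of `~/.memgpt/credentials` in its INI format; the key value is a placeholder):

[groq]
key = gsk_...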
30 changes: 28 additions & 2 deletions memgpt/data_types.py
@@ -20,6 +20,7 @@
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd, is_utc_datetime
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from memgpt.constants import JSON_ENSURE_ASCII


class Record:
@@ -150,11 +151,17 @@ def dict_to_message(
model: Optional[str] = None, # model used to make function call
allow_functions_style: bool = False, # allow deprecated functions style?
created_at: Optional[datetime] = None,
allow_null_content: Optional[bool] = False,
inner_thoughts_in_kwargs: Optional[bool] = False,
):
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""

assert "role" in openai_message_dict, openai_message_dict
assert "content" in openai_message_dict, openai_message_dict
if not (allow_null_content or inner_thoughts_in_kwargs):
assert "content" in openai_message_dict, openai_message_dict
else:
if "content" not in openai_message_dict:
openai_message_dict["content"] = None

# If we're going from deprecated function form
if openai_message_dict["role"] == "function":
@@ -227,6 +234,16 @@
else:
tool_calls = None

# Optionally, inner thoughts may be carried inside the tool call arguments rather than in content
if inner_thoughts_in_kwargs and openai_message_dict["role"] == "assistant" and tool_calls is not None and len(tool_calls) > 0:
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION

assert openai_message_dict["content"] is None, openai_message_dict
assert len(tool_calls) == 1, tool_calls
assert INNER_THOUGHTS_KWARG in json.loads(tool_calls[0].function["arguments"]), tool_calls
inner_thoughts = json.loads(tool_calls[0].function["arguments"]).pop(INNER_THOUGHTS_KWARG)
openai_message_dict["content"] = inner_thoughts

# If we're going from tool-call style
return Message(
created_at=created_at,
@@ -241,7 +258,7 @@
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)

def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN, put_inner_thoughts_in_kwargs: bool = False) -> dict:
"""Go from Message class to ChatCompletion message object"""

# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
@@ -281,6 +298,15 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
for tool_call_dict in openai_message["tool_calls"]:
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]

if put_inner_thoughts_in_kwargs:
openai_message["content"] = None
assert openai_message["tool_calls"] is not None and len(openai_message["tool_calls"]) == 1
for tc in openai_message["tool_calls"]:
existing_args = json.loads(tc["function"]["arguments"])
# TODO throw error for null case here?
existing_args[INNER_THOUGHTS_KWARG] = self.text if self.text is not None else ""
tc["function"]["arguments"] = json.dumps(existing_args, ensure_ascii=JSON_ENSURE_ASCII)

elif self.role == "tool":
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
openai_message = {
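
Concretely, the round trip introduced above looks like this (a sketch; "inner_thoughts" is the value of INNER_THOUGHTS_KWARG in memgpt.local_llm.constants):

# Wire format produced by to_openai_dict(put_inner_thoughts_in_kwargs=True):
message = {
    "role": "assistant",
    "content": None,  # inner thoughts moved out of content...
    "tool_calls": [{
        "id": "abc12345",
        "type": "function",
        "function": {
            "name": "send_message",
            # ...and into the function-call arguments
            "arguments": '{"message": "Hi!", "inner_thoughts": "User said hello."}',
        },
    }],
}
# dict_to_message(..., inner_thoughts_in_kwargs=True) reverses this: it pops
# "inner_thoughts" back out of the arguments and restores it as message content.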
40 changes: 37 additions & 3 deletions memgpt/llm_api/llm_api_tools.py
@@ -23,7 +23,7 @@
from memgpt.llm_api.cohere import cohere_chat_completions_request


LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "groq", "local"]


def is_context_overflow_error(exception: requests.exceptions.RequestException) -> bool:
@@ -152,7 +152,7 @@ def create(
# TODO do the same for Azure?
if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"OpenAI key is missing from MemGPT config file")
raise ValueError(f"OpenAI key is missing from MemGPT credentials file")
if use_tool_naming:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
@@ -172,7 +172,41 @@
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
data=data,
chat_completion_request=data,
)

# NOTE: basically the same as OpenAI
elif agent_state.llm_config.model_endpoint_type == "groq":

if credentials.groq_key is None:
raise ValueError(f"Groq API key is missing from MemGPT credentials file")
if use_tool_naming:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
# tool_choice="auto",
tool_choice={"type": "function", "function": {"name": "send_message"}},
user=str(agent_state.user_id),
)
else:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
functions=functions,
# function_call=function_call,
# function_call="auto",
tool_choice={"type": "function", "function": {"name": "send_message"}},
user=str(agent_state.user_id),
)
# NOTE: reusing the OpenAI helper since the request/response format is the same
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.groq_key,
chat_completion_request=data,
# NOTE: Groq in function calling mode doesn't seem to return non-null content, so we need to put CoT in the kwargs
inner_thoughts_in_tools=True,
)

# azure
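
Outside of MemGPT, the request this branch builds can be reproduced directly against Groq's OpenAI-compatible endpoint; a minimal standalone sketch (the model name and prompt are assumptions, not part of this diff):

import os
import requests

# Force a tool call, mirroring the tool_choice used in the branch above
resp = requests.post(
    "https://api.groq.com/openai/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['GROQ_API_KEY']}"},
    json={
        "model": "llama3-70b-8192",  # assumption: any Groq model with tool-calling support
        "messages": [{"role": "user", "content": "Say hi via the tool."}],
        "tools": [{"type": "function", "function": {
            "name": "send_message",
            "description": "Send a message to the user.",
            "parameters": {
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
        }}],
        "tool_choice": {"type": "function", "function": {"name": "send_message"}},
    },
)
resp.raise_for_status()
# The function-call arguments arrive as a JSON string
print(resp.json()["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"])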
75 changes: 71 additions & 4 deletions memgpt/llm_api/openai.py
@@ -1,13 +1,25 @@
import requests
import uuid
import time
from typing import Union, Optional
from typing import Union, Optional, List

from memgpt.data_types import Message
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin


def openai_get_model_context_window(url: str, api_key: Union[str, None], model: str, fix_url: Optional[bool] = False) -> int:
# NOTE: this actually doesn't work for OpenAI atm, just some OpenAI-compatible APIs like Groq
model_list = openai_get_model_list(url=url, api_key=api_key, fix_url=fix_url)

for model_dict in model_list["data"]:
if model_dict["id"] == model and "context_window" in model_dict:
return int(model_dict["context_window"])
raise ValueError(f"Can't find model '{model}' in model list")


def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from memgpt.utils import printd
@@ -58,13 +70,43 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional
raise e


def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse:
def add_inner_thoughts_to_tool_params(tools: List[Tool], inner_thoughts_required: Optional[bool] = True) -> List[Tool]:
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION

tools_with_inner_thoughts = []
for tool in tools:
assert INNER_THOUGHTS_KWARG not in tool.function.parameters["properties"], tool

tool.function.parameters["properties"][INNER_THOUGHTS_KWARG] = {
"type": "string",
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
}

if inner_thoughts_required:
assert INNER_THOUGHTS_KWARG not in tool.function.parameters["required"], tool
tool.function.parameters["required"].append(INNER_THOUGHTS_KWARG)

tools_with_inner_thoughts.append(tool)

return tools_with_inner_thoughts
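
# Example of the injection this helper performs (a sketch; "inner_thoughts" is the
# value of INNER_THOUGHTS_KWARG):
#   before: {"properties": {"message": {"type": "string"}}, "required": ["message"]}
#   after:  {"properties": {"message": {"type": "string"},
#                           "inner_thoughts": {"type": "string", "description": "..."}},
#            "required": ["message", "inner_thoughts"]}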


def openai_chat_completions_request(
url: str, api_key: str, chat_completion_request: ChatCompletionRequest, inner_thoughts_in_tools: Optional[bool] = False
) -> ChatCompletionResponse:
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
from memgpt.utils import printd

url = smart_urljoin(url, "chat/completions")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = data.model_dump(exclude_none=True)

# Certain inference backends or models may not support non-null content + function call
# E.g., Groq's API is technically OpenAI-compatible, but does not reliably return non-null content when function calling
# In this case, we may want to move inner thoughts into the function parameters to ensure that we still get CoT
if inner_thoughts_in_tools:
chat_completion_request.tools = add_inner_thoughts_to_tool_params(chat_completion_request.tools)

data = chat_completion_request.model_dump(exclude_none=True)

# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
@@ -75,6 +117,26 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")

print("aaa")
for m in data["messages"]:
print(m)

if inner_thoughts_in_tools:
# Move inner thoughts into the function-call arguments in the chat history by recasting (the dummy user/agent UUIDs below are only used transiently)
msg_objs = [
Message.dict_to_message(user_id=uuid.uuid4(), agent_id=uuid.uuid4(), openai_message_dict=m, inner_thoughts_in_kwargs=True)
for m in data["messages"]
]
data["messages"] = [m.to_openai_dict(put_inner_thoughts_in_kwargs=True) for m in msg_objs]

print("zzz")
for m in data["messages"]:
print(m)

print("xxx")
for t in data["tools"]:
print(t)

printd(f"Sending request to {url}")
try:
# Example code to trigger a rate limit response:
@@ -93,6 +155,11 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default

if inner_thoughts_in_tools:
# We need to strip the inner thought out of the parameters and put it back inside the content
raise NotImplementedError

return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
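
End to end, the feature would be exercised like this (a usage sketch; `memgpt configure` and `memgpt run` are the existing CLI entry points, and the key value is a placeholder):

export GROQ_API_KEY=gsk_...
memgpt configure   # choose "groq" as the LLM provider, then pick a model from the fetched list
memgpt run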