
Add OpenAI function call support (#4683)
Co-authored-by: merwanehamadi <[email protected]>
Co-authored-by: Reinier van der Leer <[email protected]>
3 people committed Jun 22, 2023
1 parent 32038c9 commit 857d26d
Showing 23 changed files with 425 additions and 189 deletions.
6 changes: 5 additions & 1 deletion .env.template
@@ -25,10 +25,14 @@ OPENAI_API_KEY=your-openai-api-key
## PROMPT_SETTINGS_FILE - Specifies which Prompt Settings file to use (defaults to prompt_settings.yaml)
# PROMPT_SETTINGS_FILE=prompt_settings.yaml

## OPENAI_API_BASE_URL - Custom url for the OpenAI API, useful for connecting to custom backends. No effect if USE_AZURE is true, leave blank to keep the default url
# the following is an example:
# OPENAI_API_BASE_URL=http://localhost:443/v1

+## OPENAI_FUNCTIONS - Enables OpenAI functions: https://platform.openai.com/docs/guides/gpt/function-calling
+## WARNING: this feature is only supported by OpenAI's newest models. Until these models become the default on 27 June, add a '-0613' suffix to the model of your choosing.
+# OPENAI_FUNCTIONS=False
+
## AUTHORISE COMMAND KEY - Key to authorise commands
# AUTHORISE_COMMAND_KEY=y

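For reference, a minimal .env sketch that enables the feature during the window described above. The SMART_LLM_MODEL/FAST_LLM_MODEL names are the pre-existing Auto-GPT model settings and are shown here as an assumed configuration, not part of this diff:

OPENAI_FUNCTIONS=True
# assumed model settings; any '-0613' model should work until 27 June
SMART_LLM_MODEL=gpt-4-0613
FAST_LLM_MODEL=gpt-3.5-turbo-0613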
8 changes: 6 additions & 2 deletions autogpt/agent/agent.py
@@ -142,7 +142,9 @@ def signal_handler(signum, frame):
)

try:
-assistant_reply_json = extract_json_from_response(assistant_reply)
+assistant_reply_json = extract_json_from_response(
+    assistant_reply.content
+)
validate_json(assistant_reply_json, self.config)
except json.JSONDecodeError as e:
logger.error(f"Exception while validating assistant reply JSON: {e}")
@@ -160,7 +162,9 @@ def signal_handler(signum, frame):
print_assistant_thoughts(
self.ai_name, assistant_reply_json, self.config
)
-command_name, arguments = get_command(assistant_reply_json)
+command_name, arguments = get_command(
+    assistant_reply_json, assistant_reply, self.config
+)
if self.config.speak_mode:
say_text(f"I want to execute {command_name}")

8 changes: 6 additions & 2 deletions autogpt/agent/agent_manager.py
@@ -41,7 +41,9 @@ def create_agent(
if plugin_messages := plugin.pre_instruction(messages.raw()):
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])
# Start GPT instance
-agent_reply = create_chat_completion(prompt=messages, config=self.config)
+agent_reply = create_chat_completion(
+    prompt=messages, config=self.config
+).content

messages.add("assistant", agent_reply)

@@ -92,7 +94,9 @@ def message_agent(self, key: str | int, message: str) -> str:
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])

# Start GPT instance
-agent_reply = create_chat_completion(prompt=messages, config=self.config)
+agent_reply = create_chat_completion(
+    prompt=messages, config=self.config
+).content

messages.add("assistant", agent_reply)

28 changes: 22 additions & 6 deletions autogpt/app.py
@@ -3,6 +3,8 @@
from typing import Dict

from autogpt.agent.agent import Agent
+from autogpt.config import Config
+from autogpt.llm import ChatModelResponse


def is_valid_int(value: str) -> bool:
@@ -21,11 +23,15 @@ def is_valid_int(value: str) -> bool:
return False


-def get_command(response_json: Dict):
+def get_command(
+    assistant_reply_json: Dict, assistant_reply: ChatModelResponse, config: Config
+):
"""Parse the response and return the command name and arguments
Args:
-    response_json (json): The response from the AI
+    assistant_reply_json (dict): The response object from the AI
+    assistant_reply (ChatModelResponse): The model response from the AI
+    config (Config): The config object
Returns:
tuple: The command name and arguments
@@ -35,14 +41,24 @@ def get_command(response_json: Dict):
Exception: If any other error occurs
"""
+if config.openai_functions:
+    if assistant_reply.function_call is None:
+        return "Error:", "No 'function_call' in assistant reply"
+    assistant_reply_json["command"] = {
+        "name": assistant_reply.function_call.name,
+        "args": json.loads(assistant_reply.function_call.arguments),
+    }
try:
-    if "command" not in response_json:
+    if "command" not in assistant_reply_json:
return "Error:", "Missing 'command' object in JSON"

-    if not isinstance(response_json, dict):
-        return "Error:", f"'response_json' object is not dictionary {response_json}"
+    if not isinstance(assistant_reply_json, dict):
+        return (
+            "Error:",
+            f"The previous message sent was not a dictionary {assistant_reply_json}",
+        )

-    command = response_json["command"]
+    command = assistant_reply_json["command"]
if not isinstance(command, dict):
return "Error:", "'command' object is not a dictionary"

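To illustrate the new config.openai_functions branch: when the model answers with a function call instead of inline JSON, get_command folds it back into the same "command" shape the JSON path produces. A self-contained sketch — the command name and arguments are illustrative, not taken from this diff, and FunctionCall stands in for the OpenAIFunctionCall dataclass defined later in this commit:

import json
from dataclasses import dataclass

@dataclass
class FunctionCall:  # stand-in for OpenAIFunctionCall (defined below in this commit)
    name: str
    arguments: str  # stringified JSON, exactly as the OpenAI API returns it

# What the model might return when it decides to invoke a command:
reply = FunctionCall(
    name="browse_website",
    arguments='{"url": "https://example.com", "question": "What is the pricing?"}',
)

# get_command() maps it into the usual response JSON:
assistant_reply_json: dict = {}
assistant_reply_json["command"] = {
    "name": reply.name,
    "args": json.loads(reply.arguments),
}
print(assistant_reply_json["command"]["name"])  # browse_website
print(assistant_reply_json["command"]["args"])  # {'url': ..., 'question': ...}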
23 changes: 19 additions & 4 deletions autogpt/command_decorator.py
@@ -1,28 +1,43 @@
import functools
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional, TypedDict

from autogpt.config import Config
-from autogpt.models.command import Command
+from autogpt.models.command import Command, CommandParameter

# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"


+class CommandParameterSpec(TypedDict):
+    type: str
+    description: str
+    required: bool


def command(
name: str,
description: str,
-    arguments: Dict[str, Dict[str, Any]],
+    parameters: dict[str, CommandParameterSpec],
enabled: bool | Callable[[Config], bool] = True,
disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
"""The command decorator is used to create Command objects from ordinary functions."""

def decorator(func: Callable[..., Any]) -> Command:
+    typed_parameters = [
+        CommandParameter(
+            name=param_name,
+            description=parameter.get("description"),
+            type=parameter.get("type", "string"),
+            required=parameter.get("required", False),
+        )
+        for param_name, parameter in parameters.items()
+    ]
cmd = Command(
name=name,
description=description,
method=func,
-        signature=arguments,
+        parameters=typed_parameters,
enabled=enabled,
disabled_reason=disabled_reason,
)
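A hedged usage sketch of the reworked decorator: each parameter is now a CommandParameterSpec dict instead of the free-form arguments mapping. The command below is illustrative, not part of this diff:

from autogpt.command_decorator import command

@command(
    "write_to_file",
    "Write text to a file",
    parameters={
        "filename": {
            "type": "string",
            "description": "The path of the file to write to",
            "required": True,
        },
        "text": {
            "type": "string",
            "description": "The text to write",
            "required": True,
        },
    },
)
def write_to_file(filename: str, text: str) -> str:
    ...
# The decorator converts each spec into a typed CommandParameter and wraps
# the function in a Command object carrying them.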
2 changes: 1 addition & 1 deletion autogpt/config/ai_config.py
@@ -164,5 +164,5 @@ def construct_full_prompt(
if self.api_budget > 0.0:
full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
-        full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
+        full_prompt += f"\n\n{prompt_generator.generate_prompt_string(config)}"
return full_prompt
2 changes: 2 additions & 0 deletions autogpt/config/config.py
@@ -88,6 +88,8 @@ def __init__(self) -> None:
if self.openai_organization is not None:
openai.organization = self.openai_organization

+self.openai_functions = os.getenv("OPENAI_FUNCTIONS", "False") == "True"

self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
# ELEVENLABS_VOICE_1_ID is deprecated and included for backwards-compatibility
self.elevenlabs_voice_id = os.getenv(
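Note the exact string comparison: only the literal value True enables the feature. A quick illustration:

import os

os.environ["OPENAI_FUNCTIONS"] = "true"  # lower-case
print(os.getenv("OPENAI_FUNCTIONS", "False") == "True")  # False - the check is case-sensitive

os.environ["OPENAI_FUNCTIONS"] = "True"
print(os.getenv("OPENAI_FUNCTIONS", "False") == "True")  # True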
10 changes: 7 additions & 3 deletions autogpt/json_utils/utilities.py
@@ -29,11 +29,15 @@ def extract_json_from_response(response_content: str) -> dict:


def llm_response_schema(
-    schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT,
+    config: Config, schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT
) -> dict[str, Any]:
filename = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
with open(filename, "r") as f:
-    return json.load(f)
+    json_schema = json.load(f)
+if config.openai_functions:
+    del json_schema["properties"]["command"]
+    json_schema["required"].remove("command")
+return json_schema


def validate_json(
@@ -47,7 +51,7 @@ def validate_json(
Returns:
bool: Whether the json_object is valid or not
"""
-schema = llm_response_schema(schema_name)
+schema = llm_response_schema(config, schema_name)
validator = Draft7Validator(schema)

if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
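The intent of the new branch: with OpenAI functions enabled, the command arrives via function_call rather than inside the response JSON, so the "command" key is dropped from the validation schema. A minimal sketch of the trimming, using an assumed schema shape rather than the real schema file:

json_schema = {
    "type": "object",
    "properties": {
        "thoughts": {"type": "object"},
        "command": {"type": "object"},
    },
    "required": ["thoughts", "command"],
}

# With config.openai_functions enabled, the command is delivered via
# function_call, so it is no longer expected in the JSON body:
del json_schema["properties"]["command"]
json_schema["required"].remove("command")
print(json_schema["required"])  # ['thoughts']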
8 changes: 6 additions & 2 deletions autogpt/llm/base.py
@@ -2,7 +2,10 @@

from dataclasses import dataclass, field
from math import ceil, floor
-from typing import List, Literal, TypedDict
+from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict

+if TYPE_CHECKING:
+    from autogpt.llm.providers.openai import OpenAIFunctionCall

MessageRole = Literal["system", "user", "assistant"]
MessageType = Literal["ai_response", "action_result"]
@@ -156,4 +159,5 @@ def __post_init__(self):
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""

-    content: str = None
+    content: Optional[str] = None
+    function_call: Optional[OpenAIFunctionCall] = None
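Since a reply can now carry either plain text or a function call, downstream code has to check which field is populated. A hedged consumer sketch, not from this commit:

from autogpt.llm.base import ChatModelResponse

def describe_reply(reply: ChatModelResponse) -> str:
    # Exactly one of the two fields is expected to be set.
    if reply.function_call is not None:
        return f"call {reply.function_call.name} with {reply.function_call.arguments}"
    return reply.content or ""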
6 changes: 5 additions & 1 deletion autogpt/llm/chat.py
@@ -3,6 +3,8 @@
import time
from typing import TYPE_CHECKING

+from autogpt.llm.providers.openai import get_openai_command_specs

if TYPE_CHECKING:
from autogpt.agent.agent import Agent

@@ -94,6 +96,7 @@ def chat_with_ai(
current_tokens_used += count_message_tokens([user_input_msg], model)

current_tokens_used += 500 # Reserve space for new_summary_message
+current_tokens_used += 500  # Reserve space for the openai functions TODO improve

# Add Messages until the token limit is reached or there are no more messages to add.
for cycle in reversed(list(agent.history.per_cycle(agent.config))):
@@ -193,11 +196,12 @@ def chat_with_ai(
assistant_reply = create_chat_completion(
prompt=message_sequence,
config=agent.config,
+functions=get_openai_command_specs(agent),
max_tokens=tokens_remaining,
)

# Update full message history
agent.history.append(user_input_msg)
-agent.history.add("assistant", assistant_reply, "ai_response")
+agent.history.add("assistant", assistant_reply.content, "ai_response")

return assistant_reply
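For context, a hedged sketch of the underlying API call shape once functions is threaded through — assuming the openai 0.27.x client's openai.ChatCompletion.create is used underneath (as elsewhere in this codebase), with the specs serialized to plain dicts via the __dict__ property shown below; agent is assumed to be in scope:

import openai
from autogpt.llm.providers.openai import get_openai_command_specs

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    messages=[{"role": "user", "content": "Append 'hello' to notes.txt"}],
    # vars(spec) yields the OpenAI wire format via the __dict__ property
    functions=[vars(spec) for spec in get_openai_command_specs(agent)],
)
message = response.choices[0].message
if "function_call" in message:
    print(message.function_call.name, message.function_call.arguments)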
83 changes: 82 additions & 1 deletion autogpt/llm/providers/openai.py
@@ -1,6 +1,9 @@
from __future__ import annotations

import functools
import time
-from typing import List
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
from unittest.mock import patch

import openai
@@ -9,6 +12,9 @@
from openai.error import APIError, RateLimitError, Timeout
from openai.openai_object import OpenAIObject

+if TYPE_CHECKING:
+    from autogpt.agent.agent import Agent

from autogpt.llm.base import (
ChatModelInfo,
EmbeddingModelInfo,
@@ -267,3 +273,78 @@ def create_embedding(
input=input,
**kwargs,
)


+@dataclass
+class OpenAIFunctionCall:
+    """Represents a function call as generated by an OpenAI model
+
+    Attributes:
+        name: the name of the function that the LLM wants to call
+        arguments: a stringified JSON object (unverified) containing `arg: value` pairs
+    """
+
+    name: str
+    arguments: str
+
+
+@dataclass
+class OpenAIFunctionSpec:
+    """Represents a "function" in OpenAI, which is mapped to a Command in Auto-GPT"""
+
+    name: str
+    description: str
+    parameters: dict[str, ParameterSpec]
+
+    @dataclass
+    class ParameterSpec:
+        name: str
+        type: str
+        description: Optional[str]
+        required: bool = False
+
+    @property
+    def __dict__(self):
+        """Output an OpenAI-consumable function specification"""
+        return {
+            "name": self.name,
+            "description": self.description,
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    param.name: {
+                        "type": param.type,
+                        "description": param.description,
+                    }
+                    for param in self.parameters.values()
+                },
+                "required": [
+                    param.name for param in self.parameters.values() if param.required
+                ],
+            },
+        }
+
+
+def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
+    """Get OpenAI-consumable function specs for the agent's available commands.
+    see https://platform.openai.com/docs/guides/gpt/function-calling
+    """
+    if not agent.config.openai_functions:
+        return []
+
+    return [
+        OpenAIFunctionSpec(
+            name=command.name,
+            description=command.description,
+            parameters={
+                param.name: OpenAIFunctionSpec.ParameterSpec(
+                    name=param.name,
+                    type=param.type,
+                    required=param.required,
+                    description=param.description,
+                )
+                for param in command.parameters
+            },
+        )
+        for command in agent.command_registry.commands.values()
+    ]
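Because __dict__ is overridden as a property, calling vars() on a spec yields the OpenAI wire format directly. An illustrative example with a hypothetical command:

import json

from autogpt.llm.providers.openai import OpenAIFunctionSpec

spec = OpenAIFunctionSpec(
    name="google",
    description="Google Search",
    parameters={
        "query": OpenAIFunctionSpec.ParameterSpec(
            name="query",
            type="string",
            description="The search query",
            required=True,
        ),
    },
)

print(json.dumps(vars(spec), indent=2))
# {
#   "name": "google",
#   "description": "Google Search",
#   "parameters": {
#     "type": "object",
#     "properties": {
#       "query": {"type": "string", "description": "The search query"}
#     },
#     "required": ["query"]
#   }
# }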