Skip to content

Commit

Permalink
Pass cfg everywhere in order to get of singleton
Browse files Browse the repository at this point in the history
Signed-off-by: Merwane Hamadi <[email protected]>
  • Loading branch information
waynehamadi committed Jun 19, 2023
1 parent 096d27f commit ab0075b
Show file tree
Hide file tree
Showing 44 changed files with 323 additions and 300 deletions.
10 changes: 6 additions & 4 deletions autogpt/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def signal_handler(signum, frame):

try:
assistant_reply_json = extract_json_from_response(assistant_reply)
validate_json(assistant_reply_json)
validate_json(assistant_reply_json, self.config)
except json.JSONDecodeError as e:
logger.error(f"Exception while validating assistant reply JSON: {e}")
assistant_reply_json = {}
Expand All @@ -158,7 +158,7 @@ def signal_handler(signum, frame):
# Get command name and arguments
try:
print_assistant_thoughts(
self.ai_name, assistant_reply_json, self.config.speak_mode
self.ai_name, assistant_reply_json, self.config
)
command_name, arguments = get_command(assistant_reply_json)
if self.config.speak_mode:
Expand Down Expand Up @@ -197,10 +197,12 @@ def signal_handler(signum, frame):
)
while True:
if self.config.chat_messages_enabled:
console_input = clean_input("Waiting for your response...")
console_input = clean_input(
self.config, "Waiting for your response..."
)
else:
console_input = clean_input(
Fore.MAGENTA + "Input:" + Style.RESET_ALL
self.config, Fore.MAGENTA + "Input:" + Style.RESET_ALL
)
if console_input.lower().strip() == self.config.authorise_key:
user_input = "GENERATE NEXT COMMAND JSON"
Expand Down
20 changes: 10 additions & 10 deletions autogpt/agent/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
class AgentManager(metaclass=Singleton):
"""Agent manager for managing GPT agents"""

def __init__(self):
def __init__(self, config: Config):
self.next_key = 0
self.agents: dict[
int, tuple[str, list[Message], str]
] = {} # key, (task, full_message_history, model)
self.cfg = Config()
self.config = config

# Create new GPT agent
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
Expand All @@ -35,18 +35,18 @@ def create_agent(
"""
messages = ChatSequence.for_model(model, [Message("user", creation_prompt)])

for plugin in self.cfg.plugins:
for plugin in self.config.plugins:
if not plugin.can_handle_pre_instruction():
continue
if plugin_messages := plugin.pre_instruction(messages.raw()):
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])
# Start GPT instance
agent_reply = create_chat_completion(prompt=messages)
agent_reply = create_chat_completion(prompt=messages, config=self.config)

messages.add("assistant", agent_reply)

plugins_reply = ""
for i, plugin in enumerate(self.cfg.plugins):
for i, plugin in enumerate(self.config.plugins):
if not plugin.can_handle_on_instruction():
continue
if plugin_result := plugin.on_instruction([m.raw() for m in messages]):
Expand All @@ -62,7 +62,7 @@ def create_agent(

self.agents[key] = (task, list(messages), model)

for plugin in self.cfg.plugins:
for plugin in self.config.plugins:
if not plugin.can_handle_post_instruction():
continue
agent_reply = plugin.post_instruction(agent_reply)
Expand All @@ -85,19 +85,19 @@ def message_agent(self, key: str | int, message: str) -> str:
messages = ChatSequence.for_model(model, messages)
messages.add("user", message)

for plugin in self.cfg.plugins:
for plugin in self.config.plugins:
if not plugin.can_handle_pre_instruction():
continue
if plugin_messages := plugin.pre_instruction([m.raw() for m in messages]):
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])

# Start GPT instance
agent_reply = create_chat_completion(prompt=messages)
agent_reply = create_chat_completion(prompt=messages, config=self.config)

messages.add("assistant", agent_reply)

plugins_reply = agent_reply
for i, plugin in enumerate(self.cfg.plugins):
for i, plugin in enumerate(self.config.plugins):
if not plugin.can_handle_on_instruction():
continue
if plugin_result := plugin.on_instruction([m.raw() for m in messages]):
Expand All @@ -107,7 +107,7 @@ def message_agent(self, key: str | int, message: str) -> str:
if plugins_reply and plugins_reply != "":
messages.add("assistant", plugins_reply)

for plugin in self.cfg.plugins:
for plugin in self.config.plugins:
if not plugin.can_handle_post_instruction():
continue
agent_reply = plugin.post_instruction(agent_reply)
Expand Down
4 changes: 2 additions & 2 deletions autogpt/command_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def command(
"""The command decorator is used to create Command objects from ordinary functions."""

# TODO: Remove this in favor of better command management
CFG = Config()
config = Config()

if callable(enabled):
enabled = enabled(CFG)
enabled = enabled(config)
if not enabled:
if disabled_reason is not None:
logger.debug(f"Command '{name}' is disabled: {disabled_reason}")
Expand Down
8 changes: 4 additions & 4 deletions autogpt/commands/execute_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from autogpt.command_decorator import command
from autogpt.config import Config
from autogpt.logs import logger
from autogpt.setup import CFG
from autogpt.workspace.workspace import Workspace

ALLOWLIST_CONTROL = "allowlist"
Expand Down Expand Up @@ -83,7 +82,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
str: The output of the file
"""
logger.info(
f"Executing python file '{filename}' in working directory '{CFG.workspace_path}'"
f"Executing python file '{filename}' in working directory '{agent.config.workspace_path}'"
)

if not filename.endswith(".py"):
Expand All @@ -105,7 +104,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
["python", str(path)],
capture_output=True,
encoding="utf8",
cwd=CFG.workspace_path,
cwd=agent.config.workspace_path,
)
if result.returncode == 0:
return result.stdout
Expand Down Expand Up @@ -174,6 +173,7 @@ def validate_command(command: str, config: Config) -> bool:
Args:
command (str): The command to validate
config (Config): The config to use to validate the command
Returns:
bool: True if the command is allowed, False otherwise
Expand All @@ -199,7 +199,7 @@ def validate_command(command: str, config: Config) -> bool:
"required": True,
}
},
enabled=lambda cfg: cfg.execute_local_commands,
enabled=lambda config: config.execute_local_commands,
disabled_reason="You are not allowed to run local shell commands. To execute"
" shell commands, EXECUTE_LOCAL_COMMANDS must be set to 'True' "
"in your config file: .env - do not attempt to bypass the restriction.",
Expand Down
3 changes: 2 additions & 1 deletion autogpt/commands/file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def is_duplicate_operation(
Args:
operation: The operation to check for
filename: The name of the file to check for
config: The agent config
checksum: The checksum of the contents to be written
Returns:
Expand Down Expand Up @@ -137,7 +138,7 @@ def read_file(filename: str, agent: Agent) -> str:
content = read_textual_file(filename, logger)

# TODO: invalidate/update memory when file is edited
file_memory = MemoryItem.from_text_file(content, filename)
file_memory = MemoryItem.from_text_file(content, filename, agent.config)
if len(file_memory.chunks) > 1:
return file_memory.summary

Expand Down
2 changes: 1 addition & 1 deletion autogpt/commands/image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def generate_image_with_sd_webui(
"negative_prompt": negative_prompt,
"sampler_index": "DDIM",
"steps": 20,
"cfg_scale": 7.0,
"config_scale": 7.0,
"width": size,
"height": size,
"n_iter": 1,
Expand Down
2 changes: 1 addition & 1 deletion autogpt/commands/web_selenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,6 @@ def summarize_memorize_webpage(

memory = get_memory(agent.config)

new_memory = MemoryItem.from_webpage(text, url, question=question)
new_memory = MemoryItem.from_webpage(text, url, agent.config, question=question)
memory.add(new_memory)
return new_memory.summary
22 changes: 10 additions & 12 deletions autogpt/config/ai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,22 @@ def __init__(
self.command_registry: CommandRegistry | None = None

@staticmethod
def load(config_file: str = SAVE_FILE) -> "AIConfig":
def load(ai_settings_file: str = SAVE_FILE) -> "AIConfig":
"""
Returns class object with parameters (ai_name, ai_role, ai_goals, api_budget) loaded from
yaml file if yaml file exists,
else returns class with no parameters.
Parameters:
config_file (int): The path to the config yaml file.
ai_settings_file (int): The path to the config yaml file.
DEFAULT: "../ai_settings.yaml"
Returns:
cls (object): An instance of given cls object
"""

try:
with open(config_file, encoding="utf-8") as file:
with open(ai_settings_file, encoding="utf-8") as file:
config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
except FileNotFoundError:
config_params = {}
Expand All @@ -91,12 +91,12 @@ def load(config_file: str = SAVE_FILE) -> "AIConfig":
# type: Type[AIConfig]
return AIConfig(ai_name, ai_role, ai_goals, api_budget)

def save(self, config_file: str = SAVE_FILE) -> None:
def save(self, ai_settings_file: str = SAVE_FILE) -> None:
"""
Saves the class parameters to the specified file yaml file path as a yaml file.
Parameters:
config_file(str): The path to the config yaml file.
ai_settings_file(str): The path to the config yaml file.
DEFAULT: "../ai_settings.yaml"
Returns:
Expand All @@ -109,11 +109,11 @@ def save(self, config_file: str = SAVE_FILE) -> None:
"ai_goals": self.ai_goals,
"api_budget": self.api_budget,
}
with open(config_file, "w", encoding="utf-8") as file:
with open(ai_settings_file, "w", encoding="utf-8") as file:
yaml.dump(config, file, allow_unicode=True)

def construct_full_prompt(
self, prompt_generator: Optional[PromptGenerator] = None
self, config, prompt_generator: Optional[PromptGenerator] = None
) -> str:
"""
Returns a prompt to the user with the class information in an organized fashion.
Expand All @@ -133,22 +133,20 @@ def construct_full_prompt(
""
)

from autogpt.config import Config
from autogpt.prompts.prompt import build_default_prompt_generator

cfg = Config()
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator = build_default_prompt_generator(config)
prompt_generator.goals = self.ai_goals
prompt_generator.name = self.ai_name
prompt_generator.role = self.ai_role
prompt_generator.command_registry = self.command_registry
for plugin in cfg.plugins:
for plugin in config.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)

if cfg.execute_local_commands:
if config.execute_local_commands:
# add OS info to prompt
os_name = platform.system()
os_info = (
Expand Down
5 changes: 2 additions & 3 deletions autogpt/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,9 @@ def set_memory_backend(self, name: str) -> None:
self.memory_backend = name


def check_openai_api_key() -> None:
def check_openai_api_key(config: Config) -> None:
"""Check if the OpenAI API key is set in config.py or as an environment variable."""
cfg = Config()
if not cfg.openai_api_key:
if not config.openai_api_key:
print(
Fore.RED
+ "Please set your OpenAI API key in .env or as an environment variable."
Expand Down
12 changes: 3 additions & 9 deletions autogpt/config/prompt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
from colorama import Fore

from autogpt import utils
from autogpt.config.config import Config
from autogpt.logs import logger

CFG = Config()


class PromptConfig:
"""
Expand All @@ -22,10 +19,7 @@ class PromptConfig:
performance_evaluations (list): Performance evaluation list for the prompt generator.
"""

def __init__(
self,
config_file: str = CFG.prompt_settings_file,
) -> None:
def __init__(self, prompt_settings_file: str) -> None:
"""
Initialize a class instance with parameters (constraints, resources, performance_evaluations) loaded from
yaml file if yaml file exists,
Expand All @@ -39,13 +33,13 @@ def __init__(
None
"""
# Validate file
(validated, message) = utils.validate_yaml_file(config_file)
(validated, message) = utils.validate_yaml_file(prompt_settings_file)
if not validated:
logger.typewriter_log("FAILED FILE VALIDATION", Fore.RED, message)
logger.double_check()
exit(1)

with open(config_file, encoding="utf-8") as file:
with open(prompt_settings_file, encoding="utf-8") as file:
config_params = yaml.load(file, Loader=yaml.FullLoader)

self.constraints = config_params.get("constraints", [])
Expand Down
5 changes: 2 additions & 3 deletions autogpt/json_utils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from autogpt.config import Config
from autogpt.logs import logger

CFG = Config()
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"


Expand Down Expand Up @@ -37,7 +36,7 @@ def llm_response_schema(


def validate_json(
json_object: object, schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT
json_object: object, config: Config, schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT
) -> bool:
"""
:type schema_name: object
Expand All @@ -54,7 +53,7 @@ def validate_json(
for error in errors:
logger.error(f"JSON Validation Error: {error}")

if CFG.debug_mode:
if config.debug_mode:
logger.error(
json.dumps(json_object, indent=4)
) # Replace 'json_object' with the variable containing the JSON data
Expand Down
7 changes: 4 additions & 3 deletions autogpt/llm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def chat_with_ai(
current_tokens_used += 500 # Reserve space for new_summary_message

# Add Messages until the token limit is reached or there are no more messages to add.
for cycle in reversed(list(agent.history.per_cycle())):
for cycle in reversed(list(agent.history.per_cycle(agent.config))):
messages_to_add = [msg for msg in cycle if msg is not None]
tokens_to_add = count_message_tokens(messages_to_add, model)
if current_tokens_used + tokens_to_add > send_token_limit:
Expand All @@ -110,14 +110,14 @@ def chat_with_ai(
# Update & add summary of trimmed messages
if len(agent.history) > 0:
new_summary_message, trimmed_messages = agent.history.trim_messages(
current_message_chain=list(message_sequence),
current_message_chain=list(message_sequence), config=agent.config
)
tokens_to_add = count_message_tokens([new_summary_message], model)
message_sequence.insert(insertion_index, new_summary_message)
current_tokens_used += tokens_to_add - 500

# FIXME: uncomment when memory is back in use
# memory_store = get_memory(cfg)
# memory_store = get_memory(config)
# for _, ai_msg, result_msg in agent.history.per_cycle(trimmed_messages):
# memory_to_add = MemoryItem.from_ai_action(ai_msg, result_msg)
# logger.debug(f"Storing the following memory:\n{memory_to_add.dump()}")
Expand Down Expand Up @@ -192,6 +192,7 @@ def chat_with_ai(
# temperature and other settings we care about
assistant_reply = create_chat_completion(
prompt=message_sequence,
config=agent.config,
max_tokens=tokens_remaining,
)

Expand Down
Loading

0 comments on commit ab0075b

Please sign in to comment.