Support custom LLM implementation #174

Merged · 11 commits · Nov 5, 2024
19 changes: 14 additions & 5 deletions README.md
@@ -30,7 +30,7 @@ To this 👇
<a href="https://www.loom.com/share/4c55f395dbd64ef3b69670eccf961124">
<img style="max-width:300px;" src="https://cdn.loom.com/sessions/thumbnails/4c55f395dbd64ef3b69670eccf961124-db2004995e8d621c-full-play.gif">
</a>

## Ways to Use HolmesGPT

<details>
@@ -137,10 +137,7 @@ plugins:
scopes:
- all
command: bash
background: false
confirm: false
args:
- -c

- |
INSTRUCTIONS="# Edit the line below. Lines starting with '#' will be ignored."
DEFAULT_ASK_COMMAND="why is $NAME of $RESOURCE_NAME in -n $NAMESPACE not working as expected"
@@ -167,6 +164,18 @@ plugins:
```
</details>


### Bring your own LLM
<details>
<summary>Bring your own LLM</summary>

You can use Holmes as a library and pass in your own LLM implementation. This is particularly useful if LiteLLM or the default Holmes implementation does not suit you.

See an example implementation [here](examples/custom_llm.py).
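
In practice, a custom LLM subclasses `holmes.core.llm.LLM`, implements its four abstract methods, and is passed to `ToolCallingLLM`. Below is a condensed sketch of the pattern used in that example (see the linked file for the full, runnable version):

```python
from litellm.types.utils import ModelResponse

from holmes.core.llm import LLM
from holmes.core.tool_calling_llm import ToolCallingLLM
from holmes.core.tools import ToolExecutor
from holmes.plugins.prompts import load_and_render_prompt
from holmes.plugins.toolsets import load_builtin_toolsets


class MyCustomLLM(LLM):
    def get_context_window_size(self) -> int:
        return 128000

    def get_maximum_output_token(self) -> int:
        return 4096

    def count_tokens_for_message(self, messages: list[dict]) -> int:
        return 1  # stub; count real tokens here

    def completion(self, messages, tools=None, tool_choice=None,
                   response_format=None, temperature=None, drop_params=None) -> ModelResponse:
        # Call your own backend here and wrap its reply in a litellm ModelResponse.
        return ModelResponse(
            choices=[{
                "finish_reason": "stop",
                "index": 0,
                "message": {"role": "assistant", "content": "There are no issues with your cluster"},
            }],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )


# Wire the custom LLM into Holmes' tool-calling loop.
system_prompt = load_and_render_prompt("builtin://generic_ask.jinja2")
ai = ToolCallingLLM(ToolExecutor(load_builtin_toolsets()), max_steps=10, llm=MyCustomLLM())
print(ai.call(system_prompt, "what issues do I have in my cluster").model_dump())
```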


</details>

Like what you see? Check out [other use cases](#other-use-cases) or get started by [installing HolmesGPT](#installation).

## Installation
60 changes: 60 additions & 0 deletions examples/custom_llm.py
@@ -0,0 +1,60 @@

from typing import Any, Dict, List, Optional, Type, Union

from litellm.types.utils import ModelResponse
from pydantic import BaseModel

from holmes.core.llm import LLM
from holmes.core.tool_calling_llm import ToolCallingLLM
from holmes.core.tools import Tool, ToolExecutor
from holmes.plugins.prompts import load_and_render_prompt
from holmes.plugins.toolsets import load_builtin_toolsets


class MyCustomLLM(LLM):
    def get_context_window_size(self) -> int:
        return 128000

    def get_maximum_output_token(self) -> int:
        return 4096

    def count_tokens_for_message(self, messages: list[dict]) -> int:
        return 1  # dummy value; a real implementation should count tokens

    def completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Tool]] = [],
        tool_choice: Optional[Union[str, dict]] = None,
        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
        temperature: Optional[float] = None,
        drop_params: Optional[bool] = None,
    ) -> ModelResponse:
        # Return a canned answer; a real implementation would call its own backend
        # and translate the reply into a litellm ModelResponse.
        return ModelResponse(
            choices=[{
                "finish_reason": "stop",
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "There are no issues with your cluster"
                }
            }],
            usage={
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        )


def ask_holmes():
    prompt = "what issues do I have in my cluster"
    system_prompt = load_and_render_prompt("builtin://generic_ask.jinja2")

    tool_executor = ToolExecutor(load_builtin_toolsets())
    ai = ToolCallingLLM(
        tool_executor,
        max_steps=10,
        llm=MyCustomLLM()
    )

    response = ai.call(system_prompt, prompt)
    print(response.model_dump())


ask_holmes()
15 changes: 10 additions & 5 deletions holmes/config.py
@@ -1,6 +1,7 @@
import logging
import os
import os.path
+from holmes.core.llm import LLM, DefaultLLM
from strenum import StrEnum
from typing import List, Optional

@@ -145,14 +146,15 @@ def create_toolcalling_llm(
    ) -> ToolCallingLLM:
        tool_executor = self._create_tool_executor(console, allowed_toolsets)
        return ToolCallingLLM(
-            self.model,
-            self.api_key.get_secret_value() if self.api_key else None,
            tool_executor,
            self.max_steps,
+            self._get_llm()
        )

    def create_issue_investigator(
-        self, console: Console, allowed_toolsets: ToolsetPattern
+        self,
+        console: Console,
+        allowed_toolsets: ToolsetPattern
    ) -> IssueInvestigator:
        all_runbooks = load_builtin_runbooks()
        for runbook_path in self.custom_runbooks:
@@ -161,11 +163,10 @@ def create_issue_investigator(
        runbook_manager = RunbookManager(all_runbooks)
        tool_executor = self._create_tool_executor(console, allowed_toolsets)
        return IssueInvestigator(
-            self.model,
-            self.api_key.get_secret_value() if self.api_key else None,
            tool_executor,
            runbook_manager,
            self.max_steps,
+            self._get_llm()
        )

    def create_jira_source(self) -> JiraSource:
@@ -266,3 +267,7 @@ def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config":
        merged_config = config_from_file.dict()
        merged_config.update(cli_options)
        return cls(**merged_config)
+
+    def _get_llm(self) -> LLM:
+        api_key = self.api_key.get_secret_value() if self.api_key else None
+        return DefaultLLM(self.model, api_key)
130 changes: 130 additions & 0 deletions holmes/core/llm.py
@@ -0,0 +1,130 @@

import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union

import litellm
from litellm.types.utils import ModelResponse
from pydantic import BaseModel

from holmes.common.env_vars import ROBUSTA_AI, ROBUSTA_API_ENDPOINT
from holmes.core.tools import Tool


class LLM(ABC):
    @abstractmethod
    def get_context_window_size(self) -> int:
        pass

    @abstractmethod
    def get_maximum_output_token(self) -> int:
        pass

    @abstractmethod
    def count_tokens_for_message(self, messages: list[dict]) -> int:
        pass

    @abstractmethod
    def completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Tool]] = [],
        tool_choice: Optional[Union[str, dict]] = None,
        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
        temperature: Optional[float] = None,
        drop_params: Optional[bool] = None,
    ) -> ModelResponse:
        pass


class DefaultLLM(LLM):
    model: str
    api_key: str
    base_url: Optional[str]

    def __init__(self, model: str, api_key: str):
        self.model = model
        self.api_key = api_key
        self.base_url = None

        if ROBUSTA_AI:
            self.base_url = ROBUSTA_API_ENDPOINT

        self.check_llm(self.model, self.api_key)

    def check_llm(self, model, api_key):
        logging.debug(f"Checking LiteLLM model {model}")
        # TODO: this WAS a hack to get around the fact that we couldn't pass an api key to litellm.validate_environment,
        # so without it litellm always complained that the environment variable for the api key was missing.
        # To work around that, we set the api key in the standard format litellm expects (${PROVIDER}_API_KEY).
        # TODO: we can now handle this better - see https://github.com/BerriAI/litellm/issues/4375#issuecomment-2223684750
        lookup = litellm.get_llm_provider(self.model)
        if not lookup:
            raise Exception(f"Unknown provider for model {model}")
        provider = lookup[1]
        api_key_env_var = f"{provider.upper()}_API_KEY"
        if api_key:
            os.environ[api_key_env_var] = api_key
        model_requirements = litellm.validate_environment(model=model)
        if not model_requirements["keys_in_environment"]:
            raise Exception(f"model {model} requires the following environment variables: {model_requirements['missing_keys']}")

    def _strip_model_prefix(self) -> str:
        """
        Strip the 'openai/' prefix from the model name if present.
        Model costs are looked up in litellm's model list, which does not use the prefix:
        https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
        """
        model_name = self.model
        if model_name.startswith('openai/'):
            model_name = model_name[len('openai/'):]  # strip the 'openai/' prefix
        return model_name

    # This unfortunately does not seem to work for Azure if the deployment name is not a well-known model name:
    # if not litellm.supports_function_calling(model=model):
    #     raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
    def get_context_window_size(self) -> int:
        model_name = self._strip_model_prefix()
        try:
            return litellm.model_cost[model_name]['max_input_tokens']
        except Exception:
            logging.warning(f"Couldn't find model {model_name} in litellm's model list; falling back to 128k for max_input_tokens")
            return 128000

    def count_tokens_for_message(self, messages: list[dict]) -> int:
        return litellm.token_counter(model=self.model, messages=messages)

    def completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Tool]] = [],
        tool_choice: Optional[Union[str, dict]] = None,
        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
        temperature: Optional[float] = None,
        drop_params: Optional[bool] = None,
    ) -> ModelResponse:
        result = litellm.completion(
            model=self.model,
            api_key=self.api_key,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            base_url=self.base_url,
            temperature=temperature,
            response_format=response_format,
            drop_params=drop_params
        )

        if isinstance(result, ModelResponse):
            response = result.choices[0]
            response_message = response.message
            # When tools are given, we expect the model to reply only with tool calls (except on Bedrock).
            if response_message.content and ('bedrock' not in self.model and logging.DEBUG != logging.root.level):
                logging.warning(f"got unexpected response when tools were given: {response_message.content}")
            return result
        else:
            raise Exception(f"Unexpected type returned by the LLM {type(result)}")

    def get_maximum_output_token(self) -> int:
        model_name = self._strip_model_prefix()
        try:
            return litellm.model_cost[model_name]['max_output_tokens']
        except Exception:
            logging.warning(f"Couldn't find model {model_name} in litellm's model list; falling back to 4096 for max_output_tokens")
            return 4096