Adds multiple providers support
akash-plane committed Dec 12, 2024
1 parent 954fd2f commit 65acf0c
Showing 2 changed files with 106 additions and 30 deletions.
9 changes: 8 additions & 1 deletion .env.example
@@ -27,9 +27,16 @@ FILE_SIZE_LIMIT=5242880
 
 # GPT settings
 OPENAI_API_BASE="https://api.openai.com/v1" # deprecated
-OPENAI_API_KEY="sk-" # deprecated
 GPT_ENGINE="gpt-3.5-turbo" # deprecated
 
+# AI Assistant Settings
+LLM_PROVIDER=openai # Can be "openai", "anthropic", or "gemini"
+LLM_MODEL=gpt-4o-mini # The specific model you want to use
+
+OPENAI_API_KEY=your-openai-api-key
+ANTHROPIC_API_KEY=your-anthropic-api-key
+GEMINI_API_KEY=your-gemini-api-key
+
 # Settings related to Docker
 DOCKERIZED=1 # deprecated
 
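For reference, a minimal sketch of the new configuration pointing the assistant at Anthropic instead of the default OpenAI (the key is a placeholder; the other providers work the same way):

LLM_PROVIDER=anthropic
LLM_MODEL=claude-3-5-sonnet-20240620
ANTHROPIC_API_KEY=your-anthropic-api-key

If LLM_MODEL is left unset, the provider's default_model from the code below is used (claude-3-sonnet-20240229 for Anthropic).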
127 changes: 98 additions & 29 deletions apiserver/plane/app/views/external/base.py
@@ -1,5 +1,6 @@
 # Python import
 import os
+from typing import List, Dict, Tuple
 
 # Third party import
 import litellm
@@ -19,30 +20,105 @@
 from ..base import BaseAPIView
 
 
-def get_gpt_config():
-    """Helper to get GPT configuration values"""
-    OPENAI_API_KEY, GPT_ENGINE = get_configuration_value([
+class LLMProvider:
+    """Base class for LLM provider configurations"""
+    name: str = ""
+    models: List[str] = []
+    api_key_env: str = ""
+    default_model: str = ""
+
+    @classmethod
+    def get_config(cls) -> Dict[str, str | List[str]]:
+        return {
+            "name": cls.name,
+            "models": cls.models,
+            "default_model": cls.default_model,
+        }
+
+class OpenAIProvider(LLMProvider):
+    name = "OpenAI"
+    models = ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o1-mini", "o1-preview"]
+    api_key_env = "OPENAI_API_KEY"
+    default_model = "gpt-4o-mini"
+
+class AnthropicProvider(LLMProvider):
+    name = "Anthropic"
+    models = [
+        "claude-3-5-sonnet-20240620",
+        "claude-3-haiku-20240307",
+        "claude-3-opus-20240229",
+        "claude-3-sonnet-20240229",
+        "claude-2.1",
+        "claude-2",
+        "claude-instant-1.2",
+        "claude-instant-1"
+    ]
+    api_key_env = "ANTHROPIC_API_KEY"
+    default_model = "claude-3-sonnet-20240229"
+
+class GeminiProvider(LLMProvider):
+    name = "Gemini"
+    models = ["gemini-pro", "gemini-1.5-pro-latest", "gemini-pro-vision"]
+    api_key_env = "GEMINI_API_KEY"
+    default_model = "gemini-pro"
+
+SUPPORTED_PROVIDERS = {
+    "openai": OpenAIProvider,
+    "anthropic": AnthropicProvider,
+    "gemini": GeminiProvider,
+}
+
+def get_llm_config() -> Tuple[str | None, str | None, str | None]:
+    """
+    Helper to get LLM configuration values, returns:
+    - api_key, model, provider
+    """
+    provider_key, model = get_configuration_value([
         {
-            "key": "OPENAI_API_KEY",
-            "default": os.environ.get("OPENAI_API_KEY", None),
+            "key": "LLM_PROVIDER",
+            "default": os.environ.get("LLM_PROVIDER", "openai"),
         },
         {
-            "key": "GPT_ENGINE",
-            "default": os.environ.get("GPT_ENGINE", "gpt-4o-mini"),
+            "key": "LLM_MODEL",
+            "default": None,
         },
     ])
 
-    if not OPENAI_API_KEY or not GPT_ENGINE:
-        return None, None
-    return OPENAI_API_KEY, GPT_ENGINE
 
+    provider = SUPPORTED_PROVIDERS.get(provider_key.lower())
+    if not provider:
+        return None, None, None
 
-def get_gpt_response(task, prompt, api_key, engine):
-    """Helper to get GPT completion response"""
+    api_key, _ = get_configuration_value([
+        {
+            "key": provider.api_key_env,
+            "default": os.environ.get(provider.api_key_env, None),
+        }
+    ])
+
+    if not api_key:
+        return None, None, None
+
+    # If no model specified, use provider's default
+    if not model:
+        model = provider.default_model
+
+    # Validate model is supported by provider
+    if model not in provider.models:
+        return None, None, None
+
+    return api_key, model, provider_key
+
+
+def get_llm_response(task, prompt, api_key: str, model: str, provider: str) -> Tuple[str | None, str | None]:
+    """Helper to get LLM completion response"""
     final_text = task + "\n" + prompt
     try:
+        # For Gemini, prepend provider name to model
+        if provider.lower() == "gemini":
+            model = f"gemini/{model}"
+
         response = litellm.completion(
-            model=engine,
+            model=model,
             messages=[{"role": "user", "content": final_text}],
             api_key=api_key,
         )
@@ -56,18 +132,11 @@ def get_gpt_response(task, prompt, api_key, engine):
 class GPTIntegrationEndpoint(BaseAPIView):
     @allow_permission([ROLE.ADMIN, ROLE.MEMBER])
     def post(self, request, slug, project_id):
-        OPENAI_API_KEY, GPT_ENGINE = get_gpt_config()
-
-        supported_models = ["gpt-4o-mini", "gpt-4o"]
-        if GPT_ENGINE not in supported_models:
-            return Response(
-                {"error": f"Unsupported model. Please use one of: {', '.join(supported_models)}"},
-                status=status.HTTP_400_BAD_REQUEST,
-            )
+        api_key, model, provider = get_llm_config()
 
-        if not OPENAI_API_KEY or not GPT_ENGINE:
+        if not api_key or not model or not provider:
             return Response(
-                {"error": "OpenAI API key and engine is required"},
+                {"error": "LLM provider API key and model are required"},
                 status=status.HTTP_400_BAD_REQUEST,
             )

@@ -77,7 +146,7 @@ def post(self, request, slug, project_id):
{"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
)

text, error = get_gpt_response(task, request.data.get("prompt", False), OPENAI_API_KEY, GPT_ENGINE)
text, error = get_llm_response(task, request.data.get("prompt", False), api_key, model, provider)
if not text and error:
return Response(
{"error": "An internal error has occurred."},
@@ -101,11 +170,11 @@ def post(self, request, slug, project_id):
 class WorkspaceGPTIntegrationEndpoint(BaseAPIView):
     @allow_permission(allowed_roles=[ROLE.ADMIN, ROLE.MEMBER], level="WORKSPACE")
     def post(self, request, slug):
-        OPENAI_API_KEY, GPT_ENGINE = get_gpt_config()
+        api_key, model, provider = get_llm_config()
 
-        if not OPENAI_API_KEY or not GPT_ENGINE:
+        if not api_key or not model or not provider:
             return Response(
-                {"error": "OpenAI API key and engine is required"},
+                {"error": "LLM provider API key and model are required"},
                 status=status.HTTP_400_BAD_REQUEST,
             )

@@ -115,7 +184,7 @@ def post(self, request, slug):
{"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
)

text, error = get_gpt_response(task, request.data.get("prompt", False), OPENAI_API_KEY, GPT_ENGINE)
text, error = get_llm_response(task, request.data.get("prompt", False), api_key, model, provider)
if not text and error:
return Response(
{"error": "An internal error has occurred."},
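For orientation, a hypothetical usage sketch of the two helpers introduced above (assuming a configured Django context, since get_configuration_value may consult instance settings before falling back to environment variables; all values are placeholders):

import os

from plane.app.views.external.base import get_llm_config, get_llm_response

# Placeholder environment, mirroring the .env.example entries above.
os.environ["LLM_PROVIDER"] = "gemini"
os.environ["GEMINI_API_KEY"] = "your-gemini-api-key"
# LLM_MODEL is left unset, so get_llm_config() falls back to
# GeminiProvider.default_model ("gemini-pro").

api_key, model, provider = get_llm_config()
# -> ("your-gemini-api-key", "gemini-pro", "gemini")

# get_llm_response() rewrites the model to "gemini/gemini-pro" so that
# litellm routes the completion to Gemini, and returns a (text, error) pair.
text, error = get_llm_response(
    task="Summarize this issue",
    prompt="Login intermittently fails with a 500 after the session expires.",
    api_key=api_key,
    model=model,
    provider=provider,
)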

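Because SUPPORTED_PROVIDERS is a plain registry mapping a key to an LLMProvider subclass, adding a provider is mostly declarative. A hypothetical sketch, not part of this commit (the Mistral model names are assumptions; a provider whose litellm route needs a prefix would also need a branch in get_llm_response, as Gemini does):

class MistralProvider(LLMProvider):
    name = "Mistral"
    models = ["mistral-large-latest", "mistral-small-latest"]
    api_key_env = "MISTRAL_API_KEY"
    default_model = "mistral-small-latest"

SUPPORTED_PROVIDERS["mistral"] = MistralProvider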