Skip to content

Commit

Permalink
Merge pull request #48 from InternLM/dev
Browse files Browse the repository at this point in the history
refactor the config.yaml to make the model setting looks more logical
  • Loading branch information
fly2tomato authored Jan 26, 2024
2 parents f98c44b + e1df753 commit 24c5b31
Show file tree
Hide file tree
Showing 34 changed files with 374 additions and 232 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ English | [简体中文](docs/README_zh-CN.md)


## Latest Progress 🎉

- \[January 2024\] refactor the config-template.yaml to control the backend and the frontend settings at the same time, [click](https://github.com/InternLM/OpenAOE/blob/main/docs/tech-report/config-template.md) to find more introduction about the `config-template.yaml`
- \[January 2024\] Add internlm2-chat-7b model
- \[January 2024\] Released version v0.0.1, officially open source!
______________________________________________________________________

Expand Down
Empty file.
6 changes: 4 additions & 2 deletions docs/todo/TODO.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
### TODO

- [x] set the workers to 3
- [ ] add Gemini model
- [x] refacotor the config.yaml to make the model setting looks more logical
- [x] add internlm2-chat-7b model as default
- [ ] add Gemini model as default
- [x] refactor the config.yaml to make the model setting looks more logical
- [ ] dynamic add new model by editing external python files and the config.yaml
- [ ] build frontend project when OpenAOE start up
95 changes: 70 additions & 25 deletions openaoe/backend/config/biz_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import os
import sys
from copy import deepcopy

import yaml

Expand All @@ -12,23 +13,53 @@

class BizConfig:
def __init__(self, **args):
# raw dict
self.__dict__.update(args)

def get(self, field):
if field in self.__dict__:
return self.__dict__[field]
return None


def init_config():
parser = argparse.ArgumentParser(description="Example app using a YAML config file.")
# provider -> model_name -> ModelConfig
@property
def models_map(self):
if "models" not in self.__dict__:
return {}

models_dict = self.__dict__["models"]
ret = {}
for model_name, model_config in models_dict.items():
if model_config["provider"] not in ret:
ret[model_config["provider"]] = {}
ret[model_config["provider"]][model_name] = ModelConfig(model_config["webui"], model_config["api"])
return ret

@models_map.getter
def __get_models_map(self):
return self.models_map

@property
def json(self):
ret = deepcopy(self.__dict__)
if "models" in ret:
models_config = ret["models"]
for model_name, config in models_config.items():
config.pop("api")
ret["models"][model_name] = config
return ret


class ModelConfig:
def __init__(self, webui_config, api_config):
self.webui_config = webui_config
self.api_config = api_config


def init_config() -> BizConfig:
parser = argparse.ArgumentParser(description="LLM group chat framework")
parser.add_argument('-f', '--file', type=str, required=True, help='Path to the YAML config file.')
config_path = parser.parse_args()
logger.info(f"your config file is: {config_path.file}")
load_config(config_path.file)
return load_config(config_path.file)


def load_config(config_path):
def load_config(config_path) -> BizConfig:
logger.info(f"start to init configuration from {config_path}.")
if not os.path.isfile(config_path):
logger.error(f"invalid path: {config_path}, not exist or not file")
Expand All @@ -43,32 +74,46 @@ def load_config(config_path):
global biz_config
biz_config = BizConfig(**m)
logger.info("init configuration successfully.")
return biz_config


def get_configuration(field):
return biz_config.get(field)


def get_model_configuration(vendor: str, field):
models = get_configuration("models")
def get_model_configuration(provider: str, field, model_name: str = None):
models = biz_config.models_map
if not models:
logger.error(f"invalid configuration file")
sys.exit(-1)

if models.get(vendor) and models.get(vendor).get(field):
conf = models.get(vendor).get(field)
return conf

logger.error(f"vendor: {vendor} has no field: {field} configuration")
provider_config = models.get(provider)
if provider_config:
if model_name:
try:
return provider_config.get(model_name).api_config.get(field)
except:
for config_model_name, config in provider_config.items():
if config_model_name.startswith(model_name):
return config.api_config.get(field)
logger.error(f"{provider} get field: {field} for model: {model_name} failed")
return ""
elif not model_name:
# default the first provider ModelConfig
provider_models_config_list = list(provider_config.values())
try:
logger.info(f"{provider} get field: {field} for anonymous model, use the first one as default.")
return provider_models_config_list[0].api_config.get(field)
except:
logger.error(f"{provider} get field: {field} for anonymous model failed")
return ""

logger.error(f"provider: {provider} has no field: {field} configuration for model: {model_name}")
return ""


def get_base_url(vendor: str) -> str:
return get_model_configuration(vendor, "api_base")
def get_base_url(provider: str, model_name: str = None) -> str:
return get_model_configuration(provider, "api_base", model_name)


def get_api_key(vendor: str) -> str:
return get_model_configuration(vendor, "api_key")
def get_api_key(provider: str, model_name: str = None) -> str:
return get_model_configuration(provider, "api_key", model_name)


def app_abs_path():
Expand Down
92 changes: 67 additions & 25 deletions openaoe/backend/config/config-template.yaml
Original file line number Diff line number Diff line change
@@ -1,29 +1,71 @@
---
models:
gpt:
api_base: https://api.openai.com/v1
api_key:

claude:
api_key:

bard:
api_base: https://bard.google.com
api_key:

minimax:
api_base: https://api.minimax.chat
group_id:
jwt:

internlm-chat-7b:
provider: internlm
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/internlm.svg'
background: 'linear-gradient(rgb(3 26 108 / 85%) 0%, rgb(29 60 161 / 85%) 100%)'
api:
api_base: http://localhost:23333
gpt-3.5-turbo:
provider: openai
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/openai.svg'
background: 'linear-gradient(180deg, rgba(156, 206, 116, 0.15) 0%, #1a8d15 100%)'
api:
api_base: https://api.openai.com/v1
api_key:
gpt-4:
provider: openai
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/openai.svg'
background: 'linear-gradient(180deg, rgba(156, 206, 116, 0.15) 0%, #08be00 100%)'
api:
api_base: https://api.openai.com/v1
api_key:
claude-1:
provider: claude
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/claude.svg'
background: 'linear-gradient(180deg, rgba(141, 90, 181, 0.15) 0%, rgba(106, 39, 123, 0.7) 53.12%, #663E9A 100%)'
api:
api_base: https://api.anthropic.com
api_key:
claude-1-100k:
provider: claude
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/claude.svg'
background: 'linear-gradient(180deg, rgba(141, 90, 181, 0.15) 0%, rgba(106, 39, 123, 0.7) 53.12%, #663E9A 100%)'
api:
api_base: https://api.anthropic.com
api_key:
chat-bison-001:
provider: google
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/google-palm.webp'
isStream: false
background: 'linear-gradient(180deg, rgba(181, 90, 90, 0.15) 0%, #fa5ab1 100%)'
api:
api_base: https://generativelanguage.googleapis.com
api_key:
abab5-chat:
provider: minimax
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/minimax.png'
background: 'linear-gradient(180deg, rgba(207, 72, 72, 0.15) 0%, rgba(151, 43, 43, 0.7) 53.12%, #742828 100%)'
api:
api_base: https://api.minimax.chat
group_id:
jwt:
spark:
api_base: wss://spark-api.xf-yun.com
app_id:
ak:
sk:

internlm:
api_base:
provider: spark
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/spark.svg'
isStream: false
background: 'linear-gradient(180deg, rgba(72, 72, 207, 0.15) 0%, #7498be 100%)'
api:
api_base: wss://spark-api.xf-yun.com/v2.1/chat
app_id:
ak:
sk:
...


12 changes: 6 additions & 6 deletions openaoe/backend/config/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
DATE_PATTERN = "%Y%m%d"
TIMEOUT_SECONDS = 30

VENDOR_OPENAI = "gpt"
VENDOR_MINIMAX = "minimax"
VENDOR_GOOGLE = "bard"
VENDOR_XUNFEI = "spark"
VENDOR_CLAUDE = "claude"
VENDOR_INTERNLM = "internlm"
PROVIDER_OPENAI = "openai"
PROVIDER_MINIMAX = "minimax"
PROVIDER_GOOGLE = "google"
PROVIDER_XUNFEI = "spark"
PROVIDER_CLAUDE = "claude"
PROVIDER_INTERNLM = "internlm"

DEFAULT_TIMEOUT_SECONDS = 600

9 changes: 5 additions & 4 deletions openaoe/backend/service/service_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from sse_starlette.sse import EventSourceResponse

from openaoe.backend.config.biz_config import get_api_key
from openaoe.backend.config.biz_config import get_api_key, get_base_url
from openaoe.backend.config.constant import TYPE_BOT, TYPE_USER, TYPE_SYSTEM
from openaoe.backend.config.constant import VENDOR_CLAUDE
from openaoe.backend.config.constant import PROVIDER_CLAUDE
from openaoe.backend.model.aoe_response import AOEResponse
from openaoe.backend.model.claude import ClaudeChatBody, ClaudeMessage

Expand All @@ -16,7 +16,8 @@ def claude_chat_stream_svc(request, body: ClaudeChatBody):
stream api logic for Claude model
use anthropic SDK: https://github.com/anthropics/anthropic-sdk-python
"""
api_key = get_api_key(VENDOR_CLAUDE)
api_key = get_api_key(PROVIDER_CLAUDE, body.model)
api_base = get_base_url(PROVIDER_CLAUDE, body.model)
prompt = _gen_prompt(body.messages)
if not prompt or len(prompt) == 0:
return AOEResponse(
Expand All @@ -25,7 +26,7 @@ def claude_chat_stream_svc(request, body: ClaudeChatBody):
data="prompt or messages must be set"
)

anthropic = Anthropic(api_key=api_key)
anthropic = Anthropic(api_key=api_key, base_url=api_base)

async def stream():
try:
Expand Down
4 changes: 2 additions & 2 deletions openaoe/backend/service/service_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def palm_chat_svc(body: GooglePalmChatBody):
"""
chat logic for google PaLM model
"""
api_key = get_api_key(VENDOR_GOOGLE)
url = get_base_url(VENDOR_GOOGLE)
api_key = get_api_key(PROVIDER_GOOGLE, body.model)
url = get_base_url(PROVIDER_GOOGLE, body.model)
url = f"{url}/google/v1beta2/models/{body.model}:generateMessage?key={api_key}"
messages = [
{"content": msg.content, "author": msg.author}
Expand Down
4 changes: 2 additions & 2 deletions openaoe/backend/service/service_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def chat_completion_v1(request, body: InternlmChatCompletionBody):
}
msgs.append(msg_item)
# restful api
url = get_base_url(VENDOR_INTERNLM) + "/v1/chat/completions"
url = get_base_url(PROVIDER_INTERNLM, body.model) + "/v1/chat/completions"
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
Expand All @@ -42,7 +42,7 @@ def chat_completion_v1(request, body: InternlmChatCompletionBody):
"top_p": body.top_p,
"n": body.n,
"max_tokens": body.max_tokens,
"stop": False,
"stop": "false",
"stream": body.stream,
"presence_penalty": body.presence_penalty,
"frequency_penalty": body.frequency_penalty,
Expand Down
6 changes: 3 additions & 3 deletions openaoe/backend/service/service_minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@


def _get_req_param(body):
group_id = get_model_configuration(VENDOR_MINIMAX, "group_id")
jwt = get_model_configuration(VENDOR_MINIMAX, "jwt")
api_base = get_base_url(VENDOR_MINIMAX)
group_id = get_model_configuration(PROVIDER_MINIMAX, "group_id", body.model)
jwt = get_model_configuration(PROVIDER_MINIMAX, "jwt", body.model)
api_base = get_base_url(PROVIDER_MINIMAX, body.model)
headers = {
"Authorization": jwt,
"Content-Type": "application/json"
Expand Down
8 changes: 4 additions & 4 deletions openaoe/backend/service/service_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def chat_completion_stream(request, body):
async def event_generator():
while True:
client = OpenAI(
api_key=get_api_key(VENDOR_OPENAI),
api_key=get_api_key(PROVIDER_OPENAI, body.model),
timeout=body.timeout,
base_url=get_base_url(VENDOR_OPENAI)
base_url=get_base_url(PROVIDER_OPENAI, body.model)
)

stop_flag = False
Expand Down Expand Up @@ -80,9 +80,9 @@ async def event_generator():
async def event_generator_json():
while True:
client = OpenAI(
api_key=get_api_key(VENDOR_OPENAI),
api_key=get_api_key(PROVIDER_OPENAI, body.model),
timeout=body.timeout,
base_url=get_base_url(VENDOR_OPENAI)
base_url=get_base_url(PROVIDER_OPENAI, body.model)
)
stop_flag = False
response = ""
Expand Down
10 changes: 5 additions & 5 deletions openaoe/backend/service/service_xunfei.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from websocket import create_connection

from openaoe.backend.config.biz_config import get_model_configuration, get_base_url
from openaoe.backend.config.constant import VENDOR_XUNFEI
from openaoe.backend.config.constant import PROVIDER_XUNFEI
from openaoe.backend.model.aoe_response import AOEResponse
from openaoe.backend.model.xunfei import XunfeiSparkChatBody
from openaoe.backend.util.log import log
Expand Down Expand Up @@ -59,10 +59,10 @@ def spark_chat_svc(body: XunfeiSparkChatBody):
"""
chat logic for spark model
"""
api_base = get_base_url(VENDOR_XUNFEI)
app_id = get_model_configuration(VENDOR_XUNFEI, "app_id")
ak = get_model_configuration(VENDOR_XUNFEI, "ak")
sk = get_model_configuration(VENDOR_XUNFEI, "sk")
api_base = get_base_url(PROVIDER_XUNFEI)
app_id = get_model_configuration(PROVIDER_XUNFEI, "app_id")
ak = get_model_configuration(PROVIDER_XUNFEI, "ak")
sk = get_model_configuration(PROVIDER_XUNFEI, "sk")

url_parse = urllib.parse.urlparse(api_base)
host = url_parse.hostname
Expand Down
1 change: 0 additions & 1 deletion openaoe/frontend/dist/assets/claude-bd5f04f1.svg

This file was deleted.

Loading

0 comments on commit 24c5b31

Please sign in to comment.