code update #2997

Merged 16 commits on Feb 5, 2024
fastchat/conversation.py (19 additions, 0 deletions)
@@ -824,6 +824,15 @@ def get_conv_template(name: str) -> Conversation:
    )
)

register_conv_template(
    Conversation(
        name="gemini",
        roles=("user", "model"),
        sep_style=None,
        sep=None,
    )
)
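Since the new "gemini" template sets sep_style=None, it is intended for API-backed chat where messages stay structured rather than being flattened into one prompt string. A minimal usage sketch, assuming the Conversation helpers already defined in this module (get_conv_template, append_message, to_openai_api_messages); illustrative only, not part of the diff:

    # Illustrative only, not part of the diff.
    conv = get_conv_template("gemini")
    conv.append_message(conv.roles[0], "Summarize this paragraph.")  # "user" turn
    conv.append_message(conv.roles[1], None)  # "model" turn, to be filled by the API
    messages = conv.to_openai_api_messages()  # structured messages for an API backend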

# BiLLa default template
register_conv_template(
    Conversation(
@@ -1474,6 +1483,16 @@ def get_conv_template(name: str) -> Conversation:
    )
)

# nvidia/Llama2-70B-SteerLM-Chat
register_conv_template(
    Conversation(
        name="steerlm",
        system_message="",
        roles=("user", "assistant"),
        sep_style=None,
        sep=None,
    )
)

# yuan 2.0 template
# reference:https://github.com/IEIT-Yuan/Yuan-2.0
fastchat/model/model_adapter.py (14 additions, 2 deletions)
@@ -68,6 +68,7 @@
    "gpt-4-0314",
    "gpt-4-0613",
    "gpt-4-turbo",
    "gpt-4-0125-preview",
)


@@ -1164,13 +1165,13 @@ class GeminiAdapter(BaseModelAdapter):
    """The model adapter for Gemini"""

    def match(self, model_path: str):
-        return model_path in ["gemini-pro"]
+        return "gemini" in model_path.lower() or "bard" in model_path.lower()

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        raise NotImplementedError()

    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("bard")
+        return get_conv_template("gemini")


class BiLLaAdapter(BaseModelAdapter):
@@ -2193,6 +2194,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("solar")


class SteerLMAdapter(BaseModelAdapter):
    """The model adapter for nvidia/Llama2-70B-SteerLM-Chat"""

    def match(self, model_path: str):
        return "steerlm-chat" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("steerlm")


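For context, a quick sketch of how the new adapter is expected to resolve a model path; get_model_adapter and get_conv_template are existing FastChat helpers, but the exact result depends on adapter registration order, so treat this as an assumption rather than a guarantee:

    # Illustrative only, not part of the diff.
    from fastchat.model.model_adapter import get_model_adapter

    adapter = get_model_adapter("nvidia/Llama2-70B-SteerLM-Chat")
    # "steerlm-chat" appears in the lowercased path, so SteerLMAdapter should match.
    conv = adapter.get_default_conv_template("nvidia/Llama2-70B-SteerLM-Chat")
    print(conv.name)  # expected: "steerlm"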
class LlavaAdapter(BaseModelAdapter):
"""The model adapter for liuhaotian/llava-v1.5 series of models"""

@@ -2327,6 +2338,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(MetaMathAdapter)
register_model_adapter(BagelAdapter)
register_model_adapter(SolarAdapter)
register_model_adapter(SteerLMAdapter)
register_model_adapter(LlavaAdapter)
register_model_adapter(YuanAdapter)

fastchat/model/model_registry.py (31 additions, 3 deletions)
@@ -41,19 +41,40 @@ def get_model_info(name: str) -> ModelInfo:
)

register_model_info(
-    ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"],
+    ["mixtral-8x7b-instruct-v0.1", "mistral-medium", "mistral-7b-instruct"],
    "Mixtral of experts",
    "https://mistral.ai/news/mixtral-of-experts/",
    "A Mixture-of-Experts model by Mistral AI",
)

register_model_info(
-    ["gemini-pro"],
+    ["bard-feb-2024", "bard-jan-24-gemini-pro"],
    "Bard",
    "https://bard.google.com/",
    "Bard by Google",
)

register_model_info(
    ["gemini-pro", "gemini-pro-dev-api"],
    "Gemini",
    "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
    "Gemini by Google",
)

register_model_info(
    ["deepseek-llm-67b-chat"],
    "DeepSeek LLM",
    "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat",
    "An advanced language model by DeepSeek",
)

register_model_info(
    ["stripedhyena-nous-7b"],
    "StripedHyena-Nous",
    "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B",
    "A chat model developed by Together Research and Nous Research.",
)

register_model_info(
    ["solar-10.7b-instruct-v1.0"],
    "SOLAR-10.7B-Instruct",
@@ -62,7 +83,7 @@ def get_model_info(name: str) -> ModelInfo:
)

register_model_info(
-    ["gpt-4-turbo"],
+    ["gpt-4-turbo", "gpt-4-0125-preview"],
    "GPT-4-Turbo",
    "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
    "GPT-4-Turbo by OpenAI",
@@ -103,6 +124,13 @@ def get_model_info(name: str) -> ModelInfo:
    "Claude Instant by Anthropic",
)

register_model_info(
    ["llama2-70b-steerlm-chat"],
    "Llama2-70B-SteerLM-Chat",
    "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat",
    "A Llama fine-tuned with SteerLM method by NVIDIA",
)

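A small sanity-check sketch for the new registry entry, assuming ModelInfo exposes simple_name and link fields as used elsewhere in this file; illustrative only, not part of the diff:

    # Illustrative only, not part of the diff.
    info = get_model_info("llama2-70b-steerlm-chat")
    print(info.simple_name)  # "Llama2-70B-SteerLM-Chat"
    print(info.link)         # "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat"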
register_model_info(
    ["pplx-70b-online", "pplx-7b-online"],
    "pplx-online-llms",