Skip to content

Commit

Permalink
code update (#2997)
Browse files Browse the repository at this point in the history
  • Loading branch information
infwinston authored Feb 5, 2024
1 parent 81785d7 commit 2264204
Show file tree
Hide file tree
Showing 11 changed files with 757 additions and 221 deletions.
19 changes: 19 additions & 0 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,15 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Gemini template: turns are labelled "user" and "model"; neither a
# separator style nor a separator string is set for this template.
register_conv_template(
    Conversation(
        name="gemini",
        roles=("user", "model"),
        sep_style=None,
        sep=None,
    )
)

# BiLLa default template
register_conv_template(
Conversation(
Expand Down Expand Up @@ -1474,6 +1483,16 @@ def get_conv_template(name: str) -> Conversation:
)
)

# nvidia/Llama2-70B-SteerLM-Chat template: "user"/"assistant" roles, an
# empty system message, and no separator style or separator string set.
register_conv_template(
    Conversation(
        name="steerlm",
        system_message="",
        roles=("user", "assistant"),
        sep_style=None,
        sep=None,
    )
)

# yuan 2.0 template
# reference:https://github.com/IEIT-Yuan/Yuan-2.0
Expand Down
16 changes: 14 additions & 2 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-turbo",
"gpt-4-0125-preview",
)


Expand Down Expand Up @@ -1164,13 +1165,13 @@ class GeminiAdapter(BaseModelAdapter):
"""The model adapter for Gemini"""

def match(self, model_path: str):
    """Return True when ``model_path`` names a Gemini or Bard model.

    The check is a case-insensitive substring test, so variants such as
    "gemini-pro-dev-api" or "bard-feb-2024" match as well as the plain
    "gemini-pro" name.
    """
    # Lower-case once and reuse it for both substring checks.
    lowered = model_path.lower()
    return "gemini" in lowered or "bard" in lowered

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
    """Unconditionally raise ``NotImplementedError``.

    Both arguments are accepted but never used.
    """
    raise NotImplementedError

def get_default_conv_template(self, model_path: str) -> Conversation:
    """Return the conversation template registered under the name "gemini".

    ``model_path`` is not consulted; every model handled by this adapter
    gets the same template.
    """
    return get_conv_template("gemini")


class BiLLaAdapter(BaseModelAdapter):
Expand Down Expand Up @@ -2193,6 +2194,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("solar")


class SteerLMAdapter(BaseModelAdapter):
    """Adapter for nvidia/Llama2-70B-SteerLM-Chat."""

    def match(self, model_path: str):
        """Return True if "steerlm-chat" occurs in ``model_path``, ignoring case."""
        lowered_path = model_path.lower()
        return "steerlm-chat" in lowered_path

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the template registered as "steerlm"; ``model_path`` is unused."""
        template = get_conv_template("steerlm")
        return template


class LlavaAdapter(BaseModelAdapter):
"""The model adapter for liuhaotian/llava-v1.5 series of models"""

Expand Down Expand Up @@ -2327,6 +2338,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(MetaMathAdapter)
register_model_adapter(BagelAdapter)
register_model_adapter(SolarAdapter)
register_model_adapter(SteerLMAdapter)
register_model_adapter(LlavaAdapter)
register_model_adapter(YuanAdapter)

Expand Down
34 changes: 31 additions & 3 deletions fastchat/model/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,40 @@ def get_model_info(name: str) -> ModelInfo:
)

# Mistral AI models listed under the shared "Mixtral of experts" entry.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["mixtral-8x7b-instruct-v0.1", "mistral-medium", "mistral-7b-instruct"],
    "Mixtral of experts",
    "https://mistral.ai/news/mixtral-of-experts/",
    "A Mixture-of-Experts model by Mistral AI",
)

# Bard models. "gemini-pro" is intentionally absent from this list: it is
# not a Bard model name and belongs under a Gemini entry instead.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["bard-feb-2024", "bard-jan-24-gemini-pro"],
    "Bard",
    "https://bard.google.com/",
    "Bard by Google",
)

# Gemini models.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["gemini-pro", "gemini-pro-dev-api"],
    "Gemini",
    "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
    "Gemini by Google",
)

# DeepSeek LLM 67B chat model.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["deepseek-llm-67b-chat"],
    "DeepSeek LLM",
    "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat",
    "An advanced language model by DeepSeek",
)

# StripedHyena-Nous 7B chat model.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["stripedhyena-nous-7b"],
    "StripedHyena-Nous",
    "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B",
    "A chat model developed by Together Research and Nous Research.",
)

register_model_info(
["solar-10.7b-instruct-v1.0"],
"SOLAR-10.7B-Instruct",
Expand All @@ -62,7 +83,7 @@ def get_model_info(name: str) -> ModelInfo:
)

register_model_info(
["gpt-4-turbo"],
["gpt-4-turbo", "gpt-4-0125-preview"],
"GPT-4-Turbo",
"https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
"GPT-4-Turbo by OpenAI",
Expand Down Expand Up @@ -103,6 +124,13 @@ def get_model_info(name: str) -> ModelInfo:
"Claude Instant by Anthropic",
)

# NVIDIA's SteerLM fine-tune of Llama 2 70B.
# Arguments, in order: model names, display name, link, description.
register_model_info(
    ["llama2-70b-steerlm-chat"],
    "Llama2-70B-SteerLM-Chat",
    "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat",
    "A Llama fine-tuned with SteerLM method by NVIDIA",
)

register_model_info(
["pplx-70b-online", "pplx-7b-online"],
"pplx-online-llms",
Expand Down
Loading

0 comments on commit 2264204

Please sign in to comment.