
Commit 2f20e3b: Merge branch 'main' into update_preset

lisadunlap authored Aug 27, 2024 (2 parents: 7e6ab6b + 05b9305)
Showing 12 changed files with 244 additions and 28 deletions.
README.md (12 changes: 6 additions & 6 deletions)
@@ -1,9 +1,9 @@
# FastChat
| [**Demo**](https://chat.lmsys.org/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) |
| [**Demo**](https://lmarena.ai/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) |

FastChat is an open platform for training, serving, and evaluating large language model based chatbots.
- FastChat powers Chatbot Arena (https://chat.lmsys.org/), serving over 10 million chat requests for 70+ LLMs.
- Chatbot Arena has collected over 500K human votes from side-by-side LLM battles to compile an online [LLM Elo leaderboard](https://leaderboard.lmsys.org).
- FastChat powers Chatbot Arena ([lmarena.ai](https://lmarena.ai)), serving over 10 million chat requests for 70+ LLMs.
- Chatbot Arena has collected over 1.5M human votes from side-by-side LLM battles to compile an online [LLM Elo leaderboard](https://lmarena.ai/?leaderboard).

FastChat's core features include:
- The training and evaluation code for state-of-the-art models (e.g., Vicuna, MT-Bench).
@@ -26,7 +26,7 @@ FastChat's core features include:

</details>

<a href="https://chat.lmsys.org"><img src="assets/demo_narrow.gif" width="70%"></a>
<a href="https://lmarena.ai"><img src="assets/demo_narrow.gif" width="70%"></a>

## Contents
- [Install](#install)
@@ -97,7 +97,7 @@ You can use the commands below to chat with them. They will automatically download

## Inference with Command Line Interface

<a href="https://chat.lmsys.org"><img src="assets/screenshot_cli.png" width="70%"></a>
<a href="https://lmarena.ai"><img src="assets/screenshot_cli.png" width="70%"></a>

(Experimental Feature: You can specify `--style rich` to enable rich text output and better text streaming quality for some non-ASCII content. This may not work properly on certain terminals.)

@@ -202,7 +202,7 @@ export FASTCHAT_USE_MODELSCOPE=True

## Serving with Web GUI

<a href="https://chat.lmsys.org"><img src="assets/screenshot_gui.png" width="70%"></a>
<a href="https://lmarena.ai"><img src="assets/screenshot_gui.png" width="70%"></a>

To serve using the web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the webserver and model workers. You can learn more about the architecture [here](docs/server_arch.md).
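For orientation, here is a minimal sketch of that three-process layout using the module entry points documented in this repository (the model path and startup ordering are illustrative, not prescriptive):

```python
# Sketch: launch the controller, one model worker, and the Gradio web server.
# Entry-point module names follow the FastChat docs; "lmsys/vicuna-7b-v1.5"
# is just an example model path.
import subprocess
import time

controller = subprocess.Popen(["python3", "-m", "fastchat.serve.controller"])
time.sleep(5)  # give the controller time to start before the worker registers
worker = subprocess.Popen([
    "python3", "-m", "fastchat.serve.model_worker",
    "--model-path", "lmsys/vicuna-7b-v1.5",
])
web = subprocess.Popen(["python3", "-m", "fastchat.serve.gradio_web_server"])

for proc in (controller, worker, web):
    proc.wait()  # each process runs until interrupted
```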

docs/arena.md (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
# Chatbot Arena
Chatbot Arena is an LLM benchmark platform featuring anonymous, randomized battles, available at https://chat.lmsys.org.
Chatbot Arena is an LLM benchmark platform featuring anonymous, randomized battles, available at https://lmarena.ai.
We invite the entire community to join this benchmarking effort by contributing your votes and models.

## How to add a new model
docs/dataset_release.md (1 change: 1 addition & 0 deletions)
@@ -2,5 +2,6 @@
We release the following datasets based on our projects and websites.

- [LMSYS-Chat-1M: A Large-Scale Real-World LLM Conversation Dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [LMSYS-Human-Preference-55k](https://huggingface.co/datasets/lmsys/lmsys-arena-human-preference-55k)
- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)
- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)
fastchat/constants.py (2 changes: 1 addition & 1 deletion)
@@ -28,7 +28,7 @@
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds."
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE <span style='color: red; font-weight: bold;'>[BATTLE MODE](https://chat.lmsys.org)</span> (the 1st tab).**"
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE <span style='color: red; font-weight: bold;'>[BATTLE MODE](https://lmarena.ai)</span> (the 1st tab).**"
# Maximum input length
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int(
fastchat/model/model_adapter.py (2 changes: 1 addition & 1 deletion)
@@ -2423,7 +2423,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:


class DBRXAdapter(BaseModelAdapter):
"""The model adapter for Cohere"""
"""The model adapter for Databricks"""

def match(self, model_path: str):
return model_path in ["dbrx-instruct"]
fastchat/model/model_registry.py (12 changes: 6 additions & 6 deletions)
@@ -195,17 +195,17 @@ def get_model_info(name: str) -> ModelInfo:
)

register_model_info(
["command-r-plus"],
"Command-R-Plus",
["command-r-plus", "command-r-plus-04-2024"],
"Command R+",
"https://txt.cohere.com/command-r-plus-microsoft-azure/",
"Command-R Plus by Cohere",
"Command R+ by Cohere",
)

register_model_info(
["command-r"],
"Command-R",
["command-r", "command-r-03-2024", "command-r-08-2024"],
"Command R",
"https://txt.cohere.com/command-r/",
"Command-R by Cohere",
"Command R by Cohere",
)

register_model_info(
fastchat/serve/gradio_block_arena_anony.py (6 changes: 3 additions & 3 deletions)
@@ -420,15 +420,15 @@ def build_side_by_side_ui_anony(models):
{SURVEY_LINK}
## 📣 News
- Chatbot Arena now supports images in beta. Check it out [here](https://chat.lmsys.org/?vision).
- Chatbot Arena now supports images in beta. Check it out [here](https://lmarena.ai/?vision).
## 📜 Rules
- Ask any question to two anonymous models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one!
- You can chat for multiple turns until you identify a winner.
- Votes won't be counted if model identities are revealed during the conversation.
## 🏆 Chatbot Arena [Leaderboard](https://leaderboard.lmsys.org)
- We've collected **1,000,000+** human votes to compute an LLM leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://leaderboard.lmsys.org)!
## 🏆 Chatbot Arena [Leaderboard](https://lmarena.ai/?leaderboard)
- We've collected **1,000,000+** human votes to compute an LLM leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://lmarena.ai/?leaderboard)!
## 👇 Chat now!
"""
fastchat/serve/gradio_block_arena_vision_anony.py (4 changes: 2 additions & 2 deletions)
@@ -386,8 +386,8 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
- Vote won't be counted if model identity is revealed during conversation.
- **NEW** Image Support: <span style='color: #DE3163; font-weight: bold'>Upload an image</span> on your first turn to unlock the multimodal arena! Images should be less than 15MB.
## 🏆 Chatbot Arena [Leaderboard](https://leaderboard.lmsys.org)
- We've collected **1,000,000+** human votes to compute an LLM Elo leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://leaderboard.lmsys.org)!
## 🏆 Chatbot Arena [Leaderboard](https://lmarena.ai/?leaderboard)
- We've collected **1,000,000+** human votes to compute an LLM Elo leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://lmarena.ai/?leaderboard)!
## 👇 Chat now!
"""
fastchat/serve/monitor/add_markdown_info.py (84 changes: 84 additions & 0 deletions)
@@ -0,0 +1,84 @@
import pandas as pd
import re
import argparse

from tqdm import tqdm

tqdm.pandas()


def count_markdown_elements(markdown_text, suffix):
counters = {
f"header_count{suffix}": {
"h1": len(re.findall(r"^#{1}\s", markdown_text, re.MULTILINE)),
"h2": len(re.findall(r"^#{2}\s", markdown_text, re.MULTILINE)),
"h3": len(re.findall(r"^#{3}\s", markdown_text, re.MULTILINE)),
"h4": len(re.findall(r"^#{4}\s", markdown_text, re.MULTILINE)),
"h5": len(re.findall(r"^#{5}\s", markdown_text, re.MULTILINE)),
"h6": len(re.findall(r"^#{6}\s", markdown_text, re.MULTILINE)),
},
f"list_count{suffix}": {
"ordered": len(re.findall(r"^\s*\d+\.\s", markdown_text, re.MULTILINE)),
"unordered": len(re.findall(r"^\s*[-*+]\s", markdown_text, re.MULTILINE)),
},
f"bold_count{suffix}": {
"**": len(re.findall(r"\*\*[^*\n]+\*\*", markdown_text)),
"__": len(re.findall(r"__[^_\n]+__", markdown_text)),
},
}
return counters


def remove_pattern(answer, pattern):
blocks = pattern.findall(answer)
for block in blocks:
answer = answer.replace(block, "")
return answer


def get_element_counts(df, column):
pattern = re.compile("```([^`]*)```")
answers = df[column].map(
lambda convo: "\n".join(
[turn["content"] for turn in convo if turn["role"] == "assistant"]
)
)
results = answers.progress_map(
lambda answer: count_markdown_elements(
            remove_pattern(answer, pattern),  # strip code blocks first
            suffix=column[-2:],  # "_a" or "_b", taken from the column name
)
)

return results.tolist()


def add_markdown_meta(row):
conv_meta = {k: v for k, v in row["conv_metadata"].items()}
return conv_meta | row["markdown_meta_a"] | row["markdown_meta_b"]


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input-file", type=str, required=True)
parser.add_argument("--output-file", type=str, required=True)
args = parser.parse_args()

print("loading file...")
data = pd.read_json(args.input_file)

assert "conv_metadata" in data.columns

temp = data[["question_id", "conv_metadata"]].copy()

print("Processing conversation_a")
temp["markdown_meta_a"] = get_element_counts(data, column="conversation_a")

print("Processing conversation_b")
temp["markdown_meta_b"] = get_element_counts(data, column="conversation_b")

print("Post-processing...")
data["conv_metadata"] = temp.apply(add_markdown_meta, axis=1)

print("Saving to file...")
data.to_json(args.output_file, orient="records", indent=4, force_ascii=False)
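To make the counting concrete, here is a small example of what `count_markdown_elements` returns once `remove_pattern` has stripped fenced code blocks (the input string is hypothetical):

```python
sample = "# Title\n\n- first\n- second\n\n**bold** and __also bold__"
counts = count_markdown_elements(sample, suffix="_a")
# counts["header_count_a"] == {"h1": 1, "h2": 0, "h3": 0, "h4": 0, "h5": 0, "h6": 0}
# counts["list_count_a"]   == {"ordered": 0, "unordered": 2}
# counts["bold_count_a"]   == {"**": 1, "__": 1}
```

The `_a`/`_b` suffix is taken from the last two characters of the column name (`conversation_a`/`conversation_b`), which keeps the per-side counts distinguishable after `add_markdown_meta` merges them into `conv_metadata`.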
fastchat/serve/monitor/elo_analysis.py (139 changes: 135 additions & 4 deletions)
@@ -21,6 +21,18 @@
pd.options.display.float_format = "{:.2f}".format


STYLE_CONTROL_ELEMENTS_V1 = [
"sum_assistant_a_tokens",
"header_count_a",
"list_count_a",
"bold_count_a",
"sum_assistant_b_tokens",
"header_count_b",
"list_count_b",
"bold_count_b",
]


def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
rating = defaultdict(lambda: INIT_RATING)

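The body of `compute_elo` is collapsed in this diff; for reference, the standard online Elo update that these defaults (K=4, SCALE=400, BASE=10) parameterize looks roughly like the sketch below, a reconstruction of the textbook formula rather than a quote of the hidden code:

```python
def elo_update_sketch(rating, model_a, model_b, winner, K=4, SCALE=400, BASE=10):
    ra, rb = rating[model_a], rating[model_b]
    ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))  # expected score of A
    sa = {"model_a": 1.0, "model_b": 0.0}.get(winner, 0.5)  # ties score 0.5
    rating[model_a] = ra + K * (sa - ea)
    rating[model_b] = rb + K * ((1 - sa) - (1 - ea))
```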
@@ -399,6 +411,109 @@ def outlier_detect(
return battles


def fit_mle_elo(X, Y, models, indices=None, SCALE=400, INIT_RATING=1000):
from sklearn.linear_model import LogisticRegression

p = len(models.index)

lr = LogisticRegression(fit_intercept=False)
    if indices is not None:
lr.fit(X[indices], Y[indices])
else:
lr.fit(X, Y)

elo_scores = SCALE * lr.coef_[0] + INIT_RATING
    # anchor the scale: pin mixtral-8x7b-instruct-v0.1 to 1114 if present
if "mixtral-8x7b-instruct-v0.1" in models.index:
elo_scores += 1114 - elo_scores[models["mixtral-8x7b-instruct-v0.1"]]
return (
pd.Series(elo_scores[:p], index=models.index).sort_values(ascending=False),
lr.coef_[0][p:],
)

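In other words, `fit_mle_elo` fits a Bradley-Terry model by logistic regression: each battle contributes +log(BASE) in model A's column, -log(BASE) in model B's column, and the standardized style differences in the trailing k columns (see `construct_style_matrices` below). Reading the code above as a formula,

P(A beats B) = sigmoid( ln(BASE) * (theta_a - theta_b) + beta' s ),    Elo_m = SCALE * theta_m + INIT_RATING,

where theta are the per-model coefficients, beta the style coefficients returned as `lr.coef_[0][p:]`, and s the style vector for the battle.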

def construct_style_matrices(
df,
BASE=10,
apply_ratio=[1, 1, 1, 1],
style_elements=STYLE_CONTROL_ELEMENTS_V1,
add_one=True,
):
    models = pd.concat([df["model_a"], df["model_b"]]).unique()
models = pd.Series(np.arange(len(models)), index=models)

# duplicate battles
df = pd.concat([df, df], ignore_index=True)
p = len(models.index)
n = df.shape[0]
assert len(style_elements) % 2 == 0
    k = len(style_elements) // 2

X = np.zeros([n, p + k])
X[np.arange(n), models[df["model_a"]]] = +math.log(BASE)
X[np.arange(n), models[df["model_b"]]] = -math.log(BASE)

    # turn each of the specified keys in "conv_metadata" into a vector
style_vector = np.array(
[
df.conv_metadata.map(
lambda x: x[element]
if type(x[element]) is int
else sum(x[element].values())
).tolist()
for element in style_elements
]
)

style_diff = (style_vector[:k] - style_vector[k:]).astype(float)
style_sum = (style_vector[:k] + style_vector[k:]).astype(float)

if add_one:
style_sum = style_sum + np.ones(style_diff.shape)

apply_ratio = np.flatnonzero(apply_ratio)

    # apply a ratio where necessary (e.g., token length per response)
    style_diff[apply_ratio] /= style_sum[apply_ratio]

style_mean = np.mean(style_diff, axis=1)
style_std = np.std(style_diff, axis=1)

X[:, -k:] = ((style_diff - style_mean[:, np.newaxis]) / style_std[:, np.newaxis]).T

    # battles were duplicated above, so one A win contributes two A-win rows
Y = np.zeros(n)
Y[df["winner"] == "model_a"] = 1.0

# one tie => one A win + one B win
# find tie + tie (both bad) index
tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
tie_idx[len(tie_idx) // 2 :] = False
Y[tie_idx] = 1.0

return X, Y, models


def get_bootstrap_result_style_control(X, Y, models, func_compute_elo, num_round=1000):
elos = []
coefs = []
for _ in tqdm(range(num_round), desc="bootstrap"):
        indices = np.random.choice(len(Y), size=len(Y), replace=True)
_X = X[indices]
_Y = Y[indices]
        # mask of models that never appear in this resample
        absent = ~_X[:, : len(models)].any(axis=0)

        elo, coef = func_compute_elo(_X, _Y, models=models[~absent])
elos.append(elo)
coefs.append(coef)

df = pd.DataFrame(elos)
return df[df.median().sort_values(ascending=False).index], coefs

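Taken together, the style-control path used later in `report_elo_analysis_results` reduces to three calls; a condensed sketch, assuming `battles` carries the `conv_metadata` fields named in `STYLE_CONTROL_ELEMENTS_V1`:

```python
X, Y, models = construct_style_matrices(battles)
elo_rating_final, coef_final = fit_mle_elo(X, Y, models)
bootstrap_df, bootstrap_coef = get_bootstrap_result_style_control(
    X, Y, models, fit_mle_elo, num_round=1000
)
```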

def filter_long_conv(row):
threshold = 768
for conversation_type in ["conversation_a", "conversation_b"]:
@@ -421,6 +536,7 @@ def report_elo_analysis_results(
run_outlier_detect=False,
scale=1,
filter_func=lambda x: True,
style_control=False,
):
battles = pd.DataFrame(battles_json)

@@ -461,10 +577,17 @@
elo_rating_online = compute_elo(battles)

if rating_system == "bt":
bootstrap_df = get_bootstrap_result(
battles, compute_elo_mle_with_tie, num_round=num_bootstrap
)
elo_rating_final = compute_elo_mle_with_tie(battles)
if style_control:
X, Y, models = construct_style_matrices(battles)
            bootstrap_df, bootstrap_coef = get_bootstrap_result_style_control(
X, Y, models, fit_mle_elo, num_round=num_bootstrap
)
elo_rating_final, coef_final = fit_mle_elo(X, Y, models)
else:
bootstrap_df = get_bootstrap_result(
battles, compute_elo_mle_with_tie, num_round=num_bootstrap
)
elo_rating_final = compute_elo_mle_with_tie(battles)
elif rating_system == "elo":
bootstrap_df = get_bootstrap_result(
battles, compute_elo, num_round=num_bootstrap
@@ -538,6 +661,12 @@
"last_updated_tstamp": last_updated_tstamp,
"bootstrap_df": bootstrap_df,
"leaderboard_table_df": leaderboard_table_df,
"style_coefficients": {
"bootstrap": np.vstack(boostrap_coef),
"final": coef_final,
}
if rating_system == "bt" and style_control
else {},
}


@@ -565,6 +694,7 @@ def pretty_print_elo_rating(rating):
parser.add_argument("--run-outlier-detect", action="store_true", default=False)
parser.add_argument("--category", nargs="+", default=["full"])
parser.add_argument("--scale", type=float, default=1)
parser.add_argument("--style-control", action="store_true")
args = parser.parse_args()

np.random.seed(42)
@@ -602,6 +732,7 @@ def pretty_print_elo_rating(rating):
run_outlier_detect=args.run_outlier_detect,
scale=args.scale,
filter_func=filter_func,
style_control=args.style_control,
)

for cat in args.category:
