diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py
index a6dd1c7325..6838891e02 100644
--- a/mteb/leaderboard/app.py
+++ b/mteb/leaderboard/app.py
@@ -218,12 +218,12 @@ def update_task_info(task_names: str) -> gr.DataFrame:
)
citation = gr.Markdown(update_citation, inputs=[benchmark_select])
with gr.Column():
- with gr.Tab("Performance-Size Plot"):
+ with gr.Tab("Performance per Model Size"):
plot = gr.Plot(performance_size_plot, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all tasks in the benchmark*"
)
- with gr.Tab("Top 5 Radar Chart"):
+ with gr.Tab("Performance per Task Type (Radar Chart)"):
radar_plot = gr.Plot(radar_chart, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all task types in the benchmark*"
diff --git a/mteb/leaderboard/figures.py b/mteb/leaderboard/figures.py
index 9f3e73f7a4..35f91dd363 100644
--- a/mteb/leaderboard/figures.py
+++ b/mteb/leaderboard/figures.py
@@ -6,6 +6,28 @@
import plotly.graph_objects as go
+def text_plot(text: str):
+    """Build a placeholder figure that shows only a message.
+
+    Creates a data-less scatter figure with the "plotly_white" template and
+    places ``text`` on it as an arrowless annotation in font size 20. Handy
+    as a stand-in figure when a real plot cannot be produced.
+    """
+    message_font = {"size": 20}
+    figure = px.scatter(template="plotly_white")
+    # add_annotation hands back the figure it was called on, so the chained
+    # call's result is returned unchanged.
+    return figure.add_annotation(text=text, showarrow=False, font=message_font)
+
+
+def failsafe_plot(fun):
+    """Decorator that makes a figure-producing function failsafe.
+
+    This is necessary because once a callback encounters an exception it
+    becomes useless in Gradio. The wrapped function therefore never raises:
+    any ``Exception`` is logged with its traceback and a placeholder figure
+    carrying a generic error message is returned instead.
+
+    Args:
+        fun: Callable that returns a figure.
+
+    Returns:
+        A wrapper with the same call signature and metadata (name,
+        docstring) as ``fun``.
+    """
+    # Local imports keep this block self-contained.
+    import functools
+    import logging
+
+    logger = logging.getLogger(__name__)
+
+    @functools.wraps(fun)
+    def wrapper(*args, **kwargs):
+        try:
+            return fun(*args, **kwargs)
+        except Exception:
+            # The broad catch is deliberate (any failure must be contained),
+            # but the cause is logged so it is not silently lost.
+            logger.exception(
+                "Couldn't produce plot in %s", getattr(fun, "__name__", fun)
+            )
+            return text_plot("Couldn't produce plot.")
+
+    return wrapper
+
+
def parse_n_params(text: str) -> int:
if text.endswith("M"):
return float(text[:-1]) * 1e6
@@ -37,6 +59,48 @@ def parse_float(value) -> float:
]
+def add_size_guide(fig: go.Figure):
+    """Draw an embedding-size guide onto ``fig`` and return the same figure.
+
+    Adds one marker trace of four hollow, black-outlined circles (one per
+    reference embedding size), an "Embedding Size:" caption, and one arrow
+    annotation per reference size labelled with that size. Marker sizes are
+    the square roots of the reference sizes, matching the "sqrt(dim)" column
+    that ``performance_size_plot`` uses for its marker sizes.
+    """
+    reference_dims = [256, 1024, 2048, 4096]
+    guide_x = [5e9] * 4
+    circle_y = [7.8, 8.5, 9, 10]
+
+    outline_marker = {
+        "size": np.sqrt(reference_dims),
+        # Fully transparent fill, so only the outline is visible.
+        "color": "rgba(0,0,0,0)",
+        "line": {"color": "black", "width": 2},
+    }
+    circles = go.Scatter(
+        x=guide_x,
+        y=circle_y,
+        mode="markers",
+        marker=outline_marker,
+        opacity=0.3,
+        showlegend=False,
+    )
+    fig.add_trace(circles)
+
+    # NOTE(review): annotation x values are passed as log10 of the data
+    # value, presumably because the caller plots with log_x=True and plotly
+    # positions annotations on log axes in log10 units; confirm.
+    fig.add_annotation(
+        text="Embedding Size:",
+        x=np.log10(1.5e9),
+        y=10,
+        font={"size": 16},
+        opacity=0.3,
+        showarrow=False,
+    )
+
+    # NOTE(review): the labels are spread over np.linspace(7.5, 14, 4) and do
+    # not reuse the circles' y positions above; verify this is intentional.
+    label_ys = np.linspace(7.5, 14, 4)
+    for x_value, label_y, dim in zip(guide_x, label_ys, reference_dims):
+        fig.add_annotation(
+            text=f"{dim}",
+            x=np.log10(x_value),
+            y=label_y,
+            font={"size": 12},
+            opacity=0.3,
+            showarrow=True,
+            arrowwidth=2,
+            ax=50,
+            ay=0,
+        )
+    return fig
+
+
+@failsafe_plot
def performance_size_plot(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Number of Parameters"] = df["Number of Parameters"].map(parse_n_params)
@@ -50,6 +114,7 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
if not len(df.index):
return go.Figure()
min_score, max_score = df["Mean (Task)"].min(), df["Mean (Task)"].max()
+ df["sqrt(dim)"] = np.sqrt(df["Embedding Dimensions"])
fig = px.scatter(
df,
x="Number of Parameters",
@@ -57,7 +122,7 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
log_x=True,
template="plotly_white",
text="model_text",
- size="Embedding Dimensions",
+ size="sqrt(dim)",
color="Log(Tokens)",
range_color=[2, 5],
range_x=[8 * 1e6, 11 * 1e9],
@@ -69,10 +134,21 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"Mean (Task)": True,
"Rank (Borda)": True,
"Log(Tokens)": False,
+ "sqrt(dim)": False,
"model_text": False,
},
hover_name="Model",
)
+ # Note: it's important that this comes before setting the size mode
+ fig = add_size_guide(fig)
+ fig.update_traces(
+ marker=dict(
+ sizemode="diameter",
+ sizeref=1.5,
+ sizemin=0,
+ )
+ )
+ fig.add_annotation(x=1e9, y=10, text="Model size:")
fig.update_layout(
coloraxis_colorbar=dict( # noqa
title="Max Tokens",
@@ -124,14 +200,15 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"#3CBBB1",
]
fill_colors = [
- "rgba(238,66,102,0.2)",
- "rgba(0,166,237,0.2)",
- "rgba(236,167,44,0.2)",
- "rgba(180,35,24,0.2)",
- "rgba(60,187,177,0.2)",
+ "rgba(238,66,102,0.05)",
+ "rgba(0,166,237,0.05)",
+ "rgba(236,167,44,0.05)",
+ "rgba(180,35,24,0.05)",
+ "rgba(60,187,177,0.05)",
]
+@failsafe_plot
def radar_chart(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Model"] = df["Model"].map(parse_model_name)
@@ -139,6 +216,10 @@ def radar_chart(df: pd.DataFrame) -> go.Figure:
task_type_columns = [
column for column in df.columns if "".join(column.split()) in task_types
]
+ if len(task_type_columns) <= 1:
+ raise ValueError(
+ "Couldn't produce radar chart, the benchmark only contains one task category."
+ )
df = df[["Model", *task_type_columns]].set_index("Model")
df = df.replace("", np.nan)
df = df.dropna()
@@ -156,7 +237,7 @@ def radar_chart(df: pd.DataFrame) -> go.Figure:
mode="lines",
line=dict(width=2, color=line_colors[i]),
fill="toself",
- fillcolor=fill_colors[i],
+ fillcolor="rgba(0,0,0,0)",
)
)
fig.update_layout(