Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mteb/leaderboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def update_task_info(task_names: str) -> gr.DataFrame:
)
citation = gr.Markdown(update_citation, inputs=[benchmark_select])
with gr.Column():
with gr.Tab("Performance-Size Plot"):
with gr.Tab("Performance per Model Size"):
plot = gr.Plot(performance_size_plot, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all tasks in the benchmark*"
)
with gr.Tab("Top 5 Radar Chart"):
with gr.Tab("Performance per Task Type (Radar Chart)"):
radar_plot = gr.Plot(radar_chart, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all task types in the benchmark*"
Expand Down
95 changes: 88 additions & 7 deletions mteb/leaderboard/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
import plotly.graph_objects as go


def text_plot(text: str):
"""Returns empty scatter plot with text added, this can be great for error messages."""
return px.scatter(template="plotly_white").add_annotation(
text=text, showarrow=False, font=dict(size=20)
)


def failsafe_plot(fun):
"""Decorator that turns the function producing a figure failsafe.
This is necessary, because once a Callback encounters an exception it
becomes useless in Gradio.
"""

def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception:
return text_plot("Couldn't produce plot.")

return wrapper


def parse_n_params(text: str) -> int:
if text.endswith("M"):
return float(text[:-1]) * 1e6
Expand Down Expand Up @@ -37,6 +59,48 @@ def parse_float(value) -> float:
]


def add_size_guide(fig: go.Figure):
xpos = [5 * 1e9] * 4
ypos = [7.8, 8.5, 9, 10]
sizes = [256, 1024, 2048, 4096]
fig.add_trace(
go.Scatter(
showlegend=False,
opacity=0.3,
mode="markers",
marker=dict(
size=np.sqrt(sizes),
color="rgba(0,0,0,0)",
line=dict(color="black", width=2),
),
x=xpos,
y=ypos,
)
)
fig.add_annotation(
text="<b>Embedding Size:</b>",
font=dict(size=16),
x=np.log10(1.5e9),
y=10,
showarrow=False,
opacity=0.3,
)
for x, y, size in zip(xpos, np.linspace(7.5, 14, 4), sizes):
fig.add_annotation(
text=f"<b>{size}</b>",
font=dict(size=12),
x=np.log10(x),
y=y,
showarrow=True,
ay=0,
ax=50,
opacity=0.3,
arrowwidth=2,
)
return fig


@failsafe_plot
def performance_size_plot(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Number of Parameters"] = df["Number of Parameters"].map(parse_n_params)
Expand All @@ -50,14 +114,15 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
if not len(df.index):
return go.Figure()
min_score, max_score = df["Mean (Task)"].min(), df["Mean (Task)"].max()
df["sqrt(dim)"] = np.sqrt(df["Embedding Dimensions"])
fig = px.scatter(
df,
x="Number of Parameters",
y="Mean (Task)",
log_x=True,
template="plotly_white",
text="model_text",
size="Embedding Dimensions",
size="sqrt(dim)",
color="Log(Tokens)",
range_color=[2, 5],
range_x=[8 * 1e6, 11 * 1e9],
Expand All @@ -69,10 +134,21 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"Mean (Task)": True,
"Rank (Borda)": True,
"Log(Tokens)": False,
"sqrt(dim)": False,
"model_text": False,
},
hover_name="Model",
)
# Note: it's important that this comes before setting the size mode
fig = add_size_guide(fig)
fig.update_traces(
marker=dict(
sizemode="diameter",
sizeref=1.5,
sizemin=0,
)
)
fig.add_annotation(x=1e9, y=10, text="Model size:")
fig.update_layout(
coloraxis_colorbar=dict( # noqa
title="Max Tokens",
Expand Down Expand Up @@ -124,21 +200,26 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"#3CBBB1",
]
fill_colors = [
"rgba(238,66,102,0.2)",
"rgba(0,166,237,0.2)",
"rgba(236,167,44,0.2)",
"rgba(180,35,24,0.2)",
"rgba(60,187,177,0.2)",
"rgba(238,66,102,0.05)",
"rgba(0,166,237,0.05)",
"rgba(236,167,44,0.05)",
"rgba(180,35,24,0.05)",
"rgba(60,187,177,0.05)",
]


@failsafe_plot
def radar_chart(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Model"] = df["Model"].map(parse_model_name)
# Remove whitespace
task_type_columns = [
column for column in df.columns if "".join(column.split()) in task_types
]
if len(task_type_columns) <= 1:
raise ValueError(
"Couldn't produce radar chart, the benchmark only contains one task category."
)
df = df[["Model", *task_type_columns]].set_index("Model")
df = df.replace("", np.nan)
df = df.dropna()
Expand All @@ -156,7 +237,7 @@ def radar_chart(df: pd.DataFrame) -> go.Figure:
mode="lines",
line=dict(width=2, color=line_colors[i]),
fill="toself",
fillcolor=fill_colors[i],
fillcolor="rgba(0,0,0,0)",
)
)
fig.update_layout(
Expand Down