diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py index a6dd1c7325..6838891e02 100644 --- a/mteb/leaderboard/app.py +++ b/mteb/leaderboard/app.py @@ -218,12 +218,12 @@ def update_task_info(task_names: str) -> gr.DataFrame: ) citation = gr.Markdown(update_citation, inputs=[benchmark_select]) with gr.Column(): - with gr.Tab("Performance-Size Plot"): + with gr.Tab("Performance per Model Size"): plot = gr.Plot(performance_size_plot, inputs=[summary_table]) gr.Markdown( "*We only display models that have been run on all tasks in the benchmark*" ) - with gr.Tab("Top 5 Radar Chart"): + with gr.Tab("Performance per Task Type (Radar Chart)"): radar_plot = gr.Plot(radar_chart, inputs=[summary_table]) gr.Markdown( "*We only display models that have been run on all task types in the benchmark*" diff --git a/mteb/leaderboard/figures.py b/mteb/leaderboard/figures.py index 9f3e73f7a4..35f91dd363 100644 --- a/mteb/leaderboard/figures.py +++ b/mteb/leaderboard/figures.py @@ -6,6 +6,28 @@ import plotly.graph_objects as go +def text_plot(text: str): + """Returns empty scatter plot with text added, this can be great for error messages.""" + return px.scatter(template="plotly_white").add_annotation( + text=text, showarrow=False, font=dict(size=20) + ) + + +def failsafe_plot(fun): + """Decorator that turns the function producing a figure failsafe. + This is necessary, because once a Callback encounters an exception it + becomes useless in Gradio. + """ + + def wrapper(*args, **kwargs): + try: + return fun(*args, **kwargs) + except Exception: + return text_plot("Couldn't produce plot.") + + return wrapper + + def parse_n_params(text: str) -> int: if text.endswith("M"): return float(text[:-1]) * 1e6 @@ -37,6 +59,48 @@ def parse_float(value) -> float: ] +def add_size_guide(fig: go.Figure): + xpos = [5 * 1e9] * 4 + ypos = [7.8, 8.5, 9, 10] + sizes = [256, 1024, 2048, 4096] + fig.add_trace( + go.Scatter( + showlegend=False, + opacity=0.3, + mode="markers", + marker=dict( + size=np.sqrt(sizes), + color="rgba(0,0,0,0)", + line=dict(color="black", width=2), + ), + x=xpos, + y=ypos, + ) + ) + fig.add_annotation( + text="Embedding Size:", + font=dict(size=16), + x=np.log10(1.5e9), + y=10, + showarrow=False, + opacity=0.3, + ) + for x, y, size in zip(xpos, np.linspace(7.5, 14, 4), sizes): + fig.add_annotation( + text=f"{size}", + font=dict(size=12), + x=np.log10(x), + y=y, + showarrow=True, + ay=0, + ax=50, + opacity=0.3, + arrowwidth=2, + ) + return fig + + +@failsafe_plot def performance_size_plot(df: pd.DataFrame) -> go.Figure: df = df.copy() df["Number of Parameters"] = df["Number of Parameters"].map(parse_n_params) @@ -50,6 +114,7 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure: if not len(df.index): return go.Figure() min_score, max_score = df["Mean (Task)"].min(), df["Mean (Task)"].max() + df["sqrt(dim)"] = np.sqrt(df["Embedding Dimensions"]) fig = px.scatter( df, x="Number of Parameters", @@ -57,7 +122,7 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure: log_x=True, template="plotly_white", text="model_text", - size="Embedding Dimensions", + size="sqrt(dim)", color="Log(Tokens)", range_color=[2, 5], range_x=[8 * 1e6, 11 * 1e9], @@ -69,10 +134,21 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure: "Mean (Task)": True, "Rank (Borda)": True, "Log(Tokens)": False, + "sqrt(dim)": False, "model_text": False, }, hover_name="Model", ) + # Note: it's important that this comes before setting the size mode + fig = add_size_guide(fig) + fig.update_traces( + marker=dict( + sizemode="diameter", + sizeref=1.5, + sizemin=0, + ) + ) + fig.add_annotation(x=1e9, y=10, text="Model size:") fig.update_layout( coloraxis_colorbar=dict( # noqa title="Max Tokens", @@ -124,14 +200,15 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure: "#3CBBB1", ] fill_colors = [ - "rgba(238,66,102,0.2)", - "rgba(0,166,237,0.2)", - "rgba(236,167,44,0.2)", - "rgba(180,35,24,0.2)", - "rgba(60,187,177,0.2)", + "rgba(238,66,102,0.05)", + "rgba(0,166,237,0.05)", + "rgba(236,167,44,0.05)", + "rgba(180,35,24,0.05)", + "rgba(60,187,177,0.05)", ] +@failsafe_plot def radar_chart(df: pd.DataFrame) -> go.Figure: df = df.copy() df["Model"] = df["Model"].map(parse_model_name) @@ -139,6 +216,10 @@ def radar_chart(df: pd.DataFrame) -> go.Figure: task_type_columns = [ column for column in df.columns if "".join(column.split()) in task_types ] + if len(task_type_columns) <= 1: + raise ValueError( + "Couldn't produce radar chart, the benchmark only contains one task category." + ) df = df[["Model", *task_type_columns]].set_index("Model") df = df.replace("", np.nan) df = df.dropna() @@ -156,7 +237,7 @@ def radar_chart(df: pd.DataFrame) -> go.Figure: mode="lines", line=dict(width=2, color=line_colors[i]), fill="toself", - fillcolor=fill_colors[i], + fillcolor="rgba(0,0,0,0)", ) ) fig.update_layout(