| 
 | 1 | +{  | 
 | 2 | + "cells": [  | 
 | 3 | +  {  | 
 | 4 | +   "cell_type": "code",  | 
 | 5 | +   "execution_count": null,  | 
 | 6 | +   "id": "702bb87d",  | 
 | 7 | +   "metadata": {},  | 
 | 8 | +   "outputs": [],  | 
 | 9 | +   "source": [  | 
 | 10 | +    "import matplotlib.pyplot as plt\n",  | 
 | 11 | +    "import numpy as np\n",  | 
 | 12 | +    "import pandas as pd\n",  | 
 | 13 | +    "\n",  | 
 | 14 | +    "df = pd.read_csv(\"https://github.com/user-attachments/files/21469860/benchmark_results.csv\")"  | 
 | 15 | +   ]  | 
 | 16 | +  },  | 
 | 17 | +  {  | 
 | 18 | +   "cell_type": "code",  | 
 | 19 | +   "execution_count": null,  | 
 | 20 | +   "id": "c0dd4297",  | 
 | 21 | +   "metadata": {},  | 
 | 22 | +   "outputs": [],  | 
 | 23 | +   "source": [  | 
 | 24 | +    "N_BOOTSTRAPS=100\n",  | 
 | 25 | +    "\n",  | 
 | 26 | +    "def generate_plots(df, metric=\"r2_train\", exclude = [\"SupportVectorMachine\", \"LightGBM\"], fontsize=\"small\"):\n",  | 
 | 27 | +    "    simulator_list = sorted(df[\"simulator\"].unique().tolist())\n",  | 
 | 28 | +    "    n_iter_list = sorted(df[\"n_iter\"].unique().tolist())\n",  | 
 | 29 | +    "    n_splits_list = sorted(df[\"n_splits\"].unique().tolist())\n",  | 
 | 30 | +    "    color = {name:f\"C{idx}\" for idx, name in enumerate(sorted(df[\"model_name\"].unique().tolist()))}\n",  | 
 | 31 | +    "    for plot_idx, simulator in enumerate(simulator_list):\n",  | 
 | 32 | +    "        fig, axs = plt.subplots(len(n_splits_list), len(n_iter_list), figsize=(12, 6), squeeze=False)\n",  | 
 | 33 | +    "        handles = []\n",  | 
 | 34 | +    "        labels = []\n",  | 
 | 35 | +    "        for row_idx, n_splits in enumerate(n_splits_list):\n",  | 
 | 36 | +    "            for col_idx, n_iter in enumerate(n_iter_list):\n",  | 
 | 37 | +    "                subset = df[df[\"simulator\"].eq(simulator) & df[\"n_splits\"].eq(n_splits) & df[\"n_iter\"].eq(n_iter)]\n",  | 
 | 38 | +    "                ax = axs[row_idx][col_idx]\n",  | 
 | 39 | +    "                for idx, ((name,), group) in enumerate(subset.groupby([\"model_name\"], sort=True)): \n",  | 
 | 40 | +    "                    if name in exclude:\n",  | 
 | 41 | +    "                        continue\n",  | 
 | 42 | +    "                    group_sorted = group.sort_values(\"n_samples\")\n",  | 
 | 43 | +    "                    line = ax.plot(group_sorted[\"n_samples\"], group_sorted[metric], label=name, c=color[name])\n",  | 
 | 44 | +    "\n",  | 
 | 45 | +    "                    if row_idx == 0 and col_idx == 0:\n",  | 
 | 46 | +    "                        handles.append(line[0])\n",  | 
 | 47 | +    "                        labels.append(name)\n",  | 
 | 48 | +    "                    \n",  | 
 | 49 | +    "                    mean = group_sorted[metric]\n",  | 
 | 50 | +    "                    ste = group_sorted[f\"{metric}_std\"] / np.sqrt(N_BOOTSTRAPS)\n",  | 
 | 51 | +    "                    ax.fill_between(group_sorted[\"n_samples\"], mean - ste, mean + ste, alpha=0.2, lw=0, color=color[name])\n",  | 
 | 52 | +    "                ax.set_ylim(-0.1, 1.05)\n",  | 
 | 53 | +    "                # ax.set_xlim(df[\"n_samples\"].min(), df[\"n_samples\"].max())\n",  | 
 | 54 | +    "                ax.set_xlim(10, df[\"n_samples\"].max())\n",  | 
 | 55 | +    "                ax.axhline(0., lw=0.5, ls=\"--\", c=\"grey\", alpha=0.5, zorder=-1)\n",  | 
 | 56 | +    "                \n",  | 
 | 57 | +    "                ax.set_xscale(\"log\")\n",  | 
 | 58 | +    "                # ax.set_yscale(\"log\")\n",  | 
 | 59 | +    "                if col_idx == 0:\n",  | 
 | 60 | +    "                    ax.set_ylabel(metric, size=fontsize)\n",  | 
 | 61 | +    "                if row_idx == len(n_splits_list)-1:\n",  | 
 | 62 | +    "                    ax.set_xlabel(\"n_samples\", size=fontsize)\n",  | 
 | 63 | +    "                ax.tick_params(labelsize=fontsize)\n",  | 
 | 64 | +    "                ax.set_title(f\"{simulator} (n_iter={n_iter}, n_splits={n_splits})\", size=fontsize)\n",  | 
 | 65 | +    "                ax.grid(True, which='both', linestyle=':', linewidth=0.5, alpha=0.7)\n",  | 
 | 66 | +    "        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=df[\"model_name\"].nunique()-len(exclude), fontsize=fontsize)\n",  | 
 | 67 | +    "        \n",  | 
 | 68 | +    "        # Adjust layout to make room for legend\n",  | 
 | 69 | +    "        plt.tight_layout()\n",  | 
 | 70 | +    "        plt.subplots_adjust(top=0.88)\n",  | 
 | 71 | +    "        \n",  | 
 | 72 | +    "        plt.show()\n"  | 
 | 73 | +   ]  | 
 | 74 | +  },  | 
 | 75 | +  {  | 
 | 76 | +   "cell_type": "code",  | 
 | 77 | +   "execution_count": null,  | 
 | 78 | +   "id": "be99a004",  | 
 | 79 | +   "metadata": {},  | 
 | 80 | +   "outputs": [],  | 
 | 81 | +   "source": [  | 
 | 82 | +    "# All models\n",  | 
 | 83 | +    "generate_plots(df, metric=\"r2_test\", exclude=[])\n"  | 
 | 84 | +   ]  | 
 | 85 | +  },  | 
 | 86 | +  {  | 
 | 87 | +   "cell_type": "code",  | 
 | 88 | +   "execution_count": null,  | 
 | 89 | +   "id": "ffa939d7",  | 
 | 90 | +   "metadata": {},  | 
 | 91 | +   "outputs": [],  | 
 | 92 | +   "source": [  | 
 | 93 | +    "# GPs, ensembles and MLPs only\n",  | 
 | 94 | +    "generate_plots(df, metric=\"r2_test\", exclude=[\"RandomForest\", \"LightGBM\", \"SupportVectorMachine\", \"RadialBasisFunctions\"])"  | 
 | 95 | +   ]  | 
 | 96 | +  },  | 
 | 97 | +  {  | 
 | 98 | +   "cell_type": "code",  | 
 | 99 | +   "execution_count": null,  | 
 | 100 | +   "id": "a313ab5c",  | 
 | 101 | +   "metadata": {},  | 
 | 102 | +   "outputs": [],  | 
 | 103 | +   "source": []  | 
 | 104 | +  }  | 
 | 105 | + ],  | 
 | 106 | + "metadata": {  | 
 | 107 | +  "kernelspec": {  | 
 | 108 | +   "display_name": ".venv",  | 
 | 109 | +   "language": "python",  | 
 | 110 | +   "name": "python3"  | 
 | 111 | +  },  | 
 | 112 | +  "language_info": {  | 
 | 113 | +   "codemirror_mode": {  | 
 | 114 | +    "name": "ipython",  | 
 | 115 | +    "version": 3  | 
 | 116 | +   },  | 
 | 117 | +   "file_extension": ".py",  | 
 | 118 | +   "mimetype": "text/x-python",  | 
 | 119 | +   "name": "python",  | 
 | 120 | +   "nbconvert_exporter": "python",  | 
 | 121 | +   "pygments_lexer": "ipython3",  | 
 | 122 | +   "version": "3.12.11"  | 
 | 123 | +  }  | 
 | 124 | + },  | 
 | 125 | + "nbformat": 4,  | 
 | 126 | + "nbformat_minor": 5  | 
 | 127 | +}  | 
0 commit comments