-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
1,660 additions
and
43 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9df307b4-54c4-4c0e-a846-cbbe2499ca3d", | ||
"metadata": {}, | ||
"source": [ | ||
"# Comparing two regression models using `stambo`\n", | ||
"\n", | ||
"[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Oulu-IMEDS/stambo/main?labpath=notebooks%2FRegression.ipynb)\n", | ||
"\n", | ||
"V1: © Aleksei Tiulpin, PhD, 2024\n", | ||
"\n", | ||
"This notebook shows an end-to-end example on how one can take a dataset, train two machine learning models, and conduct a statistical test to assess whether the two models are different. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "baebf8b7-3cea-4739-8f4d-0546b72329f5", | ||
"metadata": {}, | ||
"source": [ | ||
"## Import of necessary libraries" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "091baaa3-7f21-46fa-9e2a-0277fc53a0b2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import stambo\n", | ||
"\n", | ||
"from sklearn.neighbors import KNeighborsRegressor\n", | ||
"from sklearn.linear_model import LinearRegression\n", | ||
"from sklearn.datasets import load_diabetes\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"from sklearn.preprocessing import StandardScaler\n", | ||
"from sklearn.metrics import mean_absolute_error, mean_squared_error\n", | ||
"\n", | ||
"SEED = 2024" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7e066e9d-6f01-4a57-b555-cf24b1022c80", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loading the diabetes dataset and creating train-test split" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "ba8578f0-0b80-4c21-ab63-c0d4ca81050a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"X, y = load_diabetes(return_X_y=True)\n", | ||
"Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.5, random_state=SEED)\n", | ||
"\n", | ||
"scaler = StandardScaler()\n", | ||
"scaler.fit(Xtr)\n", | ||
"\n", | ||
"Xtr = scaler.transform(Xtr)\n", | ||
"Xte = scaler.transform(Xte)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3c2cb8d9-343f-42d1-89b9-8fafe0d2c772", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training the models\n", | ||
"\n", | ||
"We train a kNN and a logistic regression. Here, we can see that the logistic regression outperformes the kNN. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "4e2bad13-2678-4df4-bb24-59e0ed90de54", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"kNN MAE: 51.2489 / LR MAE: 44.3217\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = KNeighborsRegressor(n_neighbors=3)\n", | ||
"model.fit(Xtr, ytr)\n", | ||
"preds_knn = model.predict(Xte)\n", | ||
"\n", | ||
"model = LinearRegression()\n", | ||
"model.fit(Xtr, ytr)\n", | ||
"preds_lr = model.predict(Xte)\n", | ||
"\n", | ||
"\n", | ||
"mae_knn, mae_lr = mean_absolute_error(yte, preds_knn), mean_absolute_error(yte, preds_lr)\n", | ||
"print(f\"kNN MAE: {mae_knn:.4f} / LR MAE: {mae_lr:.4f}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "68886aa2-6944-4af3-afe9-e2a9d070d41f", | ||
"metadata": {}, | ||
"source": [ | ||
"## Statistical testing\n", | ||
"\n", | ||
"As stated in the documentation, the testing routine returns the `dict` of `tuple`. The keys in the dict are the metric tags, and the values are tuples that store the data in the following format:\n", | ||
"\n", | ||
"* p-value ($H_0: model_1 = model_2$)\n", | ||
"* Empirical value (model 1)\n", | ||
"* CI low (model 1)\n", | ||
"* CI high (model 1)\n", | ||
"* Empirical value (model 2)\n", | ||
"* CI low (model 2)\n", | ||
"* CI high (model 2)\n", | ||
"\n", | ||
"If you launch the code in Binder, decrease the number of bootstrap iterations (`10000` by default)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "682cd7cd-a685-41fe-8d8b-e76e0259978f", | ||
"metadata": {}, | ||
"source": [ | ||
"**Important to note:** Regression metrics are *errors*, which means that the lower value is better (contrary to classification metrics). Therefore, we actually ask a question whether the kNN has a larger MAE than the linear regression. So, model 1 is here is actually an *improved* model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"id": "0ffe5c23-3169-4fe6-9512-4cc91c2bac0c", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Bootstrapping: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:03<00:00, 2571.56it/s]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"testing_result = stambo.compare_models(yte, preds_lr, preds_knn, metrics=(\"MAE\", \"MSE\"), seed=SEED)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8512e7f5-ae48-4615-93f0-e5cee58c9067", | ||
"metadata": {}, | ||
"source": [ | ||
"If we want to visualize the testing results, they are available in a dict in the format we have described above:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"id": "bae40894-63b3-4836-a574-da7ea51ad9e6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'MAE': (0.0007999200079992001,\n", | ||
" 44.321658137655405,\n", | ||
" 40.17458217663324,\n", | ||
" 48.62883263591205,\n", | ||
" 51.248868778280546,\n", | ||
" 46.417496229260934,\n", | ||
" 56.1538838612368),\n", | ||
" 'MSE': (0.0008999100089991,\n", | ||
" 3020.4335055268534,\n", | ||
" 2508.771399983571,\n", | ||
" 3583.767801911658,\n", | ||
" 3978.893916540975,\n", | ||
" 3293.007629462041,\n", | ||
" 4723.732478632479)}" | ||
] | ||
}, | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"testing_result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b06b89f2-60bd-4b66-92e3-5add6759236f", | ||
"metadata": {}, | ||
"source": [ | ||
"Most commonly, we though want to visualize them in a report, paper, or a presentation. For that, we can use a function `to_latex`, and get a cut-and-paste `tabular`. To use it in a LaTeX document, one needs to not forget to import booktabs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"id": "7c2494d9-6801-4028-9423-ea0bdff22f89", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"% \\usepackage{booktabs} <-- do not for get to have this imported. \n", | ||
"\\begin{tabular}{lll} \\\\ \n", | ||
"\\toprule \n", | ||
"\\textbf{Model} & \\textbf{MAE} & \\textbf{MSE} \\\\ \n", | ||
"\\midrule \n", | ||
"LR & $44.32$ [$40.17$-$48.63$] & $3020.43$ [$2508.77$-$3583.77$] \\\\ \n", | ||
"kNN & $51.25$ [$46.42$-$56.15$] & $3978.89$ [$3293.01$-$4723.73$] \\\\ \n", | ||
"\\midrule\n", | ||
"$p$-value & $0.00$ & $0.00$ \\\\ \n", | ||
"\\bottomrule\n", | ||
"\\end{tabular}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(stambo.to_latex(testing_result, m2_name=\"kNN\", m1_name=\"LR\"))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.