Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update kernelshap_tabular_land_atmosphere.ipynb: increase timeout #871

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
introducing key locally_run to disable time consuming cell for github…
… action
  • Loading branch information
SarahAlidoost committed Dec 19, 2024
commit d92c03b2fc8ae2bd69ec674343c01467f1ab52e8
Original file line number Diff line number Diff line change
@@ -720,7 +720,7 @@
},
{
"cell_type": "markdown",
"id": "80411a9d-881e-4196-8559-17aaadd15841",
"id": "ddb1e4f0-2674-4242-bcd8-abf66f97c611",
"metadata": {},
"source": [
"#### 5 - Run the explainer at one location, several data instances (here as an example one month time series)\n",
@@ -805,6 +805,24 @@
"background_data = x_train.drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()"
]
},
{
"cell_type": "markdown",
"id": "8b612e55-e1ec-40dc-b189-65d90ffb2b1c",
"metadata": {},
"source": [
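"This step takes a few minutes, so it is not suitable for GitHub Actions. If you want to run this step locally, set `locally_run = True`. While it is `False`, the explainer cells and the plotting/metric cells that depend on them do nothing."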
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "59a54eaa-f6f2-42b5-8849-aceb37b06156",
"metadata": {},
"outputs": [],
"source": [
"locally_run = False"
]
},
{
"cell_type": "code",
"execution_count": 14,
@@ -821,11 +839,12 @@
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
@@ -846,30 +865,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
@@ -887,9 +907,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
@@ -947,12 +968,13 @@
}
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" # run explainer over time series, this might take a few minutes\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
@@ -973,30 +995,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
@@ -1014,9 +1037,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
@@ -1166,6 +1190,9 @@
}
],
"metadata": {
"execution": {
"timeout": 1800
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@@ -1182,9 +1209,6 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
},
"execution": {
"timeout": 1800
}
},
"nbformat": 4,