From 04be7698ec36f9adddbd7bad4835e4ae7c3ca312 Mon Sep 17 00:00:00 2001 From: Haifeng Wu Date: Wed, 7 Sep 2022 11:24:33 +0800 Subject: [PATCH] add docstring --- hypergbm/experiment.py | 14 ++++++++++++++ hypergbm/hyper_gbm.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/hypergbm/experiment.py b/hypergbm/experiment.py index 0070481..90d389c 100644 --- a/hypergbm/experiment.py +++ b/hypergbm/experiment.py @@ -333,6 +333,20 @@ def __init__(self, pipeline_model, data, **kwargs): self._explainer = KernelExplainer(pred_f, data=data, keep_index=True, **kwargs) def __call__(self, X, **kwargs): + """Calc explanation of X using shap kernel method. + + Parameters + ---------- + X + kwargs + + Returns + ------- + For classification task, output type is List[Explanation], length is `n_classes` in the model, + and shape of each element is equal to X.shape. + For regression task, output type is Explanation, shape is equal to X.shape + """ + explainer = self._explainer shap_values_data = explainer.shap_values(X, **kwargs) from shap._explanation import Explanation diff --git a/hypergbm/hyper_gbm.py b/hypergbm/hyper_gbm.py index 94daa0f..337ee55 100644 --- a/hypergbm/hyper_gbm.py +++ b/hypergbm/hyper_gbm.py @@ -696,6 +696,21 @@ def expected_values(self): return self._explainers[0].expected_value def __call__(self, X, transform_kwargs=None, **kwargs): + """Calc explanation of X using shap tree method. + + Parameters + ---------- + X + transform_kwargs + kwargs + + Returns + ------- + For cv training, output type is List[Explanation], length is num folds of CV. + For train-test split training, output type is Explanation. if it's a LightGBM training one + classification task the output shape is (Xt_n_rows, Xt_n_cols, n_classes), for other algorithms + output shape is (Xt_n_rows, Xt_n_cols) + """ if transform_kwargs is None: transform_kwargs = {}