Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed Sep 7, 2022
1 parent efba2ca commit 04be769
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
14 changes: 14 additions & 0 deletions hypergbm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,20 @@ def __init__(self, pipeline_model, data, **kwargs):
self._explainer = KernelExplainer(pred_f, data=data, keep_index=True, **kwargs)

def __call__(self, X, **kwargs):
"""Calc explanation of X using shap kernel method.
Parameters
----------
X
kwargs
Returns
-------
For classification task, output type is List[Explanation], length is `n_classes` in the model,
and shape of each element is equal to X.shape.
For regression task, output type is Explanation, shape is equal to X.shape
"""

explainer = self._explainer
shap_values_data = explainer.shap_values(X, **kwargs)
from shap._explanation import Explanation
Expand Down
15 changes: 15 additions & 0 deletions hypergbm/hyper_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,21 @@ def expected_values(self):
return self._explainers[0].expected_value

def __call__(self, X, transform_kwargs=None, **kwargs):
"""Calc explanation of X using shap tree method.
Parameters
----------
X
transform_kwargs
kwargs
Returns
-------
For cv training, output type is List[Explanation], length is num folds of CV.
For train-test split training, output type is Explanation. if it's a LightGBM training one
classification task the output shape is (Xt_n_rows, Xt_n_cols, n_classes), for other algorithms
output shape is (Xt_n_rows, Xt_n_cols)
"""

if transform_kwargs is None:
transform_kwargs = {}
Expand Down

0 comments on commit 04be769

Please sign in to comment.