From d8eb1d89fe8767671017da11ade1bb19f8534f4a Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Fri, 6 Sep 2019 16:04:26 +0800 Subject: [PATCH] shap with xgboost --- sql/codegen_analyze.go | 15 +++++++-------- sql/codegen_xgboost_test.go | 5 +++++ sql/executor_test.go | 4 +++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql/codegen_analyze.go b/sql/codegen_analyze.go index 588eab24cd..dc6b7940b0 100644 --- a/sql/codegen_analyze.go +++ b/sql/codegen_analyze.go @@ -43,10 +43,10 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, mod }, nil } -func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) { +func readXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) { // TODO(weiguo): It's a quick way to read column and label names from // xgboost.*, but too heavy. - fr, err := newAntXGBoostFiller(pr, nil, db) + fr, err := newXGBFiller(pr, nil, db) if err != nil { return nil, "", err } @@ -66,7 +66,7 @@ func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, err IsSparse: fr.X[i].IsSparse, } } - return xs, fr.Label, nil + return xs, fr.Y.FeatureName, nil } func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffer, error) { @@ -74,18 +74,17 @@ func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffe if err != nil { return nil, fmt.Errorf("loadModelMeta %v", err) } - if !strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { + if !strings.HasPrefix(strings.ToUpper(pr.estimator), `XGB.`) { return nil, fmt.Errorf("analyzer: model[%s] not supported", pr.estimator) } - // We untar the AntXGBoost.{pr.trainedModel}.tar.gz and get three files. + // We untar the XGBoost.{pr.trainedModel}.tar.gz and get three files. // Here, the sqlflow_booster is a raw xgboost binary file can be analyzed. - antXGBModelPath := fmt.Sprintf("%s/sqlflow_booster", pr.trainedModel) - xs, label, err := readAntXGBFeatures(pr, db) + xs, label, err := readXGBFeatures(pr, db) if err != nil { return nil, err } - fr, err := newAnalyzeFiller(pr, db, xs, label, antXGBModelPath) + fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel) if err != nil { return nil, fmt.Errorf("create analyze filler failed: %v", err) } diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 0c15f895a2..0b07cc8eca 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -32,6 +32,11 @@ COLUMN sepal_length, sepal_width, petal_length, petal_width LABEL class INTO sqlflow_models.my_xgboost_model; ` +const testAnalyzeTreeModelSelectIris = ` +SELECT * FROM iris.train +ANALYZE sqlflow_models.my_xgboost_model +USING TreeExplainer; + ` func TestXGBFiller(t *testing.T) { a := assert.New(t) diff --git a/sql/executor_test.go b/sql/executor_test.go index a6bce4348e..6268b462ec 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -104,12 +104,14 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) { } } -func TestExecutorTrainXGBoost(t *testing.T) { +func TestExecutorXGBoost(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil) a.True(goodStream(stream.ReadAll())) + stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) }) }