diff --git a/sql/codegen_analyze.go b/sql/codegen_analyze.go index dbe4c2976e..cf5adacc43 100644 --- a/sql/codegen_analyze.go +++ b/sql/codegen_analyze.go @@ -45,10 +45,10 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, mod }, nil } -func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) { +func readXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) { // TODO(weiguo): It's a quick way to read column and label names from // xgboost.*, but too heavy. - fr, err := newAntXGBoostFiller(pr, nil, db) + fr, err := newXGBFiller(pr, nil, db) if err != nil { return nil, "", err } @@ -68,7 +68,7 @@ func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, err IsSparse: fr.X[i].IsSparse, } } - return xs, fr.Label, nil + return xs, fr.Y.FeatureName, nil } func readPlotType(pr *extendedSelect) string { @@ -85,19 +85,17 @@ func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffe if err != nil { return nil, fmt.Errorf("loadModelMeta %v", err) } - if !strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { + if !strings.HasPrefix(strings.ToUpper(pr.estimator), `XGB.`) { return nil, fmt.Errorf("analyzer: model[%s] not supported", pr.estimator) } - // We untar the AntXGBoost.{pr.trainedModel}.tar.gz and get three files. - // Here, the sqlflow_booster is a raw xgboost binary file can be analyzed. - antXGBModelPath := fmt.Sprintf("%s/sqlflow_booster", pr.trainedModel) + // We untar the XGBoost.{pr.trainedModel}.tar.gz and get three files. plotType := readPlotType(pr) - xs, label, err := readAntXGBFeatures(pr, db) + xs, label, err := readXGBFeatures(pr, db) if err != nil { return nil, err } - fr, err := newAnalyzeFiller(pr, db, xs, label, antXGBModelPath, plotType) + fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel, plotType) if err != nil { return nil, fmt.Errorf("create analyze filler failed: %v", err) } diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 0c15f895a2..0b07cc8eca 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -32,6 +32,11 @@ COLUMN sepal_length, sepal_width, petal_length, petal_width LABEL class INTO sqlflow_models.my_xgboost_model; ` +const testAnalyzeTreeModelSelectIris = ` +SELECT * FROM iris.train +ANALYZE sqlflow_models.my_xgboost_model +USING TreeExplainer; + ` func TestXGBFiller(t *testing.T) { a := assert.New(t) diff --git a/sql/executor_test.go b/sql/executor_test.go index a6bce4348e..6268b462ec 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -104,12 +104,14 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) { } } -func TestExecutorTrainXGBoost(t *testing.T) { +func TestExecutorXGBoost(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil) a.True(goodStream(stream.ReadAll())) + stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) }) }