From e167c1aa4ef7fb45bb8eaae2e7da68e8a4231182 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Fri, 6 Sep 2019 14:02:37 +0800 Subject: [PATCH 1/5] xgboost predict --- sql/codegen_xgboost.go | 46 ++++++++++++------ sql/codegen_xgboost_test.go | 23 +++++++-- sql/executor.go | 4 ++ sql/template_xgboost.go | 97 +++++++++++++++++++++++++++++++------ 4 files changed, 137 insertions(+), 33 deletions(-) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 9d6f6b95f1..1bbf1de833 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -79,7 +79,7 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro func resolveObjective(pr *extendedSelect) (string, error) { estimatorParts := strings.Split(pr.estimator, ".") if len(estimatorParts) != 3 { - return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part") + return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator) } return strings.Join(estimatorParts[1:], ":"), nil } @@ -90,6 +90,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille return nil, err } training, validation := trainingAndValidationDataset(pr, ds) + isTrain := pr.train r := &xgbFiller{ Estimator: Estimator{ IsTrain: pr.train, @@ -100,24 +101,33 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille Save: pr.save, } - // resolve the attribute keys without any prefix as the XGBoost Paremeters - params, err := resolveParamsCfg(attrs) - if err != nil { - return nil, err + if !isTrain { + r.PredictionDatasetSQL = pr.standardSelect.String() + if r.TableName, _, err = parseTableColumn(pr.into); err != nil { + return nil, err + } } - // fill learning target - objective, err := resolveObjective(pr) - if err != nil { - return nil, err - } - params["objective"] = objective + if isTrain { + // resolve the attribute keys without any prefix as the XGBoost Paremeters + params, err := resolveParamsCfg(attrs) + if err != nil { + return nil, err + } - paramsJSON, err := json.Marshal(params) - if err != nil { - return nil, err + // fill learning target + objective, err := resolveObjective(pr) + if err != nil { + return nil, err + } + params["objective"] = objective + + paramsJSON, err := json.Marshal(params) + if err != nil { + return nil, err + } + r.ParamsCfgJSON = string(paramsJSON) } - r.ParamsCfgJSON = string(paramsJSON) if r.connectionConfig, err = newConnectionConfig(db); err != nil { return nil, err @@ -161,7 +171,11 @@ func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fie if pr.train { return xgbTrainTemplate.Execute(w, r) } - return fmt.Errorf("xgboost prediction codegen has not been implemented") + if e := createPredictionTable(pr, db); e != nil { + return fmt.Errorf("failed to create prediction table: %v", e) + } + return xgbPredictTemplate.Execute(w, r) } var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText)) +var xgbPredictTemplate = template.Must(template.New("codegenXGBPredict").Parse(xgbPredictTemplateText)) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 0c15f895a2..799fcab95e 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -25,14 +25,21 @@ SELECT * FROM iris.train TRAIN xgb.multi.softprob WITH - train.num_boost_round = 30, - eta = 3.1, - num_class = 3 + train.num_boost_round = 30, + eta = 3.1, + num_class = 3 COLUMN sepal_length, sepal_width, petal_length, 
petal_width LABEL class INTO sqlflow_models.my_xgboost_model; ` +const testXGBoostPredictIris = ` +SELECT * +FROM iris.test +PREDICT iris.predict.class +USING sqlflow_models.my_xgboost_model; +` + func TestXGBFiller(t *testing.T) { a := assert.New(t) parser := newParser() @@ -51,3 +58,13 @@ func TestXGBFiller(t *testing.T) { a.NoError(err) a.Equal(filler.ParamsCfgJSON, string(paramsJSON)) } + +func TestXGBFillerPredict(t *testing.T) { + a := assert.New(t) + parser := newParser() + r, e := parser.Parse(testXGBoostPredictIris) + a.NoError(e) + filler, e := newXGBFiller(r, nil, testDB) + a.NoError(e) + a.False(filler.IsTrain) +} diff --git a/sql/executor.go b/sql/executor.go index 08f71462e0..95f344243c 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -463,6 +463,10 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil { return fmt.Errorf("genAntXGBoost %v", e) } + } else if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGB.`) { + if e := genXGBoost(&buf, pr, nil, fts, db); e != nil { + return fmt.Errorf("genXGBoost %v", e) + } } else { if e := genTF(&buf, pr, nil, fts, db); e != nil { return fmt.Errorf("genTF %v", e) diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go index 5f81b40992..4946480e9b 100644 --- a/sql/template_xgboost.go +++ b/sql/template_xgboost.go @@ -34,7 +34,7 @@ num_boost_round = {{.NumBoostRound}} maximize = True if "{{.Maximize}}" == "true" else False early_stopping_rounds = {{.EarlyStoppingRounds}} if early_stopping_rounds == -1: - early_stopping_rounds = None + early_stopping_rounds = None {{if ne .ParamsCfgJSON ""}} params = {{.ParamsCfgJSON}} @@ -58,22 +58,20 @@ feature_specs["{{$value.FeatureName}}"] = { } {{end}} - - conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}") def xgb_dataset(fn, dataset_sql): - gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs) - with open(fn, 'w') as f: - for item in gen(): - features, label = item - row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] - f.write("\t".join(row_data) + "\n") - # TODO(yancey1989): genearte group and weight text file if necessary - return xgb.DMatrix(fn) - -dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}") -dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}") + gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs) + with open(fn, 'w') as f: + for item in gen(): + features, label = item + row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] + f.write("\t".join(row_data) + "\n") + # TODO(yancey1989): genearte group and weight text file if necessary + return xgb.DMatrix(fn) + +dtrain = xgb_dataset('train.txt', """{{.TrainingDatasetSQL}}""") +dtest = xgb_dataset('test.txt', """{{.ValidationDatasetSQL}}""") train_args = {} train_args["num_boost_round"] = num_boost_round @@ -84,3 +82,74 @@ train_args["evals"] = [(dtrain, "train"), (dtest, "validation")] bst = xgb.train(params, dtrain, **train_args) bst.save_model("{{.Save}}") ` + +const xgbPredictTemplateText = ` +import xgboost as xgb +import numpy as np +from sqlflow_submitter.db import connect, db_generator, buffered_db_writer + +driver="{{.Driver}}" + +{{if ne .Database ""}} +database="{{.Database}}" +{{else}} +database="" +{{end}} + +session_cfg = {} +{{ range $k, $v := .Session }} 
+session_cfg["{{$k}}"] = "{{$v}}" +{{end}} + +feature_column_names = [{{range .X}} +"{{.FeatureName}}", +{{end}}] + +{{/* Convert go side featureSpec to python dict for input_fn */}} +feature_specs = dict() +{{ range $value := .X }} +feature_specs["{{$value.FeatureName}}"] = { + "feature_name": "{{$value.FeatureName}}", + "dtype": "{{$value.Dtype}}", + "delimiter": "{{$value.Delimiter}}", + "shape": {{$value.InputShape}}, + "is_sparse": "{{$value.IsSparse}}" == "true" +} +{{end}} + +conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}") + +def xgb_dataset(fn, dataset_sql): + gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "", feature_specs) + with open(fn, 'w') as f: + for item in gen(): + features, label = item + row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] + f.write("\t".join(row_data) + "\n") + # TODO(yancey1989): genearte group and weight text file if necessary + return xgb.DMatrix(fn) + +dpred = xgb_dataset('predict.txt', """{{.PredictionDatasetSQL}}""") + +bst = xgb.Booster({'nthread': 4}) # init model +bst.load_model("{{.Save}}") # load data +preds = bst.predict(dpred) +# TODO(typhoonzero): regression models may have different behavior +pred_classes = np.argmax(np.array(preds), axis=1) + +feature_file_read = open("predict.txt", "r") + +result_column_names = feature_column_names +result_column_names.append("{{.Y.FeatureName}}") + +line_no = 0 +with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100) as w: + while True: + line = feature_file_read.readline() + if not line: + break + row = [float(i.split(":")[1]) for i in line.replace("\n", "").split("\t")[1:]] + row.append(pred_classes[line_no]) + w.write(row) + line_no += 1 +` From 3e04ae46d1156e8329b2a9774d560c7055e81339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Fri, 6 Sep 2019 14:03:59 +0800 Subject: [PATCH 2/5] refine --- sql/template_xgboost.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go index 4946480e9b..e6aa7d185a 100644 --- a/sql/template_xgboost.go +++ b/sql/template_xgboost.go @@ -148,7 +148,7 @@ with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100 line = feature_file_read.readline() if not line: break - row = [float(i.split(":")[1]) for i in line.replace("\n", "").split("\t")[1:]] + row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]] row.append(pred_classes[line_no]) w.write(row) line_no += 1 From b0e78d9d29a2c3c9db3741f4782f0f612ae88949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Fri, 6 Sep 2019 14:24:30 +0800 Subject: [PATCH 3/5] update --- sql/codegen_xgboost.go | 1 + sql/codegen_xgboost_test.go | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 1bbf1de833..e9e964bbb4 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -106,6 +106,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille if r.TableName, _, err = parseTableColumn(pr.into); err != nil { return nil, err } + r.Save = pr.model } if isTrain { diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 799fcab95e..2cd5961c2f 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -67,4 +67,8 @@ func TestXGBFillerPredict(t *testing.T) { filler, e := newXGBFiller(r, nil, testDB) a.NoError(e) 
a.False(filler.IsTrain) + a.Equal(filler.TableName, "iris.predict") + a.Equal(filler.Save, "sqlflow_models.my_xgboost_model") + a.Equal(filler.PredictionDatasetSQL, `SELECT * +FROM iris.test`) } From 76f8dd790c0cc217fb4122eb6931c9c91aee2701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Fri, 6 Sep 2019 17:27:48 +0800 Subject: [PATCH 4/5] add executor test --- sql/executor_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/executor_test.go b/sql/executor_test.go index a6bce4348e..8a18a65d14 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -104,13 +104,18 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) { } } -func TestExecutorTrainXGBoost(t *testing.T) { +func TestExecutorTrainAndPredictXGBoost(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil) a.True(goodStream(stream.ReadAll())) }) + + a.NotPanics(func() { + stream := runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) + }) } func TestExecutorTrainAndPredictDNN(t *testing.T) { From cb357d7975fd42d1e8e1e90588b01f7da27a87f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Fri, 6 Sep 2019 18:45:13 +0800 Subject: [PATCH 5/5] fix test by merge --- sql/codegen_xgboost.go | 3 +-- sql/executor_test.go | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index e9e964bbb4..ec3305d1a0 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -100,8 +100,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille xgbTrainConfig: *resolveTrainCfg(attrs), Save: pr.save, } - - if !isTrain { + if !isTrain && !pr.analyze { r.PredictionDatasetSQL = pr.standardSelect.String() if r.TableName, _, err = parseTableColumn(pr.into); err != nil { return nil, err diff --git a/sql/executor_test.go b/sql/executor_test.go index 74604beb99..07123d2046 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -112,10 +112,7 @@ func TestExecutorXGBoost(t *testing.T) { a.True(goodStream(stream.ReadAll())) stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil) a.True(goodStream(stream.ReadAll())) - }) - - a.NotPanics(func() { - stream := runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil) + stream = runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil) a.True(goodStream(stream.ReadAll())) }) }
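
Reviewer note: for context only (not part of the patch), below is a minimal, self-contained sketch of the train/save/load/predict/argmax cycle that the generated xgbTrainTemplateText and xgbPredictTemplateText scripts follow. The synthetic data, the "my_model.bin" path, and the hyperparameter values are illustrative placeholders, not values taken from the templates, which instead read rows via db_generator and write predictions back with buffered_db_writer.

# Illustrative sketch of the flow implemented by the generated scripts.
# Assumes xgboost and numpy are installed; data and paths are placeholders.
import numpy as np
import xgboost as xgb

# Toy 3-class dataset standing in for the iris columns read via db_generator.
X = np.random.rand(150, 4)
y = np.random.randint(0, 3, size=150)
dtrain = xgb.DMatrix(X, label=y)

# Train side: params come from the WITH clause plus the resolved objective,
# then the booster is saved, as in `bst.save_model("{{.Save}}")`.
params = {"objective": "multi:softprob", "num_class": 3, "eta": 0.3}
bst = xgb.train(params, dtrain, num_boost_round=30)
bst.save_model("my_model.bin")

# Predict side: reload the trained booster, score the PREDICT SELECT rows,
# and pick a class per row with argmax, as the predict template does.
booster = xgb.Booster()
booster.load_model("my_model.bin")
dpred = xgb.DMatrix(X)                   # stands in for the 'predict.txt' DMatrix
preds = booster.predict(dpred)           # shape (n_rows, num_class) for multi:softprob
pred_classes = np.argmax(preds, axis=1)  # class label written to the result table
print(pred_classes[:10])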