|  | 
|  | 1 | +{ | 
|  | 2 | + "cells": [ | 
|  | 3 | +  { | 
|  | 4 | +   "cell_type": "markdown", | 
|  | 5 | +   "metadata": {}, | 
|  | 6 | +   "source": [ | 
|  | 7 | +    "# XGBoost on SQLFlow Tutorial\n", | 
|  | 8 | +    "\n", | 
|  | 9 | +    "This tutorial shows how to train an XGBoost model and use it for prediction in SQLFlow. You can find more about SQLFlow usage in the [User Guide](https://github.com/sql-machine-learning/sqlflow/blob/develop/doc/user_guide.md). In this tutorial you will learn how to:\n", | 
|  | 10 | +    "- Train an XGBoost model to fit the Boston housing dataset; and\n", | 
|  | 11 | +    "- Predict housing prices using the trained model.\n", | 
|  | 12 | +    "\n", | 
|  | 13 | +    "\n", | 
|  | 14 | +    "## The Dataset\n", | 
|  | 15 | +    "\n", | 
|  | 16 | +    "This tutorial uses the [Boston Housing](https://www.kaggle.com/c/boston-housing) dataset for demonstration.\n", | 
|  | 17 | +    "The dataset contains 506 rows and 14 columns. The meaning of each column is as follows:\n", | 
|  | 18 | +    "\n", | 
|  | 19 | +    "Column | Description \n", | 
|  | 20 | +    "-- | -- \n", | 
|  | 21 | +    "crim|per capita crime rate by town.\n", | 
|  | 22 | +    "zn|proportion of residential land zoned for lots over 25,000 sq.ft.\n", | 
|  | 23 | +    "indus|proportion of non-retail business acres per town.\n", | 
|  | 24 | +    "chas|Charles River dummy variable (= 1 if tract bounds river; 0 otherwise).\n", | 
|  | 25 | +    "nox|nitrogen oxides concentration (parts per 10 million).\n", | 
|  | 26 | +    "rm|average number of rooms per dwelling.\n", | 
|  | 27 | +    "age|proportion of owner-occupied units built prior to 1940.\n", | 
|  | 28 | +    "dis|weighted mean of distances to five Boston employment centres.\n", | 
|  | 29 | +    "rad|index of accessibility to radial highways.\n", | 
|  | 30 | +    "tax|full-value property-tax rate per \\$10,000.\n", | 
|  | 31 | +    "ptratio|pupil-teacher ratio by town.\n", | 
|  | 32 | +    "b|1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town.\n", | 
|  | 33 | +    "lstat|lower status of the population (percent).\n", | 
|  | 34 | +    "medv|median value of owner-occupied homes in \\$1000s.\n", | 
|  | 35 | +    "\n", | 
|  | 36 | +    "We split the dataset into a training table and a test table, which are used to train the model and to predict with it, respectively. During training, SQLFlow automatically splits the training data into training and validation sets." | 
|  | 37 | +   ] | 
|  | 38 | +  }, | 
|  | 39 | +  { | 
|  | 40 | +   "cell_type": "code", | 
|  | 41 | +   "execution_count": 1, | 
|  | 42 | +   "metadata": {}, | 
|  | 43 | +   "outputs": [ | 
|  | 44 | +    { | 
|  | 45 | +     "data": { | 
|  | 46 | +      "text/plain": [ | 
|  | 47 | +       "+---------+---------+------+-----+---------+-------+\n", | 
|  | 48 | +       "|  Field  |   Type  | Null | Key | Default | Extra |\n", | 
|  | 49 | +       "+---------+---------+------+-----+---------+-------+\n", | 
|  | 50 | +       "|   crim  |  float  | YES  |     |   None  |       |\n", | 
|  | 51 | +       "|    zn   |  float  | YES  |     |   None  |       |\n", | 
|  | 52 | +       "|  indus  |  float  | YES  |     |   None  |       |\n", | 
|  | 53 | +       "|   chas  | int(11) | YES  |     |   None  |       |\n", | 
|  | 54 | +       "|   nox   |  float  | YES  |     |   None  |       |\n", | 
|  | 55 | +       "|    rm   |  float  | YES  |     |   None  |       |\n", | 
|  | 56 | +       "|   age   |  float  | YES  |     |   None  |       |\n", | 
|  | 57 | +       "|   dis   |  float  | YES  |     |   None  |       |\n", | 
|  | 58 | +       "|   rad   | int(11) | YES  |     |   None  |       |\n", | 
|  | 59 | +       "|   tax   | int(11) | YES  |     |   None  |       |\n", | 
|  | 60 | +       "| ptratio |  float  | YES  |     |   None  |       |\n", | 
|  | 61 | +       "|    b    |  float  | YES  |     |   None  |       |\n", | 
|  | 62 | +       "|  lstat  |  float  | YES  |     |   None  |       |\n", | 
|  | 63 | +       "|   medv  |  float  | YES  |     |   None  |       |\n", | 
|  | 64 | +       "+---------+---------+------+-----+---------+-------+" | 
|  | 65 | +      ] | 
|  | 66 | +     }, | 
|  | 67 | +     "execution_count": 1, | 
|  | 68 | +     "metadata": {}, | 
|  | 69 | +     "output_type": "execute_result" | 
|  | 70 | +    } | 
|  | 71 | +   ], | 
|  | 72 | +   "source": [ | 
|  | 73 | +    "%%sqlflow\n", | 
|  | 74 | +    "describe boston.train;" | 
|  | 75 | +   ] | 
|  | 76 | +  }, | 
|  | 77 | +  { | 
|  | 78 | +   "cell_type": "code", | 
|  | 79 | +   "execution_count": 2, | 
|  | 80 | +   "metadata": {}, | 
|  | 81 | +   "outputs": [ | 
|  | 82 | +    { | 
|  | 83 | +     "data": { | 
|  | 84 | +      "text/plain": [ | 
|  | 85 | +       "+---------+---------+------+-----+---------+-------+\n", | 
|  | 86 | +       "|  Field  |   Type  | Null | Key | Default | Extra |\n", | 
|  | 87 | +       "+---------+---------+------+-----+---------+-------+\n", | 
|  | 88 | +       "|   crim  |  float  | YES  |     |   None  |       |\n", | 
|  | 89 | +       "|    zn   |  float  | YES  |     |   None  |       |\n", | 
|  | 90 | +       "|  indus  |  float  | YES  |     |   None  |       |\n", | 
|  | 91 | +       "|   chas  | int(11) | YES  |     |   None  |       |\n", | 
|  | 92 | +       "|   nox   |  float  | YES  |     |   None  |       |\n", | 
|  | 93 | +       "|    rm   |  float  | YES  |     |   None  |       |\n", | 
|  | 94 | +       "|   age   |  float  | YES  |     |   None  |       |\n", | 
|  | 95 | +       "|   dis   |  float  | YES  |     |   None  |       |\n", | 
|  | 96 | +       "|   rad   | int(11) | YES  |     |   None  |       |\n", | 
|  | 97 | +       "|   tax   | int(11) | YES  |     |   None  |       |\n", | 
|  | 98 | +       "| ptratio |  float  | YES  |     |   None  |       |\n", | 
|  | 99 | +       "|    b    |  float  | YES  |     |   None  |       |\n", | 
|  | 100 | +       "|  lstat  |  float  | YES  |     |   None  |       |\n", | 
|  | 101 | +       "|   medv  |  float  | YES  |     |   None  |       |\n", | 
|  | 102 | +       "+---------+---------+------+-----+---------+-------+" | 
|  | 103 | +      ] | 
|  | 104 | +     }, | 
|  | 105 | +     "execution_count": 2, | 
|  | 106 | +     "metadata": {}, | 
|  | 107 | +     "output_type": "execute_result" | 
|  | 108 | +    } | 
|  | 109 | +   ], | 
|  | 110 | +   "source": [ | 
|  | 111 | +    "%%sqlflow\n", | 
|  | 112 | +    "describe boston.test;" | 
|  | 113 | +   ] | 
|  | 114 | +  }, | 
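|  |  | +  { | 
|  |  | +   "cell_type": "markdown", | 
|  |  | +   "metadata": {}, | 
|  |  | +   "source": [ | 
|  |  | +    "The `describe` cells above show that ordinary database statements can be run under the `%%sqlflow` magic. As a minimal sketch, assuming the same `boston` database is reachable from this notebook, a plain `COUNT` query could be used to check how the 506 rows are divided between the two tables:\n", | 
|  |  | +    "\n", | 
|  |  | +    "``` sql\n", | 
|  |  | +    "SELECT COUNT(*) FROM boston.train;\n", | 
|  |  | +    "```\n", | 
|  |  | +    "\n", | 
|  |  | +    "Running the same query against `boston.test` would give the size of the held-out table." | 
|  |  | +   ] | 
|  |  | +  }, | 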
|  | 115 | +  { | 
|  | 116 | +   "cell_type": "markdown", | 
|  | 117 | +   "metadata": {}, | 
|  | 118 | +   "source": [ | 
|  | 119 | +    "## Fit Boston Housing Dataset\n", | 
|  | 120 | +    "\n", | 
|  | 121 | +    "First, let's train an XGBoost regression model to fit the Boston housing dataset. We train the model for 30 rounds\n", | 
|  | 122 | +    "with the `reg:squarederror` objective (squared-error loss). In SQLFlow extended SQL, this can be written as:\n", | 
|  | 123 | +    "\n", | 
|  | 124 | +    "``` sql\n", | 
|  | 125 | +    "TRAIN xgboost.gbtree\n", | 
|  | 126 | +    "WITH\n", | 
|  | 127 | +    "    train.num_boost_round=30,\n", | 
|  | 128 | +    "    objective=\"reg:squarederror\"\n", | 
|  | 129 | +    "```\n", | 
|  | 130 | +    "\n", | 
|  | 131 | +    "`xgboost.gbtree` is the estimator name, and `gbtree` is one of the XGBoost boosters. You can find more information in the [XGBoost parameter documentation](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters).\n", | 
|  | 132 | +    "\n", | 
|  | 133 | +    "We can specify the training data columns in the `COLUMN` clause and the label with the `LABEL` keyword:\n", | 
|  | 134 | +    "\n", | 
|  | 135 | +    "``` sql\n", | 
|  | 136 | +    "COLUMN crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat\n", | 
|  | 137 | +    "LABEL medv\n", | 
|  | 138 | +    "```\n", | 
|  | 139 | +    "\n", | 
|  | 140 | +    "To save the trained model, we can use the `INTO` clause to specify a model name:\n", | 
|  | 141 | +    "\n", | 
|  | 142 | +    "``` sql\n", | 
|  | 143 | +    "INTO sqlflow_models.my_xgb_regression_model\n", | 
|  | 144 | +    "```\n", | 
|  | 145 | +    "\n", | 
|  | 146 | +    "Second, let's use a standard SQL statement to fetch the training data from the table `boston.train`:\n", | 
|  | 147 | +    "\n", | 
|  | 148 | +    "``` sql\n", | 
|  | 149 | +    "SELECT * FROM boston.train\n", | 
|  | 150 | +    "```\n", | 
|  | 151 | +    "\n", | 
|  | 152 | +    "Finally, the following is the complete SQLFlow TRAIN statement for this regression task. You can run it in the cell below:" | 
|  | 153 | +   ] | 
|  | 154 | +  }, | 
|  | 155 | +  { | 
|  | 156 | +   "cell_type": "code", | 
|  | 157 | +   "execution_count": 5, | 
|  | 158 | +   "metadata": {}, | 
|  | 159 | +   "outputs": [ | 
|  | 160 | +    { | 
|  | 161 | +     "name": "stdout", | 
|  | 162 | +     "output_type": "stream", | 
|  | 163 | +     "text": [ | 
|  | 164 | +      "[03:44:56] 387x13 matrix with 5031 entries loaded from train.txt\n", | 
|  | 165 | +      "\n", | 
|  | 166 | +      "[03:44:56] 109x13 matrix with 1417 entries loaded from test.txt\n", | 
|  | 167 | +      "\n", | 
|  | 168 | +      "[0]\ttrain-rmse:17.0286\tvalidation-rmse:17.8089\n", | 
|  | 169 | +      "\n", | 
|  | 170 | +      "[1]\ttrain-rmse:12.285\tvalidation-rmse:13.2787\n", | 
|  | 171 | +      "\n", | 
|  | 172 | +      "[2]\ttrain-rmse:8.93071\tvalidation-rmse:9.87677\n", | 
|  | 173 | +      "\n", | 
|  | 174 | +      "[3]\ttrain-rmse:6.60757\tvalidation-rmse:7.64013\n", | 
|  | 175 | +      "\n", | 
|  | 176 | +      "[4]\ttrain-rmse:4.96022\tvalidation-rmse:6.0181\n", | 
|  | 177 | +      "\n", | 
|  | 178 | +      "[5]\ttrain-rmse:3.80725\tvalidation-rmse:4.95013\n", | 
|  | 179 | +      "\n", | 
|  | 180 | +      "[6]\ttrain-rmse:2.94382\tvalidation-rmse:4.2357\n", | 
|  | 181 | +      "\n", | 
|  | 182 | +      "[7]\ttrain-rmse:2.36361\tvalidation-rmse:3.74683\n", | 
|  | 183 | +      "\n", | 
|  | 184 | +      "[8]\ttrain-rmse:1.95236\tvalidation-rmse:3.43284\n", | 
|  | 185 | +      "\n", | 
|  | 186 | +      "[9]\ttrain-rmse:1.66604\tvalidation-rmse:3.20455\n", | 
|  | 187 | +      "\n", | 
|  | 188 | +      "[10]\ttrain-rmse:1.4738\tvalidation-rmse:3.08947\n", | 
|  | 189 | +      "\n", | 
|  | 190 | +      "[11]\ttrain-rmse:1.35336\tvalidation-rmse:3.0492\n", | 
|  | 191 | +      "\n", | 
|  | 192 | +      "[12]\ttrain-rmse:1.22835\tvalidation-rmse:2.99508\n", | 
|  | 193 | +      "\n", | 
|  | 194 | +      "[13]\ttrain-rmse:1.15615\tvalidation-rmse:2.98604\n", | 
|  | 195 | +      "\n", | 
|  | 196 | +      "[14]\ttrain-rmse:1.11082\tvalidation-rmse:2.96433\n", | 
|  | 197 | +      "\n", | 
|  | 198 | +      "[15]\ttrain-rmse:1.01666\tvalidation-rmse:2.96584\n", | 
|  | 199 | +      "\n", | 
|  | 200 | +      "[16]\ttrain-rmse:0.953761\tvalidation-rmse:2.94013\n", | 
|  | 201 | +      "\n", | 
|  | 202 | +      "[17]\ttrain-rmse:0.905753\tvalidation-rmse:2.91569\n", | 
|  | 203 | +      "\n", | 
|  | 204 | +      "[18]\ttrain-rmse:0.870137\tvalidation-rmse:2.89735\n", | 
|  | 205 | +      "\n", | 
|  | 206 | +      "[19]\ttrain-rmse:0.800778\tvalidation-rmse:2.87206\n", | 
|  | 207 | +      "\n", | 
|  | 208 | +      "[20]\ttrain-rmse:0.757704\tvalidation-rmse:2.86564\n", | 
|  | 209 | +      "\n", | 
|  | 210 | +      "[21]\ttrain-rmse:0.74058\tvalidation-rmse:2.86587\n", | 
|  | 211 | +      "\n", | 
|  | 212 | +      "[22]\ttrain-rmse:0.66901\tvalidation-rmse:2.86224\n", | 
|  | 213 | +      "\n", | 
|  | 214 | +      "[23]\ttrain-rmse:0.647195\tvalidation-rmse:2.87395\n", | 
|  | 215 | +      "\n", | 
|  | 216 | +      "[24]\ttrain-rmse:0.609025\tvalidation-rmse:2.86069\n", | 
|  | 217 | +      "\n", | 
|  | 218 | +      "[25]\ttrain-rmse:0.562925\tvalidation-rmse:2.87205\n", | 
|  | 219 | +      "\n", | 
|  | 220 | +      "[26]\ttrain-rmse:0.541676\tvalidation-rmse:2.86275\n", | 
|  | 221 | +      "\n", | 
|  | 222 | +      "[27]\ttrain-rmse:0.524815\tvalidation-rmse:2.87106\n", | 
|  | 223 | +      "\n", | 
|  | 224 | +      "[28]\ttrain-rmse:0.483566\tvalidation-rmse:2.86129\n", | 
|  | 225 | +      "\n", | 
|  | 226 | +      "[29]\ttrain-rmse:0.460363\tvalidation-rmse:2.85877\n", | 
|  | 227 | +      "\n" | 
|  | 228 | +     ] | 
|  | 229 | +    } | 
|  | 230 | +   ], | 
|  | 231 | +   "source": [ | 
|  | 232 | +    "%%sqlflow\n", | 
|  | 233 | +    "SELECT * FROM boston.train\n", | 
|  | 234 | +    "TRAIN xgboost.gbtree\n", | 
|  | 235 | +    "WITH\n", | 
|  | 236 | +    "    objective=\"reg:squarederror\",\n", | 
|  | 237 | +    "    train.num_boost_round = 30\n", | 
|  | 238 | +    "COLUMN crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat\n", | 
|  | 239 | +    "LABEL medv\n", | 
|  | 240 | +    "INTO sqlflow_models.my_xgb_regression_model;" | 
|  | 241 | +   ] | 
|  | 242 | +  }, | 
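|  |  | +  { | 
|  |  | +   "cell_type": "markdown", | 
|  |  | +   "metadata": {}, | 
|  |  | +   "source": [ | 
|  |  | +    "In the log above, `validation-rmse` stays at roughly 2.86 to 2.87 from round 20 onward while `train-rmse` keeps falling, so the later rounds mostly fit the training split more closely. As a sketch of how one might try a shorter run, the statement below reuses the same clauses with `train.num_boost_round = 20`. The model name `sqlflow_models.my_xgb_regression_model_20` is a hypothetical name chosen for illustration, and whether fewer rounds is actually preferable here is an assumption to check against the new training log:\n", | 
|  |  | +    "\n", | 
|  |  | +    "``` sql\n", | 
|  |  | +    "SELECT * FROM boston.train\n", | 
|  |  | +    "TRAIN xgboost.gbtree\n", | 
|  |  | +    "WITH\n", | 
|  |  | +    "    objective=\"reg:squarederror\",\n", | 
|  |  | +    "    train.num_boost_round = 20\n", | 
|  |  | +    "COLUMN crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat\n", | 
|  |  | +    "LABEL medv\n", | 
|  |  | +    "INTO sqlflow_models.my_xgb_regression_model_20;\n", | 
|  |  | +    "```\n", | 
|  |  | +    "\n", | 
|  |  | +    "Saving under a separate name keeps the 30-round model available for the prediction step." | 
|  |  | +   ] | 
|  |  | +  }, | 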
|  | 243 | +  { | 
|  | 244 | +   "cell_type": "markdown", | 
|  | 245 | +   "metadata": {}, | 
|  | 246 | +   "source": [ | 
|  | 247 | +    "## Predict the Housing Price\n", | 
|  | 248 | +    "After training the regression model, let's predict the house price using the trained model.\n", | 
|  | 249 | +    "\n", | 
|  | 250 | +    "First, we specify the trained model with the `USING` clause:\n", | 
|  | 251 | +    "\n", | 
|  | 252 | +    "``` sql\n", | 
|  | 253 | +    "USING sqlflow_models.my_xgb_regression_model\n", | 
|  | 254 | +    "```\n", | 
|  | 255 | +    "\n", | 
|  | 256 | +    "Then, we specify the prediction result table with the `PREDICT` clause:\n", | 
|  | 257 | +    "\n", | 
|  | 258 | +    "``` sql\n", | 
|  | 259 | +    "PREDICT boston.predict.medv\n", | 
|  | 260 | +    "```\n", | 
|  | 261 | +    "\n", | 
|  | 262 | +    "And we use a standard SQL statement to fetch the prediction data:\n", | 
|  | 263 | +    "\n", | 
|  | 264 | +    "``` sql\n", | 
|  | 265 | +    "SELECT * FROM boston.test\n", | 
|  | 266 | +    "```\n", | 
|  | 267 | +    "\n", | 
|  | 268 | +    "Finally, the following is the complete SQLFlow prediction statement:" | 
|  | 269 | +   ] | 
|  | 270 | +  }, | 
|  | 271 | +  { | 
|  | 272 | +   "cell_type": "code", | 
|  | 273 | +   "execution_count": 8, | 
|  | 274 | +   "metadata": {}, | 
|  | 275 | +   "outputs": [ | 
|  | 276 | +    { | 
|  | 277 | +     "name": "stdout", | 
|  | 278 | +     "output_type": "stream", | 
|  | 279 | +     "text": [ | 
|  | 280 | +      "[03:45:18] 10x13 matrix with 130 entries loaded from predict.txt\n", | 
|  | 281 | +      "\n", | 
|  | 282 | +      "Done predicting. Predict table : boston.predict\n", | 
|  | 283 | +      "\n" | 
|  | 284 | +     ] | 
|  | 285 | +    } | 
|  | 286 | +   ], | 
|  | 287 | +   "source": [ | 
|  | 288 | +    "%%sqlflow\n", | 
|  | 289 | +    "SELECT * FROM boston.test\n", | 
|  | 290 | +    "PREDICT boston.predict.medv\n", | 
|  | 291 | +    "USING sqlflow_models.my_xgb_regression_model;" | 
|  | 292 | +   ] | 
|  | 293 | +  }, | 
|  | 294 | +  { | 
|  | 295 | +   "cell_type": "markdown", | 
|  | 296 | +   "metadata": {}, | 
|  | 297 | +   "source": [ | 
|  | 298 | +    "Let's take a look at the prediction results. The `medv` column holds the values predicted by the model." | 
|  | 299 | +   ] | 
|  | 300 | +  }, | 
|  | 301 | +  { | 
|  | 302 | +   "cell_type": "code", | 
|  | 303 | +   "execution_count": 10, | 
|  | 304 | +   "metadata": {}, | 
|  | 305 | +   "outputs": [ | 
|  | 306 | +    { | 
|  | 307 | +     "data": { | 
|  | 308 | +      "text/plain": [ | 
|  | 309 | +       "+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+\n", | 
|  | 310 | +       "|   crim  |  zn | indus | chas |  nox  |   rm  | age  |  dis   | rad | tax | ptratio |   b    | lstat |   medv  |\n", | 
|  | 311 | +       "+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+\n", | 
|  | 312 | +       "|  0.2896 | 0.0 |  9.69 |  0   | 0.585 |  5.39 | 72.9 | 2.7986 |  6  | 391 |   19.2  | 396.9  | 21.14 | 21.9436 |\n", | 
|  | 313 | +       "| 0.26838 | 0.0 |  9.69 |  0   | 0.585 | 5.794 | 70.6 | 2.8927 |  6  | 391 |   19.2  | 396.9  |  14.1 | 21.9667 |\n", | 
|  | 314 | +       "| 0.23912 | 0.0 |  9.69 |  0   | 0.585 | 6.019 | 65.3 | 2.4091 |  6  | 391 |   19.2  | 396.9  | 12.92 | 22.9708 |\n", | 
|  | 315 | +       "| 0.17783 | 0.0 |  9.69 |  0   | 0.585 | 5.569 | 73.5 | 2.3999 |  6  | 391 |   19.2  | 395.77 |  15.1 | 22.6373 |\n", | 
|  | 316 | +       "| 0.22438 | 0.0 |  9.69 |  0   | 0.585 | 6.027 | 79.7 | 2.4982 |  6  | 391 |   19.2  | 396.9  | 14.33 | 21.9439 |\n", | 
|  | 317 | +       "| 0.06263 | 0.0 | 11.93 |  0   | 0.573 | 6.593 | 69.1 | 2.4786 |  1  | 273 |   21.0  | 391.99 |  9.67 | 24.0095 |\n", | 
|  | 318 | +       "| 0.04527 | 0.0 | 11.93 |  0   | 0.573 |  6.12 | 76.7 | 2.2875 |  1  | 273 |   21.0  | 396.9  |  9.08 |   25.0  |\n", | 
|  | 319 | +       "| 0.06076 | 0.0 | 11.93 |  0   | 0.573 | 6.976 | 91.0 | 2.1675 |  1  | 273 |   21.0  | 396.9  |  5.64 | 31.6326 |\n", | 
|  | 320 | +       "| 0.10959 | 0.0 | 11.93 |  0   | 0.573 | 6.794 | 89.3 | 2.3889 |  1  | 273 |   21.0  | 393.45 |  6.48 | 26.8375 |\n", | 
|  | 321 | +       "| 0.04741 | 0.0 | 11.93 |  0   | 0.573 |  6.03 | 80.8 | 2.505  |  1  | 273 |   21.0  | 396.9  |  7.88 | 22.5877 |\n", | 
|  | 322 | +       "+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+" | 
|  | 323 | +      ] | 
|  | 324 | +     }, | 
|  | 325 | +     "execution_count": 10, | 
|  | 326 | +     "metadata": {}, | 
|  | 327 | +     "output_type": "execute_result" | 
|  | 328 | +    } | 
|  | 329 | +   ], | 
|  | 330 | +   "source": [ | 
|  | 331 | +    "%%sqlflow\n", | 
|  | 332 | +    "SELECT * FROM boston.predict;" | 
|  | 333 | +   ] | 
|  | 334 | +  } | 
|  | 335 | + ], | 
|  | 336 | + "metadata": { | 
|  | 337 | +  "kernelspec": { | 
|  | 338 | +   "display_name": "Python 3", | 
|  | 339 | +   "language": "python", | 
|  | 340 | +   "name": "python3" | 
|  | 341 | +  }, | 
|  | 342 | +  "language_info": { | 
|  | 343 | +   "codemirror_mode": { | 
|  | 344 | +    "name": "ipython", | 
|  | 345 | +    "version": 3 | 
|  | 346 | +   }, | 
|  | 347 | +   "file_extension": ".py", | 
|  | 348 | +   "mimetype": "text/x-python", | 
|  | 349 | +   "name": "python", | 
|  | 350 | +   "nbconvert_exporter": "python", | 
|  | 351 | +   "pygments_lexer": "ipython3", | 
|  | 352 | +   "version": "3.6.9" | 
|  | 353 | +  } | 
|  | 354 | + }, | 
|  | 355 | + "nbformat": 4, | 
|  | 356 | + "nbformat_minor": 2 | 
|  | 357 | +} | 