|
17 | 17 |
|
18 | 18 | __all__ = ['ParamGridBuilder'] |
19 | 19 |
|
| 20 | + |
20 | 21 | class ParamGridBuilder(object): |
21 | 22 | """ |
22 | 23 | Builder for a param grid used in grid search-based model selection. |
@@ -66,18 +67,17 @@ def build(self): |
66 | 67 | lr = LogisticRegression() |
67 | 68 | grid_test.addGrid(lr.regParam, [1.0, 2.0, 3.0]) |
68 | 69 | grid_test.addGrid(lr.maxIter, [1, 5]) |
69 | | - grid_test.addGrid(lr.featuresCol, ['f']) |
| 70 | + grid_test.addGrid(lr.inputCol, ['f']) |
70 | 71 | grid_test.baseOn({lr.labelCol: 'l'}) |
71 | | - grid_test.baseOn([lr.predictionCol, 'p']) |
| 72 | + grid_test.baseOn([lr.outputCol, 'p']) |
72 | 73 | grid = grid_test.build() |
73 | 74 | expected = [ |
74 | | - {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
75 | | - {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
76 | | - {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
77 | | - {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
78 | | - {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
79 | | - {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'} |
80 | | - ] |
| 75 | + {lr.regParam: 1.0, lr.inputCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.outputCol: 'p'}, |
| 76 | + {lr.regParam: 2.0, lr.inputCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.outputCol: 'p'}, |
| 77 | + {lr.regParam: 3.0, lr.inputCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.outputCol: 'p'}, |
| 78 | + {lr.regParam: 1.0, lr.inputCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.outputCol: 'p'}, |
| 79 | + {lr.regParam: 2.0, lr.inputCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.outputCol: 'p'}, |
| 80 | + {lr.regParam: 3.0, lr.inputCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.outputCol: 'p'}] |
81 | 81 |
|
82 | 82 | for a, b in zip(grid, expected): |
83 | 83 | if a != b: |
|
0 commit comments