|
27 | 27 |
|
28 | 28 |
|
29 | 29 | class ParamGridBuilder(object): |
30 | | - """ |
| 30 | + r""" |
31 | 31 | Builder for a param grid used in grid search-based model selection. |
32 | 32 |
|
33 | | - >>> from classification import LogisticRegression |
| 33 | + >>> from pyspark.ml.classification import LogisticRegression |
34 | 34 | >>> lr = LogisticRegression() |
35 | | - >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \ |
36 | | - .baseOn([lr.predictionCol, 'p']) \ |
37 | | - .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \ |
38 | | - .addGrid(lr.maxIter, [1, 5]) \ |
39 | | - .addGrid(lr.featuresCol, ['f']) \ |
40 | | - .build() |
41 | | - >>> expected = [ \ |
42 | | -{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
43 | | -{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
44 | | -{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
45 | | -{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
46 | | -{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
47 | | -{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] |
| 35 | + >>> output = ParamGridBuilder() \ |
| 36 | + ... .baseOn({lr.labelCol: 'l'}) \ |
| 37 | + ... .baseOn([lr.predictionCol, 'p']) \ |
| 38 | + ... .addGrid(lr.regParam, [1.0, 2.0]) \ |
| 39 | + ... .addGrid(lr.maxIter, [1, 5]) \ |
| 40 | + ... .build() |
| 41 | + >>> expected = [ |
| 42 | + ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 43 | + ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 44 | + ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 45 | + ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] |
48 | 46 | >>> len(output) == len(expected) |
49 | 47 | True |
50 | 48 | >>> all([m in expected for m in output]) |
|
0 commit comments