|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml |
19 | 19 |
|
20 | | -import org.apache.spark.ml.param.{ParamMap, Param} |
21 | | -import org.apache.spark.sql.SchemaRDD |
22 | | - |
23 | 20 | import scala.collection.mutable.ListBuffer |
24 | 21 |
|
| 22 | +import org.apache.spark.ml.param.{Param, ParamMap} |
| 23 | +import org.apache.spark.sql.SchemaRDD |
| 24 | + |
| 25 | +/** |
| 26 | + * A stage in a pipeline, either an Estimator or an Transformer. |
| 27 | + */ |
25 | 28 | trait PipelineStage extends Identifiable |
26 | 29 |
|
27 | 30 | /** |
28 | 31 | * A simple pipeline, which acts as an estimator. |
29 | 32 | */ |
30 | 33 | class Pipeline extends Estimator[PipelineModel] { |
31 | 34 |
|
32 | | - val stages: Param[Array[PipelineStage]] = |
33 | | - new Param(this, "stages", "stages of the pipeline") |
34 | | - |
35 | | - def setStages(stages: Array[PipelineStage]): this.type = { |
36 | | - set(this.stages, stages) |
37 | | - this |
38 | | - } |
39 | | - |
40 | | - def getStages: Array[PipelineStage] = { |
41 | | - get(stages) |
42 | | - } |
| 35 | + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") |
| 36 | + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } |
| 37 | + def getStages: Array[PipelineStage] = get(stages) |
43 | 38 |
|
44 | 39 | override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { |
45 | 40 | val map = this.paramMap ++ paramMap |
46 | 41 | val theStages = map(stages) |
47 | | - // Search for last estimator. |
| 42 | + // Search for the last estimator. |
48 | 43 | var lastIndexOfEstimator = -1 |
49 | 44 | theStages.view.zipWithIndex.foreach { case (stage, index) => |
50 | 45 | stage match { |
@@ -75,10 +70,11 @@ class Pipeline extends Estimator[PipelineModel] { |
75 | 70 |
|
76 | 71 | new PipelineModel(transformers.toArray) |
77 | 72 | } |
78 | | - |
79 | | - override def params: Array[Param[_]] = Array.empty |
80 | 73 | } |
81 | 74 |
|
| 75 | +/** |
| 76 | + * Represents a compiled pipeline. |
| 77 | + */ |
82 | 78 | class PipelineModel(val transformers: Array[Transformer]) extends Model { |
83 | 79 |
|
84 | 80 | override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { |
|
0 commit comments