Skip to content

Commit 2cc93fd

Browse files
committed
hide APIs as much as I can
1 parent 34319ba commit 2cc93fd

File tree

11 files changed

+79
-60
lines changed

11 files changed

+79
-60
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
import org.apache.spark.ml.PipelineStage;
2828
import org.apache.spark.ml.classification.LogisticRegression;
2929
import org.apache.spark.ml.feature.HashingTF;
30+
import org.apache.spark.ml.feature.Tokenizer;
3031
import org.apache.spark.sql.api.java.JavaSQLContext;
3132
import org.apache.spark.sql.api.java.JavaSchemaRDD;
3233
import org.apache.spark.sql.api.java.Row;
3334
import org.apache.spark.SparkConf;
3435

3536
/**
3637
* A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
37-
* bean classes {@link LabeledDocument} and {@link Document}, and the tokenizer {@link MyTokenizer}
38-
* defined in the Scala counterpart of this example {@link SimpleTextClassificationPipeline}.
39-
* Run with
38+
* bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of
39+
* this example {@link SimpleTextClassificationPipeline}. Run with
4040
* <pre>
4141
* bin/run-example ml.JavaSimpleTextClassificationPipeline
4242
* </pre>
@@ -58,7 +58,7 @@ public static void main(String[] args) {
5858
jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
5959

6060
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
61-
MyTokenizer tokenizer = new MyTokenizer()
61+
Tokenizer tokenizer = new Tokenizer()
6262
.setInputCol("text")
6363
.setOutputCol("words");
6464
HashingTF hashingTF = new HashingTF()

examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,20 @@ package org.apache.spark.examples.ml
2020
import scala.beans.BeanInfo
2121

2222
import org.apache.spark.{SparkConf, SparkContext}
23-
import org.apache.spark.ml.{Pipeline, UnaryTransformer}
23+
import org.apache.spark.ml.Pipeline
2424
import org.apache.spark.ml.classification.LogisticRegression
25-
import org.apache.spark.ml.feature.HashingTF
26-
import org.apache.spark.ml.param.ParamMap
27-
import org.apache.spark.sql.{DataType, SQLContext, StringType}
25+
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
26+
import org.apache.spark.sql.SQLContext
2827

2928
@BeanInfo
3029
case class LabeledDocument(id: Long, text: String, label: Double)
3130

3231
@BeanInfo
3332
case class Document(id: Long, text: String)
3433

35-
/**
36-
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
37-
*/
38-
class MyTokenizer extends UnaryTransformer[String, Seq[String], MyTokenizer] {
39-
40-
override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
41-
_.toLowerCase.split("\\s")
42-
}
43-
44-
override protected def validateInputType(inputType: DataType): Unit = {
45-
require(inputType == StringType, s"Input type must be string type but got $inputType.")
46-
}
47-
}
48-
4934
/**
5035
* A simple text classification pipeline that recognizes "spark" from input text. This is to show
51-
* how to define a simple tokenizer and then use it as part of a ML pipeline. Run with
36+
* how to create and configure an ML pipeline. Run with
5237
* {{{
5338
* bin/run-example ml.SimpleTextClassificationPipeline
5439
* }}}
@@ -69,7 +54,7 @@ object SimpleTextClassificationPipeline {
6954
LabeledDocument(3L, "hadoop mapreduce", 0.0)))
7055

7156
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
72-
val tokenizer = new MyTokenizer()
57+
val tokenizer = new Tokenizer()
7358
.setInputCol("text")
7459
.setOutputCol("words")
7560
val hashingTF = new HashingTF()

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,15 @@ package org.apache.spark.ml
1919

2020
import java.util.UUID
2121

22-
import org.apache.spark.annotation.AlphaComponent
23-
2422
/**
25-
* :: AlphaComponent ::
2623
* Object with a unique id.
2724
*/
28-
@AlphaComponent
29-
trait Identifiable extends Serializable {
25+
private[ml] trait Identifiable extends Serializable {
3026

3127
/**
3228
* A unique id for the object. The default implementation concatenates the class name, "-", and 8
3329
* random hex chars.
3430
*/
35-
val uid: String = this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
31+
private[ml] val uid: String =
32+
this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
3633
}

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,20 @@ import scala.collection.mutable.ListBuffer
2121

2222
import org.apache.spark.Logging
2323
import org.apache.spark.annotation.AlphaComponent
24-
import org.apache.spark.ml.param.{Param, ParamMap}
24+
import org.apache.spark.ml.param.{Params, Param, ParamMap}
2525
import org.apache.spark.sql.{SchemaRDD, StructType}
2626

2727
/**
2828
* :: AlphaComponent ::
29-
* A stage in a pipeline, either an Estimator or an Transformer.
29+
* A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
3030
*/
3131
@AlphaComponent
3232
abstract class PipelineStage extends Serializable with Logging {
3333

3434
/**
3535
* Derives the output schema from the input schema and parameters.
3636
*/
37-
def transformSchema(schema: StructType, paramMap: ParamMap): StructType
37+
private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
3838

3939
/**
4040
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -123,7 +123,7 @@ class Pipeline extends Estimator[PipelineModel] {
123123
new PipelineModel(this, map, transformers.toArray)
124124
}
125125

126-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
126+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
127127
val map = this.paramMap ++ paramMap
128128
val theStages = map(stages)
129129
require(theStages.toSet.size == theStages.size,
@@ -137,23 +137,23 @@ class Pipeline extends Estimator[PipelineModel] {
137137
* Represents a compiled pipeline.
138138
*/
139139
@AlphaComponent
140-
class PipelineModel(
140+
class PipelineModel private[ml] (
141141
override val parent: Pipeline,
142142
override val fittingParamMap: ParamMap,
143-
val transformers: Array[Transformer])
143+
private[ml] val stages: Array[Transformer])
144144
extends Model[PipelineModel] with Logging {
145145

146146
/**
147147
* Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
148148
* estimator does not exist in the pipeline.
149149
*/
150-
def getModel[M <: Model[M]](estimator: Estimator[M]): M = {
151-
val matched = transformers.filter {
152-
case m: Model[_] => m.parent.eq(estimator)
150+
def getModel[M <: Model[M]](stage: Estimator[M]): M = {
151+
val matched = stages.filter {
152+
case m: Model[_] => m.parent.eq(stage)
153153
case _ => false
154154
}
155155
if (matched.isEmpty) {
156-
throw new NoSuchElementException(s"Cannot find estimator $estimator from the pipeline.")
156+
throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
157157
} else if (matched.size > 1) {
158158
throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
159159
} else {
@@ -163,10 +163,10 @@ class PipelineModel(
163163

164164
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
165165
transformSchema(dataset.schema, paramMap, logging = true)
166-
transformers.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
166+
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
167167
}
168168

169-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
170-
transformers.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
169+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
170+
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
171171
}
172172
}

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.dsl._
3131
import org.apache.spark.sql.catalyst.types._
3232

3333
/**
34-
* :: AlphaComponet ::
34+
* :: AlphaComponent ::
3535
* Abstract class for transformers that transform one dataset into another.
3636
*/
3737
@AlphaComponent
@@ -83,12 +83,10 @@ abstract class Transformer extends PipelineStage with Params {
8383
}
8484

8585
/**
86-
* :: AlphaComponent ::
8786
* Abstract class for transformers that take one input column, apply transformation, and output the
8887
* result as a new column.
8988
*/
90-
@AlphaComponent
91-
abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
89+
private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
9290
extends Transformer with HasInputCol with HasOutputCol with Logging {
9391

9492
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
105105
lrm
106106
}
107107

108-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
108+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
109109
validateAndTransformSchema(schema, paramMap, fitting = true)
110110
}
111111
}
@@ -118,15 +118,15 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
118118
class LogisticRegressionModel private[ml] (
119119
override val parent: LogisticRegression,
120120
override val fittingParamMap: ParamMap,
121-
val weights: Vector)
121+
weights: Vector)
122122
extends Model[LogisticRegressionModel] with LogisticRegressionParams {
123123

124124
def setThreshold(value: Double): this.type = set(threshold, value)
125125
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
126126
def setScoreCol(value: String): this.type = set(scoreCol, value)
127127
def setPredictionCol(value: String): this.type = set(predictionCol, value)
128128

129-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
129+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
130130
validateAndTransformSchema(schema, paramMap, fitting = false)
131131
}
132132

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
5656
model
5757
}
5858

59-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
59+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
6060
val map = this.paramMap ++ paramMap
6161
val inputType = schema(map(inputCol)).dataType
6262
require(inputType.isInstanceOf[VectorUDT],
@@ -92,7 +92,7 @@ class StandardScalerModel private[ml] (
9292
dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
9393
}
9494

95-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
95+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
9696
val map = this.paramMap ++ paramMap
9797
val inputType = schema(map(inputCol)).dataType
9898
require(inputType.isInstanceOf[VectorUDT],
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.annotation.AlphaComponent
21+
import org.apache.spark.ml.UnaryTransformer
22+
import org.apache.spark.ml.param.ParamMap
23+
import org.apache.spark.sql.{DataType, StringType}
24+
25+
/**
26+
* :: AlphaComponent ::
27+
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
28+
*/
29+
@AlphaComponent
30+
class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
31+
32+
protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
33+
_.toLowerCase.split("\\s")
34+
}
35+
36+
protected override def validateInputType(inputType: DataType): Unit = {
37+
require(inputType == StringType, s"Input type must be string type but got $inputType.")
38+
}
39+
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ private[ml] object Params {
187187
* @param parent the parent estimator
188188
* @param child the child model
189189
*/
190-
private[ml] def inheritValues[E <: Params, M <: E](
190+
def inheritValues[E <: Params, M <: E](
191191
paramMap: ParamMap,
192192
parent: E,
193193
child: M): Unit = {

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
9999
cvModel
100100
}
101101

102-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
102+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
103103
val map = this.paramMap ++ paramMap
104104
map(estimator).transformSchema(schema, paramMap)
105105
}
@@ -120,7 +120,7 @@ class CrossValidatorModel private[ml] (
120120
bestModel.transform(dataset, paramMap)
121121
}
122122

123-
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
123+
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
124124
bestModel.transformSchema(schema, paramMap)
125125
}
126126
}

0 commit comments

Comments
 (0)