Commit b74f5a5

Merge branch 'master' into SPARK-14994
2 parents: 045865d + 4607f6e

File tree: 38 files changed, +600 -384 lines

38 files changed

+600
-384
lines changed

core/src/main/scala/org/apache/spark/NewAccumulator.scala
8 additions, 2 deletions

@@ -22,6 +22,8 @@ import java.io.ObjectInputStream
 import java.util.concurrent.atomic.AtomicLong
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.util.Utils
 
@@ -57,7 +59,7 @@ abstract class NewAccumulator[IN, OUT] extends Serializable {
    * registered before ues, or it will throw exception.
    */
  final def isRegistered: Boolean =
-    metadata != null && AccumulatorContext.originals.containsKey(metadata.id)
+    metadata != null && AccumulatorContext.get(metadata.id).isDefined
 
  private def assertMetadataNotNull(): Unit = {
    if (metadata == null) {
@@ -197,7 +199,7 @@ private[spark] object AccumulatorContext {
    * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
    */
  @GuardedBy("AccumulatorContext")
-  val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
+  private val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
 
  private[this] val nextId = new AtomicLong(0L)
 
@@ -207,6 +209,10 @@
    */
  def newId(): Long = nextId.getAndIncrement
 
+  def numAccums: Int = synchronized(originals.size)
+
+  def accumIds: Set[Long] = synchronized(originals.keySet().asScala.toSet)
+
  /**
   * Register an [[Accumulator]] created on the driver such that it can be used on the executors.
   *
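With `originals` now private, code outside `AccumulatorContext` goes through `get`, `numAccums`, and `accumIds` rather than reading the map directly. A minimal sketch of the new test-side idiom, assuming it lives in the `org.apache.spark` package (the object is `private[spark]`) and that the accumulator exposes its `id`:

package org.apache.spark

// Hypothetical helper illustrating the accessor-based checks that replace
// direct reads of the now-private `originals` map.
private[spark] object AccumulatorContextChecks {
  def assertRegistered(acc: NewAccumulator[_, _]): Unit = {
    val id = acc.id                                    // assumption: NewAccumulator exposes its id
    assert(AccumulatorContext.get(id).isDefined)       // was: originals.containsKey(id)
    assert(AccumulatorContext.accumIds.contains(id))   // was: originals.keySet().asScala
    assert(AccumulatorContext.numAccums > 0)           // was: originals.size
  }
}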

core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
1 addition, 1 deletion

@@ -191,7 +191,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     assert(ref.get.isEmpty)
 
     AccumulatorContext.remove(accId)
-    assert(!AccumulatorContext.originals.containsKey(accId))
+    assert(!AccumulatorContext.get(accId).isDefined)
   }
 
   test("get accum") {

core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
3 additions, 3 deletions

@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
       private val myCleaner = new SaveAccumContextCleaner(this)
       override def cleaner: Option[ContextCleaner] = Some(myCleaner)
     }
-    assert(AccumulatorContext.originals.isEmpty)
+    assert(AccumulatorContext.numAccums == 0)
     sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
     val numInternalAccums = TaskMetrics.empty.internalAccums.length
     // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
-    assert(AccumulatorContext.originals.size === numInternalAccums * 2)
+    assert(AccumulatorContext.numAccums === numInternalAccums * 2)
     val accumsRegistered = sc.cleaner match {
       case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
       case _ => Seq.empty[Long]
     }
     // Make sure the same set of accumulators is registered for cleanup
     assert(accumsRegistered.size === numInternalAccums * 2)
-    assert(accumsRegistered.toSet === AccumulatorContext.originals.keySet().asScala)
+    assert(accumsRegistered.toSet === AccumulatorContext.accumIds)
   }
 
   /**

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
67 additions, 3 deletions

@@ -17,14 +17,17 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
  * (private[spark]) Params for classification.
@@ -62,6 +65,67 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
   // TODO: defaultEvaluator (follow-up PR)
+
+  /**
+   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   *
+   * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
+   *                 and features ([[Vector]]). Labels are cast to [[DoubleType]].
+   * @param numClasses  Number of classes label can take. Labels must be integers in the range
+   *                    [0, numClasses).
+   * @throws SparkException  if any label is not an integer >= 0
+   */
+  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
+    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) =>
+        require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
+          s" dataset with invalid label $label. Labels must be integers in range" +
+          s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
+        LabeledPoint(label, features)
+    }
+  }
+
+  /**
+   * Get the number of classes. This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in [[extractLabeledPoints()]].
+   *
+   * @param dataset  Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses  Maximum number of classes allowed when inferred from data. If numClasses
+   *                       is specified in the metadata, then maxNumClasses is ignored.
+   * @return  number of classes
+   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
+   *                                   actual numClasses exceeds maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
+        if (maxLabelRow.isEmpty) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+          s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+          s" StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+        numClasses
+    }
+  }
 }
 
 /**
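With both helpers in the base class, a concrete classifier's train method can delegate label validation and class-count inference entirely to Classifier. A minimal sketch of the intended call pattern, assuming a hypothetical subclass (the model type and the final fitting step are placeholders, not part of this patch):

  override protected def train(dataset: Dataset[_]): MyClassificationModel = {
    // Read numClasses from label-column metadata if present; otherwise infer it from
    // max(label), capped by the default maxNumClasses of 100.
    val numClasses: Int = getNumClasses(dataset)
    // Cast labels to Double, validate that they lie in [0, numClasses), and build the RDD.
    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
    // ... fit the underlying algorithm on oldDataset ...
    ???
  }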

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
2 additions, 8 deletions

@@ -85,14 +85,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), parentUID = Some(uid))

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
14 additions, 11 deletions

@@ -35,8 +35,9 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -126,16 +127,16 @@ class GBTClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("GBTClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    require(numClasses == 2,
-      s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
+    // 2 classes now. This lets us provide a more precise error message.
+    val oldDataset: RDD[LabeledPoint] =
+      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+        case Row(label: Double, features: Vector) =>
+          require(label == 0 || label == 1, s"GBTClassifier was given" +
+            s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+            s" GBTClassifier currently only supports binary classification.")
+          LabeledPoint(label, features)
+      }
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
@@ -165,6 +166,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
  * model for classification.
  * It supports binary labels, as well as both continuous and categorical features.
  * Note: Multiclass labels are not currently supported.
+ *
  * @param _trees  Decision trees in the ensemble.
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
@@ -185,6 +187,7 @@ class GBTClassificationModel private[ml](
 
 /**
  * Construct a GBTClassificationModel
+ *
  * @param _trees  Decision trees in the ensemble.
  * @param _treeWeights  Weights for the decision trees in the ensemble.
 */
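The inlined check changes which message a user sees when GBT is handed a non-binary label. A small test-style sketch of the behavior, modeled on the ClassifierSuite tests below (the data and suite context are made up; the require failure inside the map surfaces as a SparkException when the job runs):

  val bad = sqlContext.createDataFrame(Seq(
    LabeledPoint(0.0, Vectors.dense(1.0)),
    LabeledPoint(2.0, Vectors.dense(2.0))))
  val e = intercept[SparkException] {
    new GBTClassifier().setMaxIter(1).fit(bad)
  }
  // The message now names GBTClassifier and {0,1} rather than the generic
  // "invalid label column ... without the number of classes specified" error.
  assert(e.getMessage.contains("Labels must be in {0,1}"))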

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
2 additions, 8 deletions

@@ -101,14 +101,8 @@ class RandomForestClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
     val trees =

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
4 additions, 0 deletions

@@ -327,5 +327,9 @@ object FPGrowth {
     def javaItems: java.util.List[Item] = {
       items.toList.asJava
     }
+
+    override def toString: String = {
+      s"${items.mkString("{", ",", "}")}: $freq"
+    }
   }
 }
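The new toString gives FreqItemset a readable printed form. A small usage sketch (the items and frequency below are made-up sample values):

  import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

  val itemset = new FreqItemset[String](Array("a", "b"), 3L)
  // With the override above this prints "{a,b}: 3" instead of the default
  // object hash representation.
  println(itemset)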

mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
108 additions, 0 deletions

@@ -17,6 +17,86 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("extractLabeledPoints") {
+    def getTestData(labels: Seq[Double]): DataFrame = {
+      val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+      sqlContext.createDataFrame(data)
+    }
+
+    val c = new MockClassifier
+    // Valid dataset
+    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+    c.extractLabeledPoints(df0, 6).count()
+    // Invalid datasets
+    val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
+    withClue("Classifier should fail if label is negative") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df1, 6).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
+    withClue("Classifier should fail if label is not an integer") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df2, 6).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    // extractLabeledPoints with numClasses specified
+    withClue("Classifier should fail if label is >= numClasses") {
+      val e: SparkException = intercept[SparkException] {
+        c.extractLabeledPoints(df0, numClasses = 5).count()
+      }
+      assert(e.getMessage.contains("given dataset with invalid label"))
+    }
+    withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.extractLabeledPoints(df0, numClasses = 0).count()
+      }
+      assert(e.getMessage.contains("but requires numClasses > 0"))
+    }
+  }
+
+  test("getNumClasses") {
+    def getTestData(labels: Seq[Double]): DataFrame = {
+      val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
+      sqlContext.createDataFrame(data)
+    }
+
+    val c = new MockClassifier
+    // Valid dataset
+    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
+    assert(c.getNumClasses(df0) === 6)
+    // Invalid datasets
+    val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1))
+    withClue("getNumClasses should fail if label is max label not an integer") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.getNumClasses(df1)
+      }
+      assert(e.getMessage.contains("requires integers in range"))
+    }
+    val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble))
+    withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") {
+      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+        c.getNumClasses(df2)
+      }
+      assert(e.getMessage.contains("requires integers in range"))
+    }
+  }
+}
+
 object ClassifierSuite {
 
   /**
@@ -29,4 +109,32 @@ object ClassifierSuite {
     "rawPredictionCol" -> "myRawPrediction"
   )
 
+  class MockClassifier(override val uid: String)
+    extends Classifier[Vector, MockClassifier, MockClassificationModel] {
+
+    def this() = this(Identifiable.randomUID("mockclassifier"))
+
+    override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
+
+    override def train(dataset: Dataset[_]): MockClassificationModel =
+      throw new NotImplementedError()
+
+    // Make methods public
+    override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
+      super.extractLabeledPoints(dataset, numClasses)
+    def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+  }
+
+  class MockClassificationModel(override val uid: String)
+    extends ClassificationModel[Vector, MockClassificationModel] {
+
+    def this() = this(Identifiable.randomUID("mockclassificationmodel"))
+
+    protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
+
+    override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
+
+    override def numClasses: Int = throw new NotImplementedError()
+  }
+
 }

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
6 additions, 0 deletions

@@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite
     }
   }
 
+  test("Fitting without numClasses in metadata") {
+    val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
+    val dt = new DecisionTreeClassifier().setMaxDepth(1)
+    dt.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
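This new test exercises the inference path added in Classifier.getNumClasses: the label column carries no numClasses metadata, so the count is derived from the maximum label. For contrast, a minimal sketch of how a caller could attach numClasses metadata explicitly so that inference is skipped, reusing the df from the test above (column names and the class count of 3 are illustrative):

  import org.apache.spark.ml.attribute.NominalAttribute
  import org.apache.spark.sql.functions.col

  // Attach ML attribute metadata declaring 3 classes on the label column.
  val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toMetadata()
  val dfWithMeta = df.select(col("features"), col("label").as("label", labelMeta))

  // getNumClasses would now read 3 from the metadata instead of scanning the labels.
  new DecisionTreeClassifier().setMaxDepth(1).fit(dfWithMeta)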
