
Commit 11be383

fix unit tests

1 parent 3df7952

4 files changed: +25 -23 lines

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 3 additions & 13 deletions

@@ -18,26 +18,17 @@
 package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
-import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.tuning.CrossValidator;
-import org.apache.spark.ml.tuning.CrossValidatorModel;
-import org.apache.spark.ml.tuning.ParamGridBuilder;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.api.java.JavaSQLContext;
 import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite
   .generateLogisticInputAsList;
 
@@ -51,9 +42,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
     jsql = new JavaSQLContext(jsc);
-    JavaRDD<LabeledPoint> points =
-      jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
-    dataset = jsql.applySchema(points, LabeledPoint.class);
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
   }
 
   @After

mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java

Lines changed: 6 additions & 5 deletions

@@ -17,12 +17,14 @@
 
 package org.apache.spark.ml.tuning;
 
+import java.io.Serializable;
+import java.util.List;
+
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
@@ -33,7 +35,7 @@
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite
   .generateLogisticInputAsList;
 
-public class JavaCrossValidatorSuite {
+public class JavaCrossValidatorSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
   private transient JavaSQLContext jsql;
@@ -43,9 +45,8 @@ public class JavaCrossValidatorSuite {
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
     jsql = new JavaSQLContext(jsc);
-    JavaRDD<LabeledPoint> points =
-      jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
-    dataset = jsql.applySchema(points, LabeledPoint.class);
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
   }
 
   @After

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 8 additions & 2 deletions

@@ -23,9 +23,15 @@ import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.sql.test.TestSQLContext._
 
-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
+class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Serializable {
 
-  var dataset: SchemaRDD = sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 1000, 42), 2)
+  var dataset: SchemaRDD = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val points = generateLogisticInput(1.0, 1.0, 100, 42)
+    dataset = sparkContext.parallelize(points, 2)
+  }
 
   test("logistic regression") {
     val lr = new LogisticRegression

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 8 additions & 3 deletions

@@ -17,17 +17,22 @@
 
 package org.apache.spark.ml.tuning
 
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.test.TestSQLContext._
 
-class CrossValidatorSuite extends FunSuite {
+class CrossValidatorSuite extends FunSuite with BeforeAndAfterAll with Serializable {
 
-  val dataset: SchemaRDD = sparkContext.makeRDD(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  var dataset: SchemaRDD = _
+
+  override def beforeAll(): Unit = {
+    val points = generateLogisticInput(1.0, 1.0, 100, 42)
+    dataset = sparkContext.parallelize(points, 2)
+  }
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
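
Taken together, the four diffs apply one fix: each suite is marked Serializable, presumably so that Spark's closure serializer can handle test closures that capture the enclosing suite instance, and the test dataset is built in a setup hook (JUnit's setUp for the Java suites, ScalaTest's beforeAll for the Scala ones) rather than at field-initialization time. Below is a minimal sketch of the resulting ScalaTest shape, mirroring the diffs above; ExampleSuite is a hypothetical name, and TestSQLContext._ supplies both sparkContext and the implicit RDD-to-SchemaRDD conversion in this era of Spark.

import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.test.TestSQLContext._

// Serializable keeps Spark's closure serializer happy if a test body
// accidentally captures `this`.
class ExampleSuite extends FunSuite with BeforeAndAfterAll with Serializable {

  // Assigned in beforeAll rather than at construction time, so the shared
  // SparkContext is only touched once the test framework has initialized.
  var dataset: SchemaRDD = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val points = generateLogisticInput(1.0, 1.0, 100, 42)
    // Implicit conversion from TestSQLContext._ turns the RDD of case
    // classes into a SchemaRDD.
    dataset = sparkContext.parallelize(points, 2)
  }

  test("dataset is available to tests") {
    assert(dataset.count() === 100)
  }
}

The Java suites follow the same outline: implements Serializable on the class, transient fields for the contexts, and an @Before setUp that builds a local List<LabeledPoint> and parallelizes it inline.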
