1919
2020import java .io .Serializable ;
2121
22+ import org .junit .After ;
23+ import org .junit .Before ;
24+ import org .junit .Test ;
25+
2226import org .apache .spark .api .java .JavaRDD ;
2327import org .apache .spark .api .java .JavaSparkContext ;
2428import org .apache .spark .ml .Pipeline ;
3236import org .apache .spark .sql .api .java .JavaSchemaRDD ;
3337import org .apache .spark .sql .api .java .Row ;
3438
35- import org .junit .After ;
36- import org .junit .Before ;
37- import org .junit .Test ;
38-
3939public class JavaLogisticRegressionSuite implements Serializable {
4040
4141 private transient JavaSparkContext jsc ;
@@ -61,8 +61,8 @@ public void tearDown() {
6161 @ Test
6262 public void logisticRegression () {
6363 LogisticRegression lr = new LogisticRegression ();
64- LogisticRegressionModel model = lr .fit (dataset . schemaRDD () );
65- model .transform (dataset . schemaRDD () ).registerTempTable ("prediction" );
64+ LogisticRegressionModel model = lr .fit (dataset );
65+ model .transform (dataset ).registerTempTable ("prediction" );
6666 JavaSchemaRDD predictions = jsql .sql ("SELECT label, score, prediction FROM prediction" );
6767 for (Row r : predictions .collect ()) {
6868 System .out .println (r );
@@ -74,8 +74,8 @@ public void logisticRegressionWithSetters() {
7474 LogisticRegression lr = new LogisticRegression ()
7575 .setMaxIter (10 )
7676 .setRegParam (1.0 );
77- LogisticRegressionModel model = lr .fit (dataset . schemaRDD () );
78- model .transform (dataset . schemaRDD () , model .threshold ().w (0.8 )) // overwrite threshold
77+ LogisticRegressionModel model = lr .fit (dataset );
78+ model .transform (dataset , model .threshold ().w (0.8 )) // overwrite threshold
7979 .registerTempTable ("prediction" );
8080 JavaSchemaRDD predictions = jsql .sql ("SELECT label, score, prediction FROM prediction" );
8181 for (Row r : predictions .collect ()) {
@@ -95,7 +95,7 @@ public void chainModelParams() {
9595 @ Test
9696 public void logisticRegressionFitWithVarargs () {
9797 LogisticRegression lr = new LogisticRegression ();
98- lr .fit (dataset . schemaRDD () , lr .maxIter ().w (10 ), lr .regParam ().w (1.0 ));
98+ lr .fit (dataset , lr .maxIter ().w (10 ), lr .regParam ().w (1.0 ));
9999 }
100100
101101 @ Test
@@ -111,7 +111,7 @@ public void logisticRegressionWithCrossValidation() {
111111 .setEstimatorParamMaps (lrParamMaps )
112112 .setEvaluator (eval )
113113 .setNumFolds (3 );
114- CrossValidatorModel bestModel = cv .fit (dataset . baseSchemaRDD () );
114+ CrossValidatorModel bestModel = cv .fit (dataset );
115115 }
116116
117117 @ Test
@@ -123,8 +123,8 @@ public void logisticRegressionWithPipeline() {
123123 .setFeaturesCol ("scaledFeatures" );
124124 Pipeline pipeline = new Pipeline ()
125125 .setStages (new PipelineStage [] {scaler , lr });
126- PipelineModel model = pipeline .fit (dataset . baseSchemaRDD () );
127- model .transform (dataset . baseSchemaRDD () ).registerTempTable ("prediction" );
126+ PipelineModel model = pipeline .fit (dataset );
127+ model .transform (dataset ).registerTempTable ("prediction" );
128128 JavaSchemaRDD predictions = jsql .sql ("SELECT label, score, prediction FROM prediction" );
129129 for (Row r : predictions .collect ()) {
130130 System .out .println (r );
0 commit comments