Skip to content

Commit fd751fc

Browse files
committed
add java-friendly versions of fit and transform
1 parent 3f810cd commit fd751fc

File tree

3 files changed

+41
-12
lines changed

3 files changed

+41
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
package org.apache.spark.ml
1919

2020
import scala.annotation.varargs
21+
import scala.collection.JavaConverters._
2122

2223
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
2324
import org.apache.spark.sql.SchemaRDD
25+
import org.apache.spark.sql.api.java.JavaSchemaRDD
2426

2527
/**
2628
* Abstract class for estimators that fit models to data.
@@ -63,6 +65,21 @@ abstract class Estimator[M <: Model] extends PipelineStage with Params {
6365
paramMaps.map(fit(dataset, _))
6466
}
6567

68+
// Java-friendly versions of fit.
69+
70+
@varargs
71+
def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
72+
fit(dataset.schemaRDD, paramPairs: _*)
73+
}
74+
75+
def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
76+
fit(dataset.schemaRDD, paramMap)
77+
}
78+
79+
def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
80+
fit(dataset.schemaRDD, paramMaps).asJava
81+
}
82+
6683
/**
6784
* Parameters for the output model.
6885
*/

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.annotation.varargs
2121

2222
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
2323
import org.apache.spark.sql.SchemaRDD
24+
import org.apache.spark.sql.api.java.JavaSchemaRDD
2425

2526
/**
2627
* Abstract class for transformers that transform one dataset into another.
@@ -47,4 +48,15 @@ abstract class Transformer extends PipelineStage with Params {
4748
* @return transformed dataset
4849
*/
4950
def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
51+
52+
// Java-friendly versions of transform.
53+
54+
@varargs
55+
def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
56+
transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
57+
}
58+
59+
def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
60+
transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
61+
}
5062
}

mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
import java.io.Serializable;
2121

22+
import org.junit.After;
23+
import org.junit.Before;
24+
import org.junit.Test;
25+
2226
import org.apache.spark.api.java.JavaRDD;
2327
import org.apache.spark.api.java.JavaSparkContext;
2428
import org.apache.spark.ml.Pipeline;
@@ -32,10 +36,6 @@
3236
import org.apache.spark.sql.api.java.JavaSchemaRDD;
3337
import org.apache.spark.sql.api.java.Row;
3438

35-
import org.junit.After;
36-
import org.junit.Before;
37-
import org.junit.Test;
38-
3939
public class JavaLogisticRegressionSuite implements Serializable {
4040

4141
private transient JavaSparkContext jsc;
@@ -61,8 +61,8 @@ public void tearDown() {
6161
@Test
6262
public void logisticRegression() {
6363
LogisticRegression lr = new LogisticRegression();
64-
LogisticRegressionModel model = lr.fit(dataset.schemaRDD());
65-
model.transform(dataset.schemaRDD()).registerTempTable("prediction");
64+
LogisticRegressionModel model = lr.fit(dataset);
65+
model.transform(dataset).registerTempTable("prediction");
6666
JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
6767
for (Row r: predictions.collect()) {
6868
System.out.println(r);
@@ -74,8 +74,8 @@ public void logisticRegressionWithSetters() {
7474
LogisticRegression lr = new LogisticRegression()
7575
.setMaxIter(10)
7676
.setRegParam(1.0);
77-
LogisticRegressionModel model = lr.fit(dataset.schemaRDD());
78-
model.transform(dataset.schemaRDD(), model.threshold().w(0.8)) // overwrite threshold
77+
LogisticRegressionModel model = lr.fit(dataset);
78+
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
7979
.registerTempTable("prediction");
8080
JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
8181
for (Row r: predictions.collect()) {
@@ -95,7 +95,7 @@ public void chainModelParams() {
9595
@Test
9696
public void logisticRegressionFitWithVarargs() {
9797
LogisticRegression lr = new LogisticRegression();
98-
lr.fit(dataset.schemaRDD(), lr.maxIter().w(10), lr.regParam().w(1.0));
98+
lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
9999
}
100100

101101
@Test
@@ -111,7 +111,7 @@ public void logisticRegressionWithCrossValidation() {
111111
.setEstimatorParamMaps(lrParamMaps)
112112
.setEvaluator(eval)
113113
.setNumFolds(3);
114-
CrossValidatorModel bestModel = cv.fit(dataset.baseSchemaRDD());
114+
CrossValidatorModel bestModel = cv.fit(dataset);
115115
}
116116

117117
@Test
@@ -123,8 +123,8 @@ public void logisticRegressionWithPipeline() {
123123
.setFeaturesCol("scaledFeatures");
124124
Pipeline pipeline = new Pipeline()
125125
.setStages(new PipelineStage[] {scaler, lr});
126-
PipelineModel model = pipeline.fit(dataset.baseSchemaRDD());
127-
model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
126+
PipelineModel model = pipeline.fit(dataset);
127+
model.transform(dataset).registerTempTable("prediction");
128128
JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
129129
for (Row r: predictions.collect()) {
130130
System.out.println(r);

0 commit comments

Comments
 (0)