21 changes: 11 additions & 10 deletions examples/src/main/python/ml/simple_params_example.py
@@ -20,11 +20,10 @@
import pprint
import sys

-from pyspark import SparkContext
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import DenseVector
from pyspark.mllib.regression import LabeledPoint
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession

"""
A simple example demonstrating ways to specify parameters for Estimators and Transformers.
@@ -36,18 +35,20 @@
    if len(sys.argv) > 1:
Contributor

It seems that argv is never used in this example, so what about just removing this if block?

Member Author

Hm.. isn't it making sure that this script takes no arguments?

Contributor

This check seems meaningless, and the Scala and Java examples don't have it.

Member Author

I see.. Hm.. but I don't think this is meaningless; it explicitly signals that the script takes no arguments.

Actually, I think all the examples that take no arguments should perform this check for consistency, because some of the example scripts that do take arguments already check it.

Strictly speaking, running this example with arguments might not be a proper way to run it.

Member Author
@HyukjinKwon HyukjinKwon May 18, 2016

@yanboliang @MLnick Do you mind if I ask your thoughts as well? I'm fine either changing this example not to check sys.argv, or opening another PR that adds the sys.argv check to all the other examples in this way.

Contributor

We're moving most examples towards being simpler (with a few exceptions, such as keeping the longer ML examples that show a bit of how to build an app and use argument parsing). As such, I agree we should remove this.

Member Author
@HyukjinKwon HyukjinKwon May 18, 2016

Thank you all!
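For context, the guard being debated mirrors what the argument-taking PySpark examples already do. A minimal sketch of that existing pattern — the usage string and argument count are borrowed from examples/src/main/python/wordcount.py, not from this file:

```python
from __future__ import print_function

import sys

if __name__ == "__main__":
    # Examples that take arguments validate argv up front and exit with a
    # usage message; the question above is whether examples that take *no*
    # arguments should symmetrically reject any extras.
    if len(sys.argv) != 2:
        print("Usage: wordcount <file>", file=sys.stderr)
        exit(-1)
```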

print("Usage: simple_params_example", file=sys.stderr)
exit(1)
sc = SparkContext(appName="PythonSimpleParamsExample")
sqlContext = SQLContext(sc)
spark = SparkSession \
.builder \
.appName("SimpleTextClassificationPipeline") \
.getOrCreate()
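One migration note that may help reviewers here: a builder-created SparkSession still exposes the old entry point, so RDD-based code can be ported incrementally rather than all at once. A minimal sketch, with an illustrative app name that is not part of this PR:

```python
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("MigrationSketch") \
    .getOrCreate()

# The underlying SparkContext remains reachable from the session, so any
# leftover RDD code such as sc.parallelize(...) keeps working.
sc = spark.sparkContext
print(sc.parallelize([1, 2, 3]).count())  # prints 3

spark.stop()
```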

    # prepare training data.
    # We create an RDD of LabeledPoints and convert them into a DataFrame.
    # A LabeledPoint is an Object with two fields named label and features
    # and Spark SQL identifies these fields and creates the schema appropriately.
-    training = sc.parallelize([
+    training = spark.createDataFrame([
        LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
        LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
        LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
-        LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+        LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))])
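As the comment block above notes, Spark SQL derives the schema from the two fields of each LabeledPoint, which is why createDataFrame needs no explicit schema here. A minimal standalone sketch of that inference, assuming only a local session:

```python
from pyspark.mllib.linalg import DenseVector
from pyspark.mllib.regression import LabeledPoint
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SchemaInferenceSketch").getOrCreate()

df = spark.createDataFrame([
    LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
    LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0]))])

# The inferred schema carries the 'label' and 'features' fields of
# LabeledPoint, so downstream ML stages find the columns they expect.
df.printSchema()
df.show()

spark.stop()
```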

    # Create a LogisticRegression instance with maxIter = 10.
    # This instance is an Estimator.
@@ -70,18 +71,18 @@

    # We may alternatively specify parameters using a parameter map.
    # paramMap overrides all lr parameters set earlier.
-    paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}
+    paramMap = {lr.maxIter: 20, lr.thresholds: [0.5, 0.5], lr.probabilityCol: "myProbability"}
Contributor
@yanboliang yanboliang May 16, 2016

Oh, it throws an exception when we make predictions, because we need to find an authoritative threshold. This change is okay. Actually, we use threshold more frequently than thresholds in LogisticRegression, because LR does not support multiclass classification currently. The community is trying to find a way to harmonize the two params for LR, but has not settled on a final solution. You can refer to SPARK-11834 and SPARK-11543.
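To unpack that comment: for binary logistic regression, Spark treats thresholds = [p0, p1] as implying the single threshold t = 1 / (1 + p0 / p1), and resolving the authoritative threshold fails when this implied value conflicts with a threshold that is also in effect. So [0.45, 0.55] implies 0.55 and clashes with 0.5, while [0.5, 0.5] implies exactly 0.5. A small pure-Python sketch of the equivalence, no Spark required:

```python
def implied_threshold(thresholds):
    # Spark's documented equivalence for binary classification:
    # threshold = 1 / (1 + thresholds[0] / thresholds[1])
    p0, p1 = thresholds
    return 1.0 / (1.0 + p0 / p1)

print(implied_threshold([0.5, 0.5]))    # 0.5  -- consistent with threshold=0.5
print(implied_threshold([0.45, 0.55]))  # 0.55 -- conflicts with threshold=0.5
```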


    # Now learn a new model using the new parameters.
    model2 = lr.fit(training, paramMap)
    print("Model 2 was fit using parameters:\n")
    pprint.pprint(model2.extractParamMap())
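A quick illustration of the override semantics described above: a param map passed to fit() takes precedence for that call only, leaving the estimator's own settings untouched. A sketch reusing the lr (built with maxIter = 10) and training objects from this example:

```python
# The param map wins for this fit() call; lr itself is unchanged.
model2 = lr.fit(training, {lr.maxIter: 20})
print(lr.getMaxIter())  # still 10
```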

    # prepare test data.
-    test = sc.parallelize([
+    test = spark.createDataFrame([
        LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
        LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
-        LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+        LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))])

    # Make predictions on test data using the Transformer.transform() method.
    # LogisticRegressionModel.transform will only use the 'features' column.
@@ -95,4 +96,4 @@
print("features=%s,label=%s -> prob=%s, prediction=%s"
% (row.features, row.label, row.myProbability, row.prediction))

-    sc.stop()
+    spark.stop()