Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of
key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table,
and the types are inferred by looking at the first row. Since we currently only look at the first
row, it is important that there is no missing data in the first row of the RDD. In future versions we
plan to more completely infer the schema by looking at more data, similar to the inference that is
performed on JSON files.
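For instance, the inference described above can be exercised with a short PySpark snippet (a sketch, assuming a `SparkContext` `sc` and a `SQLContext` `sqlContext` already exist, as in the earlier examples):

{% highlight python %}
from pyspark.sql import Row

# Each Row is built from kwargs; the keys become the column names.
rows = sc.parallelize([Row(name="Alice", age=25), Row(name="Bob", age=30)])

# Column types are inferred from the first Row, so the first Row
# must not contain any missing values.
df = sqlContext.createDataFrame(rows)
df.printSchema()
{% endhighlight %}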
Spark SQL supports operating on a variety of data sources through the `DataFrame` interface.
A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table.
Registering a DataFrame as a table allows you to run SQL queries over its data. This section
describes the general methods for loading and saving data using the Spark Data Sources and then
goes into specific options that are available for the built-in data sources.
### Manually Specifying Options
You can also manually specify the data source that will be used along with any extra options
that you would like to pass to the data source. Data sources are specified by their fully qualified
name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short
names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types
using this syntax.
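A minimal PySpark sketch of this (assuming `sqlContext` and the example JSON file used in earlier sections):

{% highlight python %}
# Load using the short name of the built-in JSON source...
df = sqlContext.read.format("json").load("examples/src/main/resources/people.json")

# ...and write the result back out through a different built-in source.
df.select("name", "age").write.format("parquet").save("namesAndAges.parquet")
{% endhighlight %}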
### Save Modes
Save operations can optionally take a `SaveMode`, which specifies how to handle existing data if
present. It is important to realize that these save modes do not utilize any locking and are not
atomic. Additionally, when performing an `Overwrite`, the data will be deleted before writing out the
new data.
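In PySpark, for example, the mode is passed as a short string (a sketch, assuming a DataFrame `df`):

{% highlight python %}
# "error" (the default), "append", "overwrite", and "ignore" correspond to
# the ErrorIfExists, Append, Overwrite, and Ignore save modes respectively.
df.write.mode("overwrite").parquet("people.parquet")
{% endhighlight %}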
Ignore mode means that when saving a DataFrame to a data source, if data already exists,
the save operation is expected to not save the contents of the DataFrame and to not
  change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
### Saving to Persistent Tables
When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the
`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the
contents of the DataFrame and create a pointer to the data in the HiveMetastore. Persistent tables
will still exist even after your Spark program has restarted, as long as you maintain your connection
to the same metastore. A DataFrame for a persistent table can be created by calling the `table`
method on a `SQLContext` with the name of the table.
By default `saveAsTable` will create a "managed table", meaning that the location of the data will
be controlled by the metastore. Managed tables will also have their data deleted automatically
when a table is dropped.
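For example, with a `HiveContext` named `sqlContext` (a sketch; the table name is illustrative and a DataFrame `df` is assumed):

{% highlight python %}
# Materialize the DataFrame as a managed table in the Hive metastore...
df.write.saveAsTable("people_table")

# ...and read it back later, even from a new Spark application.
people = sqlContext.table("people_table")
{% endhighlight %}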
## Parquet Files
val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet.
people.write.parquet("people.parquet")
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a Parquet file is also a DataFrame.
val parquetFile = sqlContext.read.parquet("people.parquet")
DataFrame schemaPeople = ... // The DataFrame from the previous example.
// DataFrames can be saved as Parquet files, maintaining the schema information.
schemaPeople.write().parquet("people.parquet");
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a Parquet file is also a DataFrame.
DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
schemaPeople # The DataFrame from the previous example.
# DataFrames can be saved as Parquet files, maintaining the schema information.
schemaPeople.write.parquet("people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
# The result of loading a Parquet file is also a DataFrame.
parquetFile = sqlContext.read.parquet("people.parquet")
schemaPeople # The DataFrame from the previous example.
# DataFrames can be saved as Parquet files, maintaining the schema information.
saveAsParquetFile(schemaPeople, "people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
# The result of loading a Parquet file is also a DataFrame.
parquetFile <- parquetFile(sqlContext, "people.parquet")
### Partition Discovery
Table partitioning is a common optimization approach used in systems like Hive. In a partitioned
table, data are usually stored in different directories, with partitioning column values encoded in
the path of each partition directory. The Parquet data source is now able to discover and infer
partitioning information automatically. For example, we can store all our previously used
population data into a partitioned table using the following directory structure, with two extra
columns, `gender` and `country`, as partitioning columns:
{% endhighlight %}
Notice that the data types of the partitioning columns are automatically inferred. Currently,
numeric data types and string type are supported. Sometimes users may not want to automatically
infer the data types of the partitioning columns. For these use cases, the automatic type inference
can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to
### Schema Merging
Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with
a simple schema, and gradually add more columns to the schema as needed. In this way, users may end
up with multiple Parquet files with different but mutually compatible schemas. The Parquet data
source is now able to automatically detect this case and merge schemas of all these files.
Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we
turned it off by default starting from 1.5.0. You may enable it by
1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the
examples below), or
1. Hive considers all columns nullable, while nullability in Parquet is significant
Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a
Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are:
1. Fields that have the same name in both schema must have the same data type regardless of
   nullability. The reconciled field should have the data type of the Parquet side, so that
nullability is respected.
1. The reconciled schema contains exactly those fields defined in Hive metastore schema.
#### Metadata Refreshing
Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table
conversion is enabled, metadata of those converted tables are also cached. If these tables are
updated by Hive or other external tools, you need to refresh them manually to ensure consistent
metadata.
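For example, in PySpark (a sketch; `my_table` is an illustrative table name and `sqlContext` is assumed to be a `HiveContext`):

{% highlight python %}
# Invalidate and refresh the cached metadata for a single table.
sqlContext.refreshTable("my_table")
{% endhighlight %}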
Configuration of Parquet can be done using the `setConf` method on `SQLContext`
spark.sql.parquet.int96AsTimestamp |
true |
  Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This
flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems.
The output committer class used by Parquet. The specified class needs to be a subclass of
  org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a
subclass of org.apache.parquet.hadoop.ParquetOutputCommitter.
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do
not have an existing Hive deployment can still create a `HiveContext`. When not configured by the
hive-site.xml, the context automatically creates `metastore_db` in the current directory and
creates a `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`.
Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts
The following options can be used to configure the version of Hive that is used
enabled. When this option is chosen, spark.sql.hive.metastore.version must be
either 1.2.1 or not defined.
maven
  Use Hive jars of specified version downloaded from Maven repositories. This configuration
is not generally recommended for production deployments.
  A classpath in the standard format for the JVM. This classpath must include all of Hive
  and its dependencies, including the correct version of Hadoop. These jars only need to be
  present on the driver, but if you are running in yarn cluster mode then you must ensure
  they are packaged with your application.
## JDBC To Other Databases
Spark SQL also includes a data source that can read data from other databases using JDBC. This
functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD).
This is because the results are returned
as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources.
run queries using Spark SQL).
To get started you will need to include the JDBC driver for your particular database on the
spark classpath. For example, to connect to postgres from the Spark Shell you would run the
following command:
{% highlight bash %}
SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell
{% endhighlight %}
Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using
the Data Sources API. The following options are supported:
| Property Name | Meaning |
dbtable |
  The JDBC table that should be read. Note that anything that is valid in a FROM clause of
  a SQL query can be used. For example, instead of a full table you could also use a
subquery in parentheses.
driver |
  The class name of the JDBC driver needed to connect to this URL. This class will be loaded
  on the master and workers before running any JDBC commands to allow the driver to
register itself with the JDBC subsystem.
partitionColumn, lowerBound, upperBound, numPartitions |
  These options must all be specified if any of them is specified. They describe how to
partition the table when reading in parallel from multiple workers.
partitionColumn must be a numeric column from the table in question. Notice
that lowerBound and upperBound are just used to decide the
Configuration of in-memory caching can be done using the `setConf` method on `SQLContext`
| spark.sql.inMemoryColumnarStorage.batchSize |
10000 |
  Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
and compression, but risk OOMs when caching data.
## Other Configuration Options
The following options can also be used to tune the performance of query execution. It is possible
that these options will be deprecated in a future release as more optimizations are performed automatically.
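These options are set like any other SQL configuration property, e.g. (a sketch, assuming a `SQLContext` named `sqlContext`):

{% highlight python %}
# Disable broadcast joins entirely by setting the threshold to -1.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1")
{% endhighlight %}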
| 10485760 (10 MB) |
Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
  performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
statistics are only supported for Hive Metastore tables where the command
ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
To start the JDBC/ODBC server, run the following in the Spark directory:
./sbin/start-thriftserver.sh
This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to
specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of
all available options. By default, the server listens on localhost:10000. You may override this
behaviour via either environment variables, i.e.:
{% highlight bash %}
## Upgrading From Spark SQL 1.5 to 1.6
 - From Spark 1.6, by default the Thrift server runs in multi-session mode, which means each JDBC/ODBC
   connection owns a copy of its own SQL configuration and temporary function registry. Cached
   tables are still shared though. If you prefer to run the Thrift server in the old single-session
   mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add
this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`:
{% highlight bash %}
## Upgrading From Spark SQL 1.4 to 1.5
- Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with
   code generation for expression evaluation. These features can both be disabled by setting
`spark.sql.tungsten.enabled` to `false`.
 - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting
`spark.sql.parquet.mergeSchema` to `true`.
- Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or
   access nested values. For example `df['table.column.nestedField']`. However, this means that if
   your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``).
- In-memory columnar storage partition pruning is on by default. It can be disabled by setting
`spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`.
- Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum
   precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now
used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`.
- Timestamps are now stored at a precision of 1us, rather than 1ns
 - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains
unchanged.
 - The canonical names of SQL/DataFrame functions are now lower case (e.g., sum vs SUM).
- It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe
## Upgrading from Spark SQL 1.0-1.2 to 1.3
In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the
available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other
releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked
as unstable (i.e., DeveloperAPI or Experimental).
#### Rename of SchemaRDD to DataFrame
The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has
been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD
directly, but instead provide most of the functionality that RDDs provide through their own
implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method.
In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for
some use cases. It is still recommended that users update their code to use `DataFrame` instead.
Java and Python users will need to update their code.
#### Unification of the Java and Scala APIs
Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`)
that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users
of either language should use `SQLContext` and `DataFrame`. In general these classes try to
use types that are usable from both languages (i.e. `Array` instead of language specific collections).
In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading
is used instead.
Additionally, the Java specific types API has been removed. Users of both Scala and Java should
use the classes present in `org.apache.spark.sql.types` to describe schema programmatically.
#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only)
Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought
all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit
conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`.
Users should now write `import sqlContext.implicits._`.
Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e.,
case classes or tuples) with a method `toDF`, instead of applying automatically.
When using functions inside of the DSL (now replaced with the `DataFrame` API) users used to import
`org.apache.spark.sql.catalyst.dsl`. Instead the public DataFrame functions API should be used:
`import org.apache.spark.sql.functions._`.
#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
new file mode 100644
index 0000000000000..3ccd6993261e2
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import java.util.Arrays;
+
+import org.apache.spark.ml.feature.IndexToString;
+import org.apache.spark.ml.feature.StringIndexer;
+import org.apache.spark.ml.feature.StringIndexerModel;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+public class JavaIndexToStringExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ RowFactory.create(0, "a"),
+ RowFactory.create(1, "b"),
+ RowFactory.create(2, "c"),
+ RowFactory.create(3, "a"),
+ RowFactory.create(4, "a"),
+ RowFactory.create(5, "c")
+ ));
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("category", DataTypes.StringType, false, Metadata.empty())
+ });
+ DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+
+ StringIndexerModel indexer = new StringIndexer()
+ .setInputCol("category")
+ .setOutputCol("categoryIndex")
+ .fit(df);
+ DataFrame indexed = indexer.transform(df);
+
+ IndexToString converter = new IndexToString()
+ .setInputCol("categoryIndex")
+ .setOutputCol("originalCategory");
+ DataFrame converted = converter.transform(indexed);
+ converted.select("id", "originalCategory").show();
+ // $example off$
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java
new file mode 100644
index 0000000000000..251ae79d9a108
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+// $example on$
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.QuantileDiscretizer;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+public class JavaQuantileDiscretizerExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+    JavaRDD<Row> jrdd = jsc.parallelize(
+ Arrays.asList(
+ RowFactory.create(0, 18.0),
+ RowFactory.create(1, 19.0),
+ RowFactory.create(2, 8.0),
+ RowFactory.create(3, 5.0),
+ RowFactory.create(4, 2.2)
+ )
+ );
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("hour", DataTypes.DoubleType, false, Metadata.empty())
+ });
+
+ DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+
+ QuantileDiscretizer discretizer = new QuantileDiscretizer()
+ .setInputCol("hour")
+ .setOutputCol("result")
+ .setNumBuckets(3);
+
+ DataFrame result = discretizer.fit(df).transform(df);
+ result.show();
+ // $example off$
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java
new file mode 100644
index 0000000000000..d55c70796a967
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+// $example on$
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.SQLTransformer;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.*;
+// $example off$
+
+public class JavaSQLTransformerExample {
+ public static void main(String[] args) {
+
+ SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+ JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+ RowFactory.create(0, 1.0, 3.0),
+ RowFactory.create(2, 2.0, 5.0)
+ ));
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("v2", DataTypes.DoubleType, false, Metadata.empty())
+ });
+ DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+
+ SQLTransformer sqlTrans = new SQLTransformer().setStatement(
+ "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__");
+
+ sqlTrans.transform(df).show();
+ // $example off$
+ jsc.stop();
+ }
+}
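The effect of the `SQLTransformer` statement above, sketched without Spark: `__THIS__` is a placeholder for the input DataFrame, and the `SELECT` appends the derived columns `v3 = v1 + v2` and `v4 = v1 * v2` to every row.

```python
# The two input rows from the example, as plain dicts.
rows = [
    {"id": 0, "v1": 1.0, "v2": 3.0},
    {"id": 2, "v1": 2.0, "v2": 5.0},
]

# SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__
transformed = [
    {**row, "v3": row["v1"] + row["v2"], "v4": row["v1"] * row["v2"]}
    for row in rows
]
```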
diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py
new file mode 100644
index 0000000000000..fb0ba2950bbd6
--- /dev/null
+++ b/examples/src/main/python/ml/index_to_string_example.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.ml.feature import IndexToString, StringIndexer
+# $example off$
+from pyspark.sql import SQLContext
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="IndexToStringExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ df = sqlContext.createDataFrame(
+ [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
+ ["id", "category"])
+
+ stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
+ model = stringIndexer.fit(df)
+ indexed = model.transform(df)
+
+ converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory")
+ converted = converter.transform(indexed)
+
+ converted.select("id", "originalCategory").show()
+ # $example off$
+
+ sc.stop()
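The `StringIndexer` -> `IndexToString` round trip above, sketched in plain Python: `StringIndexer` assigns index 0 to the most frequent label, and `IndexToString` maps indices back through the same label array.

```python
from collections import Counter

categories = ["a", "b", "c", "a", "a", "c"]   # the "category" column above

counts = Counter(categories)
# Labels ordered by descending frequency, ties broken alphabetically
# (matching Spark's default frequency-descending ordering).
labels = sorted(counts, key=lambda c: (-counts[c], c))
to_index = {label: i for i, label in enumerate(labels)}

indexed = [to_index[c] for c in categories]   # StringIndexer step
recovered = [labels[i] for i in indexed]      # IndexToString step
```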
diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py
new file mode 100644
index 0000000000000..9575d728d8159
--- /dev/null
+++ b/examples/src/main/python/ml/sql_transformer.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.ml.feature import SQLTransformer
+# $example off$
+from pyspark.sql import SQLContext
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="SQLTransformerExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ df = sqlContext.createDataFrame([
+ (0, 1.0, 3.0),
+ (2, 2.0, 5.0)
+ ], ["id", "v1", "v2"])
+ sqlTrans = SQLTransformer(
+ statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+ sqlTrans.transform(df).show()
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
new file mode 100644
index 0000000000000..52537e5bb568d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
+// $example off$
+
+object IndexToStringExample {
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("IndexToStringExample")
+ val sc = new SparkContext(conf)
+
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ // $example on$
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a"),
+ (1, "b"),
+ (2, "c"),
+ (3, "a"),
+ (4, "a"),
+ (5, "c")
+ )).toDF("id", "category")
+
+ val indexer = new StringIndexer()
+ .setInputCol("category")
+ .setOutputCol("categoryIndex")
+ .fit(df)
+ val indexed = indexer.transform(df)
+
+ val converter = new IndexToString()
+ .setInputCol("categoryIndex")
+ .setOutputCol("originalCategory")
+
+ val converted = converter.transform(indexed)
+ converted.select("id", "originalCategory").show()
+ // $example off$
+ sc.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala
new file mode 100644
index 0000000000000..8f29b7eaa6d26
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.QuantileDiscretizer
+// $example off$
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+object QuantileDiscretizerExample {
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("QuantileDiscretizerExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // $example on$
+ val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2))
+ val df = sc.parallelize(data).toDF("id", "hour")
+
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("hour")
+ .setOutputCol("result")
+ .setNumBuckets(3)
+
+ val result = discretizer.fit(df).transform(df)
+ result.show()
+ // $example off$
+ sc.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala
new file mode 100644
index 0000000000000..014abd1fdbc63
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.SQLTransformer
+// $example off$
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+
+object SQLTransformerExample {
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("SQLTransformerExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+
+ // $example on$
+ val df = sqlContext.createDataFrame(
+ Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+
+ val sqlTrans = new SQLTransformer().setStatement(
+ "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+
+ sqlTrans.transform(df).show()
+ // $example off$
+ sc.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 75b0f69cf91aa..70010b05e4345 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -18,19 +18,16 @@
// scalastyle:off println
package org.apache.spark.examples.mllib
-import java.text.BreakIterator
-
-import scala.collection.mutable
-
import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
-
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
-
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext}
/**
* An example Latent Dirichlet Allocation (LDA) app. Run with
@@ -192,115 +189,45 @@ object LDAExample {
vocabSize: Int,
stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
// Get dataset of document texts
// One document per line in each text file. If the input consists of many small files,
// this can result in a large number of small partitions, which can degrade performance.
// In this case, consider using coalesce() to create fewer, larger partitions.
- val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
-
- // Split text into words
- val tokenizer = new SimpleTokenizer(sc, stopwordFile)
- val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
- id -> tokenizer.getWords(text)
- }
- tokenized.cache()
-
- // Counts words: RDD[(word, wordCount)]
- val wordCounts: RDD[(String, Long)] = tokenized
- .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
- .reduceByKey(_ + _)
- wordCounts.cache()
- val fullVocabSize = wordCounts.count()
- // Select vocab
- // (vocab: Map[word -> id], total tokens after selecting vocab)
- val (vocab: Map[String, Int], selectedTokenCount: Long) = {
- val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
- }
- (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
- }
-
- val documents = tokenized.map { case (id, tokens) =>
- // Filter tokens by vocabulary, and create word count vector representation of document.
- val wc = new mutable.HashMap[Int, Int]()
- tokens.foreach { term =>
- if (vocab.contains(term)) {
- val termIndex = vocab(term)
- wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
- }
- }
- val indices = wc.keys.toArray.sorted
- val values = indices.map(i => wc(i).toDouble)
-
- val sb = Vectors.sparse(vocab.size, indices, values)
- (id, sb)
- }
-
- val vocabArray = new Array[String](vocab.size)
- vocab.foreach { case (term, i) => vocabArray(i) = term }
-
- (documents, vocabArray, selectedTokenCount)
- }
-}
-
-/**
- * Simple Tokenizer.
- *
- * TODO: Formalize the interface, and make this a public class in mllib.feature
- */
-private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
-
- private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
- Set.empty[String]
- } else {
- val stopwordText = sc.textFile(stopwordFile).collect()
- stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
- }
-
- // Matches sequences of Unicode letters
- private val allWordRegex = "^(\\p{L}*)$".r
-
- // Ignore words shorter than this length.
- private val minWordLength = 3
-
- def getWords(text: String): IndexedSeq[String] = {
-
- val words = new mutable.ArrayBuffer[String]()
-
- // Use Java BreakIterator to tokenize text into words.
- val wb = BreakIterator.getWordInstance
- wb.setText(text)
-
- // current,end index start,end of each word
- var current = wb.first()
- var end = wb.next()
- while (end != BreakIterator.DONE) {
- // Convert to lowercase
- val word: String = text.substring(current, end).toLowerCase
- // Remove short words and strings that aren't only letters
- word match {
- case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
- words += w
- case _ =>
- }
-
- current = end
- try {
- end = wb.next()
- } catch {
- case e: Exception =>
- // Ignore remaining text in line.
- // This is a known bug in BreakIterator (for some Java versions),
- // which fails when it sees certain characters.
- end = BreakIterator.DONE
- }
+ val df = sc.textFile(paths.mkString(",")).toDF("docs")
+ val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
+ Array.empty[String]
+ } else {
+ val stopWordText = sc.textFile(stopwordFile).collect()
+ stopWordText.flatMap(_.stripMargin.split("\\s+"))
}
- words
+ val tokenizer = new RegexTokenizer()
+ .setInputCol("docs")
+ .setOutputCol("rawTokens")
+ val stopWordsRemover = new StopWordsRemover()
+ .setInputCol("rawTokens")
+ .setOutputCol("tokens")
+ stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+ val countVectorizer = new CountVectorizer()
+ .setVocabSize(vocabSize)
+ .setInputCol("tokens")
+ .setOutputCol("features")
+
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
+
+ val model = pipeline.fit(df)
+ val documents = model.transform(df)
+ .select("features")
+ .map { case Row(features: Vector) => features }
+ .zipWithIndex()
+ .map(_.swap)
+
+ (documents,
+ model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary
+ documents.map(_._2.numActives).sum().toLong) // total token count
}
-
}
// scalastyle:on println
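What the tokenizer -> stop-word remover -> `CountVectorizer` pipeline in the rewritten `LDAExample` produces, sketched in plain Python: a vocabulary capped at `vocabSize` terms (most frequent first) and one term-count vector per document. This is an illustration with made-up toy documents; Spark's `CountVectorizer` runs on distributed data.

```python
from collections import Counter

docs = ["spark is fast", "spark scales", "fast and fast"]  # toy corpus
stop_words = {"is", "and"}
vocab_size = 3

# Tokenize and drop stop words (the RegexTokenizer/StopWordsRemover stages).
tokenized = [[t for t in d.split() if t not in stop_words] for d in docs]

# Build a vocabulary of the vocab_size most frequent terms
# (ties broken alphabetically), as CountVectorizer does.
term_counts = Counter(t for tokens in tokenized for t in tokens)
vocabulary = sorted(term_counts, key=lambda t: (-term_counts[t], t))[:vocab_size]

# One count vector per document (dense here; Spark emits sparse vectors).
vectors = [[tokens.count(t) for t in vocabulary] for tokens in tokenized]
total_tokens = sum(sum(v) for v in vectors)  # the "total token count" above
```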
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
index b6677c6476639..49f5df39443e9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
@@ -18,7 +18,7 @@
package org.apache.spark.examples.mllib
import org.apache.spark.SparkConf
-import org.apache.spark.mllib.stat.test.StreamingTest
+import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.util.Utils
@@ -66,7 +66,7 @@ object StreamingTestExample {
// $example on$
val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
- case Array(label, value) => (label.toBoolean, value.toDouble)
+ case Array(label, value) => BinarySample(label.toBoolean, value.toDouble)
})
val streamingTest = new StreamingTest()
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index ad2fb8aa5f24c..fe572220528d5 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -51,6 +51,7 @@ object KafkaUtils {
* in its own thread
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createStream(
ssc: StreamingContext,
@@ -74,6 +75,11 @@ object KafkaUtils {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam U type of Kafka message key decoder
+ * @tparam T type of Kafka message value decoder
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag](
ssc: StreamingContext,
@@ -93,6 +99,7 @@ object KafkaUtils {
* @param groupId The group id for this consumer
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createStream(
jssc: JavaStreamingContext,
@@ -111,6 +118,7 @@ object KafkaUtils {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level.
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createStream(
jssc: JavaStreamingContext,
@@ -135,6 +143,11 @@ object KafkaUtils {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread
* @param storageLevel RDD storage level.
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam U type of Kafka message key decoder
+ * @tparam T type of Kafka message value decoder
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createStream[K, V, U <: Decoder[_], T <: Decoder[_]](
jssc: JavaStreamingContext,
@@ -219,6 +232,11 @@ object KafkaUtils {
* host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @return RDD of (Kafka message key, Kafka message value)
*/
def createRDD[
K: ClassTag,
@@ -251,6 +269,12 @@ object KafkaUtils {
* @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map,
* in which case leaders will be looked up on the driver.
* @param messageHandler Function for translating each message and metadata into the desired type
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @tparam R type returned by messageHandler
+ * @return RDD of R
*/
def createRDD[
K: ClassTag,
@@ -288,6 +312,15 @@ object KafkaUtils {
* host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
+ * @param keyClass type of Kafka message key
+ * @param valueClass type of Kafka message value
+ * @param keyDecoderClass type of Kafka message key decoder
+ * @param valueDecoderClass type of Kafka message value decoder
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @return RDD of (Kafka message key, Kafka message value)
*/
def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]](
jsc: JavaSparkContext,
@@ -321,6 +354,12 @@ object KafkaUtils {
* @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map,
* in which case leaders will be looked up on the driver.
* @param messageHandler Function for translating each message and metadata into the desired type
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @tparam R type returned by messageHandler
+ * @return RDD of R
*/
def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
jsc: JavaSparkContext,
@@ -373,6 +412,12 @@ object KafkaUtils {
* @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
* starting point of the stream
* @param messageHandler Function for translating each message and metadata into the desired type
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @tparam R type returned by messageHandler
+ * @return DStream of R
*/
def createDirectStream[
K: ClassTag,
@@ -419,6 +464,11 @@ object KafkaUtils {
* If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
* to determine where the stream starts (defaults to "largest")
* @param topics Names of the topics to consume
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createDirectStream[
K: ClassTag,
@@ -470,6 +520,12 @@ object KafkaUtils {
* @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
* starting point of the stream
* @param messageHandler Function for translating each message and metadata into the desired type
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @tparam R type returned by messageHandler
+ * @return DStream of R
*/
def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
jssc: JavaStreamingContext,
@@ -529,6 +585,11 @@ object KafkaUtils {
* If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
* to determine where the stream starts (defaults to "largest")
* @param topics Names of the topics to consume
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ * @tparam KD type of Kafka message key decoder
+ * @tparam VD type of Kafka message value decoder
+ * @return DStream of (Kafka message key, Kafka message value)
*/
def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]](
jssc: JavaStreamingContext,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index c478aea44ace8..8c4cec1326653 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
@@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
+@Since("1.4.0")
@Experimental
-final class DecisionTreeClassifier(override val uid: String)
+final class DecisionTreeClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
+ @Since("1.4.0")
def this() = this(Identifiable.randomUID("dtc"))
// Override parameter setters from parent trait for Java API compatibility.
+ @Since("1.4.0")
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+ @Since("1.4.0")
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+ @Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type =
super.setMinInstancesPerNode(value)
+ @Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+ @Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+ @Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+ @Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+ @Since("1.4.0")
override def setImpurity(value: String): this.type = super.setImpurity(value)
+ @Since("1.6.0")
override def setSeed(value: Long): this.type = super.setSeed(value)
override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
@@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String)
subsamplingRate = 1.0)
}
+ @Since("1.4.1")
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
}
+@Since("1.4.0")
@Experimental
object DecisionTreeClassifier {
/** Accessor for supported impurities: entropy, gini */
+ @Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
}
@@ -104,12 +119,13 @@ object DecisionTreeClassifier {
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
+@Since("1.4.0")
@Experimental
final class DecisionTreeClassificationModel private[ml] (
- override val uid: String,
- override val rootNode: Node,
- override val numFeatures: Int,
- override val numClasses: Int)
+ @Since("1.4.0") override val uid: String,
+ @Since("1.4.0") override val rootNode: Node,
+ @Since("1.6.0") override val numFeatures: Int,
+ @Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] (
}
}
+ @Since("1.4.0")
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
.setParent(parent)
}
+ @Since("1.4.0")
override def toString: String = {
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 74aef94bf7675..cda2bca58c50d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
*/
+@Since("1.4.0")
@Experimental
-final class GBTClassifier(override val uid: String)
+final class GBTClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTParams with TreeClassifierParams with Logging {
+ @Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtc"))
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeClassifierParams:
+ @Since("1.4.0")
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+ @Since("1.4.0")
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+ @Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type =
super.setMinInstancesPerNode(value)
+ @Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+ @Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+ @Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+ @Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
/**
* The impurity setting is ignored for GBT models.
* Individual trees are built using impurity "Variance."
*/
+ @Since("1.4.0")
override def setImpurity(value: String): this.type = {
logWarning("GBTClassifier.setImpurity should NOT be used")
this
@@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String)
// Parameters from TreeEnsembleParams:
+ @Since("1.4.0")
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+ @Since("1.4.0")
override def setSeed(value: Long): this.type = {
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
super.setSeed(value)
@@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String)
// Parameters from GBTParams:
+ @Since("1.4.0")
override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+ @Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
// Parameters for GBTClassifier:
@@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String)
* (default = logistic)
* @group param
*/
+ @Since("1.4.0")
val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
" tries to minimize (case-insensitive). Supported options:" +
s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
@@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String)
setDefault(lossType -> "logistic")
/** @group setParam */
+ @Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
+ @Since("1.4.0")
def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
@@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
}
+ @Since("1.4.1")
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}
+@Since("1.4.0")
@Experimental
object GBTClassifier {
// The losses below should be lowercase.
/** Accessor for supported loss settings: logistic */
+ @Since("1.4.0")
final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
}
@@ -164,12 +185,13 @@ object GBTClassifier {
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
+@Since("1.6.0")
@Experimental
final class GBTClassificationModel private[ml](
- override val uid: String,
+ @Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
- override val numFeatures: Int)
+ @Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml](
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
+ @Since("1.6.0")
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
this(uid, _trees, _treeWeights, -1)
+ @Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ @Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}
+ @Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
extra).setParent(parent)
}
+ @Since("1.4.0")
override def toString: String = {
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
}
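The `predict` logic shown above thresholds a weighted sum of the per-tree outputs at 0.0. A minimal pure-Scala sketch of that voting rule (plain arrays stand in for Spark's trees and `Vector`; this is an illustration, not the Spark API):

```scala
// Illustrative GBT binary prediction: a weighted sum of the regression-tree
// outputs, thresholded at 0.0 as in `if (prediction > 0.0) 1.0 else 0.0` above.
def gbtPredict(treePredictions: Array[Double], treeWeights: Array[Double]): Double = {
  require(treePredictions.length == treeWeights.length, "one weight per tree")
  val margin = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
  if (margin > 0.0) 1.0 else 0.0
}
```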
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index d320d64dd90d0..19cc323d5073f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -154,11 +154,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* Currently, this class only supports binary classification. It will support multiclass
* in the future.
*/
+@Since("1.2.0")
@Experimental
-class LogisticRegression(override val uid: String)
+class LogisticRegression @Since("1.2.0") (
+ @Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams with DefaultParamsWritable with Logging {
+ @Since("1.4.0")
def this() = this(Identifiable.randomUID("logreg"))
/**
@@ -166,6 +169,7 @@ class LogisticRegression(override val uid: String)
* Default is 0.0.
* @group setParam
*/
+ @Since("1.2.0")
def setRegParam(value: Double): this.type = set(regParam, value)
setDefault(regParam -> 0.0)
@@ -176,6 +180,7 @@ class LogisticRegression(override val uid: String)
* Default is 0.0 which is an L2 penalty.
* @group setParam
*/
+ @Since("1.4.0")
def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
setDefault(elasticNetParam -> 0.0)
@@ -184,6 +189,7 @@ class LogisticRegression(override val uid: String)
* Default is 100.
* @group setParam
*/
+ @Since("1.2.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
setDefault(maxIter -> 100)
@@ -193,6 +199,7 @@ class LogisticRegression(override val uid: String)
* Default is 1E-6.
* @group setParam
*/
+ @Since("1.4.0")
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
@@ -201,6 +208,7 @@ class LogisticRegression(override val uid: String)
* Default is true.
* @group setParam
*/
+ @Since("1.4.0")
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
@@ -213,11 +221,14 @@ class LogisticRegression(override val uid: String)
* Default is true.
* @group setParam
*/
+ @Since("1.5.0")
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
+ @Since("1.5.0")
override def setThreshold(value: Double): this.type = super.setThreshold(value)
+ @Since("1.5.0")
override def getThreshold: Double = super.getThreshold
/**
@@ -226,11 +237,14 @@ class LogisticRegression(override val uid: String)
* Default is empty, so all instances have weight one.
* @group setParam
*/
+ @Since("1.6.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
setDefault(weightCol -> "")
+ @Since("1.5.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+ @Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
@@ -384,11 +398,14 @@ class LogisticRegression(override val uid: String)
model.setSummary(logRegSummary)
}
+ @Since("1.4.0")
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}
+@Since("1.6.0")
object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
+ @Since("1.6.0")
override def load(path: String): LogisticRegression = super.load(path)
}
@@ -396,23 +413,28 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
* :: Experimental ::
* Model produced by [[LogisticRegression]].
*/
+@Since("1.4.0")
@Experimental
class LogisticRegressionModel private[ml] (
- override val uid: String,
- val coefficients: Vector,
- val intercept: Double)
+ @Since("1.4.0") override val uid: String,
+ @Since("1.6.0") val coefficients: Vector,
+ @Since("1.3.0") val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams with MLWritable {
@deprecated("Use coefficients instead.", "1.6.0")
def weights: Vector = coefficients
+ @Since("1.5.0")
override def setThreshold(value: Double): this.type = super.setThreshold(value)
+ @Since("1.5.0")
override def getThreshold: Double = super.getThreshold
+ @Since("1.5.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+ @Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
/** Margin (rawPrediction) for class label 1. For binary classification only. */
@@ -426,8 +448,10 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-m))
}
+ @Since("1.6.0")
override val numFeatures: Int = coefficients.size
+ @Since("1.3.0")
override val numClasses: Int = 2
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
@@ -436,6 +460,7 @@ class LogisticRegressionModel private[ml] (
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
+ @Since("1.5.0")
def summary: LogisticRegressionTrainingSummary = trainingSummary match {
case Some(summ) => summ
case None =>
@@ -451,6 +476,7 @@ class LogisticRegressionModel private[ml] (
}
/** Indicates whether a training summary exists for this model instance. */
+ @Since("1.5.0")
def hasSummary: Boolean = trainingSummary.isDefined
/**
@@ -493,6 +519,7 @@ class LogisticRegressionModel private[ml] (
Vectors.dense(-m, m)
}
+ @Since("1.4.0")
override def copy(extra: ParamMap): LogisticRegressionModel = {
val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
@@ -710,12 +737,13 @@ sealed trait LogisticRegressionSummary extends Serializable {
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
+@Since("1.5.0")
class BinaryLogisticRegressionTrainingSummary private[classification] (
- predictions: DataFrame,
- probabilityCol: String,
- labelCol: String,
- featuresCol: String,
- val objectiveHistory: Array[Double])
+ @Since("1.5.0") predictions: DataFrame,
+ @Since("1.5.0") probabilityCol: String,
+ @Since("1.5.0") labelCol: String,
+ @Since("1.6.0") featuresCol: String,
+ @Since("1.5.0") val objectiveHistory: Array[Double])
extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
with LogisticRegressionTrainingSummary {
@@ -731,11 +759,13 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
+@Since("1.5.0")
class BinaryLogisticRegressionSummary private[classification] (
- @transient override val predictions: DataFrame,
- override val probabilityCol: String,
- override val labelCol: String,
- override val featuresCol: String) extends LogisticRegressionSummary {
+ @Since("1.5.0") @transient override val predictions: DataFrame,
+ @Since("1.5.0") override val probabilityCol: String,
+ @Since("1.5.0") override val labelCol: String,
+ @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
+
private val sqlContext = predictions.sqlContext
import sqlContext.implicits._
@@ -760,6 +790,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* This will change in later Spark versions.
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
*/
+ @Since("1.5.0")
@transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
/**
@@ -768,6 +799,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
* This will change in later Spark versions.
*/
+ @Since("1.5.0")
lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
/**
@@ -777,6 +809,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
* This will change in later Spark versions.
*/
+ @Since("1.5.0")
@transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
/**
@@ -785,6 +818,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
* This will change in later Spark versions.
*/
+ @Since("1.5.0")
@transient lazy val fMeasureByThreshold: DataFrame = {
binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
}
@@ -797,6 +831,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
* This will change in later Spark versions.
*/
+ @Since("1.5.0")
@transient lazy val precisionByThreshold: DataFrame = {
binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
}
@@ -809,6 +844,7 @@ class BinaryLogisticRegressionSummary private[classification] (
* Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
* This will change in later Spark versions.
*/
+ @Since("1.5.0")
@transient lazy val recallByThreshold: DataFrame = {
binaryMetrics.recallByThreshold().toDF("threshold", "recall")
}
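The model's margin and probability path referenced above computes `1.0 / (1.0 + math.exp(-m))` for class 1. A self-contained sketch of that step (arrays replace Spark's `Vector`; names are illustrative):

```scala
// Margin for class 1: dot(coefficients, features) + intercept,
// then the logistic function, matching 1.0 / (1.0 + math.exp(-m)) above.
def margin(coefficients: Array[Double], intercept: Double, features: Array[Double]): Double =
  coefficients.zip(features).map { case (c, x) => c * x }.sum + intercept

def probabilityOfClassOne(m: Double): Double = 1.0 / (1.0 + math.exp(-m))
```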
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index cd7462596dd9e..a691aa005ef54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap}
@@ -104,19 +104,23 @@ private object LabelConverter {
* Each layer has sigmoid activation function, output layer has softmax.
* Number of inputs has to be equal to the size of feature vectors.
* Number of outputs has to be equal to the total number of labels.
- *
*/
+@Since("1.5.0")
@Experimental
-class MultilayerPerceptronClassifier(override val uid: String)
+class MultilayerPerceptronClassifier @Since("1.5.0") (
+ @Since("1.5.0") override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
with MultilayerPerceptronParams {
+ @Since("1.5.0")
def this() = this(Identifiable.randomUID("mlpc"))
/** @group setParam */
+ @Since("1.5.0")
def setLayers(value: Array[Int]): this.type = set(layers, value)
/** @group setParam */
+ @Since("1.5.0")
def setBlockSize(value: Int): this.type = set(blockSize, value)
/**
@@ -124,6 +128,7 @@ class MultilayerPerceptronClassifier(override val uid: String)
* Default is 100.
* @group setParam
*/
+ @Since("1.5.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
/**
@@ -132,14 +137,17 @@ class MultilayerPerceptronClassifier(override val uid: String)
* Default is 1E-4.
* @group setParam
*/
+ @Since("1.5.0")
def setTol(value: Double): this.type = set(tol, value)
/**
* Set the seed for weights initialization.
* @group setParam
*/
+ @Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
+ @Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
/**
@@ -173,14 +181,16 @@ class MultilayerPerceptronClassifier(override val uid: String)
* @param weights vector of initial weights for the model that consists of the weights of layers
* @return prediction model
*/
+@Since("1.5.0")
@Experimental
class MultilayerPerceptronClassificationModel private[ml] (
- override val uid: String,
- val layers: Array[Int],
- val weights: Vector)
+ @Since("1.5.0") override val uid: String,
+ @Since("1.5.0") val layers: Array[Int],
+ @Since("1.5.0") val weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable {
+ @Since("1.6.0")
override val numFeatures: Int = layers.head
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
@@ -200,6 +210,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
LabelConverter.decodeLabel(mlpModel.predict(features))
}
+ @Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
}
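`LabelConverter`, used by `predict` above, maps a class index to a one-hot output vector and decodes a network output via argmax. A hedged stand-alone sketch of that encoding (plain arrays instead of Spark vectors):

```scala
// Encode a class label as a one-hot vector of length numClasses,
// and decode by taking the index of the maximum output activation.
def encodeLabel(label: Int, numClasses: Int): Array[Double] = {
  val out = Array.fill(numClasses)(0.0)
  out(label) = 1.0
  out
}

def decodeLabel(output: Array[Double]): Int = output.indices.maxBy(i => output(i))
```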
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index c512a2cb8bf3d..718f49d3aedcd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -72,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
* The input feature values must be nonnegative.
*/
+@Since("1.5.0")
@Experimental
-class NaiveBayes(override val uid: String)
+class NaiveBayes @Since("1.5.0") (
+ @Since("1.5.0") override val uid: String)
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams with DefaultParamsWritable {
+ @Since("1.5.0")
def this() = this(Identifiable.randomUID("nb"))
/**
@@ -84,6 +87,7 @@ class NaiveBayes(override val uid: String)
* Default is 1.0.
* @group setParam
*/
+ @Since("1.5.0")
def setSmoothing(value: Double): this.type = set(smoothing, value)
setDefault(smoothing -> 1.0)
@@ -93,6 +97,7 @@ class NaiveBayes(override val uid: String)
* Default is "multinomial"
* @group setParam
*/
+ @Since("1.5.0")
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)
@@ -102,6 +107,7 @@ class NaiveBayes(override val uid: String)
NaiveBayesModel.fromOld(oldModel, this)
}
+ @Since("1.5.0")
override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}
@@ -119,11 +125,12 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
* @param theta log of class conditional probabilities, whose dimension is C (number of classes)
* by D (number of features)
*/
+@Since("1.5.0")
@Experimental
class NaiveBayesModel private[ml] (
- override val uid: String,
- val pi: Vector,
- val theta: Matrix)
+ @Since("1.5.0") override val uid: String,
+ @Since("1.5.0") val pi: Vector,
+ @Since("1.5.0") val theta: Matrix)
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
with NaiveBayesParams with MLWritable {
@@ -148,8 +155,10 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
+ @Since("1.6.0")
override val numFeatures: Int = theta.numCols
+ @Since("1.5.0")
override val numClasses: Int = pi.size
private def multinomialCalculation(features: Vector) = {
@@ -206,10 +215,12 @@ class NaiveBayesModel private[ml] (
}
}
+ @Since("1.5.0")
override def copy(extra: ParamMap): NaiveBayesModel = {
copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
}
+ @Since("1.5.0")
override def toString: String = {
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
}
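The `multinomialCalculation` referenced above scores each class as pi(k) + theta(k) · x in log space and predicts the argmax. A simplified pure-Scala sketch (arrays stand in for Spark's `Vector` and `Matrix`):

```scala
// Multinomial naive Bayes scoring: logPi(k) + sum_j logTheta(k)(j) * x(j);
// the predicted class is the argmax over k.
def predictMultinomial(
    logPi: Array[Double],
    logTheta: Array[Array[Double]],
    features: Array[Double]): Int = {
  val scores = logPi.indices.map { k =>
    logPi(k) + logTheta(k).zip(features).map { case (t, x) => t * x }.sum
  }
  scores.indices.maxBy(i => scores(i))
}
```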
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index debc164bf2432..08a51109d6c62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,7 +21,7 @@ import java.util.UUID
import scala.language.existentials
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{Param, ParamMap}
@@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams {
* The i-th model is produced by testing the i-th class (taking label 1) vs the rest
* (taking label 0).
*/
+@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
- override val uid: String,
- labelMetadata: Metadata,
- val models: Array[_ <: ClassificationModel[_, _]])
+ @Since("1.4.0") override val uid: String,
+ @Since("1.4.0") labelMetadata: Metadata,
+ @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams {
+ @Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
}
+ @Since("1.4.0")
override def transform(dataset: DataFrame): DataFrame = {
// Check schema
transformSchema(dataset.schema, logging = true)
@@ -134,6 +137,7 @@ final class OneVsRestModel private[ml] (
.drop(accColName)
}
+ @Since("1.4.1")
override def copy(extra: ParamMap): OneVsRestModel = {
val copied = new OneVsRestModel(
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
@@ -150,30 +154,39 @@ final class OneVsRestModel private[ml] (
* Each example is scored against all k models and the model with highest score
* is picked to label the example.
*/
+@Since("1.4.0")
@Experimental
-final class OneVsRest(override val uid: String)
+final class OneVsRest @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
extends Estimator[OneVsRestModel] with OneVsRestParams {
+ @Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
/** @group setParam */
+ @Since("1.4.0")
def setClassifier(value: Classifier[_, _, _]): this.type = {
set(classifier, value.asInstanceOf[ClassifierType])
}
/** @group setParam */
+ @Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
/** @group setParam */
+ @Since("1.5.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
/** @group setParam */
+ @Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ @Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
+ @Since("1.4.0")
override def fit(dataset: DataFrame): OneVsRestModel = {
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
@@ -222,6 +235,7 @@ final class OneVsRest(override val uid: String)
copyValues(model)
}
+ @Since("1.4.1")
override def copy(extra: ParamMap): OneVsRest = {
val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
if (isDefined(classifier)) {
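`fit` above trains one binary classifier per class by relabeling: examples of class i get label 1.0 and everything else 0.0, and prediction picks the class whose model scores highest. A pure-Scala sketch of those two steps (the score array is an illustrative stand-in for the k models' outputs):

```scala
// Relabel a multiclass label for the i-th binary (one-vs-rest) problem.
def relabel(label: Double, positiveClass: Double): Double =
  if (label == positiveClass) 1.0 else 0.0

// Pick the class whose binary model produced the highest score.
def predictOneVsRest(scores: Array[Double]): Int = scores.indices.maxBy(i => scores(i))
```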
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bae329692a68d..d6d85ad2533a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
@@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
+@Since("1.4.0")
@Experimental
-final class RandomForestClassifier(override val uid: String)
+final class RandomForestClassifier @Since("1.4.0") (
+ @Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
+ @Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
// Override parameter setters from parent trait for Java API compatibility.
// Parameters from TreeClassifierParams:
+ @Since("1.4.0")
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+ @Since("1.4.0")
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+ @Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type =
super.setMinInstancesPerNode(value)
+ @Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+ @Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+ @Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+ @Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+ @Since("1.4.0")
override def setImpurity(value: String): this.type = super.setImpurity(value)
// Parameters from TreeEnsembleParams:
+ @Since("1.4.0")
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+ @Since("1.4.0")
override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from RandomForestParams:
+ @Since("1.4.0")
override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+ @Since("1.4.0")
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
@@ -99,15 +114,19 @@ final class RandomForestClassifier(override val uid: String)
new RandomForestClassificationModel(trees, numFeatures, numClasses)
}
+ @Since("1.4.1")
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
+@Since("1.4.0")
@Experimental
object RandomForestClassifier {
/** Accessor for supported impurity settings: entropy, gini */
+ @Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ @Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
}
@@ -120,12 +139,13 @@ object RandomForestClassifier {
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
*/
+@Since("1.4.0")
@Experimental
final class RandomForestClassificationModel private[ml] (
- override val uid: String,
+ @Since("1.5.0") override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
- override val numFeatures: Int,
- override val numClasses: Int)
+ @Since("1.6.0") override val numFeatures: Int,
+ @Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -141,11 +161,13 @@ final class RandomForestClassificationModel private[ml] (
numClasses: Int) =
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
+ @Since("1.4.0")
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ @Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -186,11 +208,13 @@ final class RandomForestClassificationModel private[ml] (
}
}
+ @Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
.setParent(parent)
}
+ @Since("1.4.0")
override def toString: String = {
s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
}
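The model above gives every tree an equal `_treeWeights` value of 1.0, so prediction amounts to summing per-tree class votes and normalizing. A hedged sketch of that aggregation in plain Scala (per-tree probability vectors are illustrative inputs):

```scala
// Sum per-tree class-probability vectors with equal weight, then normalize,
// mirroring the equal tree weights of 1.0 above.
def averageVotes(treeProbs: Array[Array[Double]]): Array[Double] = {
  val numClasses = treeProbs.head.length
  val sums = Array.fill(numClasses)(0.0)
  for (p <- treeProbs; k <- 0 until numClasses) sums(k) += p(k)
  val total = sums.sum
  sums.map(_ / total)
}
```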
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index daaa174a086e0..b6b25ecd01b3d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
+ val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
- require(predictionType == FloatType || predictionType == DoubleType)
+ require(predictionType == FloatType || predictionType == DoubleType,
+ s"Prediction column $predictionColName must be of type float or double, " +
+      s"but not $predictionType")
+ val labelColName = $(labelCol)
val labelType = schema($(labelCol)).dataType
- require(labelType == FloatType || labelType == DoubleType)
+ require(labelType == FloatType || labelType == DoubleType,
+ s"Label column $labelColName must be of type float or double, but not $labelType")
val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
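The hunk above only improves the `require` error messages; the evaluator itself casts both columns to double and computes a regression metric over (prediction, label) pairs. A pure-Scala sketch of the default metric, RMSE (a `Seq` of pairs stands in for the selected DataFrame):

```scala
// Root-mean-squared error over (prediction, label) pairs,
// the evaluator's default metric.
def rmse(pairs: Seq[(Double, Double)]): Double = {
  require(pairs.nonEmpty, "need at least one (prediction, label) pair")
  math.sqrt(pairs.map { case (p, l) => val d = p - l; d * d }.sum / pairs.length)
}
```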
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 3a735017ba836..c09f4d076c964 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -27,9 +27,16 @@ import org.apache.spark.sql.types.StructType
/**
* :: Experimental ::
- * Implements the transforms which are defined by SQL statement.
- * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
+ * Implements the transformations defined by a SQL statement.
+ * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...'
 * where '__THIS__' represents the underlying table of the input dataset.
+ * The select clause specifies the fields, constants, and expressions to display in
+ * the output; it can be any select clause that Spark SQL supports. Users can also
+ * use Spark SQL built-in functions and UDFs to operate on the selected columns.
+ * For example, [[SQLTransformer]] supports statements like:
+ * - SELECT a, a + b AS a_b FROM __THIS__
+ * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ WHERE a > 5
+ * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
*/
@Experimental
@Since("1.6.0")
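Under the hood, SQLTransformer registers the input dataset as a temporary table and substitutes that table's name for the `__THIS__` placeholder before running the statement. A minimal sketch of the substitution step (the table name is an illustrative stand-in):

```scala
// Replace the __THIS__ placeholder with a concrete temp-table name,
// as the transformer does before executing the SQL statement.
def substituteThis(statement: String, tableName: String): String =
  statement.replace("__THIS__", tableName)
```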
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a8d61b6dea00b..f105a983a34f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params
/** @group getParam */
def getVectorSize: Int = $(vectorSize)
+ /**
+ * The window size (context words from [-window, window]). Default: 5.
+ * @group expertParam
+ */
+ final val windowSize = new IntParam(
+ this, "windowSize", "the window size (context words from [-window, window])")
+ setDefault(windowSize -> 5)
+
+ /** @group expertGetParam */
+ def getWindowSize: Int = $(windowSize)
+
/**
* Number of partitions for sentences of words.
* Default: 1
@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setVectorSize(value: Int): this.type = set(vectorSize, value)
+ /** @group expertSetParam */
+ def setWindowSize(value: Int): this.type = set(windowSize, value)
+
/** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)
@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
.setNumPartitions($(numPartitions))
.setSeed($(seed))
.setVectorSize($(vectorSize))
+ .setWindowSize($(windowSize))
.fit(input)
copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
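The new `windowSize` param above bounds the context to words within [-window, window] of the center word. A pure-Scala sketch of extracting that context from a tokenized sentence (this is an illustration of the parameter's meaning, not Word2Vec internals):

```scala
// Context words for position i: the words within `window` positions
// on either side, excluding the center word itself.
def contextWords(sentence: Array[String], i: Int, window: Int): Array[String] = {
  val lo = math.max(0, i - window)
  val hi = math.min(sentence.length - 1, i + window)
  (lo to hi).filter(_ != i).map(j => sentence(j)).toArray
}
```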
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 54b03a9f90283..2aa6aec0b4347 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
// We use DataFrames for serialization of IndexedRows to Python,
// so return a DataFrame.
- val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext)
+ val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext)
sqlContext.createDataFrame(indexedRowMatrix.rows)
}
@@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
// We use DataFrames for serialization of MatrixEntry entries to
// Python, so return a DataFrame.
- val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext)
+ val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext)
sqlContext.createDataFrame(coordinateMatrix.entries)
}
@@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
// We use DataFrames for serialization of sub-matrix blocks to
// Python, so return a DataFrame.
- val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext)
+ val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext)
sqlContext.createDataFrame(blockMatrix.blocks)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index a956084ae06e8..aef9ef2cb052d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
modelType: String)
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// Create JSON metadata.
@@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
@Since("1.3.0")
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
// Load Parquet data.
val dataRDD = sqlContext.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
@@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
theta: Array[Array[Double]])
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// Create JSON metadata.
@@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
}
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
// Load Parquet data.
val dataRDD = sqlContext.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index fe09f6b75d28b..2910c027ae06d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// Create JSON metadata.
@@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel {
*/
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val datapath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val dataRDD = sqlContext.read.parquet(datapath)
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 2115f7d99c182..74d13e4f77945 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
weights: Array[Double],
gaussians: Array[MultivariateGaussian]): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// Create JSON metadata.
@@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val dataPath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index a741584982725..91fa9b0d3590d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] {
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
@@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] {
def load(sc: SparkContext, path: String): KMeansModel = {
implicit val formats = DefaultFormats
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 7cd9b08fa8e0e..bb1804505948b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
@Since("1.4.0")
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val metadata = compact(render(
@@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
@Since("1.4.0")
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
implicit val formats = DefaultFormats
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index d4d022afde051..eaa99cfe82e27 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
implicit val formats = DefaultFormats
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a47f27b0afb14..a01077524f5cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging {
this
}
+ /**
+   * Sets the window size (context words from [-window, window], default: 5).
+ */
+ @Since("1.6.0")
+ def setWindowSize(window: Int): this.type = {
+ this.window = window
+ this
+ }
+
/**
* Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5).
@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging {
private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */
- private val window = 5
+ private var window = 5
private var trainWordsCount = 0
private var vocabSize = 0
@@ -582,7 +591,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
def load(sc: SparkContext, path: String): Word2VecModel = {
val dataPath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
@@ -594,7 +603,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val vectorSize = model.values.head.size
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 46562eb2ad0f7..0dc40483dd0ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
*/
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
@@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
implicit val formats = DefaultFormats
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index ec78ea24539b5..f235089873ab8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
}
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val dataRDD = sqlContext.read.parquet(dataPath(path))
checkSchema[Data](dataRDD.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 317d3a5702636..02af281fb726b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
modelClass: String,
weights: Vector,
intercept: Double): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// Create JSON metadata.
@@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel {
*/
def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
val datapath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val dataRDD = sqlContext.read.parquet(datapath)
val dataArray = dataRDD.select("weights", "intercept").take(1)
assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
index 75c6a51d09571..e990fe0768bc9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
@@ -17,12 +17,30 @@
package org.apache.spark.mllib.stat.test
+import scala.beans.BeanInfo
+
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.JavaDStream
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter
+/**
+ * Class that represents the group and value of a sample.
+ *
+ * @param isExperiment if the sample is of the experiment group.
+ * @param value numeric value of the observation.
+ */
+@Since("1.6.0")
+@BeanInfo
+case class BinarySample @Since("1.6.0") (
+ @Since("1.6.0") isExperiment: Boolean,
+ @Since("1.6.0") value: Double) {
+ override def toString: String = {
+ s"($isExperiment, $value)"
+ }
+}
+
/**
* :: Experimental ::
* Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
@@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
/**
* Register a [[DStream]] of values for significance testing.
*
- * @param data stream of (key,value) pairs where the key denotes group membership (true =
- * experiment, false = control) and the value is the numerical metric to test for
- * significance
+   * @param data stream of BinarySample(isExperiment, value) pairs where isExperiment denotes
+   *             group membership (true = experiment, false = control) and value is the
+   *             numerical metric to test for significance
* @return stream of significance testing results
*/
@Since("1.6.0")
- def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
+ def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = {
val dataAfterPeacePeriod = dropPeacePeriod(data)
val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
val pairedSummaries = pairSummaries(summarizedData)
@@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
testMethod.doTest(pairedSummaries)
}
+ /**
+ * Register a [[JavaDStream]] of values for significance testing.
+ *
+   * @param data stream of BinarySample(isExperiment, value) pairs where isExperiment denotes
+   *             the group (true = experiment, false = control) and value is the numerical
+   *             metric to test for significance
+ * @return stream of significance testing results
+ */
+ @Since("1.6.0")
+ def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = {
+ JavaDStream.fromDStream(registerStream(data.dstream))
+ }
+
/** Drop all batches inside the peace period. */
private[stat] def dropPeacePeriod(
- data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
+ data: DStream[BinarySample]): DStream[BinarySample] = {
data.transform { (rdd, time) =>
if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
rdd
@@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
/** Compute summary statistics over each key and the specified test window size. */
private[stat] def summarizeByKeyAndWindow(
- data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
+ data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = {
+ val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value))
if (this.windowSize == 0) {
- data.updateStateByKey[StatCounter](
+ categoryValuePair.updateStateByKey[StatCounter](
(newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
val newSummary = oldSummary.getOrElse(new StatCounter())
newSummary.merge(newValues)
@@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
})
} else {
val windowDuration = data.slideDuration * this.windowSize
- data
+ categoryValuePair
.groupByKeyAndWindow(windowDuration)
.mapValues { values =>
val summary = new StatCounter()
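To make the data flow above concrete: `StreamingTest` reduces each group's `BinarySample` values to a `StatCounter`-style summary (count, mean, variance), and the test method then works from those summaries alone. A minimal Python sketch of the Welch's t statistic computed from two such summaries (standard Welch–Satterthwaite formulas; this is an illustration, not Spark's `WelchTTest` code):

```python
import math

def welch_t(n1, mean1, var1, n2, mean2, var2):
    """Welch's t statistic and degrees of freedom from two group summaries.
    var1/var2 are sample variances, as a StatCounter-style summary would hold."""
    se2 = var1 / n1 + var2 / n2                      # squared standard error
    t = (mean1 - mean2) / math.sqrt(se2)
    # Welch-Satterthwaite approximation for the degrees of freedom
    df = se2 ** 2 / ((var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1))
    return t, df
```

Because only (n, mean, variance) per group is needed, the streaming job can keep constant-size state per key via `updateStateByKey` or a windowed aggregation, exactly as the code above does.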
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 54c136aecf660..89c470d573431 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
}
def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// SPARK-6120: We do a hacky check here so users understand why save() is failing
@@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
val datapath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
// Load Parquet data.
val dataRDD = sqlContext.read.parquet(datapath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 90e032e3d9842..feabcee24fa2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.annotation.Since
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") (
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
/**
+ * :: DeveloperApi ::
* Compute the initial predictions and errors for a dataset for the first
* iteration of gradient boosting.
* @param data: training data.
@@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
* corresponding to every sample.
*/
@Since("1.4.0")
+ @DeveloperApi
def computeInitialPredictionAndError(
data: RDD[LabeledPoint],
initTreeWeight: Double,
@@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
}
/**
+ * :: DeveloperApi ::
* Update a zipped predictionError RDD
* (as obtained with computeInitialPredictionAndError)
* @param data: training data.
@@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
* corresponding to each sample.
*/
@Since("1.4.0")
+ @DeveloperApi
def updatePredictionError(
data: RDD[LabeledPoint],
predictionAndError: RDD[(Double, Double)],
@@ -408,7 +412,7 @@ private[tree] object TreeEnsembleModel extends Logging {
case class EnsembleNodeData(treeId: Int, node: NodeData)
def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
// SPARK-6120: We do a hacky check here so users understand why save() is failing
@@ -468,7 +472,7 @@ private[tree] object TreeEnsembleModel extends Logging {
path: String,
treeAlgo: String): Array[DecisionTreeModel] = {
val datapath = Loader.dataPath(path)
- val sqlContext = new SQLContext(sc)
+ val sqlContext = SQLContext.getOrCreate(sc)
val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)
val trees = constructTrees(nodes)
trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 4795809e47a46..66b2ceacb05f2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -18,34 +18,49 @@
package org.apache.spark.mllib.stat;
import java.io.Serializable;
-
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
+import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
+ private transient JavaStreamingContext ssc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaStatistics");
+ SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("JavaStatistics")
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+ sc = new JavaSparkContext(conf);
+ ssc = new JavaStreamingContext(sc, new Duration(1000));
+ ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
- sc.stop();
+ ssc.stop();
+ ssc = null;
sc = null;
}
@@ -76,4 +91,21 @@ public void chiSqTest() {
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
}
+
+ @Test
+ public void streamingTest() {
+    List<BinarySample> trainingBatch = Arrays.asList(
+      new BinarySample(true, 1.0),
+      new BinarySample(false, 2.0));
+    JavaDStream<BinarySample> training =
+      attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
+ int numBatches = 2;
+ StreamingTest model = new StreamingTest()
+ .setWindowSize(0)
+ .setPeacePeriod(0)
+ .setTestMethod("welch");
+ model.registerStream(training);
+ attachTestOutputStream(training);
+ runStreams(ssc, numBatches, numBatches);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a773244cd735e..d561bbbb25529 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("Word2Vec") {
- val sqlContext = new SQLContext(sc)
+
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
expectedSimilarity.zip(similarity).map {
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
}
+ }
+
+ test("window size") {
+
+ val sqlContext = this.sqlContext
+ import sqlContext.implicits._
+
+ val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+ val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setWindowSize(2)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .fit(docDF)
+ val (synonyms, similarity) = model.findSynonyms("a", 6).map {
+ case Row(w: String, sim: Double) => (w, sim)
+ }.collect().unzip
+
+ // Increase the window size
+ val biggerModel = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .setSeed(42L)
+ .setWindowSize(10)
+ .fit(docDF)
+
+    val (synonymsLarger, similarityLarger) = biggerModel.findSynonyms("a", 6).map {
+ case Row(w: String, sim: Double) => (w, sim)
+ }.collect().unzip
+ // The similarity score should be very different with the larger window
+    assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5)
}
test("Word2Vec read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
index d3e9ef4ff079c..3c657c8cfe743 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.mllib.stat
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest}
+import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest, StreamingTestResult,
+  StudentTTest, WelchTTest}
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter
@@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
@@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
input,
- (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream))
+ (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream))
val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches)
val outputCounts = outputBatches.flatten.map(_._2.count)
@@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream))
val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches)
assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch)
@@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
.setPeacePeriod(0)
val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
- .map(batch => batch.filter(_._1)) // only keep one test group
+ .map(batch => batch.filter(_.isExperiment)) // only keep one test group
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001))
@@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
stdevA: Double,
meanB: Double,
stdevB: Double,
- seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = {
+ seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = {
val rand = new XORShiftRandom(seed)
val numTrues = pointsPerBatch / 2
val data = (0 until numBatches).map { i =>
- (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++
+ (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++
(pointsPerBatch / 2 until pointsPerBatch).map { idx =>
- (false, meanB + stdevB * rand.nextGaussian())
+ BinarySample(false, meanB + stdevB * rand.nextGaussian())
}
}
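The hunk above switches `generateTestData` from raw `(Boolean, Double)` tuples to `BinarySample` objects. As a rough illustration only (not the MLlib implementation), the same two-group Gaussian generator can be sketched in plain Python with a namedtuple standing in for the case class:

```python
import random
from collections import namedtuple

# Hypothetical stand-in for MLlib's BinarySample case class.
BinarySample = namedtuple("BinarySample", ["is_experiment", "value"])

def generate_test_data(num_batches, points_per_batch,
                       mean_a, stdev_a, mean_b, stdev_b, seed):
    """Generate batches of labeled Gaussian samples:
    the first half of each batch is group A (is_experiment=True),
    the second half is group B (is_experiment=False)."""
    rand = random.Random(seed)
    num_trues = points_per_batch // 2
    batches = []
    for _ in range(num_batches):
        batch = [BinarySample(True, rand.gauss(mean_a, stdev_a))
                 for _ in range(num_trues)]
        batch += [BinarySample(False, rand.gauss(mean_b, stdev_b))
                  for _ in range(points_per_batch - num_trues)]
        batches.append(batch)
    return batches
```

Filtering one test group, as the changed test does with `batch.filter(_.isExperiment)`, becomes `[s for s in batch if s.is_experiment]`.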
diff --git a/pom.xml b/pom.xml
index 4d8b0832ae3c5..da50fcbb57eb8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -492,7 +492,7 @@
         <version>${commons.math3.version}</version>
       </dependency>
       <dependency>
-        <groupId>org.apache.commons</groupId>
+        <groupId>commons-collections</groupId>
         <artifactId>commons-collections</artifactId>
         <version>${commons.collections.version}</version>
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 77710a13394c6..529d16b480399 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -222,6 +222,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# create a signal handler which would be invoked on receiving SIGINT
def signal_handler(signal, frame):
self.cancelAllJobs()
+ raise KeyboardInterrupt()
# see http://stackoverflow.com/questions/23206787/
if isinstance(threading.current_thread(), threading._MainThread):
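The `context.py` change makes the SIGINT handler re-raise `KeyboardInterrupt` after cancelling jobs, so Ctrl-C still interrupts the driver instead of being swallowed. A minimal plain-Python sketch of the same pattern (names hypothetical, not the pyspark API):

```python
import signal
import threading

def install_sigint_handler(cancel_all_jobs):
    """Install a SIGINT handler that first runs a cleanup callback
    (here: cancelling jobs) and then re-raises KeyboardInterrupt so the
    interpreter still sees Ctrl-C. Python only allows installing signal
    handlers from the main thread, hence the thread check."""
    def handler(signum, frame):
        cancel_all_jobs()
        raise KeyboardInterrupt()
    if isinstance(threading.current_thread(), threading._MainThread):
        signal.signal(signal.SIGINT, handler)
    return handler  # returned so it can be invoked directly in tests
```

Without the `raise`, the handler would return normally and the interrupted call would simply resume, which is the bug this hunk fixes.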
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 1911588309aff..9ca303a974cd4 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -169,16 +169,20 @@ def sum(self, *cols):
@since(1.6)
def pivot(self, pivot_col, values=None):
- """Pivots a column of the current DataFrame and perform the specified aggregation.
+ """
+ Pivots a column of the current :class:`DataFrame` and performs the specified aggregation.
+ There are two versions of pivot function: one that requires the caller to specify the list
+ of distinct values to pivot on, and one that does not. The latter is more concise but less
+ efficient, because Spark needs to first compute the list of distinct values internally.
- :param pivot_col: Column to pivot
- :param values: Optional list of values of pivot column that will be translated to columns in
- the output DataFrame. If values are not provided the method will do an immediate call
- to .distinct() on the pivot column.
+ :param pivot_col: Name of the column to pivot.
+ :param values: List of values that will be translated to columns in the output DataFrame.
+ # Compute the sum of earnings for each year by course with each course as a separate column
>>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+ # Or without specifying column values (less efficient)
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c40061ae0aafd..bb0fdc4c3d83b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -19,20 +19,60 @@ package org.apache.spark.sql
import java.lang.reflect.Modifier
+import scala.annotation.implicitNotFound
import scala.reflect.{ClassTag, classTag}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer}
import org.apache.spark.sql.types._
/**
+ * :: Experimental ::
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
*
- * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
- * and reuse internal buffers to improve performance.
+ * == Scala ==
+ * Encoders are generally created automatically through implicits from a `SQLContext`.
+ *
+ * {{{
+ * import sqlContext.implicits._
+ *
+ * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder)
+ * }}}
+ *
+ * == Java ==
+ * Encoders are specified by calling static methods on [[Encoders]].
+ *
+ * {{{
+ * List<String> data = Arrays.asList("abc", "abc", "xyz");
+ * Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ * }}}
+ *
+ * Encoders can be composed into tuples:
+ *
+ * {{{
+ * Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
+ * List<Tuple2<Integer, String>> data2 = Arrays.asList(new scala.Tuple2(1, "a"));
+ * Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+ * }}}
+ *
+ * Or constructed from Java Beans:
+ *
+ * {{{
+ * Encoders.bean(MyClass.class);
+ * }}}
+ *
+ * == Implementation ==
+ * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard
+ * against concurrent access if they reuse internal buffers to improve performance.
*
* @since 1.6.0
*/
+@Experimental
+@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " +
+ "(Int, String, etc) and Product types (case classes) are supported by importing " +
+ "sqlContext.implicits._ Support for serializing other types will be added in future " +
+ "releases.")
trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
@@ -43,10 +83,12 @@ trait Encoder[T] extends Serializable {
}
/**
- * Methods for creating encoders.
+ * :: Experimental ::
+ * Methods for creating an [[Encoder]].
*
* @since 1.6.0
*/
+@Experimental
object Encoders {
/**
@@ -97,6 +139,24 @@ object Encoders {
*/
def STRING: Encoder[java.lang.String] = ExpressionEncoder()
+ /**
+ * An encoder for nullable decimal type.
+ * @since 1.6.0
+ */
+ def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable date type.
+ * @since 1.6.0
+ */
+ def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable timestamp type.
+ * @since 1.6.0
+ */
+ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
+
/**
* Creates an encoder for Java Bean of type T.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f117d3c4a1dd7..6f1f088ba1e95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -259,7 +259,7 @@ class Analyzer(
object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Pivot if !p.childrenResolved => p
+ case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) => p
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
val singleAgg = aggregates.size == 1
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 29502a59915f0..dbcbd6854b474 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -594,6 +594,20 @@ object HiveTypeCoercion {
case None => c
}
+ case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonType(types) match {
+ case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
+ case None => g
+ }
+
+ case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonType(types) match {
+ case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
+ case None => l
+ }
+
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
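The new `Greatest`/`Least` cases above find the tightest common type among the children and cast everything to it, leaving the expression untouched when no common type exists. A toy Python model of that rule, using a simplified widening lattice (the real `findTightestCommonType` also handles decimal precision, strings, and more):

```python
# Simplified numeric widening lattice; an assumption for illustration only.
_WIDTH = {"int": 0, "long": 1, "float": 2, "double": 3}

def tightest_common_type(types):
    """Return the widest numeric type among `types`, or None when some
    type falls outside the lattice (mirroring `case None => g`)."""
    if not all(t in _WIDTH for t in types):
        return None
    return max(types, key=lambda t: _WIDTH[t])

def coerce_children(children):
    """Mimic the Greatest/Least coercion: children are (value, type)
    pairs; if their types disagree, cast all of them to the tightest
    common type, otherwise return them unchanged."""
    types = [t for _, t in children]
    if len(set(types)) <= 1:
        return children
    common = tightest_common_type(types)
    if common is None:
        return children  # no common type: leave for the analyzer to reject
    return [(float(v) if common in ("float", "double") else v, common)
            for v, _ in children]
```

This matches the `HiveTypeCoercionSuite` expectation below: mixing `double`, `int`, and `float` literals casts every child to `double`.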
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ba1866efc84e1..915c585ec91fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
'intField.int,
'stringField.string,
'booleanField.boolean,
+ 'decimalField.decimal(8, 0),
'arrayField.array(StringType),
'mapField.map(StringType, LongType))
@@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Round('intField, 'mapField), "requires int type")
assertError(Round('booleanField, 'intField), "requires numeric type")
}
+
+ test("check types for Greatest/Least") {
+ for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
+ assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
+ assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
+ assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
+ assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index d3fafaae89938..142915056f451 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
+ test("greatest/least cast") {
+ for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ operator(Literal(1.0)
+ :: Literal(1)
+ :: Literal.create(1.0, FloatType)
+ :: Nil),
+ operator(Cast(Literal(1.0), DoubleType)
+ :: Cast(Literal(1), DoubleType)
+ :: Cast(Literal.create(1.0, FloatType), DoubleType)
+ :: Nil))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ operator(Literal(1L)
+ :: Literal(1)
+ :: Literal(new java.math.BigDecimal("1000000000000000000000"))
+ :: Nil),
+ operator(Cast(Literal(1L), DecimalType(22, 0))
+ :: Cast(Literal(1), DecimalType(22, 0))
+ :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
+ :: Nil))
+ }
+ }
+
test("nanvl casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ad6af481fadc4..d641fcac1c8ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -73,7 +73,26 @@ class TypedColumn[-T, U](
/**
* :: Experimental ::
- * A column in a [[DataFrame]].
+ * A column that will be computed based on the data in a [[DataFrame]].
+ *
+ * A new column is constructed based on the input columns present in a DataFrame:
+ *
+ * {{{
+ * df("columnName") // On a specific DataFrame.
+ * col("columnName") // A generic column no yet associcated with a DataFrame.
+ * col("columnName.field") // Extracting a struct field
+ * col("`a.column.with.dots`") // Escape `.` in column names.
+ * $"columnName" // Scala short hand for a named column.
+ * expr("a + 1") // A column that is constructed from a parsed SQL Expression.
+ * lit("1") // A column that produces a literal (constant) value.
+ * }}}
+ *
+ * [[Column]] objects can be composed to form complex expressions:
+ *
+ * {{{
+ * $"a" + 1
+ * $"a" === $"b"
+ * }}}
*
* @groupname java_expr_ops Java-specific expression operators
* @groupname expr_ops Expression operators
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d6bb1d2ad8e50..3bd18a14f9e8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
tEncoder: Encoder[T]) extends Queryable with Serializable {
/**
- * An unresolved version of the internal encoder for the type of this dataset. This one is marked
- * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
- * object type (that will be possibly resolved to a different schema).
+ * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
+ * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+ * same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
+ unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+ /**
+ * The encoder where the expressions used to construct an object from an input row have been
+ * bound to the ordinals of the given schema.
+ */
+ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
private implicit def classTag = resolvedTEncoder.clsTag
@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
override def schema: StructType = resolvedTEncoder.schema
/**
- * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
+ * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
* @since 1.6.0
*/
override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
* ************* */
/**
- * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+ * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
* method used to map columns depend on the type of `U`:
* - When `U` is a class, fields for the class will be mapped to columns of the same name
* (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,7 +151,7 @@ class Dataset[T] private[sql](
def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
/**
- * Returns this Dataset.
+ * Returns this [[Dataset]].
* @since 1.6.0
*/
// This is declared with parentheses to prevent the Scala compiler from treating
@@ -153,15 +159,12 @@ class Dataset[T] private[sql](
def toDS(): Dataset[T] = this
/**
- * Converts this Dataset to an RDD.
+ * Converts this [[Dataset]] to an [[RDD]].
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = resolvedTEncoder
- val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
- val bound = tEnc.bind(input)
- iter.map(bound.fromRow)
+ iter.map(boundTEncoder.fromRow)
}
}
@@ -189,7 +192,7 @@ class Dataset[T] private[sql](
def show(numRows: Int): Unit = show(numRows, truncate = true)
/**
- * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
+ * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
* will be truncated, and all cells will be aligned right.
*
* @since 1.6.0
@@ -197,7 +200,7 @@ class Dataset[T] private[sql](
def show(): Unit = show(20)
/**
- * Displays the top 20 rows of [[DataFrame]] in a tabular form.
+ * Displays the top 20 rows of [[Dataset]] in a tabular form.
*
* @param truncate Whether truncate long strings. If true, strings more than 20 characters will
* be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
def show(truncate: Boolean): Unit = show(20, truncate)
/**
- * Displays the [[DataFrame]] in a tabular form. For example:
+ * Displays the [[Dataset]] in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Runs `func` on each element of this Dataset.
+ * Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)
/**
* (Java-specific)
- * Runs `func` on each element of this Dataset.
+ * Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* (Scala-specific)
- * Runs `func` on each partition of this Dataset.
+ * Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
/**
* (Java-specific)
- * Runs `func` on each partition of this Dataset.
+ * Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,7 +377,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -382,7 +385,7 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -390,11 +393,11 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
- val inputPlan = queryExecution.analyzed
+ val inputPlan = logicalPlan
val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)
@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
- def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
- groupBy(f.call(_))(encoder)
+ def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ groupBy(func.call(_))(encoder)
/* ****************** *
* Typed Relational *
* ****************** */
/**
- * Selects a set of column based expressions.
+ * Returns a new [[DataFrame]] by selecting a set of column based expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder.bind(queryExecution.analyzed.output),
- queryExecution.analyzed.output).named :: Nil,
+ boundTEncoder,
+ logicalPlan.output).named :: Nil,
logicalPlan))
}
@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
+ columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,7 +657,7 @@ class Dataset[T] private[sql](
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -662,17 +665,14 @@ class Dataset[T] private[sql](
def collect(): Array[T] = {
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
// to convert the rows into objects of type T.
- val tEnc = resolvedTEncoder
- val input = queryExecution.analyzed.output
- val bound = tEnc.bind(input)
- queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+ queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
}
/**
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index ae47f4fe0e231..383a2d0badb53 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -18,6 +18,9 @@
package test.org.apache.spark.sql;
import java.io.Serializable;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.Timestamp;
import java.util.*;
import scala.Tuple2;
@@ -385,6 +388,20 @@ public void testNestedTupleEncoder() {
Assert.assertEquals(data3, ds3.collectAsList());
}
+ @Test
+ public void testPrimitiveEncoder() {
+ Encoder<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> encoder =
+ Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(),
+ Encoders.FLOAT());
+ List<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> data =
+ Arrays.asList(new Tuple5<Double, BigDecimal, Date, Timestamp, Float>(
+ 1.7976931348623157E308, new BigDecimal("0.922337203685477589"),
+ Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE));
+ Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds =
+ context.createDataset(data, encoder);
+ Assert.assertEquals(data, ds.collectAsList());
+ }
+
@Test
public void testTypedAggregation() {
Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index fc53aba68ebb7..bc1a336ea4fd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -85,4 +85,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
}
+
+ test("pivot with UnresolvedFunction") {
+ checkAnswer(
+ courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
+ .agg("earnings" -> "sum"),
+ Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+ )
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index fd0e8d5d690b6..d0046afdeb447 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -277,7 +277,7 @@ class CheckpointWriter(
val bytes = Checkpoint.serialize(checkpoint, conf)
executor.execute(new CheckpointWriteHandler(
checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
- logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
+ logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 0ada1111ce30a..ea6213420e7ab 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
- val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
- TrackStateRDD.createFromPairRDD[K, V, S, E](
- spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
- partitioner, validTime
- )
+ val prevStateRDD = getOrCompute(validTime - slideDuration) match {
+ case Some(rdd) =>
+ if (rdd.partitioner != Some(partitioner)) {
+ // If the RDD is not partitioned the right way, let us repartition it using the
+ // partition index as the key. This is to ensure that state RDD is always partitioned
+ // before creating another state RDD using it
+ TrackStateRDD.createFromRDD[K, V, S, E](
+ rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
+ } else {
+ rdd
+ }
+ case None =>
+ TrackStateRDD.createFromPairRDD[K, V, S, E](
+ spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+ partitioner,
+ validTime
+ )
}
+
// Compute the new state RDD with previous state RDD and partitioned data RDD
- parent.getOrCompute(validTime).map { dataRDD =>
- val partitionedDataRDD = dataRDD.partitionBy(partitioner)
- val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
- (validTime - interval).milliseconds
- }
- new TrackStateRDD(
- prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
+ // Even if there is no data RDD, use an empty one to create a new state RDD
+ val dataRDD = parent.getOrCompute(validTime).getOrElse {
+ context.sparkContext.emptyRDD[(K, V)]
+ }
+ val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+ val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+ (validTime - interval).milliseconds
}
+ Some(new TrackStateRDD(
+ prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}
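The rewritten `compute` above has two behavioral fixes: it repartitions the previous state RDD when its partitioner does not match, and it always produces a new state snapshot, substituting an empty batch when the parent DStream has no data for the interval. A dict-based Python skeleton of the second fix (partitioning elided; names are illustrative, not the Spark API):

```python
def compute_state(prev_state, batch, update_fn):
    """Skeleton of the compute() change: start from the previous state
    snapshot (or an empty one), then run the update function over the
    batch -- even an empty batch -- so every interval yields a state RDD.
    State is modeled as a plain dict of key -> state."""
    state = dict(prev_state) if prev_state is not None else {}
    batch = batch if batch is not None else []  # the emptyRDD fallback
    for key, value in batch:
        state[key] = update_fn(key, value, state.get(key))
    return state
```

Before the change, an interval with no parent data produced `None` instead of a state RDD, so state silently stopped flowing forward.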
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index 7050378d0feb0..30aafcf1460e3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
private[streaming] object TrackStateRDD {
- def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, T] = {
+ updateTime: Time): TrackStateRDD[K, V, S, E] = {
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+ Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ }
+
+ def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ rdd: RDD[(K, S, Long)],
+ partitioner: Partitioner,
+ updateTime: Time): TrackStateRDD[K, V, S, E] = {
+
+ val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
+ val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+ val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+ iterator.foreach { case (key, (state, updateTime)) =>
+ stateMap.put(key, state, updateTime)
+ }
+ Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ }, preservesPartitioning = true)
+
+ val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+ val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+ new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}
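The new `createFromRDD` rebuilds a per-partition state map from `(key, state, updateTime)` triples after repartitioning. As a rough sketch of that rebuild step in plain Python (a single map instead of one per partition, and keeping the latest update per key as an assumption of this sketch):

```python
def state_map_from_triples(triples):
    """Rebuild a keyed state map from (key, state, update_time) triples,
    analogous to what createFromRDD's mapPartitions body does with
    StateMap.put. Returns {key: (state, update_time)}; when a key appears
    more than once, the entry with the latest update_time wins."""
    state_map = {}
    for key, state, update_time in triples:
        current = state_map.get(key)
        if current is None or update_time >= current[1]:
            state_map[key] = (state, update_time)
    return state_map
```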
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 6e6ed8d819721..862272bb4498f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -165,10 +165,12 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
var segment: WriteAheadLogRecordHandle = null
if (buffer.length > 0) {
logDebug(s"Batched ${buffer.length} records for Write Ahead Log write")
+ // Multiple threads may not add items to the queue in time order
+ val sortedByTime = buffer.sortBy(_.time)
// We take the latest record for the timestamp. Please refer to the class Javadoc for
// detailed explanation
- val time = buffer.last.time
- segment = wrappedLog.write(aggregate(buffer), time)
+ val time = sortedByTime.last.time
+ segment = wrappedLog.write(aggregate(sortedByTime), time)
}
buffer.foreach(_.promise.success(segment))
} catch {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index b1cbc7163bee3..cd28d3cf408d5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+/**
+ * A trait that can be mixed in to get methods for testing DStream operations under
+ * DStream checkpointing. Note that implementations of this trait have to implement
+ * the `setupCheckpointOperation` method.
+ */
+trait DStreamCheckpointTester { self: SparkFunSuite =>
+
+ /**
+ * Tests a streaming operation under checkpointing, by restarting the operation
+ * from checkpoint file and verifying whether the final output is correct.
+   * The output is assumed to have come from a reliable queue which can replay
+   * data as required.
+ *
+ * NOTE: This takes into consideration that the last batch processed before
+ * master failure will be re-processed after restart/recovery.
+ */
+ protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ numBatchesBeforeRestart: Int,
+ batchDuration: Duration = Milliseconds(500),
+ stopSparkContextAfterTest: Boolean = true
+ ) {
+ require(numBatchesBeforeRestart < expectedOutput.size,
+      "Number of batches before context restart must be less than the number of expected " +
+      "outputs (i.e., the total number of batches to run)")
+ require(StreamingContext.getActive().isEmpty,
+ "Cannot run test with already active streaming context")
+
+ // Current code assumes that number of batches to be run = number of inputs
+ val totalNumBatches = input.size
+ val batchDurationMillis = batchDuration.milliseconds
+
+ // Setup the stream computation
+ val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
+ logDebug(s"Using checkpoint directory $checkpointDir")
+ val ssc = createContextForCheckpointOperation(batchDuration)
+ require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
+ "Cannot run test without manual clock in the conf")
+
+ val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+ val operatedStream = operation(inputStream)
+ operatedStream.print()
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
+ outputStream.register()
+ ssc.checkpoint(checkpointDir)
+
+ // Do the computation for initial number of batches, create checkpoint file and quit
+ val beforeRestartOutput = generateOutput[V](ssc,
+ Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
+ assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
+ // Restart and complete the computation from checkpoint file
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation " +
+ "\n-------------------------------------------\n"
+ )
+
+ val restartedSsc = new StreamingContext(checkpointDir)
+ val afterRestartOutput = generateOutput[V](restartedSsc,
+ Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
+ assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
+ }
+
+ protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
+ val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+ conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+ new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
+ }
+
+ private def generateOutput[V: ClassTag](
+ ssc: StreamingContext,
+ targetBatchTime: Time,
+ checkpointDir: String,
+ stopSparkContext: Boolean
+ ): Seq[Seq[V]] = {
+ try {
+ val batchDuration = ssc.graph.batchDuration
+ val batchCounter = new BatchCounter(ssc)
+ ssc.start()
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val currentTime = clock.getTimeMillis()
+
+ logInfo("Manual clock before advancing = " + clock.getTimeMillis())
+ clock.setTime(targetBatchTime.milliseconds)
+ logInfo("Manual clock after advancing = " + clock.getTimeMillis())
+
+ val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
+ dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
+ }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+
+ eventually(timeout(10 seconds)) {
+ ssc.awaitTerminationOrTimeout(10)
+ assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
+ }
+
+ eventually(timeout(10 seconds)) {
+ val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
+ _.toString.contains(clock.getTimeMillis.toString)
+ }
+        // Checkpoint files are written twice for every batch interval, so assert
+        // that both of them have been written.
+ assert(checkpointFilesOfLatestTime.size === 2)
+ }
+ outputStream.output.map(_.flatten)
+
+ } finally {
+ ssc.stop(stopSparkContext = stopSparkContext)
+ }
+ }
+
+ private def assertOutput[V: ClassTag](
+ output: Seq[Seq[V]],
+ expectedOutput: Seq[Seq[V]],
+ beforeRestart: Boolean): Unit = {
+ val expectedPartialOutput = if (beforeRestart) {
+ expectedOutput.take(output.size)
+ } else {
+ expectedOutput.takeRight(output.size)
+ }
+ val setComparison = output.zip(expectedPartialOutput).forall {
+ case (o, e) => o.toSet === e.toSet
+ }
+ assert(setComparison, s"set comparison failed\n" +
+ s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
+ s"Generated output items: ${output.mkString("\n")}"
+ )
+ }
+}
+
/**
* This test suites tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
var ssc: StreamingContext = null
@@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {
override def afterFunction() {
super.afterFunction()
- if (ssc != null) ssc.stop()
+ if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}
@@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
Seq(("", 2)),
Seq(),
Seq(("a", 2), ("b", 1)),
- Seq(("", 2)), Seq() ),
+ Seq(("", 2)),
+ Seq()
+ ),
3
)
}
@@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase {
checkpointWriter.stop()
}
- /**
- * Tests a streaming operation under checkpointing, by restarting the operation
- * from checkpoint file and verifying whether the final output is correct.
- * The output is assumed to have come from a reliable queue which an replay
- * data as required.
- *
- * NOTE: This takes into consideration that the last batch processed before
- * master failure will be re-processed after restart/recovery.
- */
- def testCheckpointedOperation[U: ClassTag, V: ClassTag](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- expectedOutput: Seq[Seq[V]],
- initialNumBatches: Int
- ) {
-
- // Current code assumes that:
- // number of inputs = number of outputs = number of batches to be run
- val totalNumBatches = input.size
- val nextNumBatches = totalNumBatches - initialNumBatches
- val initialNumExpectedOutputs = initialNumBatches
- val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
- // because the last batch will be processed again
-
- // Do the computation for initial number of batches, create checkpoint file and quit
- ssc = setupStreams[U, V](input, operation)
- ssc.start()
- val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
- ssc.stop()
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
- Thread.sleep(1000)
-
- // Restart and complete the computation from checkpoint file
- logInfo(
- "\n-------------------------------------------\n" +
- " Restarting stream computation " +
- "\n-------------------------------------------\n"
- )
- ssc = new StreamingContext(checkpointDir)
- ssc.start()
- val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
- // the first element will be re-processed data of the last batch before restart
- verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
- ssc.stop()
- ssc = null
- }
-
/**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index a45c92d9c7bc8..be0f4636a6cb8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) {
// All access to this state should be guarded by `BatchCounter.this.synchronized`
private var numCompletedBatches = 0
private var numStartedBatches = 0
+ private var lastCompletedBatchTime: Time = null
private val listener = new StreamingListener {
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
@@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) {
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
BatchCounter.this.synchronized {
numCompletedBatches += 1
+ lastCompletedBatchTime = batchCompleted.batchInfo.batchTime
BatchCounter.this.notifyAll()
}
}
@@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) {
numStartedBatches
}
+ def getLastCompletedBatchTime: Time = this.synchronized {
+ lastCompletedBatchTime
+ }
+
/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
index 58aef74c0040f..1fc320d31b18b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -25,31 +25,27 @@ import scala.reflect.ClassTag
import org.scalatest.PrivateMethodTester._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+class TrackStateByKeySuite extends SparkFunSuite
+ with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
private var sc: SparkContext = null
- private var ssc: StreamingContext = null
- private var checkpointDir: File = null
- private val batchDuration = Seconds(1)
+ protected var checkpointDir: File = null
+ protected val batchDuration = Seconds(1)
before {
- StreamingContext.getActive().foreach {
- _.stop(stopSparkContext = false)
- }
+ StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
checkpointDir = Utils.createTempDir("checkpoint")
-
- ssc = new StreamingContext(sc, batchDuration)
- ssc.checkpoint(checkpointDir.toString)
}
after {
- StreamingContext.getActive().foreach {
- _.stop(stopSparkContext = false)
+ if (checkpointDir != null) {
+ Utils.deleteRecursively(checkpointDir)
}
+ StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
}
override def beforeAll(): Unit = {
@@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
assert(dstreamImpl.stateClass === classOf[Double])
assert(dstreamImpl.emittedClass === classOf[Long])
}
-
+ val ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
// Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
@@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
expectedCheckpointDuration: Duration,
explicitCheckpointDuration: Option[Duration] = None
): Unit = {
+ val ssc = new StreamingContext(sc, batchDuration)
+
try {
- ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
val dummyFunc = (value: Option[Int], state: State[Int]) => 0
val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
@@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
trackStateStream.checkpoint(d)
}
trackStateStream.register()
+ ssc.checkpoint(checkpointDir.toString)
ssc.start() // should initialize all the checkpoint durations
assert(trackStateStream.checkpointDuration === null)
assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
} finally {
- StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+ ssc.stop(stopSparkContext = false)
}
}
@@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
}
+
+ test("trackStateByKey - driver failure recovery") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ def operation(dstream: DStream[String]): DStream[(String, Int)] = {
+
+ val checkpointDuration = batchDuration * (stateData.size / 2)
+
+ val runningCount = (value: Option[Int], state: State[Int]) => {
+ state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
+ state.get()
+ }
+
+ val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
+ StateSpec.function(runningCount))
+      // Set the checkpoint interval to make sure there is at least one RDD checkpoint
+ trackStateStream.checkpoint(checkpointDuration)
+ trackStateStream.stateSnapshots()
+ }
+
+ testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
+ batchDuration = batchDuration, stopSparkContextAfterTest = false)
+ }
+
private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
trackStateSpec: StateSpec[K, Int, S, T],
@@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
// Setup the stream computation
+ val ssc = new StreamingContext(sc, Seconds(1))
val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
@@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
stateSnapshotStream.register()
val batchCounter = new BatchCounter(ssc)
+ ssc.checkpoint(checkpointDir.toString)
ssc.start()
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
clock.advance(batchDuration.milliseconds * numBatches)
batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+ ssc.stop(stopSparkContext = false)
(collectedOutputs, collectedStateSnapshots)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index eaa88ea3cd380..ef1e89df31305 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -480,7 +480,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
p
}
- test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") {
+ test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") {
val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
@@ -500,8 +500,14 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
// rest of the records will be batched while it takes time for 3 to get written
writeAsync(batchedWal, event2, 5L)
writeAsync(batchedWal, event3, 8L)
- writeAsync(batchedWal, event4, 12L)
- writeAsync(batchedWal, event5, 10L)
+ // we would like event 5 to be written before event 4 in order to test that they get
+ // sorted before being aggregated
+ writeAsync(batchedWal, event5, 12L)
+ eventually(timeout(1 second)) {
+ assert(blockingWal.isBlocked)
+ assert(batchedWal.invokePrivate(queueLength()) === 3)
+ }
+ writeAsync(batchedWal, event4, 10L)
eventually(timeout(1 second)) {
assert(walBatchingThreadPool.getActiveCount === 5)
assert(batchedWal.invokePrivate(queueLength()) === 4)
@@ -517,7 +523,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
-    // the file name should be the timestamp of the last record, as events should be naturally
-    // in order of timestamp, and we need the last element.
+    // The log file name should carry the highest timestamp among the aggregated records,
+    // even though the records may have been queued out of order.
val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
- verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L))
+ verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L))
val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString)
assert(records.toSet === queuedEvents)
}
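The `BatchedWriteAheadLog` fix above can be illustrated in isolation: because writer threads may enqueue records out of time order, the buffer is sorted by time before aggregation, so the batch is written under the highest timestamp regardless of arrival order. A simplified sketch, with `Record` standing in for the private record wrapper in the patch (the real one also carries a promise and the serialized payload):

```scala
// Stand-in for BatchedWriteAheadLog's private record wrapper.
case class Record(data: String, time: Long)

// Mirrors the write path above: sort the buffered records by time, then
// aggregate them and name the batch after the latest record's timestamp.
def batchForWrite(buffer: Seq[Record]): (Seq[String], Long) = {
  require(buffer.nonEmpty, "cannot batch an empty buffer")
  val sortedByTime = buffer.sortBy(_.time)
  (sortedByTime.map(_.data), sortedByTime.last.time)
}

// Records arrive out of order, as in the event4/event5 test case above.
val (aggregated, time) = batchForWrite(
  Seq(Record("e3", 8L), Record("e5", 12L), Record("e4", 10L)))
println(time)        // 12
println(aggregated)  // List(e3, e4, e5)
```

Without the sort, `buffer.last.time` would name the log file after whichever record happened to be enqueued last (10 in this example), which is what the updated `BatchedWriteAheadLogSuite` test guards against.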