From 9e5b4ce727cf262a14a411efded85ee1e50a88ed Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Mar 2017 18:44:31 -0800 Subject: [PATCH 01/78] [SPARK-19084][SQL] Ensure context class loader is set when initializing Hive. A change in Hive 2.2 (most probably HIVE-13149) causes this code path to fail, since the call to "state.getConf.setClassLoader" does not actually change the context's class loader. Spark doesn't yet officially support Hive 2.2, but some distribution-specific metastore client libraries may have that change (as certain versions of CDH already do), and this also makes it easier to support 2.2 when it comes out. Tested with existing unit tests; we've also used this patch extensively with Hive metastore client jars containing the offending patch. Author: Marcelo Vanzin Closes #17154 from vanzin/SPARK-19804. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 8f98c8f447037..7acaa9a7ab417 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -269,16 +269,21 @@ private[hive] class HiveClientImpl( */ def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this HiveClientImpl. - Hive.set(client) + val originalConfLoader = state.getConf.getClassLoader // The classloader in clientLoader could be changed after addJar, always use the latest - // classloader + // classloader. We explicitly set the context class loader since "conf.setClassLoader" does + // not do that, and the Hive client libraries may need to load classes defined by the client's + // class loader. + Thread.currentThread().setContextClassLoader(clientLoader.classLoader) state.getConf.setClassLoader(clientLoader.classLoader) + // Set the thread local metastore client to the client associated with this HiveClientImpl. + Hive.set(client) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. shim.setCurrentSessionState(state) val ret = try f finally { + state.getConf.setClassLoader(originalConfLoader) Thread.currentThread().setContextClassLoader(original) HiveCatalogMetrics.incrementHiveClientCalls(1) } From fbc4058037cf5b0be9f14a7dd28105f7f8151bed Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 3 Mar 2017 19:00:35 -0800 Subject: [PATCH 02/78] [SPARK-19816][SQL][TESTS] Fix an issue that DataFrameCallbackSuite doesn't recover the log level ## What changes were proposed in this pull request? "DataFrameCallbackSuite.execute callback functions when a DataFrame action failed" sets the log level to "fatal" but doesn't recover it. Hence, tests running after it won't output any logs except fatal logs. This PR uses `testQuietly` instead to avoid changing the log level. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17156 from zsxwing/SPARK-19816. 
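For context, the advantage of `testQuietly` over calling `sparkContext.setLogLevel("FATAL")` inside the test body is that it scopes the change and restores the previous level even when the test throws. A minimal sketch of that save-and-restore pattern, assuming a Log4j 1.x root logger (`withQuietLogging` is an illustrative name, not Spark's actual helper):

```scala
import org.apache.log4j.{Level, LogManager}

// Hypothetical helper: silence everything below ERROR for the duration of
// `body`, then restore the previous level even if `body` throws.
def withQuietLogging[T](body: => T): T = {
  val rootLogger = LogManager.getRootLogger
  val originalLevel = rootLogger.getLevel
  rootLogger.setLevel(Level.ERROR)
  try body finally rootLogger.setLevel(originalLevel)
}
```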
--- .../org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9f27d06dcb366..7c9ea7d393630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -60,7 +60,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.listenerManager.unregister(listener) } - test("execute callback functions when a DataFrame action failed") { + testQuietly("execute callback functions when a DataFrame action failed") { val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] val listener = new QueryExecutionListener { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { @@ -75,8 +75,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") - // Ignore the log when we are expecting an exception. - sparkContext.setLogLevel("FATAL") val e = intercept[SparkException](df.select(errorUdf($"i")).collect()) assert(metrics.length == 1) From 6b0cfd9fa51aca4536d7c3f2a4bbceae11a50339 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 4 Mar 2017 16:43:31 +0000 Subject: [PATCH 03/78] [SPARK-19550][SPARKR][DOCS] Update R document to use JDK8 ## What changes were proposed in this pull request? Update the R document to use JDK8. ## How was this patch tested? manual tests Author: Yuming Wang Closes #17162 from wangyum/SPARK-19550. --- R/WINDOWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/WINDOWS.md b/R/WINDOWS.md index cb2eebb9ffe6e..9ca7e58e20cd2 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -6,7 +6,7 @@ To build SparkR on Windows, the following steps are required include Rtools and R in `PATH`. 2. Install -[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +[JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set `JAVA_HOME` in the system environment variables. 3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` From 42c4cd9e2a44eaa6a16e3b490eb82b6292d9b2ea Mon Sep 17 00:00:00 2001 From: liuxian Date: Sun, 5 Mar 2017 10:23:50 +0000 Subject: [PATCH 04/78] =?UTF-8?q?[SPARK-19792][WEBUI]=20In=20the=20Master?= =?UTF-8?q?=20Page,the=20column=20named=20=E2=80=9CMemory=20per=20Node?= =?UTF-8?q?=E2=80=9D=20,I=20think=20it=20is=20not=20all=20right?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: liuxian ## What changes were proposed in this pull request? On the Spark web UI, the Master page has two tables, Running Applications and Completed Applications, each with a column named “Memory per Node”. That name is not quite right, because a node may run more than one executor, so the column should be named “Memory per Executor”; otherwise it is easy for users to misunderstand. ## How was this patch tested? N/A Author: liuxian Closes #17132 from 10110346/wid-lx-0302. 
--- .../scala/org/apache/spark/deploy/master/ui/MasterPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7dbe32975435d..e722a24d4a89e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -76,7 +76,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Submitted Time", "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) From f48461ab2bdb91cd00efa5a5ec4b0b2bc361e7a2 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sun, 5 Mar 2017 03:35:42 -0800 Subject: [PATCH 05/78] [SPARK-19805][TEST] Log the row type when query result does not match ## What changes were proposed in this pull request? Improve the log message when the query result does not match. before pr: ``` == Results == !== Correct Answer - 3 == == Spark Answer - 3 == [1] [1] [2] [2] [3] [3] ``` after pr: ~~== Results == !== Correct Answer - 3 == == Spark Answer - 3 == !RowType[string] RowType[integer] [1] [1] [2] [2] [3] [3]~~ ``` == Results == !== Correct Answer - 3 == == Spark Answer - 3 == !struct struct [1] [1] [2] [2] [3] [3] ``` ## How was this patch tested? Jenkins Author: uncleGen Closes #17145 from uncleGen/improve-test-result. --- .../test/scala/org/apache/spark/sql/QueryTest.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 34fa626e00e31..f9808834df4a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -312,13 +312,23 @@ object QueryTest { sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + val errorMessage = s""" |== Results == |${sideBySide( s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: prepareAnswer(expectedAnswer, isSorted).map(_.toString()), s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} """.stripMargin return Some(errorMessage) From 14bb398fae974137c3e38162cefc088e12838258 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 5 Mar 2017 03:53:19 -0800 Subject: [PATCH 06/78] [SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit ## What changes were proposed in this pull request? This PR adds support for Seq, Map, and Struct in functions.lit; it adds a new function, `typedLit`, that takes a `TypeTag` to avoid type erasure. ## How was this patch tested? 
Added tests in `LiteralExpressionSuite` Author: Takeshi Yamamuro Author: Takeshi YAMAMURO Closes #16610 from maropu/SPARK-19254. --- .../sql/catalyst/expressions/literals.scala | 12 ++- .../expressions/LiteralExpressionSuite.scala | 90 ++++++++++++++++--- .../org/apache/spark/sql/functions.scala | 25 ++++-- .../spark/sql/ColumnExpressionSuite.scala | 14 +++ 4 files changed, 121 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e66fb893394eb..eaeaf08c37b4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -32,11 +32,13 @@ import java.util.Objects import javax.xml.bind.DatatypeConverter import scala.math.{BigDecimal, BigInt} +import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.json4s.JsonAST._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -153,6 +155,14 @@ object Literal { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } + def create[T : TypeTag](v: T): Literal = Try { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(v), dataType) + }.getOrElse { + Literal(v) + } + /** * Create a literal with default value for given DataType */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 15e8e6c057baf..a9e0eb0e377a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import scala.reflect.runtime.universe.{typeTag, TypeTag} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("boolean literals") { checkEvaluation(Literal(true), true) checkEvaluation(Literal(false), false) + + checkEvaluation(Literal.create(true), true) + checkEvaluation(Literal.create(false), false) } test("int literals") { @@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(d.toLong), d.toLong) checkEvaluation(Literal(d.toShort), d.toShort) checkEvaluation(Literal(d.toByte), d.toByte) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toLong), d.toLong) + checkEvaluation(Literal.create(d.toShort), d.toShort) + 
checkEvaluation(Literal.create(d.toByte), d.toByte) } checkEvaluation(Literal(Long.MinValue), Long.MinValue) checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) + + checkEvaluation(Literal.create(Long.MinValue), Long.MinValue) + checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue) } test("double literals") { List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => checkEvaluation(Literal(d), d) checkEvaluation(Literal(d.toFloat), d.toFloat) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toFloat), d.toFloat) } checkEvaluation(Literal(Double.MinValue), Double.MinValue) checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) checkEvaluation(Literal(Float.MinValue), Float.MinValue) checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + checkEvaluation(Literal.create(Double.MinValue), Double.MinValue) + checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal.create(Float.MinValue), Float.MinValue) + checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue) + } test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") checkEvaluation(Literal("\u0000"), "\u0000") + + checkEvaluation(Literal.create(""), "") + checkEvaluation(Literal.create("test"), "test") + checkEvaluation(Literal.create("\u0000"), "\u0000") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) + checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2) } test("binary literals") { checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + + checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2)) } test("decimal") { @@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal((d * 1000L).toLong, 10, 3)) checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) + + checkEvaluation(Literal.create(Decimal(d)), Decimal(d)) + checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d)) + } } + private def toCatalyst[T: TypeTag](value: T): Any = { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + CatalystTypeConverters.createToCatalystConverter(dataType)(value) + } + test("array") { - def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { - val toCatalyst = (a: Array[_], elementType: DataType) => { - CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) - } - checkEvaluation(Literal(a), toCatalyst(a, elementType)) + def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = { + checkEvaluation(Literal(a), toCatalyst(a)) + checkEvaluation(Literal.create(a), toCatalyst(a)) + } + checkArrayLiteral(Array(1, 2, 3)) + checkArrayLiteral(Array("a", "b", "c")) + checkArrayLiteral(Array(1.0, 4.0)) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR)) + } + + test("seq") { + def checkSeqLiteral[T: TypeTag](a: Seq[T], 
elementType: DataType): Unit = { + checkEvaluation(Literal.create(a), toCatalyst(a)) } - checkArrayLiteral(Array(1, 2, 3), IntegerType) - checkArrayLiteral(Array("a", "b", "c"), StringType) - checkArrayLiteral(Array(1.0, 4.0), DoubleType) - checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), + checkSeqLiteral(Seq(1, 2, 3), IntegerType) + checkSeqLiteral(Seq("a", "b", "c"), StringType) + checkSeqLiteral(Seq(1.0, 4.0), DoubleType) + checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), CalendarIntervalType) } - test("unsupported types (map and struct) in literals") { + test("map") { + def checkMapLiteral[T: TypeTag](m: T): Unit = { + checkEvaluation(Literal.create(m), toCatalyst(m)) + } + checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3)) + checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0)) + } + + test("struct") { + def checkStructLiteral[T: TypeTag](s: T): Unit = { + checkEvaluation(Literal.create(s), toCatalyst(s)) + } + checkStructLiteral((1, 3.0, "abcde")) + checkStructLiteral(("de", 1, 2.0f)) + checkStructLiteral((1, ("fgh", 3.0))) + } + + test("unsupported types (map and struct) in Literal.apply") { def checkUnsupportedTypeInLiteral(v: Any): Unit = { val errMsgMap = intercept[RuntimeException] { Literal(v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 24ed906d33683..2247010ac3f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -91,15 +91,24 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = { - literal match { - case c: Column => return c - case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) - case _ => // continue - } + def lit(literal: Any): Column = typedLit(literal) - val literalExpr = Literal(literal) - Column(literalExpr) + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * The difference between this function and [[lit]] is that this function + * can handle parameterized scala types e.g.: List, Seq and Map. 
+ * + * @group normal_funcs + * @since 2.2.0 + */ + def typedLit[T : TypeTag](literal: T): Column = literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee280a313cc04..b0f398dab7455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) } + + test("typedLit") { + val df = Seq(Tuple1(0)).toDF("a") + // Only check the types `lit` cannot handle + checkAnswer( + df.select(typedLit(Seq(1, 2, 3))), + Row(Seq(1, 2, 3)) :: Nil) + checkAnswer( + df.select(typedLit(Map("a" -> 1, "b" -> 2))), + Row(Map("a" -> 1, "b" -> 2)) :: Nil) + checkAnswer( + df.select(typedLit(("a", 2, 1.0))), + Row(Row("a", 2, 1.0)) :: Nil) + } } From 80d5338b32e856870cf187ce17bc87335d690761 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 5 Mar 2017 12:37:02 -0800 Subject: [PATCH 07/78] [SPARK-19795][SPARKR] add column functions to_json, from_json ## What changes were proposed in this pull request? Add column functions: to_json, from_json, and tests covering error cases. ## How was this patch tested? unit tests, manual Author: Felix Cheung Closes #17134 from felixcheung/rtojson. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 57 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 43 ++++++++++++++--- 4 files changed, 103 insertions(+), 7 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 81e19364ae7ea..871f8e41a0f23 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -229,6 +229,7 @@ exportMethods("%in%", "floor", "format_number", "format_string", + "from_json", "from_unixtime", "from_utc_timestamp", "getField", @@ -327,6 +328,7 @@ exportMethods("%in%", "toDegrees", "toRadians", "to_date", + "to_json", "to_timestamp", "to_utc_timestamp", "translate", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9e5084481fcde..edf2bcf8fdb3c 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1793,6 +1793,33 @@ setMethod("to_date", column(jc) }) +#' to_json +#' +#' Converts a column containing a \code{structType} into a Column of JSON string. +#' Resolving the Column can fail if an unsupported type is encountered. +#' +#' @param x Column containing the struct +#' @param ... additional named properties to control how it is converted, accepts the same options +#' as the JSON data source. +#' +#' @family normal_funcs +#' @rdname to_json +#' @name to_json +#' @aliases to_json,Column-method +#' @export +#' @examples +#' \dontrun{ +#' to_json(df$t, dateFormat = 'dd/MM/yyyy') +#' select(df, to_json(df$t)) +#'} +#' @note to_json since 2.2.0 +setMethod("to_json", signature(x = "Column"), + function(x, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", "to_json", x@jc, options) + column(jc) + }) + #' to_timestamp #' #' Converts the column into a TimestampType. 
You may optionally specify a format @@ -2403,6 +2430,36 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) +#' from_json +#' +#' Parses a column containing a JSON string into a Column of \code{structType} with the specified +#' \code{schema}. If the string is unparseable, the Column will contains the value NA. +#' +#' @param x Column containing the JSON string. +#' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' @param ... additional named properties to control how the json is parsed, accepts the same +#' options as the JSON data source. +#' +#' @family normal_funcs +#' @rdname from_json +#' @name from_json +#' @aliases from_json,Column,structType-method +#' @export +#' @examples +#' \dontrun{ +#' schema <- structType(structField("name", "string"), +#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) +#'} +#' @note from_json since 2.2.0 +setMethod("from_json", signature(x = "Column", schema = "structType"), + function(x, schema, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "from_json", + x@jc, schema$jobj, options) + column(jc) + }) + #' from_utc_timestamp #' #' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 647cbbdd825e3..45bc12746511c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -991,6 +991,10 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @export setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) +#' @rdname from_json +#' @export +setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) + #' @rdname from_unixtime #' @export setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) @@ -1265,6 +1269,10 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @export setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) +#' @rdname to_json +#' @export +setGeneric("to_json", function(x, ...) 
{ standardGeneric("to_json") }) + #' @rdname to_timestamp #' @export setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1dd8c5ce6cb32..7c096597fea66 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -88,6 +88,13 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) +# For test map type and struct type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("calling sparkRSQL.init returns existing SQL context", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) @@ -466,13 +473,6 @@ test_that("create DataFrame from a data.frame with complex types", { expect_equal(ldf$an_envir, collected$an_envir) }) -# For test map type and struct type in DataFrame -mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", - "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", - "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") -mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") -writeLines(mockLinesMapType, mapTypeJsonPath) - test_that("Collect DataFrame with complex types", { # ArrayType df <- read.json(complexTypeJsonPath) @@ -1337,6 +1337,33 @@ test_that("column functions", { df <- createDataFrame(data.frame(x = c(2.5, 3.5))) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + + # Test to_json(), from_json() + df <- read.json(mapTypeJsonPath) + j <- collect(select(df, alias(to_json(df$info), "json"))) + expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") + df <- as.DataFrame(j) + schema <- structType(structField("age", "integer"), + structField("height", "double")) + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + + # passing option + df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) + schema2 <- structType(structField("date", "date")) + expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), + error = function(e) { stop(e) }), + paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) + expect_is(s[[1]][[1]]$date, "Date") + expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") + + # check for unparseable + df <- as.DataFrame(list(list("a" = ""))) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) }) test_that("column binary mathfunctions", { @@ -2867,5 +2894,7 @@ unlink(parquetPath) unlink(orcPath) unlink(jsonPath) unlink(jsonPathNa) +unlink(complexTypeJsonPath) +unlink(mapTypeJsonPath) sparkR.session.stop() From 369a148e591bb16ec7da54867610b207602cd698 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 5 Mar 2017 14:35:06 -0800 Subject: [PATCH 08/78] 
[SPARK-19595][SQL] Support json array in from_json ## What changes were proposed in this pull request? This PR proposes two changes. First, **do not allow json arrays with multiple elements, and return null, in `from_json` with `StructType` as the schema.** Currently, it only reads the single row when the input is a json array. So, the code below: ```scala import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val schema = StructType(StructField("a", IntegerType) :: Nil) Seq(("""[{"a": 1}, {"a": 2}]""")).toDF("struct").select(from_json(col("struct"), schema)).show() ``` prints ``` +--------------------+ |jsontostruct(struct)| +--------------------+ | [1]| +--------------------+ ``` This PR simply suggests printing this as `null` if the schema is `StructType` and the input is a json array with multiple elements: ``` +--------------------+ |jsontostruct(struct)| +--------------------+ | null| +--------------------+ ``` Second, **support json arrays in `from_json` with `ArrayType` as the schema.** ```scala import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) Seq(("""[{"a": 1}, {"a": 2}]""")).toDF("array").select(from_json(col("array"), schema)).show() ``` prints ``` +-------------------+ |jsontostruct(array)| +-------------------+ | [[1], [2]]| +-------------------+ ``` ## How was this patch tested? Unit tests in `JsonExpressionsSuite` and `JsonFunctionsSuite`, Python doctests and a manual test. Author: hyukjinkwon Closes #16929 from HyukjinKwon/disallow-array. --- python/pyspark/sql/functions.py | 11 +++- .../expressions/jsonExpressions.scala | 57 +++++++++++++++--- .../expressions/JsonExpressionsSuite.scala | 58 ++++++++++++++++++- .../org/apache/spark/sql/functions.scala | 52 +++++++++++++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 25 +++++++- 5 files changed, 186 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 426a4a8c93a67..376b86ea69bd4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1773,11 +1773,11 @@ def json_tuple(col, *fields): @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a [[StructType]] with the - specified schema. Returns `null`, in the case of an unparseable string. + Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] + with the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format - :param schema: a StructType to use when parsing the json column + :param schema: a StructType or ArrayType to use when parsing the json column :param options: options to control parsing. 
accepts the same options as the json datasource >>> from pyspark.sql.types import * @@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> data = [(1, '''[{"a": 1}]''')] + >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[Row(a=1)])] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 1e690a446951e..dbff62efdddb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -480,23 +480,45 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] with the specified schema. + * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. */ case class JsonToStruct( - schema: StructType, + schema: DataType, options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { override def nullable: Boolean = true - def this(schema: StructType, options: Map[String, String], child: Expression) = + def this(schema: DataType, options: Map[String, String], child: Expression) = this(schema, options, child, None) + override def checkInputDataTypes(): TypeCheckResult = schema match { + case _: StructType | ArrayType(_: StructType, _) => + super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + } + + @transient + lazy val rowSchema = schema match { + case st: StructType => st + case ArrayType(st: StructType, _) => st + } + + // This converts parsed rows to the desired output by the given schema. + @transient + lazy val converter = schema match { + case _: StructType => + (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null + case ArrayType(_: StructType, _) => + (rows: Seq[InternalRow]) => new GenericArrayData(rows) + } + @transient lazy val parser = new JacksonParser( - schema, + rowSchema, new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) override def dataType: DataType = schema @@ -505,11 +527,32 @@ case class JsonToStruct( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { + // When input is, + // - `null`: `null`. + // - invalid json: `null`. + // - empty string: `null`. 
+ // + // When the schema is array, + // - json array: `Array(Row(...), ...)` + // - json object: `Array(Row(...))` + // - empty json array: `Array()`. + // - empty json object: `Array(Row(null))`. + // + // When the schema is a struct, + // - json object/array with single element: `Row(...)` + // - json array with multiple elements: `null` + // - empty json array: `null`. + // - empty json object: `Row(null)`. + + // We need `null` if the input string is an empty string. `JacksonParser` can + // deal with this but produces `Nil`. + if (json.toString.trim.isEmpty) return null + try { - parser.parse( + converter(parser.parse( json.asInstanceOf[UTF8String], CreateJacksonParser.utf8String, - identity[UTF8String]).headOption.orNull + identity[UTF8String])) } catch { case _: SparkSQLJsonProcessingException => null } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 0c46819cdb9cd..e3584909ddc4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,7 +22,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("from_json - input=array, schema=array, output=array") { + val input = """[{"a": 1}, {"a": 2}]""" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(1) :: InternalRow(2) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=object, schema=array, output=array of single row") { + val input = """{"a": 1}""" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(1) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty array, schema=array, output=empty array") { + val input = "[ ]" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty object, schema=array, output=array of single row with null") { + val input = "{ }" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(null) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=array of single object, schema=struct, output=single row") { + val input = """[{"a": 1}]""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = InternalRow(1) + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=array, schema=struct, output=null") { + val input = """[{"a": 1}, {"a": 2}]""" + val schema = StructType(StructField("a", IntegerType) 
:: Nil) + val output = null + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty array, schema=struct, output=null") { + val input = """[]""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = null + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty object, schema=struct, output=single row with null") { + val input = """{ }""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = InternalRow(null) + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2247010ac3f3f..201f726db3fad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2973,7 +2973,22 @@ object functions { * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = + from_json(e, schema.asInstanceOf[DataType], options) + + /** + * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { JsonToStruct(schema, options, e.expr) } @@ -2992,6 +3007,21 @@ object functions { def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = from_json(e, schema, options.asScala.toMap) + /** + * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column = + from_json(e, schema, options.asScala.toMap) + /** * Parses a column containing a JSON string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. @@ -3006,8 +3036,21 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` with the specified schema. - * Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. 
+ * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType): Column = + from_json(e, schema, Map.empty[String, String]) + + /** + * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string @@ -3016,8 +3059,7 @@ object functions { * @since 2.1.0 */ def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) - + from_json(e, DataType.fromJson(schema), options) /** * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 9c39b3c7f09bf..953d161ec2a1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions.{from_json, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType} +import org.apache.spark.sql.types._ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil) } + test("from_json invalid schema") { + val df = Seq("""{"a" 1}""").toDS() + val schema = ArrayType(StringType) + val message = intercept[AnalysisException] { + df.select(from_json($"value", schema)) + }.getMessage + + assert(message.contains( + "Input schema array must be a struct or an array of structs.")) + } + + test("from_json array support") { + val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS() + val schema = ArrayType( + StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: Nil)) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) + } + test("to_json") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") From 70f9d7f71c63d2b1fdfed75cb7a59285c272a62b Mon Sep 17 00:00:00 2001 From: Sue Ann Hong Date: Sun, 5 Mar 2017 16:49:31 -0800 Subject: [PATCH 09/78] [SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe ## What changes were proposed in this pull request? This is a simple implementation of RecommendForAllUsers & RecommendForAllItems for the Dataframe version of ALS. It uses Dataframe operations (not a wrapper on the RDD implementation). Haven't benchmarked against a wrapper, but unit test examples do work. ## How was this patch tested? Unit tests ``` $ build/sbt > mllib/testOnly *ALSSuite -- -z "recommendFor" > mllib/testOnly ``` Author: Your Name Author: sueann Closes #17090 from sueann/SPARK-19535. 
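A sketch of how the two new methods are called, assuming a `ratings` DataFrame with integer `user`/`item` columns and a float `rating` column (the DataFrame and column names here are illustrative):

```scala
import org.apache.spark.ml.recommendation.ALS

val als = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
val model = als.fit(ratings)

// Top 10 items per user: (user: Int, recommendations: array of (item, rating) rows)
val userRecs = model.recommendForAllUsers(10)
// Top 10 users per item: (item: Int, recommendations: array of (user, rating) rows)
val itemRecs = model.recommendForAllItems(10)
```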
--- .../apache/spark/ml/recommendation/ALS.scala | 79 ++++++++++++++-- .../recommendation/TopByKeyAggregator.scala | 60 ++++++++++++ .../spark/ml/recommendation/ALSSuite.scala | 94 +++++++++++++++++++ .../TopByKeyAggregatorSuite.scala | 73 ++++++++++++++ 4 files changed, 297 insertions(+), 9 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 799e881fad74a..60dd7367053e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,8 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -284,18 +285,20 @@ class ALSModel private[ml] ( @Since("2.2.0") def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => + if (featuresA != null && featuresB != null) { + // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for + // potential optimization. + blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + } else { + Float.NaN + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. - val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => - if (userFeatures != null && itemFeatures != null) { - blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) - } else { - Float.NaN - } - } val predictions = dataset .join(userFactors, checkedCast(dataset($(userCol))) === userFactors("id"), "left") @@ -327,6 +330,64 @@ class ALSModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new ALSModel.ALSModelWriter(this) + + /** + * Returns top `numItems` items recommended for each user, for all users. + * @param numItems max number of recommendations for each user + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllUsers(numItems: Int): DataFrame = { + recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) + } + + /** + * Returns top `numUsers` users recommended for each item, for all items. + * @param numUsers max number of recommendations for each item + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllItems(numUsers: Int): DataFrame = { + recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Makes recommendations for all users (or items). 
+ * @param srcFactors src factors for which to generate recommendations + * @param dstFactors dst factors used to make recommendations + * @param srcOutputColumn name of the column for the source ID in the output DataFrame + * @param dstOutputColumn name of the column for the destination ID in the output DataFrame + * @param num max number of recommendations for each record + * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are + * stored as an array of (dstOutputColumn: Int, rating: Float) Rows. + */ + private def recommendForAll( + srcFactors: DataFrame, + dstFactors: DataFrame, + srcOutputColumn: String, + dstOutputColumn: String, + num: Int): DataFrame = { + import srcFactors.sparkSession.implicits._ + + val ratings = srcFactors.crossJoin(dstFactors) + .select( + srcFactors("id"), + dstFactors("id"), + predict(srcFactors("features"), dstFactors("features"))) + // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) + val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) + .toDF("id", "recommendations") + + val arrayType = ArrayType( + new StructType() + .add(dstOutputColumn, IntegerType) + .add("rating", FloatType) + ) + recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + } } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala new file mode 100644 index 0000000000000..517179c0eb9ae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.{Encoder, Encoders} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds + * the top `num` K2 items based on the given Ordering. 
+ */ +private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag] + (num: Int, ord: Ordering[(K2, V)]) + extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] { + + override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord) + + override def reduce( + q: BoundedPriorityQueue[(K2, V)], + a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = { + q += {(a._2, a._3)} + } + + override def merge( + q1: BoundedPriorityQueue[(K2, V)], + q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = { + q1 ++= q2 + } + + override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = { + r.toArray.sorted(ord.reverse) + } + + override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = { + Encoders.kryo[BoundedPriorityQueue[(K2, V)]] + } + + override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]() +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index c8228dd004374..e494ea89e63bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.WrappedArray import scala.collection.JavaConverters._ import scala.language.existentials @@ -660,6 +661,99 @@ class ALSSuite model.setColdStartStrategy(s).transform(data) } } + + private def getALSModel = { + val spark = this.spark + import spark.implicits._ + + val userFactors = Seq( + (0, Array(6.0f, 4.0f)), + (1, Array(3.0f, 4.0f)), + (2, Array(3.0f, 6.0f)) + ).toDF("id", "features") + val itemFactors = Seq( + (3, Array(5.0f, 6.0f)), + (4, Array(6.0f, 2.0f)), + (5, Array(3.0f, 6.0f)), + (6, Array(4.0f, 1.0f)) + ).toDF("id", "features") + val als = new ALS().setRank(2) + new ALSModel(als.uid, als.getRank, userFactors, itemFactors) + .setUserCol("user") + .setItemCol("item") + } + + test("recommendForAllUsers with k < num_items") { + val topItems = getALSModel.recommendForAllUsers(2) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f), (5, 33f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllUsers with k = num_items") { + val topItems = getALSModel.recommendForAllUsers(4) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllItems with k < num_users") { + val topUsers = getALSModel.recommendForAllItems(2) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f)), + 4 -> Array((0, 44f), (2, 30f)), + 5 -> Array((2, 45f), (0, 42f)), + 6 -> Array((0, 28f), (2, 18f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + test("recommendForAllItems with k = num_users") { + val topUsers = getALSModel.recommendForAllItems(3) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 
54f), (2, 51f), (1, 39f)), + 4 -> Array((0, 44f), (2, 30f), (1, 26f)), + 5 -> Array((2, 45f), (0, 42f), (1, 33f)), + 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + private def checkRecommendations( + topK: DataFrame, + expected: Map[Int, Array[(Int, Float)]], + dstColName: String): Unit = { + val spark = this.spark + import spark.implicits._ + + assert(topK.columns.contains("recommendations")) + topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) => + assert(recs === expected(id)) + } + topK.collect().foreach { row => + val recs = row.getAs[WrappedArray[Row]]("recommendations") + assert(recs(0).fieldIndex(dstColName) == 0) + assert(recs(0).fieldIndex("rating") == 1) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala new file mode 100644 index 0000000000000..5e763a8e908b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.recommendation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + + +class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2)) + Seq( + (0, 3, 54f), + (0, 4, 44f), + (0, 5, 42f), + (0, 6, 28f), + (1, 3, 39f), + (2, 3, 51f), + (2, 5, 45f), + (2, 6, 18f) + ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn) + } + + test("topByKey with k < #items") { + val topK = getTopK(2) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkTopK(topK, expected) + } + + test("topByKey with k > #items") { + val topK = getTopK(5) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f), (6, 18f)) + ) + checkTopK(topK, expected) + } + + private def checkTopK( + topK: Dataset[(Int, Array[(Int, Float)])], + expected: Map[Int, Array[(Int, Float)]]): Unit = { + topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) } + } +} From 224e0e785b4b449ea638c2629263c798116a3011 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 5 Mar 2017 18:04:52 -0800 Subject: [PATCH 10/78] [SPARK-19701][SQL][PYTHON] Throws a correct exception for 'in' operator against column ## What changes were proposed in this pull request? This PR proposes to remove an incorrect implementation that has not been executed so far (at least since Spark 1.5.2) for the `in` operator, and to throw a correct exception rather than saying it is a bool. I tested the code in 1.5.2, 1.6.3, 2.1.0 and in the master branch as below: **1.5.2** ```python >>> df = sqlContext.createDataFrame([[1]]) >>> 1 in df._1 Traceback (most recent call last): File "", line 1, in File ".../spark-1.5.2-bin-hadoop2.6/python/pyspark/sql/column.py", line 418, in __nonzero__ raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions. ``` **1.6.3** ```python >>> 1 in sqlContext.range(1).id Traceback (most recent call last): File "", line 1, in File ".../spark-1.6.3-bin-hadoop2.6/python/pyspark/sql/column.py", line 447, in __nonzero__ raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions. ``` **2.1.0** ```python >>> 1 in spark.range(1).id Traceback (most recent call last): File "", line 1, in File ".../spark-2.1.0-bin-hadoop2.7/python/pyspark/sql/column.py", line 426, in __nonzero__ raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
``` **Current Master** ```python >>> 1 in spark.range(1).id Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/column.py", line 452, in __nonzero__ raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions. ``` **After** ```python >>> 1 in spark.range(1).id Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/column.py", line 184, in __contains__ raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' " ValueError: Cannot apply 'in' operator against a column: please use 'contains' in a string column or 'array_contains' function for an array column. ``` In more detail, it seems the implementation intended to support this: ```python 1 in df.column ``` However, currently, it throws an exception as below: ```python Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/column.py", line 426, in __nonzero__ raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions. ``` What happens here is as below: ```python class Column(object): def __contains__(self, item): print "I am contains" return Column() def __nonzero__(self): raise Exception("I am nonzero.") >>> 1 in Column() I am contains Traceback (most recent call last): File "", line 1, in File "", line 6, in __nonzero__ Exception: I am nonzero. ``` It seems Python calls `__contains__` first, and then `__nonzero__` or `__bool__` is called against the returned `Column()` to turn it into a bool (or an int, to be specific). It seems `__nonzero__` (for Python 2), `__bool__` (for Python 3) and `__contains__` force the return value into a bool, unlike other operators. There are a few references about this below: https://bugs.python.org/issue16011 http://stackoverflow.com/questions/12244074/python-source-code-for-built-in-in-operator/12244378#12244378 http://stackoverflow.com/questions/38542543/functionality-of-python-in-vs-contains/38542777 It seems we can't override `__nonzero__` or `__bool__` as a workaround to make this work, because these force the return type to be a bool, as below: ```python class Column(object): def __contains__(self, item): print "I am contains" return Column() def __nonzero__(self): return "a" >>> 1 in Column() I am contains Traceback (most recent call last): File "", line 1, in TypeError: __nonzero__ should return bool or int, returned str ``` ## How was this patch tested? Added unit tests in `tests.py`. Author: hyukjinkwon Closes #17160 from HyukjinKwon/SPARK-19701.
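As a side note, the alternatives that the new error message points to are ordinary Column/functions APIs and exist in the Scala API as well. The following is a minimal standalone sketch (not part of this patch; it assumes a local `SparkSession` and an illustrative DataFrame) of the recommended replacements for the unsupported `in` operator:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_contains

object InOperatorAlternatives {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("in-alternatives").getOrCreate()
    import spark.implicits._

    val df = Seq(("abc", Seq(1, 2)), ("xyz", Seq(3))).toDF("s", "arr")

    // String column: use Column.contains instead of `value in column`.
    df.filter($"s".contains("b")).show()

    // Array column: use the array_contains function.
    df.filter(array_contains($"arr", 1)).show()

    spark.stop()
  }
}
```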
--- python/pyspark/sql/column.py | 4 +++- python/pyspark/sql/tests.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index c10ab9638a21f..ec05c18d4f062 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -180,7 +180,9 @@ def __init__(self, jc): __ror__ = _bin_op("or") # container operators - __contains__ = _bin_op("contains") + def __contains__(self, item): + raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' " + "in a string column or 'array_contains' function for an array column.") # bitwise operators bitwiseOR = _bin_op("bitwiseOR") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e943f8da3db14..81f3d1d36a342 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -967,6 +967,9 @@ def test_column_operators(self): cs.startswith('a'), cs.endswith('a') self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) + self.assertRaisesRegexp(ValueError, + "Cannot apply 'in' operator against a column", + lambda: 1 in cs) def test_column_getitem(self): from pyspark.sql.functions import col From 207067ead6db6dc87b0d144a658e2564e3280a89 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sun, 5 Mar 2017 18:17:30 -0800 Subject: [PATCH 11/78] [SPARK-19822][TEST] CheckpointSuite.testCheckpointedOperation: should not filter checkpointFilesOfLatestTime with the PATH string. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/73800/testReport/ ``` sbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 617 times over 10.003740484 seconds. Last failure message: 8 did not equal 2. at org.scalatest.concurrent.Eventually$class.tryTryAgain$1(Eventually.scala:420) at org.scalatest.concurrent.Eventually$class.eventually(Eventually.scala:438) at org.scalatest.concurrent.Eventually$.eventually(Eventually.scala:478) at org.scalatest.concurrent.Eventually$class.eventually(Eventually.scala:336) at org.scalatest.concurrent.Eventually$.eventually(Eventually.scala:478) at org.apache.spark.streaming.DStreamCheckpointTester$class.generateOutput(CheckpointSuite .scala:172) at org.apache.spark.streaming.CheckpointSuite.generateOutput(CheckpointSuite.scala:211) ``` the check condition is: ``` val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { _.toString.contains(clock.getTimeMillis.toString) } // Checkpoint files are written twice for every batch interval. So assert that both // are written to make sure that both of them have been written. 
assert(checkpointFilesOfLatestTime.size === 2) ``` the path string may contain the `clock.getTimeMillis.toString`, like `3500` : ``` file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-500 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-1000 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-1500 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-2000 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-2500 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-3000 file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-3500.bk file:/root/dev/spark/assembly/CheckpointSuite/spark-20035007-9891-4fb6-91c1-cc15b7ccaf15/checkpoint-3500 ▲▲▲▲ ``` so we should only check the filename, but not the whole path. ## How was this patch tested? Jenkins. Author: uncleGen Closes #17167 from uncleGen/flaky-CheckpointSuite. --- .../scala/org/apache/spark/streaming/CheckpointSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 7fcf45e7dedc9..ee2fd45a7e851 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -152,11 +152,9 @@ trait DStreamCheckpointTester { self: SparkFunSuite => stopSparkContext: Boolean ): Seq[Seq[V]] = { try { - val batchDuration = ssc.graph.batchDuration val batchCounter = new BatchCounter(ssc) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val currentTime = clock.getTimeMillis() logInfo("Manual clock before advancing = " + clock.getTimeMillis()) clock.setTime(targetBatchTime.milliseconds) @@ -171,7 +169,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => eventually(timeout(10 seconds)) { val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { - _.toString.contains(clock.getTimeMillis.toString) + _.getName.contains(clock.getTimeMillis.toString) } // Checkpoint files are written twice for every batch interval. So assert that both // are written to make sure that both of them have been written. From 2a0bc867a4a1dad4ecac47701199e540d345ff4f Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Mon, 6 Mar 2017 10:16:20 -0800 Subject: [PATCH 12/78] [SPARK-17495][SQL] Support Decimal type in Hive-hash ## What changes were proposed in this pull request? Hive hash to support Decimal datatype. [Hive internally normalises decimals](https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/storage-api/src/java/org/apache/hadoop/hive/common/type/HiveDecimalV1.java#L307) and I have ported that logic as-is to HiveHash. ## How was this patch tested? Added unit tests Author: Tejas Patil Closes #17056 from tejasapatil/SPARK-17495_decimal. 
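To make the ported normalization concrete, here is a minimal standalone sketch (for illustration only; the authoritative logic is `HiveHashFunction.normalizeDecimal` in the diff below, and null/overflow handling is elided) showing why values that differ only in trailing zeros hash identically:

```scala
import java.math.{BigDecimal, RoundingMode}

// Re-statement of the normalization steps, mirroring HiveHashFunction.normalizeDecimal below.
object NormalizeDecimalSketch {
  private val MaxPrecision = 38
  private val MaxScale = 38

  private def trim(d: BigDecimal): BigDecimal = {
    if (d.compareTo(BigDecimal.ZERO) == 0) {
      // Special case for 0: stripTrailingZeros does not handle it well.
      BigDecimal.ZERO
    } else {
      val stripped = d.stripTrailingZeros
      // No negative scales: e.g. 400.00 strips to 4E+2 (scale -2), reset to 400 (scale 0).
      if (stripped.scale < 0) stripped.setScale(0) else stripped
    }
  }

  def normalize(input: BigDecimal): BigDecimal = {
    var result = trim(input)
    val intDigits = result.precision - result.scale
    if (intDigits > MaxPrecision) return null
    val maxScale = math.min(MaxScale, math.min(MaxPrecision - intDigits, result.scale))
    if (result.scale > maxScale) {
      // Rounding may introduce new trailing zeros, so trim again.
      result = trim(result.setScale(maxScale, RoundingMode.HALF_UP))
    }
    result
  }

  def main(args: Array[String]): Unit = {
    // Both normalize to the same BigDecimal, so both hash to 558, matching
    // checkHiveHashForDecimal("18", 38, 0, 558) in the new tests below.
    println(normalize(new BigDecimal("18")).hashCode())     // 558
    println(normalize(new BigDecimal("18.000")).hashCode()) // 558
  }
}
```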
--- .../spark/sql/catalyst/expressions/hash.scala | 56 ++++++++++++++++++- .../expressions/HashExpressionsSuite.scala | 46 ++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2d9c2e42064b3..03101b4bfc5f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal, RoundingMode} import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 @@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction { * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution * - * TODO: Support Decimal and date related types + * TODO: Support date related types */ @ExpressionDescription( usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.") @@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashBytes(b: String, result: String): String = s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + override protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, + input: String, + result: String): String = { + s""" + $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal( + $input.toJavaBigDecimal()).hashCode();""" + } + override protected def genHashCalendarInterval(input: String, result: String): String = { s""" $result = (31 * $hasherClassName.hashInt($input.months)) + @@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashUnsafeBytes(base, offset, len) } + private val HIVE_DECIMAL_MAX_PRECISION = 38 + private val HIVE_DECIMAL_MAX_SCALE = 38 + + // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize() + def normalizeDecimal(input: BigDecimal): BigDecimal = { + if (input == null) return null + + def trimDecimal(input: BigDecimal) = { + var result = input + if (result.compareTo(BigDecimal.ZERO) == 0) { + // Special case for 0, because java doesn't strip zeros correctly on that number. + result = BigDecimal.ZERO + } else { + result = result.stripTrailingZeros + if (result.scale < 0) { + // no negative scale decimals + result = result.setScale(0) + } + } + result + } + + var result = trimDecimal(input) + val intDigits = result.precision - result.scale + if (intDigits > HIVE_DECIMAL_MAX_PRECISION) { + return null + } + + val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, + Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale)) + if (result.scale > maxScale) { + result = result.setScale(maxScale, RoundingMode.HALF_UP) + // Trimming is again necessary, because rounding may introduce new trailing 0's. 
+ result = trimDecimal(result) + } + result + } + override def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => 0 @@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction { } result - case _ => super.hash(value, dataType, 0) + case d: Decimal => + normalizeDecimal(d.toJavaBigDecimal).hashCode() + + case _ => super.hash(value, dataType, seed) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 0cb3a79eee67d..0c77dc2709dad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { // Note : All expected hashes need to be computed using Hive 1.2.1 val actual = HiveHashFunction.hash(input, dataType, seed = 0) @@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) + test("hive-hash for decimal") { + def checkHiveHashForDecimal( + input: String, + precision: Int, + scale: Int, + expected: Long): Unit = { + val decimalType = DataTypes.createDecimalType(precision, scale) + val decimal = { + val value = Decimal.apply(new java.math.BigDecimal(input)) + if (value.changePrecision(precision, scale)) value else null + } + + checkHiveHash(decimal, decimalType, expected) + } + + checkHiveHashForDecimal("18", 38, 0, 558) + checkHiveHashForDecimal("-18", 38, 0, -558) + checkHiveHashForDecimal("-18", 38, 12, -558) + checkHiveHashForDecimal("18446744073709001000", 38, 19, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057) + checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057) + checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656) + checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656) + checkHiveHashForDecimal("00000.00000000000", 38, 34, 0) + checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0) + checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544) + 
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666) + } + test("SPARK-18207: Compute hash for a lot of expressions") { val N = 1000 val wideRow = new GenericInternalRow( From 339b53a1311e08521d84a83c94201fcf3c766fb2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 6 Mar 2017 10:36:50 -0800 Subject: [PATCH 13/78] [SPARK-19737][SQL] New analysis rule for reporting unregistered functions without relying on relation resolution ## What changes were proposed in this pull request? This PR adds a new `Once` analysis rule batch consists of a single analysis rule `LookupFunctions` that performs simple existence check over `UnresolvedFunctions` without actually resolving them. The benefit of this rule is that it doesn't require function arguments to be resolved first and therefore doesn't rely on relation resolution, which may incur potentially expensive partition/schema discovery cost. Please refer to [SPARK-19737][1] for more details about the motivation. ## How was this patch tested? New test case added in `AnalysisErrorSuite`. [1]: https://issues.apache.org/jira/browse/SPARK-19737 Author: Cheng Lian Closes #17168 from liancheng/spark-19737-lookup-functions. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++++++++ .../catalog/SessionCatalogSuite.scala | 23 ++++++++++++++++++- .../spark/sql/hive/HiveSessionCatalog.scala | 5 ++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d569b612de7d..2f8489de6b000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -117,6 +117,8 @@ class Analyzer( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), ResolveHints.RemoveAllHints), + Batch("Simple Sanity Check", Once, + LookupFunctions), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -1038,6 +1040,25 @@ class Analyzer( } } + /** + * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the + * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It + * only performs simple existence check according to the function identifier to quickly identify + * undefined functions without triggering relation resolution, which may incur potentially + * expensive partition/schema discovery process in some cases. + * + * @see [[ResolveFunctions]] + * @see https://issues.apache.org/jira/browse/SPARK-19737 + */ + object LookupFunctions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case f: UnresolvedFunction if !catalog.functionExists(f.name) => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a755231962be2..ffc272c6c0c39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -1196,4 +1196,25 @@ class SessionCatalogSuite extends PlanTest { catalog.listFunctions("unknown_db", "func*") } } + + test("SPARK-19737: detect undefined functions without triggering relation resolution") { + import org.apache.spark.sql.catalyst.dsl.plans._ + + Seq(true, false) foreach { caseSensitive => + val conf = SimpleCatalystConf(caseSensitive) + val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + // The analyzer should report the undefined function rather than the undefined table first. + val cause = intercept[AnalysisException] { + analyzer.execute( + UnresolvedRelation(TableIdentifier("undefined_table")).select( + UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + ) + ) + } + + assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c9be1b9d100b0..f1ea86890c210 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -199,6 +199,11 @@ private[sql] class HiveSessionCatalog( } } + // TODO Removes this method after implementing Spark native "histogram_numeric". + override def functionExists(name: FunctionIdentifier): Boolean = { + super.functionExists(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. From 46a64d1e0ae12c31e848f377a84fb28e3efb3699 Mon Sep 17 00:00:00 2001 From: Gaurav Date: Mon, 6 Mar 2017 10:41:49 -0800 Subject: [PATCH 14/78] [SPARK-19304][STREAMING][KINESIS] fix kinesis slow checkpoint recovery ## What changes were proposed in this pull request? Added a limit to the getRecords API call in `KinesisBackedBlockRDD`. This helps reduce the amount of data returned by the Kinesis API call, making the recovery considerably faster. As we are storing the `fromSeqNum` & `toSeqNum` in checkpoint metadata, we can also store the number of records, which can later be used for the API call. ## How was this patch tested? The patch was manually tested. Apologies for any silly mistakes; this is my first pull request. Author: Gaurav Closes #16842 from Gauravshah/kinesis_checkpoint_recovery_fix_2_1_0.
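To illustrate the effect of the new `recordCount` field, here is a hedged standalone sketch (not part of the patch; it assumes the AWS Kinesis SDK is on the classpath and uses a placeholder shard iterator) of how the checkpointed record count now bounds a single `GetRecords` request, mirroring the change below:

```scala
import com.amazonaws.services.kinesis.model.GetRecordsRequest

object GetRecordsLimitSketch {
  // AWS caps a single GetRecords call at a maximum of 10,000 records, as in the patch below.
  private val MaxGetRecordsLimit = 10000

  // Ask for no more records than the block's checkpointed recordCount,
  // so recovery avoids over-fetching from the shard.
  def buildRequest(shardIterator: String, recordCount: Int): GetRecordsRequest = {
    val request = new GetRecordsRequest
    request.setShardIterator(shardIterator)
    request.setLimit(Math.min(recordCount, MaxGetRecordsLimit))
    request
  }

  def main(args: Array[String]): Unit = {
    // 67 echoes the recordCount used in the updated KinesisStreamSuite below.
    val request = buildRequest("shard-iterator-placeholder", recordCount = 67)
    println(request.getLimit) // 67, rather than fetching up to the 10,000-record cap
  }
}
```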
--- .../kinesis/KinesisBackedBlockRDD.scala | 25 ++++++++++++++----- .../streaming/kinesis/KinesisReceiver.scala | 3 ++- .../kinesis/KinesisBackedBlockRDDSuite.scala | 4 +-- .../kinesis/KinesisStreamSuite.scala | 4 +-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 23c4d99e50f51..0f1790bddcc3d 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -36,7 +36,11 @@ import org.apache.spark.util.NextIterator /** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */ private[kinesis] case class SequenceNumberRange( - streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + streamName: String, + shardId: String, + fromSeqNumber: String, + toSeqNumber: String, + recordCount: Int) /** Class representing an array of Kinesis sequence number ranges */ private[kinesis] @@ -136,6 +140,8 @@ class KinesisSequenceRangeIterator( private val client = new AmazonKinesisClient(credentials) private val streamName = range.streamName private val shardId = range.shardId + // AWS limits to maximum of 10k records per get call + private val maxGetRecordsLimit = 10000 private var toSeqNumberReceived = false private var lastSeqNumber: String = null @@ -153,12 +159,14 @@ class KinesisSequenceRangeIterator( // If the internal iterator has not been initialized, // then fetch records from starting sequence number - internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber, + range.recordCount) } else if (!internalIterator.hasNext) { // If the internal iterator does not have any more records, // then fetch more records after the last consumed sequence number - internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber, + range.recordCount) } if (!internalIterator.hasNext) { @@ -191,9 +199,12 @@ class KinesisSequenceRangeIterator( /** * Get records starting from or after the given sequence number. */ - private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + private def getRecords( + iteratorType: ShardIteratorType, + seqNum: String, + recordCount: Int): Iterator[Record] = { val shardIterator = getKinesisIterator(iteratorType, seqNum) - val result = getRecordsAndNextKinesisIterator(shardIterator) + val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount) result._1 } @@ -202,10 +213,12 @@ class KinesisSequenceRangeIterator( * to get records from Kinesis), and get the next shard iterator for next consumption. 
*/ private def getRecordsAndNextKinesisIterator( - shardIterator: String): (Iterator[Record], String) = { + shardIterator: String, + recordCount: Int): (Iterator[Record], String) = { val getRecordsRequest = new GetRecordsRequest getRecordsRequest.setRequestCredentials(credentials) getRecordsRequest.setShardIterator(shardIterator) + getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit)) val getRecordsResult = retryOrTimeout[GetRecordsResult]( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 13fc54e531dda..320728f4bb221 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -210,7 +210,8 @@ private[kinesis] class KinesisReceiver[T]( if (records.size > 0) { val dataIterator = records.iterator().asScala.map(messageHandler) val metadata = SequenceNumberRange(streamName, shardId, - records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber(), + records.size()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 18a5a1509a33a..2c7b9c58e6fa6 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -51,7 +51,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => val seqNumRange = SequenceNumberRange( - testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last, seqNumbers.size) (shardId, seqNumRange) } allRanges = shardIdToRange.values.toSeq @@ -181,7 +181,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) // Create the necessary ranges to use in the RDD val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( - SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 1))) val realRanges = Array.tabulate(numPartitionsInKinesis) { i => val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) SequenceNumberRanges(Array(range)) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 387a96f26b305..afb55c84f81fe 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -119,13 +119,13 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) 
extends KinesisFunSuite // Generate block info data for testing val seqNumRanges1 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 67)) val blockId1 = StreamBlockId(kinesisStream.id, 123) val blockInfo1 = ReceivedBlockInfo( 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) val seqNumRanges2 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb", 89)) val blockId2 = StreamBlockId(kinesisStream.id, 345) val blockInfo2 = ReceivedBlockInfo( 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) From 096df6d933c5326e5782aa8c5de842a0800eb369 Mon Sep 17 00:00:00 2001 From: windpiger Date: Mon, 6 Mar 2017 10:44:26 -0800 Subject: [PATCH 15/78] [SPARK-19257][SQL] location for table/partition/database should be java.net.URI ## What changes were proposed in this pull request? Currently we treat the location of a table/partition/database as a URI string. It will be safer if we make the type of the location java.net.URI. In this PR, the following classes change: **1. CatalogDatabase** ``` case class CatalogDatabase( name: String, description: String, locationUri: String, properties: Map[String, String]) ---> case class CatalogDatabase( name: String, description: String, locationUri: URI, properties: Map[String, String]) ``` **2. CatalogStorageFormat** ``` case class CatalogStorageFormat( locationUri: Option[String], inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], compressed: Boolean, properties: Map[String, String]) ----> case class CatalogStorageFormat( locationUri: Option[URI], inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], compressed: Boolean, properties: Map[String, String]) ``` Before and after this PR, everything is transparent for the user; there is no change that the user needs to be concerned about. The `String` to `URI` conversion just happens internally in Spark SQL. Here are some location-related operations: **1. whitespace in the location** e.g. `/a/b c/d` For both table location and partition location, after `CREATE TABLE t... (PARTITIONED BY ...) LOCATION '/a/b c/d'`, `DESC EXTENDED t` shows the location as `/a/b c/d`, and the real path in the FileSystem also shows `/a/b c/d` **2. colon(:) in the location** e.g. `/a/b:c/d` For both table location and partition location, when `CREATE TABLE t... (PARTITIONED BY ...) LOCATION '/a/b:c/d'`, **in a Linux file system** `DESC EXTENDED t` shows the location as `/a/b:c/d`, and the real path in the FileSystem also shows `/a/b:c/d`; **in HDFS** it throws an exception: `java.lang.IllegalArgumentException: Pathname /a/b:c/d from hdfs://iZbp1151s8hbnnwriekxdeZ:9000/a/b:c/d is not a valid DFS filename.` **while** after `INSERT INTO TABLE t PARTITION(a="a:b") SELECT 1`, `DESC EXTENDED t` shows the location as `/xxx/a=a%3Ab`, and the real path in the FileSystem also shows `/xxx/a=a%3Ab` **3. percent sign(%) in the location** e.g. `/a/b%c/d` For both table location and partition location, after `CREATE TABLE t... (PARTITIONED BY ...) LOCATION '/a/b%c/d'`, `DESC EXTENDED t` shows the location as `/a/b%c/d`, and the real path in the FileSystem also shows `/a/b%c/d` **4. encoded(%25) in the location** e.g. `/a/b%25c/d` For both table location and partition location, after `CREATE TABLE t... (PARTITIONED BY ...)
LOCATION '/a/b%25c/d'`, `DESC EXTENDED t` shows the location as `/a/b%25c/d`, and the real path in the FileSystem also shows `/a/b%25c/d` **while** after `INSERT INTO TABLE t PARTITION(a="%25") SELECT 1`, `DESC EXTENDED t` shows the location as `/xxx/a=%2525`, and the real path in the FileSystem also shows `/xxx/a=%2525` **Additionally**, besides the location, there are two other factors that affect the location of the table/partition: one is the table name, which is not allowed to contain special characters, and the other is the `partition name`, which behaves the same way as the `partition value`; for the case of a `partition name` with special characters, some test cases were added and a bug resolved in this [PR](https://github.com/apache/spark/pull/17173) ### Summary: After `CREATE TABLE t... (PARTITIONED BY ...) LOCATION path`, the path we get from `DESC TABLE` and the `real path in FileSystem` are all the same as in the `CREATE TABLE` command (different filesystems differ in which special characters they allow when creating the path, e.g. HDFS does not allow a colon, but the Linux filesystem allows it). `DataBase` also follows the same logic as `CREATE TABLE`, while if the `partition value` has some special character like `%` `:` `#` etc., then we will get the path with the encoded `partition value`, like `/xxx/a=A%25B`, from `DESC TABLE` and the `real path in FileSystem`. In this PR, the core change is using `new Path(str).toUri` and `new Path(uri).toString`, which transform `str` to `uri` or `uri` to `str`. For example: ``` val str = "/a/b c/d" val uri = new Path(str).toUri --> '/a/b%20c/d' val strFromUri = new Path(uri).toString -> '/a/b c/d' ``` When we restore a table/partition from the metastore, or get the location from a `CREATE TABLE` command, we can use it as above to change the string to a URI via `new Path(str).toUri`. ## How was this patch tested? Unit tests added. The `current master branch` also `passed all the test cases` added in this PR with a little change: https://github.com/apache/spark/pull/17149/files#diff-b7094baa12601424a5d19cb930e3402fR1764 (here `toURI` -> `toString` when testing in the master branch). This shows that this PR is transparent for the user. Author: windpiger Closes #17149 from windpiger/changeStringToURI.
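The `str`/`uri` round-trip above can be checked with a short standalone sketch (for illustration; it only assumes Hadoop's `Path` is on the classpath, the same mechanism used by the new `CatalogUtils.stringToURI`/`URIToString` helpers in the diff below):

```scala
import org.apache.hadoop.fs.Path

object LocationUriRoundTrip {
  def main(args: Array[String]): Unit = {
    // String -> URI: Path.toUri percent-encodes special characters,
    // matching the whitespace example in the description above.
    val uri = new Path("/a/b c/d").toUri
    println(uri) // /a/b%20c/d

    // URI -> String: constructing a Path from the URI decodes it again,
    // so the user-visible location round-trips unchanged.
    println(new Path(uri).toString) // /a/b c/d
  }
}
```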
--- .../catalog/ExternalCatalogUtils.scala | 26 +++ .../catalyst/catalog/InMemoryCatalog.scala | 12 +- .../sql/catalyst/catalog/SessionCatalog.scala | 15 +- .../sql/catalyst/catalog/interface.scala | 14 +- .../catalog/ExternalCatalogSuite.scala | 18 +- .../spark/sql/execution/SparkSqlParser.scala | 8 +- .../command/createDataSourceTables.scala | 10 +- .../spark/sql/execution/command/ddl.scala | 14 +- .../spark/sql/execution/command/tables.scala | 11 +- .../datasources/CatalogFileIndex.scala | 4 +- .../execution/datasources/DataSource.scala | 5 +- .../datasources/DataSourceStrategy.scala | 6 +- .../spark/sql/internal/CatalogImpl.scala | 4 +- .../spark/sql/internal/SharedState.scala | 6 +- .../execution/command/DDLCommandSuite.scala | 8 +- .../sql/execution/command/DDLSuite.scala | 136 +++++++++++--- .../spark/sql/internal/CatalogSuite.scala | 4 +- .../sql/sources/BucketedWriteSuite.scala | 2 +- .../spark/sql/sources/PathOptionSuite.scala | 12 +- .../spark/sql/hive/HiveExternalCatalog.scala | 21 ++- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../sql/hive/client/HiveClientImpl.scala | 15 +- .../spark/sql/hive/client/HiveShim.scala | 9 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 12 +- ...nalCatalogBackwardCompatibilitySuite.scala | 23 ++- .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 12 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 8 +- .../spark/sql/hive/client/VersionsSuite.scala | 13 +- .../sql/hive/execution/HiveDDLSuite.scala | 171 +++++++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- 33 files changed, 460 insertions(+), 155 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 58ced549bafe9..a418edc302d9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI + import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -162,6 +164,30 @@ object CatalogUtils { BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols) } + /** + * Convert URI to String. + * Since URI.toString does not decode the uri, e.g. change '%25' to '%'. + * Here we create a hadoop Path with the given URI, and rely on Path.toString + * to decode the uri + * @param uri the URI of the path + * @return the String of the path + */ + def URIToString(uri: URI): String = { + new Path(uri).toString + } + + /** + * Convert String to URI. + * Since new URI(string) does not encode string, e.g. change '%' to '%25'. 
+ * Here we create a hadoop Path with the given String, and rely on Path.toUri + * to encode the string + * @param str the String of the path + * @return the URI of the path + */ + def stringToURI(str: String): URI = { + new Path(str).toUri + } + private def normalizeColumnName( tableName: String, tableCols: Seq[String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 340e8451f14ee..80aba4af9436c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -202,7 +202,7 @@ class InMemoryCatalog( tableDefinition.storage.locationUri.isEmpty val tableWithLocation = if (needDefaultTableLocation) { - val defaultTableLocation = new Path(catalog(db).db.locationUri, table) + val defaultTableLocation = new Path(new Path(catalog(db).db.locationUri), table) try { val fs = defaultTableLocation.getFileSystem(hadoopConfig) fs.mkdirs(defaultTableLocation) @@ -211,7 +211,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to create table $table as failed " + s"to create its directory $defaultTableLocation", e) } - tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri.toString)) + tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri)) } else { tableDefinition } @@ -274,7 +274,7 @@ class InMemoryCatalog( "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") val oldDir = new Path(oldDesc.table.location) - val newDir = new Path(catalog(db).db.locationUri, newName) + val newDir = new Path(new Path(catalog(db).db.locationUri), newName) try { val fs = oldDir.getFileSystem(hadoopConfig) fs.rename(oldDir, newDir) @@ -283,7 +283,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to rename table $oldName to $newName as failed " + s"to rename its directory $oldDir", e) } - oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri.toString)) + oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri)) } catalog(db).tables.put(newName, oldDesc) @@ -389,7 +389,7 @@ class InMemoryCatalog( existingParts.put( p.spec, - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri)))) } } @@ -462,7 +462,7 @@ class InMemoryCatalog( } oldPartition.copy( spec = newSpec, - storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toUri))) } else { oldPartition.copy(spec = newSpec) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f6412e42c13d5..498bfbde9d7a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -131,10 +132,10 @@ class SessionCatalog( * does not contain a scheme, this path will not be changed after the default * FileSystem is changed. 
*/ - private def makeQualifiedPath(path: String): Path = { + private def makeQualifiedPath(path: URI): URI = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(hadoopConf) - fs.makeQualified(hadoopPath) + fs.makeQualified(hadoopPath).toUri } private def requireDbExists(db: String): Unit = { @@ -170,7 +171,7 @@ class SessionCatalog( "you cannot create a database with this name.") } validateName(dbName) - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri) externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -228,9 +229,9 @@ class SessionCatalog( * Get the path for creating a non-default database when database location is not provided * by users. */ - def getDefaultDBPath(db: String): String = { + def getDefaultDBPath(db: String): URI = { val database = formatDatabaseName(db) - new Path(new Path(conf.warehousePath), database + ".db").toString + new Path(new Path(conf.warehousePath), database + ".db").toUri } // ---------------------------------------------------------------------------- @@ -351,11 +352,11 @@ class SessionCatalog( db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } - def defaultTablePath(tableIdent: TableIdentifier): String = { + def defaultTablePath(tableIdent: TableIdentifier): URI = { val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) val dbLocation = getDatabaseMetadata(dbName).locationUri - new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString + new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri } // ---------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 887caf07d1481..4452c479875fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI import java.util.Date import com.google.common.base.Objects @@ -48,10 +49,7 @@ case class CatalogFunction( * Storage format, used to describe how a partition or a table is stored. */ case class CatalogStorageFormat( - // TODO(ekl) consider storing this field as java.net.URI for type safety. Note that this must - // be converted to/from a hadoop Path object using new Path(new URI(locationUri)) and - // path.toUri respectively before use as a filesystem path due to URI char escaping. - locationUri: Option[String], + locationUri: Option[URI], inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], @@ -105,7 +103,7 @@ case class CatalogTablePartition( } /** Return the partition location, assuming it is specified. */ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") throw new AnalysisException(s"Partition [$specString] did not specify locationUri") } @@ -210,7 +208,7 @@ case class CatalogTable( } /** Return the table location, assuming it is specified. 
*/ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { throw new AnalysisException(s"table $identifier did not specify locationUri") } @@ -241,7 +239,7 @@ case class CatalogTable( /** Syntactic sugar to update a field in `storage`. */ def withNewStorage( - locationUri: Option[String] = storage.locationUri, + locationUri: Option[URI] = storage.locationUri, inputFormat: Option[String] = storage.inputFormat, outputFormat: Option[String] = storage.outputFormat, compressed: Boolean = false, @@ -337,7 +335,7 @@ object CatalogTableType { case class CatalogDatabase( name: String, description: String, - locationUri: String, + locationUri: URI, properties: Map[String, String]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index a5d399a065589..07ccd68698e94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -340,7 +342,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac "db1", "tbl", Map("partCol1" -> "1", "partCol2" -> "2")).location - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") assert(new Path(partitionLocation) == defaultPartitionLocation) } @@ -508,7 +510,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val mixedCasePart1 = CatalogTablePartition( Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) @@ -699,7 +701,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // File System operations // -------------------------------------------------------------------------- - private def exists(uri: String, children: String*): Boolean = { + private def exists(uri: URI, children: String*): Boolean = { val base = new Path(uri) val finalPath = children.foldLeft(base) { case (parent, child) => new Path(parent, child) @@ -742,7 +744,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac identifier = TableIdentifier("external_table", Some("db1")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), + Some(Utils.createTempDir().toURI), None, None, None, false, Map.empty), schema = new StructType().add("a", "int").add("b", "string"), provider = Some(defaultProvider) @@ -790,7 +792,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithExistingDir = CatalogTablePartition( Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), 
ignoreIfExists = false) @@ -799,7 +801,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithNonExistingDir = CatalogTablePartition( Map("partCol1" -> "9", "partCol2" -> "10"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) assert(tempPath.exists()) @@ -883,7 +885,7 @@ abstract class CatalogTestUtils { def newFunc(): CatalogFunction = newFunc("funcName") - def newUriForDatabase(): String = Utils.createTempDir().toURI.toString.stripSuffix("/") + def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) def newDb(name: String): CatalogDatabase = { CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) @@ -895,7 +897,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().toURI)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 65df688689397..c106163741278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -386,7 +386,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + "you can only specify one of them.", ctx) } - val customLocation = storage.locationUri.orElse(location) + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_))) val tableType = if (customLocation.isDefined) { CatalogTableType.EXTERNAL @@ -1080,8 +1080,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) } + + val locUri = location.map(CatalogUtils.stringToURI(_)) val storage = CatalogStorageFormat( - locationUri = location, + locationUri = locUri, inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), @@ -1132,7 +1134,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. 
val newTableDesc = tableDesc.copy( - storage = CatalogStorageFormat.empty.copy(locationUri = location), + storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index d835b521166a8..3da66afceda9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.command +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -54,7 +58,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo // Create the relation to validate the arguments before writing the metadata to the metastore, // and infer the table schema and partition if users didn't specify schema in CREATE TABLE. - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // Fill in some default table options from the session conf val tableWithDefaultOptions = table.copy( identifier = table.identifier.copy( @@ -175,12 +179,12 @@ case class CreateDataSourceTableAsSelectCommand( private def saveDataIntoTable( session: SparkSession, table: CatalogTable, - tableLocation: Option[String], + tableLocation: Option[URI], data: LogicalPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { // Create the relation based on the input logical plan: `data`. 
- val pathOption = tableLocation.map("path" -> _) + val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, className = table.provider.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 82cbb4aa47445..b5c60423514cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -66,7 +66,7 @@ case class CreateDatabaseCommand( CatalogDatabase( databaseName, comment.getOrElse(""), - path.getOrElse(catalog.getDefaultDBPath(databaseName)), + path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)), props), ifNotExists) Seq.empty[Row] @@ -146,7 +146,7 @@ case class DescribeDatabaseCommand( val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: - Row("Location", dbMetadata.locationUri) :: Nil + Row("Location", CatalogUtils.URIToString(dbMetadata.locationUri)) :: Nil if (extended) { val properties = @@ -426,7 +426,8 @@ case class AlterTableAddPartitionCommand( table.identifier.quotedString, sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) - CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location)) + CatalogTablePartition(normalizedSpec, table.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_)))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] @@ -710,7 +711,7 @@ case class AlterTableRecoverPartitionsCommand( // inherit table storage format (possibly except for location) CatalogTablePartition( spec, - table.storage.copy(locationUri = Some(location.toUri.toString)), + table.storage.copy(locationUri = Some(location.toUri)), params) } spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true) @@ -741,6 +742,7 @@ case class AlterTableSetLocationCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + val locUri = CatalogUtils.stringToURI(location) DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => @@ -748,11 +750,11 @@ case class AlterTableSetLocationCommand( sparkSession, table, "ALTER TABLE ... 
SET LOCATION") // Partition spec is specified, so we set the location only for this partition val part = catalog.getPartition(table.identifier, spec) - val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location))) + val newPart = part.copy(storage = part.storage.copy(locationUri = Some(locUri))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself - catalog.alterTable(table.withNewStorage(locationUri = Some(location))) + catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 3e80916104bd9..86394ff23e379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -79,7 +79,8 @@ case class CreateTableLikeCommand( CatalogTable( identifier = targetTable, tableType = tblType, - storage = sourceTableDesc.storage.copy(locationUri = location), + storage = sourceTableDesc.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, partitionColumnNames = sourceTableDesc.partitionColumnNames, @@ -495,7 +496,8 @@ case class DescribeTableCommand( append(buffer, "Owner:", table.owner, "") append(buffer, "Create Time:", new Date(table.createTime).toString, "") append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") - append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "") + append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) + .getOrElse(""), "") append(buffer, "Table Type:", table.tableType.name, "") table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) @@ -587,7 +589,8 @@ case class DescribeTableCommand( append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") append(buffer, "Database:", table.database, "") append(buffer, "Table:", tableIdentifier.table, "") - append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") + append(buffer, "Location:", partition.storage.locationUri.map(CatalogUtils.URIToString(_)) + .getOrElse(""), "") append(buffer, "Partition Parameters:", "", "") partition.parameters.foreach { case (key, value) => append(buffer, s" $key", value, "") @@ -953,7 +956,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman // when the table creation DDL contains the PATH option. 
None } else { - Some(s"path '${escapeSingleQuotedString(location)}'") + Some(s"path '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 2068811661fec..d6c4b97ebd080 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -46,7 +48,7 @@ class CatalogFileIndex( assert(table.identifier.database.isDefined, "The table identifier must be qualified in CatalogFileIndex") - private val baseLocation: Option[String] = table.storage.locationUri + private val baseLocation: Option[URI] = table.storage.locationUri override def partitionSchema: StructType = table.partitionSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 4947dfda6fc7e..c9384e44255b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -597,6 +597,7 @@ object DataSource { def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") - CatalogStorageFormat.empty.copy(locationUri = path, properties = optionsWithoutPath) + CatalogStorageFormat.empty.copy( + locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f694a0d6d724b..bddf5af23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -21,13 +21,15 @@ import java.util.concurrent.Callable import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import 
org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -220,7 +222,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3d9f41832bc73..ed07ff3ff0599 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} @@ -77,7 +79,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { new Database( name = metadata.name, description = metadata.description, - locationUri = metadata.locationUri) + locationUri = CatalogUtils.URIToString(metadata.locationUri)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index bce84de45c3d7..86129fa87feaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -95,7 +96,10 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // Create the default database if it doesn't exist. 
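// The default database definition below takes the warehouse location, which
// arrives as a raw configuration string and may contain characters that are
// illegal in a URI; presumably that is why it goes through
// CatalogUtils.stringToURI rather than java.net.URI's parser. A minimal
// sketch of the difference, assuming stringToURI routes through Hadoop's Path:
import java.net.{URI, URISyntaxException}
import org.apache.hadoop.fs.Path

// Direct parsing rejects a space outright ...
try { new URI("/user/my warehouse"); sys.error("unreachable") }
catch { case _: URISyntaxException => () }

// ... while Hadoop's Path encodes the path component instead of failing.
assert(new Path("/user/my warehouse").toUri.getPath == "/user/my warehouse")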
{ val defaultDbDefinition = CatalogDatabase( - SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map()) + SessionCatalog.DEFAULT_DATABASE, + "default database", + CatalogUtils.stringToURI(warehousePath), + Map()) // Initialize default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 76bb9e5929a71..4b73b078da38e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.catalyst.TableIdentifier @@ -317,7 +319,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE EXTERNAL TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create hive table - property values must be set") { @@ -334,7 +336,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create table - with partitioned by") { @@ -409,7 +411,7 @@ class DDLCommandSuite extends PlanTest { val expectedTableDesc = CatalogTable( identifier = TableIdentifier("my_tab"), tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy(locationUri = Some("/tmp/file")), + storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), schema = new StructType().add("a", IntegerType).add("b", StringType), provider = Some("parquet")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 8b8cd0fdf4db2..6ffa58bcd9af1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -26,9 +26,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogStorageFormat} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -72,7 +70,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + CatalogDatabase( + name, "", CatalogUtils.stringToURI(spark.sessionState.conf.warehousePath), Map()), ignoreIfExists = false) } @@ -133,11 +132,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def makeQualifiedPath(path: String): String = { + private def makeQualifiedPath(path: String): URI = { // copy-paste from SessionCatalog val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) - fs.makeQualified(hadoopPath).toString + fs.makeQualified(hadoopPath).toUri } test("Create Database using Default Warehouse Path") { @@ -449,7 +448,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") @@ -458,7 +457,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -467,7 +466,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) } finally { catalog.reset() @@ -1094,7 +1093,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) // Verify that the location is set to the expected string - def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { catalog.getTableMetadata(tableIdent).storage } @@ -1111,17 +1110,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") - verifyLocation("/path/to/your/lovely/heart") + verifyLocation(new URI("/path/to/your/lovely/heart")) // set table partition location sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - verifyLocation("/path/to/part/ways", Some(partSpec)) + verifyLocation(new URI("/path/to/part/ways"), Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") - verifyLocation("/swanky/steak/place") 
+ verifyLocation(new URI("/swanky/steak/place")) // set table partition location without explicitly specifying database sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - verifyLocation("vienna", Some(partSpec)) + verifyLocation(new URI("vienna"), Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") @@ -1255,7 +1254,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(new URI("paris"))) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database @@ -1819,7 +1818,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // SET LOCATION won't move data from previous table path to new table path. assert(spark.table("tbl").count() == 0) // the previous table path should be still there. - assert(new File(new URI(defaultTablePath)).exists()) + assert(new File(defaultTablePath).exists()) sql("INSERT INTO tbl SELECT 2") checkAnswer(spark.table("tbl"), Row(2)) @@ -1843,28 +1842,27 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) dir.delete - val tableLocFile = new File(table.location) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) Utils.deleteRecursively(dir) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT OVERWRITE TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.getAbsolutePath spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table1.location == newDir) + assert(table1.location == new URI(newDir)) assert(!newDirFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") @@ -1885,7 +1883,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |LOCATION "$dir" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1911,13 +1909,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new 
URI(dir.getAbsolutePath)) dir.delete() checkAnswer(spark.table("t"), Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.toURI spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1967,7 +1965,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } @@ -1986,7 +1984,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) @@ -1996,4 +1994,84 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for datasource table") { + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + try { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + spark.sql("USE tmpdb") + + import testImplicits._ + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val tblPath = new Path(tblloc.getAbsolutePath) + val fs = 
tblPath.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(tblPath).toUri) + assert(tblloc.listFiles().nonEmpty) + } + } + } finally { + spark.sql("DROP DATABASE IF EXISTS tmpdb") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 75723d0abcfcc..989a7f2698171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -459,7 +459,7 @@ class CatalogSuite options = Map("path" -> dir.getAbsolutePath)) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.EXTERNAL) - assert(table.storage.locationUri.get == dir.getAbsolutePath) + assert(table.storage.locationUri.get == new URI(dir.getAbsolutePath)) Seq((1)).toDF("i").write.insertInto("t") assert(dir.exists() && dir.listFiles().nonEmpty) @@ -481,7 +481,7 @@ class CatalogSuite options = Map.empty[String, String]) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.MANAGED) - val tablePath = new File(new URI(table.storage.locationUri.get)) + val tablePath = new File(table.storage.locationUri.get) assert(tablePath.exists() && tablePath.listFiles().isEmpty) Seq((1)).toDF("i").write.insertInto("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 9082261af7b00..93f3efe2ccc4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -92,7 +92,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { def tableDir: File = { val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") - new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier))) + new File(spark.sessionState.catalog.defaultTablePath(identifier)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index faf9afc49a2d3..7ab339e005295 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.sources +import java.net.URI + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructType} @@ -78,7 +81,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { // should exist even path option is not specified when creating table withTable("src") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") - assert(getPathOption("src") == Some(defaultTablePath("src"))) + assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src")))) } } @@ -105,7 +108,8 @@ class PathOptionSuite extends 
DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == defaultTablePath("src")) + assert(spark.table("src").schema.head.metadata.getString("path") == + CatalogUtils.URIToString(defaultTablePath("src"))) } } @@ -123,7 +127,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { withTable("src", "src2") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") sql("ALTER TABLE src RENAME TO src2") - assert(getPathOption("src2") == Some(defaultTablePath("src2"))) + assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2")))) } } @@ -133,7 +137,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { }.head } - private def defaultTablePath(tableName: String): String = { + private def defaultTablePath(tableName: String): URI = { spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 43d9c2bec6823..9ab4624594924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -210,7 +210,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat tableDefinition.storage.locationUri.isEmpty val tableLocation = if (needDefaultTableLocation) { - Some(defaultTablePath(tableDefinition.identifier)) + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { tableDefinition.storage.locationUri } @@ -260,7 +260,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, in older version of Spark we already store table location in storage properties // with key "path". Here we keep this behaviour for backward compatibility. val storagePropsWithLocation = table.storage.properties ++ - table.storage.locationUri.map("path" -> _) + table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and // bucket specification to empty. Note that partition columns are retained, so that we can @@ -285,7 +285,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. require(table.storage.locationUri.isDefined, "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(table.location).toUri.toString) + Some(table.location) } else { None } @@ -432,13 +432,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // // Please refer to https://issues.apache.org/jira/browse/SPARK-15269 for more details. 
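// Because database locations are now URIs, the placeholder path below is built
// by lifting the URI back into a Hadoop Path before resolving the child, and
// tempPath.toUri (rather than tempPath.toString) goes into the table storage.
// A small sketch of that composition, with an illustrative location:
import java.net.URI
import org.apache.hadoop.fs.Path

val dbLocation = new Path(new URI("file:/warehouse/test_db.db"))
val placeholder = new Path(dbLocation, "tbl-__PLACEHOLDER__")
assert(placeholder.toUri ==
  new URI("file:/warehouse/test_db.db/tbl-__PLACEHOLDER__"))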
val tempPath = { - val dbLocation = getDatabase(tableDefinition.database).locationUri + val dbLocation = new Path(getDatabase(tableDefinition.database).locationUri) new Path(dbLocation, tableDefinition.identifier.table + "-__PLACEHOLDER__") } try { client.createTable( - tableDefinition.withNewStorage(locationUri = Some(tempPath.toString)), + tableDefinition.withNewStorage(locationUri = Some(tempPath.toUri)), ignoreIfExists) } finally { FileSystem.get(tempPath.toUri, hadoopConf).delete(tempPath, true) @@ -563,7 +563,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // want to alter the table location to a file path, we will fail. This should be fixed // in the future. - val newLocation = tableDefinition.storage.locationUri + val newLocation = tableDefinition.storage.locationUri.map(CatalogUtils.URIToString(_)) val storageWithPathOption = tableDefinition.storage.copy( properties = tableDefinition.storage.properties ++ newLocation.map("path" -> _)) @@ -704,7 +704,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val storageWithLocation = { val tableLocation = getLocationFromStorageProps(table) // We pass None as `newPath` here, to remove the path option in storage properties. - updateLocationInStorageProps(table, newPath = None).copy(locationUri = tableLocation) + updateLocationInStorageProps(table, newPath = None).copy( + locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) @@ -848,10 +849,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, Hive metastore is not case preserving and will generate wrong partition location // with lower cased partition column names. Here we set the default partition location // manually to avoid this problem. 
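// Hive is not case preserving: a partition column declared as, say, partCol
// would end up under partcol=..., so the default location is computed on the
// Spark side instead, escaping special characters Hive-style (':' becomes %3A,
// '%' becomes %25, matching the expectations in the DDL suites further down).
// A hedged sketch; the argument order of generatePartitionPath is assumed:
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

val tablePath = new Path("file:/warehouse/t1")
val partPath = ExternalCatalogUtils.generatePartitionPath(
  Map("b" -> "2017-03-03 12:13%3A14"), Seq("b"), tablePath)
// Expected: file:/warehouse/t1/b=2017-03-03 12%3A13%253A14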
- val partitionPath = p.storage.locationUri.map(uri => new Path(new URI(uri))).getOrElse { + val partitionPath = p.storage.locationUri.map(uri => new Path(uri)).getOrElse { ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) } - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri.toString))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri))) } val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) @@ -890,7 +891,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val newParts = newSpecs.map { spec => val rightPath = renamePartitionDirectory(fs, tablePath, partitionColumnNames, spec) val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec)) - partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString))) + partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toUri))) } alterPartitions(db, table, newParts) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 151a69aebf1d8..4d3b6c3cec1c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -128,7 +128,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions - val tablePath = new Path(new URI(relation.tableMeta.location)) + val tablePath = new Path(relation.tableMeta.location) val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { @@ -141,7 +141,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // locations,_omitting_ the table's base path. 
val paths = sparkSession.sharedState.externalCatalog .listPartitions(tableIdentifier.database, tableIdentifier.name) - .map(p => new Path(new URI(p.storage.locationUri.get))) + .map(p => new Path(p.storage.locationUri.get)) if (paths.isEmpty) { Seq(tablePath) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 624cfa206eeb2..b5ce027d51e73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -133,7 +133,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { try { val hadoopConf = session.sessionState.newHadoopConf() - val tablePath = new Path(new URI(table.location)) + val tablePath = new Path(table.location) val fs: FileSystem = tablePath.getFileSystem(hadoopConf) fs.getContentSummary(tablePath).getLength } catch { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7acaa9a7ab417..469c9d84de054 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -317,7 +317,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull), ignoreIfExists) } @@ -335,7 +335,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull)) } @@ -344,7 +344,7 @@ private[hive] class HiveClientImpl( CatalogDatabase( name = d.getName, description = d.getDescription, - locationUri = d.getLocationUri, + locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) } @@ -410,7 +410,7 @@ private[hive] class HiveClientImpl( createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( - locationUri = shim.getDataLocation(h), + locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI(_)), // To avoid ClassNotFound exception, we try our best to not get the format class, but get // the class name directly. However, for non-native tables, there is no interface to get // the format class name, so we may still throw ClassNotFound in this case. 
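// HiveClientImpl is where the URI plumbing meets Hive's string-based API: the
// hunk above parses getDataLocation's string into a URI, and the hunks below
// render URIs back to strings before handing them to Hive. A minimal sketch of
// the outbound direction against a stand-in for Hive's StorageDescriptor
// (whose real setter is setLocation(String)):
import java.net.URI

final class StorageDescriptorStandIn { var location: String = _ }

val sd = new StorageDescriptorStandIn
val locationUri: Option[URI] = Some(new URI("file:/warehouse/db.db/t"))
locationUri.map(_.toString).foreach(loc => sd.location = loc) // URIToString role
assert(sd.location == "file:/warehouse/db.db/t")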
@@ -851,7 +851,8 @@ private[hive] object HiveClientImpl { conf.foreach(c => hiveTable.setOwner(c.getUser)) hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) - table.storage.locationUri.foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)} + table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => + hiveTable.getTTable.getSd.setLocation(loc)} table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) hiveTable.setSerializationLib( @@ -885,7 +886,7 @@ private[hive] object HiveClientImpl { } val storageDesc = new StorageDescriptor val serdeInfo = new SerDeInfo - p.storage.locationUri.foreach(storageDesc.setLocation) + p.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach(storageDesc.setLocation) p.storage.inputFormat.foreach(storageDesc.setInputFormat) p.storage.outputFormat.foreach(storageDesc.setOutputFormat) p.storage.serde.foreach(serdeInfo.setSerializationLib) @@ -906,7 +907,7 @@ private[hive] object HiveClientImpl { CatalogTablePartition( spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), storage = CatalogStorageFormat( - locationUri = Option(apiPartition.getSd.getLocation), + locationUri = Option(CatalogUtils.stringToURI(apiPartition.getSd.getLocation)), inputFormat = Option(apiPartition.getSd.getInputFormat), outputFormat = Option(apiPartition.getSd.getOutputFormat), serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 7280748361d60..c6188fc683e77 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,10 +24,9 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver @@ -41,7 +40,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} @@ -268,7 +267,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { val table = hive.getTable(database, tableName) parts.foreach { s => val location = s.storage.locationUri.map( - uri => new Path(table.getPath, new Path(new URI(uri)))).orNull + uri => new Path(table.getPath, new Path(uri))).orNull val params = if (s.parameters.nonEmpty) s.parameters.asJava else null val 
spec = s.spec.asJava if (hive.getPartition(table, spec, false) != null && ignoreIfExists) { @@ -463,7 +462,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) parts.zipWithIndex.foreach { case (s, i) => addPartitionDesc.addPartition( - s.spec.asJava, s.storage.locationUri.map(u => new Path(new URI(u)).toString).orNull) + s.spec.asJava, s.storage.locationUri.map(CatalogUtils.URIToString(_)).orNull) if (s.parameters.nonEmpty) { addPartitionDesc.getPartition(i).setPartParams(s.parameters.asJava) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 6d7a1c3937a96..490e02d0bd541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -70,7 +72,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText @@ -102,7 +104,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created // TODO will be SQLText assert(desc.comment == Some("This is the staging page view table")) @@ -338,7 +340,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" val (desc, _) = extractTableDesc(query) assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/path/to/nowhere")) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) } test("create table - if not exists") { @@ -469,7 +471,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.viewText.isEmpty) assert(desc.viewDefaultDatabase.isEmpty) assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) assert(desc.storage.inputFormat == Some("winput")) assert(desc.storage.outputFormat == Some("wowput")) assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) @@ -644,7 +646,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle .add("id", "int") .add("name", "string", nullable = true, comment = "blabla")) assert(table.provider == 
Some(DDLUtils.HIVE_PROVIDER)) - assert(table.storage.locationUri == Some("/tmp/file")) + assert(table.storage.locationUri == Some(new URI("/tmp/file"))) assert(table.storage.properties == Map("my_prop" -> "1")) assert(table.comment == Some("BLABLA")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala index ee632d24b717e..705d43f1f3aba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala @@ -40,7 +40,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val tempDir = Utils.createTempDir().getCanonicalFile - val tempDirUri = tempDir.toURI.toString.stripSuffix("/") + val tempDirUri = tempDir.toURI + val tempDirStr = tempDir.getAbsolutePath override def beforeEach(): Unit = { sql("CREATE DATABASE test_db") @@ -59,9 +60,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest } private def defaultTableURI(tableName: String): URI = { - val defaultPath = - spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) - new Path(defaultPath).toUri + spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) } // Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark @@ -170,8 +169,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl7", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl7").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map( @@ -184,7 +183,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( locationUri = Some(tempDirUri), - properties = Map("path" -> tempDirUri)), + properties = Map("path" -> tempDirStr)), schema = simpleSchema, properties = Map( "spark.sql.sources.provider" -> "parquet", @@ -195,8 +194,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl9", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl9").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map("spark.sql.sources.provider" -> "json")) @@ -220,7 +219,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest if (tbl.tableType == CatalogTableType.EXTERNAL) { // trim the URI prefix - val tableLocation = new URI(readBack.storage.locationUri.get).getPath + val tableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = tempDir.toURI.getPath.stripSuffix("/") assert(tableLocation == expectedLocation) } @@ -236,7 +235,7 @@ class 
HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest val readBack = getTableMetadata(tbl.identifier.table) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expected = dir.toURI.getPath.stripSuffix("/") assert(actualTableLocation == expected) } @@ -252,7 +251,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest assert(readBack.schema.sameType(expectedSchema)) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) { tempDir.toURI.getPath.stripSuffix("/") } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 16cf4d7ec67f6..892a22ddfafc8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -140,7 +142,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.serde === Some(serde)) assert(hiveTable.tableType === CatalogTableType.EXTERNAL) - assert(hiveTable.storage.locationUri === Some(path.toString)) + assert(hiveTable.storage.locationUri === Some(new URI(path.getAbsolutePath))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 8f0d5d886c9d5..5f15a705a2e99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -485,7 +485,7 @@ object SetWarehouseLocationTest extends Logging { val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocation" + CatalogUtils.stringToURI(s"file:${expectedWarehouseLocation.toString}/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -500,8 +500,8 @@ object SetWarehouseLocationTest extends Logging { sparkSession.sql("create table testLocation (a int)") val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) - val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" + val expectedLocation = CatalogUtils.stringToURI( + s"file:${expectedWarehouseLocation.toString}/testlocationdb.db/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -868,14 +868,16 @@ object SPARK_18360 { val rawTable = hiveClient.getTable("default", "test_tbl") // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table // location for tables in default database. 
- assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + assert(rawTable.storage.locationUri.map( + CatalogUtils.URIToString(_)).get.contains(newWarehousePath)) hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") // Spark SQL will use the location of default database to generate default table // location for tables in default database. - assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + assert(readBack.storage.locationUri.map(CatalogUtils.URIToString(_)) + .get.contains(defaultDbLocation)) } finally { hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 03ea0c8c77682..f02b7218d6eee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1011,7 +1011,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv identifier = TableIdentifier("not_skip_hive_metadata"), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempPath.getCanonicalPath), + locationUri = Some(tempPath.toURI), properties = Map("skipHiveMetadata" -> "false") ), schema = schema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 47ee4dd4d952c..4aea6d14efb0e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.hive +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -26,8 +30,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = spark.sharedState.externalCatalog.getTable(dbName, tableName) - val expectedPath = - spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName + val expectedPath = new Path(new Path( + spark.sharedState.externalCatalog.getDatabase(dbName).locationUri), tableName).toUri assert(metastoreTable.location === expectedPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d61d10bf869e2..dd624eca6b7b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.net.URI import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -54,7 +55,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleton w test("success sanity check") { val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) - val db = new CatalogDatabase("default", "desc", "loc", Map()) + val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) badClient.createDatabase(db, ignoreIfExists = true) } @@ -125,10 +126,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w // Database related API /////////////////////////////////////////////////////////////////////////// - val tempDatabasePath = Utils.createTempDir().getCanonicalPath + val tempDatabasePath = Utils.createTempDir().toURI test(s"$version: createDatabase") { - val defaultDB = CatalogDatabase("default", "desc", "loc", Map()) + val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map()) client.createDatabase(defaultDB, ignoreIfExists = true) val tempDB = CatalogDatabase( "temporary", description = "test create", tempDatabasePath, Map()) @@ -346,7 +347,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") - val newLocation = Utils.createTempDir().getPath() + val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions @@ -660,7 +661,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w val expectedPath = s"file:${tPath.toUri.getPath.stripSuffix("/")}" val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == CatalogUtils.stringToURI(expectedPath)) assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) checkAnswer(spark.table("t"), Row("1") :: Nil) @@ -669,7 +670,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) val expectedPath1 = s"file:${t1Path.toUri.getPath.stripSuffix("/")}" - assert(table1.location.stripSuffix("/") == expectedPath1) + assert(table1.location == CatalogUtils.stringToURI(expectedPath1)) assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) checkAnswer(spark.table("t1"), Row(2) :: Nil) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 81ae5b7bdb672..e956c9abae514 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.lang.reflect.InvocationTargetException +import java.net.URI import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -25,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.TableIdentifier 
import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog @@ -710,7 +712,7 @@ class HiveDDLSuite } sql(s"CREATE DATABASE $dbName Location '${tmpDir.toURI.getPath.stripSuffix("/")}'") val db1 = catalog.getDatabaseMetadata(dbName) - val dbPath = tmpDir.toURI.toString.stripSuffix("/") + val dbPath = new URI(tmpDir.toURI.toString.stripSuffix("/")) assert(db1 == CatalogDatabase(dbName, "", dbPath, Map.empty)) sql("USE db1") @@ -747,11 +749,12 @@ class HiveDDLSuite sql(s"CREATE DATABASE $dbName") val catalog = spark.sessionState.catalog val expectedDBLocation = s"file:${dbPath.toUri.getPath.stripSuffix("/")}/$dbName.db" + val expectedDBUri = CatalogUtils.stringToURI(expectedDBLocation) val db1 = catalog.getDatabaseMetadata(dbName) assert(db1 == CatalogDatabase( dbName, "", - expectedDBLocation, + expectedDBUri, Map.empty)) // the database directory was created assert(fs.exists(dbPath) && fs.isDirectory(dbPath)) @@ -1606,7 +1609,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } @@ -1626,7 +1629,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == new URI(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) @@ -1686,4 +1689,162 @@ class HiveDDLSuite } } } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"datasource table: location uri contains $specialChars") { + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"hive table: location uri contains $specialChars") { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, 
specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING hive + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val path = new Path(loc.getAbsolutePath) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(path).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } else { + val e = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t SELECT 1") + }.getTargetException.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING hive + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + val path = new Path(loc.getAbsolutePath) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(path).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), + Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } else { + val e = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + }.getTargetException.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + + val e1 = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + }.getTargetException.getMessage + assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + try { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + spark.sql("USE tmpdb") + + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val tblPath = new Path(tblloc.getAbsolutePath) + val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(tblPath).toUri) + assert(tblloc.listFiles().nonEmpty) + } + } + } finally { + spark.sql("DROP DATABASE IF EXISTS tmpdb") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ef2d451e6b2d6..be9a5fd71bd25 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } userSpecifiedLocation match { case Some(location) => - assert(r.tableMeta.location === location) + assert(r.tableMeta.location === CatalogUtils.stringToURI(location)) case None => // OK. } // Also make sure that the format and serde are as desired. From 12bf832407eaaed90d7c599522457cb36b303b6c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 6 Mar 2017 14:06:11 -0600 Subject: [PATCH 16/78] [SPARK-19796][CORE] Fix serialization of long property values in TaskDescription ## What changes were proposed in this pull request? The properties that are serialized with a TaskDescription can have very long values (e.g. "spark.job.description", which is set to the full SQL statement with the Thrift server). DataOutputStream.writeUTF() does not work well for long strings, so this changes the way those values are serialized to handle longer strings. ## How was this patch tested? Updated existing unit test to reproduce the issue. All unit tests via Jenkins. Author: Imran Rashid Closes #17140 from squito/SPARK-19796. --- .../apache/spark/scheduler/TaskDescription.scala | 12 ++++++++++-- .../spark/scheduler/TaskDescriptionSuite.scala | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 78aa5c40010cc..c98b87148e404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.Properties import scala.collection.JavaConverters._ @@ -86,7 +87,10 @@ private[spark] object TaskDescription { dataOut.writeInt(taskDescription.properties.size()) taskDescription.properties.asScala.foreach { case (key, value) => dataOut.writeUTF(key) - dataOut.writeUTF(value) + // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values + val bytes = value.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) } // Write the task. The task is already serialized, so write it directly to the byte buffer. 
@@ -124,7 +128,11 @@ private[spark] object TaskDescription { val properties = new Properties() val numProperties = dataIn.readInt() for (i <- 0 until numProperties) { - properties.setProperty(dataIn.readUTF(), dataIn.readUTF()) + val key = dataIn.readUTF() + val valueLength = dataIn.readInt() + val valueBytes = new Array[Byte](valueLength) + dataIn.readFully(valueBytes) + properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) } // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 9f1fe0515732e..97487ce1d2ca8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException} import java.nio.ByteBuffer import java.util.Properties @@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite { val originalProperties = new Properties() originalProperties.put("property1", "18") originalProperties.put("property2", "test value") + // SPARK-19796 -- large property values (like a large job description for a long SQL query) + // can cause problems for DataOutputStream, so make sure we handle them correctly + val sb = new StringBuilder() + (0 to 10000).foreach(_ => sb.append("1234567890")) + val largeString = sb.toString() + originalProperties.put("property3", largeString) + // make sure we've got a good test case + intercept[UTFDataFormatException] { + val out = new DataOutputStream(new ByteArrayOutputStream()) + try { + out.writeUTF(largeString) + } finally { + out.close() + } + } // Create a dummy byte buffer for the task. val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) From 9991c2dad6d09d77d5a61e4c4dcd1770e5d984d4 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 6 Mar 2017 12:35:03 -0800 Subject: [PATCH 17/78] [SPARK-19211][SQL] Explicitly prevent Insert into View or Create View As Insert ## What changes were proposed in this pull request? Currently we don't explicitly forbid the following behaviors: 1. The statement CREATE VIEW AS INSERT INTO throws the following exception: ``` scala> spark.sql("CREATE VIEW testView AS INSERT INTO tab VALUES (1, \"a\")") org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: org.apache.hadoop.hive.ql.metadata.HiveException: at least one column must be specified for the table; scala> spark.sql("CREATE VIEW testView(a, b) AS INSERT INTO tab VALUES (1, \"a\")") org.apache.spark.sql.AnalysisException: The number of columns produced by the SELECT clause (num: `0`) does not match the number of column names specified by CREATE VIEW (num: `2`).; ``` 2. The statement INSERT INTO view VALUES throws the following exception from checkAnalysis: ``` scala> spark.sql("INSERT INTO testView VALUES (1, \"a\")") org.apache.spark.sql.AnalysisException: Inserting into an RDD-based table is not allowed.;; 'InsertIntoTable View (`default`.`testView`, [a#16,b#17]), false, false +- LocalRelation [col1#14, col2#15] ``` After this PR, the behavior changes to: ``` scala> spark.sql("CREATE VIEW testView AS INSERT INTO tab VALUES (1, \"a\")") org.apache.spark.sql.catalyst.parser.ParseException: Operation not allowed: CREATE VIEW ... 
AS INSERT INTO; scala> spark.sql("CREATE VIEW testView(a, b) AS INSERT INTO tab VALUES (1, \"a\")") org.apache.spark.sql.catalyst.parser.ParseException: Operation not allowed: CREATE VIEW ... AS INSERT INTO; scala> spark.sql("INSERT INTO testView VALUES (1, \"a\")") org.apache.spark.sql.AnalysisException: `default`.`testView` is a view, inserting into a view is not allowed; ``` ## How was this patch tested? Add a new test case in `SparkSqlParserSuite`; Update the corresponding test case in `SQLViewSuite`. Author: jiangxingbo Closes #17125 from jiangxb1987/insert-with-view. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++++- .../apache/spark/sql/execution/SparkSqlParser.scala | 9 +++++++++ .../org/apache/spark/sql/execution/SQLViewSuite.scala | 2 +- .../spark/sql/execution/SparkSqlParserSuite.scala | 11 +++++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f8489de6b000..ffa5aed30e19f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -606,7 +606,11 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) + lookupTableFromCatalog(u).canonicalized match { + case v: View => + u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.") + case other => i.copy(table = other) + } case u: UnresolvedRelation => resolveRelation(u) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c106163741278..00d1d6d2701f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1331,6 +1331,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2d95cb6d64a87..0e5a1dc6ab629 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -172,7 +172,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { var e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }.getMessage - assert(e.contains("Inserting into an RDD-based table is not allowed")) + assert(e.contains("Inserting into a view is not allowed. View: `default`.`testview`")) val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/employee.dat") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bb6c486e880a0..d44a6e41cb347 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -210,6 +210,17 @@ class SparkSqlParserSuite extends PlanTest { "no viable alternative at input") } + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( From 926543664f9d785e70f8314ed6ecc6ecda96d0f4 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 6 Mar 2017 13:08:59 -0800 Subject: [PATCH 18/78] [SPARK-19382][ML] Test sparse vectors in LinearSVCSuite ## What changes were proposed in this pull request? Add unit tests for testing SparseVector. We can't add mixed DenseVector and SparseVector test case, as discussed in JIRA 19382. def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got $ {other.n} .") ## How was this patch tested? Unit tests Author: wm624@hotmail.com Author: Miao Wang Closes #16784 from wangmiao1981/bk. 
--- .../ml/classification/LinearSVCSuite.scala | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a165d8a9345cf..fe47176a4aaa6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.udf class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau @transient var smallValidationDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ + @transient var smallSparseBinaryDataset: Dataset[_] = _ + @transient var smallSparseValidationDataset: Dataset[_] = _ + override def beforeAll(): Unit = { super.beforeAll() @@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + + // Dataset for testing SparseVector + val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse + val sparse = udf(toSparse) + smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features)) + smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features)) + } /** @@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("Linear SVC binary classification with regularization") { @@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.setRegParam(0.1).fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("params") { @@ -235,7 +250,7 @@ object LinearSVCSuite { "aggregationDepth" -> 3 ) - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) def generateSVMInput( intercept: Double, weights: Array[Double], @@ -252,5 +267,10 @@ object LinearSVCSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + def checkModels(model1: LinearSVCModel, model2: 
LinearSVCModel): Unit = { + assert(model1.intercept == model2.intercept) + assert(model1.coefficients.equals(model2.coefficients)) + } + } From f6471dc0d5db2d98e48f9f1ae1dba0f174ed9648 Mon Sep 17 00:00:00 2001 From: Wojtek Szymanski Date: Mon, 6 Mar 2017 13:19:36 -0800 Subject: [PATCH 19/78] [SPARK-19709][SQL] Read empty file with CSV data source ## What changes were proposed in this pull request? Bugfix for reading empty file with CSV data source. Instead of throwing `NoSuchElementException`, an empty data frame is returned. ## How was this patch tested? Added new unit test in `org.apache.spark.sql.execution.datasources.csv.CSVSuite` Author: Wojtek Szymanski Closes #17068 from wojtek-szymanski/SPARK-19709. --- .../datasources/csv/CSVDataSource.scala | 68 ++++++++++--------- .../execution/datasources/csv/CSVSuite.scala | 10 ++- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 73e6abc6dad37..47567032b0195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -133,20 +133,24 @@ object TextInputCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) + val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first line could not be read, just return the empty schema. 
+ Some(StructType(Nil)) } - - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) } private def createBaseDataset( @@ -190,28 +194,28 @@ object WholeFileCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) - val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) + csv.flatMap { lines => UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - false, + shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) - }.take(1).headOption - - if (maybeFirstRow.isDefined) { - val firstRow = maybeFirstRow.get - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.flatMap { lines => - UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) - } else { - // If the first row could not be read, just return the empty schema. - Some(StructType(Nil)) + }.take(1).headOption match { + case Some(firstRow) => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource( + lines.getConfiguration, + lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 56071803f685f..eaedede349134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1077,14 +1077,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe with empty schema - wholeFile option") { - withTempPath { path => - path.createNewFile() - + test("Empty file produces empty dataframe with empty schema") { + Seq(false, true).foreach { wholeFile => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", true) - .load(path.getAbsolutePath) + .option("wholeFile", wholeFile) + .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) checkAnswer(df, spark.emptyDataFrame) From b0a5cd89097c563e9949d8cfcf84d18b03b8d24c Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Mon, 6 Mar 2017 16:39:05 -0800 Subject: [PATCH 20/78] [SPARK-19719][SS] Kafka writer for both structured streaming and batch queries ## What changes were proposed in this pull request? Add a new Kafka Sink and Kafka Relation for writing streaming and batch queries, respectively, to Apache Kafka.
### Streaming Kafka Sink - When addBatch is called -- If batchId is greater than the last written batch --- Write batch to Kafka ---- Topic will be taken from the record, if present, or from a topic option, which overrides topic in record. -- Else ignore ### Batch Kafka Sink - KafkaSourceProvider will implement CreatableRelationProvider - CreatableRelationProvider#createRelation will write the passed-in DataFrame to Kafka - Topic will be taken from the record, if present, or from the topic option, which overrides the topic in the record. - Save modes Append and ErrorIfExists are supported under identical semantics. Other save modes result in an AnalysisException. tdas zsxwing ## How was this patch tested? ### The following unit tests will be included - write to stream with topic field: valid stream write with data that includes an existing topic in the schema - write structured streaming aggregation w/o topic field, with default topic: valid stream write with data that does not include a topic field, but the configuration includes a default topic - write data with bad schema: various cases of writing data that does not conform to a proper schema e.g., 1. no topic field or default topic, and 2. no value field - write data with valid schema but wrong types: data with a complete schema but wrong types e.g., key and value types are integers. - write to non-existing topic: write a stream to a topic that does not exist in Kafka, which has been configured to not auto-create topics. - write batch to kafka: simple write batch to Kafka, which goes through the same code path as streaming scenario, so validity checks will not be redone here. ### Examples ```scala // Structured Streaming val writer = inputStringStream.map(s => s.get(0).toString.getBytes()).toDF("value") .selectExpr("value as key", "value as value") .writeStream .format("kafka") .option("checkpointLocation", checkpointDir) .outputMode(OutputMode.Append) .option("kafka.bootstrap.servers", brokerAddress) .option("topic", topic) .queryName("kafkaStream") .start() // Batch val df = spark .sparkContext .parallelize(Seq("1", "2", "3", "4", "5")) .map(v => (topic, v)) .toDF("topic", "value") df.write .format("kafka") .option("kafka.bootstrap.servers", brokerAddress) .option("topic", topic) .save() ``` Author: Tyson Condie Closes #17043 from tcondie/kafka-writer.
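To complement the examples above, records written by either path can be read back through the existing Kafka source (a sketch; `brokerAddress` and `topic` are the same placeholders used in the examples):

```scala
// Batch read-back of everything written to the topic so far; key and value
// arrive as binary columns and are cast to strings for inspection.
val readBack = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", brokerAddress)
  .option("subscribe", topic)
  .option("startingOffsets", "earliest")
  .option("endingOffsets", "latest")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
```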
--- .../apache/spark/sql/kafka010/KafkaSink.scala | 43 ++ .../sql/kafka010/KafkaSourceProvider.scala | 83 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 123 ++++++ .../spark/sql/kafka010/KafkaWriter.scala | 97 +++++ .../spark/sql/kafka010/KafkaSinkSuite.scala | 412 ++++++++++++++++++ 5 files changed, 753 insertions(+), 5 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala new file mode 100644 index 0000000000000..08914d82fffdd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.streaming.Sink + +private[kafka010] class KafkaSink( + sqlContext: SQLContext, + executorKafkaParams: ju.Map[String, Object], + topic: Option[String]) extends Sink with Logging { + @volatile private var latestBatchId = -1L + + override def toString(): String = "KafkaSink" + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId <= latestBatchId) { + logInfo(s"Skipping already committed batch $batchId") + } else { + KafkaWriter.write(sqlContext.sparkSession, + data.queryExecution, executorKafkaParams, topic) + latestBatchId = batchId + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a7456719875f..febe3c217122a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -23,12 +23,14 @@ import java.util.UUID import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig -import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -36,8 +38,12 @@ import org.apache.spark.sql.types.StructType * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ -private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider - with RelationProvider with Logging { +private[kafka010] class KafkaSourceProvider extends DataSourceRegister + with StreamSourceProvider + with StreamSinkProvider + with RelationProvider + with CreatableRelationProvider + with Logging { import KafkaSourceProvider._ override def shortName(): String = "kafka" @@ -152,6 +158,72 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre endingRelationOffsets) } + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + new KafkaSink(sqlContext, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic) + } + + override def createRelation( + outerSQLContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + mode match { + case SaveMode.Overwrite | SaveMode.Ignore => + throw new AnalysisException(s"Save mode $mode not allowed for Kafka. 
" + + s"Allowed save modes are ${SaveMode.Append} and " + + s"${SaveMode.ErrorIfExists} (default).") + case _ => // good + } + val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic) + + /* This method is suppose to return a relation that reads the data that was written. + * We cannot support this for Kafka. Therefore, in order to make things consistent, + * we return an empty base relation. + */ + new BaseRelation { + override def sqlContext: SQLContext = unsupportedException + override def schema: StructType = unsupportedException + override def needConversion: Boolean = unsupportedException + override def sizeInBytes: Long = unsupportedException + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException + private def unsupportedException = + throw new UnsupportedOperationException("BaseRelation from Kafka write " + + "operation is not usable.") + } + } + + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } + private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = ConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -381,6 +453,7 @@ private[kafka010] object KafkaSourceProvider { private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" private val deserClassName = classOf[ByteArrayDeserializer].getName } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala new file mode 100644 index 0000000000000..6e160cbe2db52 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.producer.{KafkaProducer, _} +import org.apache.kafka.common.serialization.ByteArraySerializer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.types.{BinaryType, StringType} + +/** + * A simple class for writing out data in a single Spark task, without any concerns about how + * to commit or abort tasks. Exceptions thrown by the implementation of this class will + * automatically trigger task aborts. + */ +private[kafka010] class KafkaWriteTask( + producerConfiguration: ju.Map[String, Object], + inputSchema: Seq[Attribute], + topic: Option[String]) { + // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection + private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + + /** + * Writes key value data out to topics. + */ + def execute(iterator: Iterator[InternalRow]): Unit = { + producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration) + while (iterator.hasNext && failedWrite == null) { + val currentRow = iterator.next() + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) + } + } + + def close(): Unit = { + if (producer != null) { + checkForErrors + producer.close() + checkForErrors + producer = null + } + } + + private def createProjection: UnsafeProjection = { + val topicExpression = topic.map(Literal(_)).orElse { + inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) + }.getOrElse { + throw new IllegalStateException(s"topic option required when no " + + s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present") + } + topicExpression.dataType match { + case StringType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t. 
${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"must be a ${StringType}") + } + val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) + .getOrElse(Literal(null, BinaryType)) + keyExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + val valueExpression = inputSchema + .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( + throw new IllegalStateException(s"Required attribute " + + s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") + ) + valueExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + UnsafeProjection.create( + Seq(topicExpression, Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType)), inputSchema) + } + + private def checkForErrors: Unit = { + if (failedWrite != null) { + throw failedWrite + } + } +} + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala new file mode 100644 index 0000000000000..a637d52c933a3 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.util.Utils + +/** + * The [[KafkaWriter]] class is used to write data from a batch query + * or structured streaming query, given by a [[QueryExecution]], to Kafka. + * The data is assumed to have a value column, and an optional topic and key + * columns. If the topic column is missing, then the topic must come from + * the 'topic' configuration option. If the key column is missing, then a + * null valued key field will be added to the + * [[org.apache.kafka.clients.producer.ProducerRecord]]. 
+ */ +private[kafka010] object KafkaWriter extends Logging { + val TOPIC_ATTRIBUTE_NAME: String = "topic" + val KEY_ATTRIBUTE_NAME: String = "key" + val VALUE_ATTRIBUTE_NAME: String = "value" + + override def toString: String = "KafkaWriter" + + def validateQuery( + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( + if (topic == None) { + throw new AnalysisException(s"topic option required when no " + + s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") + } else { + Literal(topic.get, StringType) + } + ).dataType match { + case StringType => // good + case _ => + throw new AnalysisException(s"Topic type must be a String") + } + schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( + Literal(null, StringType) + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( + throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + } + + def write( + sparkSession: SparkSession, + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + validateQuery(queryExecution, kafkaParameters, topic) + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) + } + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala new file mode 100644 index 0000000000000..490535623cb36 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, DataType} + +class KafkaSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("batch - write to kafka") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + checkAnswer( + createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + } + + test("batch - null topic field value, and no topic option") { + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + val ex = intercept[SparkException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + "null topic present in the data")) + } + + test("batch - unsupported save modes") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + + // Test bad save mode Ignore + var ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Ignore) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode ignore not allowed for kafka")) + + // Test bad save mode Overwrite + ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Overwrite) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode overwrite not allowed for kafka")) + } + + test("streaming - write to kafka with topic field") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + input.addData("1", "2", "3", "4", "5") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + input.addData("6", "7", "8", "9", "10") + 
failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write aggregation w/o topic field, with topic option") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + test("streaming - aggregation with topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "'foo' as topic", + "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + + test("streaming - write data with bad schema") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + } + + 
test("streaming - write data with valid schema but wrong types") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val input = MemoryStream[String] + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val input = MemoryStream[String] + var writer: StreamingQuery = null + var ex: Exception = null + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'key.serializer' is not supported")) + + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'value.serializer' is not supported")) + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. 
+ */
+ val topic = newTopic()
+ testUtils.createTopic(topic, 1)
+ val options = new java.util.HashMap[String, Object]
+ options.put("bootstrap.servers", testUtils.brokerAddress)
+ options.put("buffer.memory", "16384") // min buffer size
+ options.put("block.on.buffer.full", "true")
+ options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ val inputSchema = Seq(AttributeReference("value", BinaryType)())
+ val data = new Array[Byte](15000) // large value
+ val writeTask = new KafkaWriteTask(options, inputSchema, Some(topic))
+ try {
+ val fieldTypes: Array[DataType] = Array(BinaryType)
+ val converter = UnsafeProjection.create(fieldTypes)
+ val row = new SpecificInternalRow(fieldTypes)
+ row.update(0, data)
+ val iter = Seq.fill(1000)(converter.apply(row)).iterator
+ writeTask.execute(iter)
+ } finally {
+ writeTask.close()
+ }
+ }
+
+ private val topicId = new AtomicInteger(0)
+
+ private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+
+ private def createKafkaReader(topic: String): DataFrame = {
+ spark.read
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("startingOffsets", "earliest")
+ .option("endingOffsets", "latest")
+ .option("subscribe", topic)
+ .load()
+ }
+
+ private def createKafkaWriter(
+ input: DataFrame,
+ withTopic: Option[String] = None,
+ withOutputMode: Option[OutputMode] = None,
+ withOptions: Map[String, String] = Map[String, String]())
+ (withSelectExpr: String*): StreamingQuery = {
+ var stream: DataStreamWriter[Row] = null
+ withTempDir { checkpointDir =>
+ var df = input.toDF()
+ if (withSelectExpr.length > 0) {
+ df = df.selectExpr(withSelectExpr: _*)
+ }
+ stream = df.writeStream
+ .format("kafka")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .queryName("kafkaStream")
+ withTopic.foreach(stream.option("topic", _))
+ withOutputMode.foreach(stream.outputMode(_))
+ withOptions.foreach(opt => stream.option(opt._1, opt._2))
+ }
+ stream.start()
+ }
+}
From 9909f6d361fdf2b7ef30fa7fbbc91e00f2999794 Mon Sep 17 00:00:00 2001
From: wangzhenhua
Date: Mon, 6 Mar 2017 21:45:36 -0800
Subject: [PATCH 21/78] [SPARK-19350][SQL] Cardinality estimation of Limit and Sample

## What changes were proposed in this pull request?

Before this PR, LocalLimit/GlobalLimit/Sample propagated the same row count and column stats as their children, which is incorrect.
We can get the correct rowCount in Statistics for GlobalLimit/Sample whether cbo is enabled or not. We don't know the rowCount for LocalLimit because we don't know the number of partitions at that time. Column stats should not be propagated, because we don't know the distribution of columns after Limit or Sample.

## How was this patch tested?

Added test cases.

Author: wangzhenhua

Closes #16696 from wzhfy/limitEstimation.
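
For reference, the new estimation rules boil down to the following small standalone sketch (hypothetical `Stats` case class and method names, not the actual Catalyst API; `rowSize` stands for the estimated output size per row):

```
import scala.math.BigDecimal.RoundingMode

// Toy model of the statistics carried by a plan node.
case class Stats(sizeInBytes: BigInt, rowCount: Option[BigInt])

// GlobalLimit: the output can never exceed `limit` rows, so cap the child's
// row count, falling back to the limit itself when the child count is unknown.
def estimateGlobalLimit(limit: Int, child: Stats, rowSize: Long): Stats = {
  val rows = child.rowCount.map(_.min(limit)).getOrElse(BigInt(limit))
  Stats(sizeInBytes = (rows * rowSize).max(1), rowCount = Some(rows))
}

// Sample: scale both size and row count by the sampling ratio, rounding up
// and keeping sizeInBytes strictly positive.
def estimateSample(lowerBound: Double, upperBound: Double, child: Stats): Stats = {
  val ratio = BigDecimal(upperBound - lowerBound)
  def ceil(d: BigDecimal): BigInt = d.setScale(0, RoundingMode.CEILING).toBigInt
  Stats(
    sizeInBytes = ceil(BigDecimal(child.sizeInBytes) * ratio).max(1),
    rowCount = child.rowCount.map(c => ceil(BigDecimal(c) * ratio)))
}
```

For example, with a child of 10 rows at 12 bytes per row (8 bytes of row overhead plus a 4-byte int), a limit of 2 yields rowCount = 2 and sizeInBytes = 24, which is exactly what the new BasicStatsEstimationSuite asserts for GlobalLimit.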
--- .../plans/logical/basicLogicalOperators.scala | 38 +++--- .../BasicStatsEstimationSuite.scala | 122 ++++++++++++++++++ .../statsEstimation/StatsConfSuite.scala | 64 --------- .../spark/sql/StatisticsCollectionSuite.scala | 24 ---- 4 files changed, 145 insertions(+), 103 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc2701..4d27ff2acdbad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - 1 - } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum - } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val childStats = child.stats(conf) + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after a limit operation + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + isBroadcastable = childStats.isBroadcastable) } } @@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { + val childStats = child.stats(conf) + if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - 1 + Statistics( + sizeInBytes = 1, + rowCount = Some(0), + isBroadcastable = childStats.isBroadcastable) } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. 
+ childStats.copy(attributeStats = AttributeMap(Nil)) } - child.stats(conf).copy(sizeInBytes = sizeInBytes) } } @@ -816,12 +822,14 @@ case class Sample( override def computeStats(conf: CatalystConf): Statistics = { val ratio = upperBound - lowerBound - // BigInt can't multiply with Double - var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 + val childStats = child.stats(conf) + var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala new file mode 100644 index 0000000000000..e5dc811c8b7db --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + + +class BasicStatsEstimationSuite extends StatsEstimationTestBase { + val attribute = attr("key") + val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val plan = StatsTestPlan( + outputList = Seq(attribute), + attributeStats = AttributeMap(Seq(attribute -> colStat)), + rowCount = 10, + // row count * (overhead + column size) + size = Some(10 * (8 + 4))) + + test("limit estimation: limit < child's rowCount") { + val localLimit = LocalLimit(Literal(2), plan) + val globalLimit = GlobalLimit(Literal(2), plan) + // LocalLimit's stats is just its child's stats except column stats + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) + } + + test("limit estimation: limit > child's rowCount") { + val localLimit = LocalLimit(Literal(20), plan) + val globalLimit = GlobalLimit(Literal(20), plan) + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. + checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + } + + test("limit estimation: limit = 0") { + val localLimit = LocalLimit(Literal(0), plan) + val globalLimit = GlobalLimit(Literal(0), plan) + val stats = Statistics(sizeInBytes = 1, rowCount = Some(0)) + checkStats(localLimit, stats) + checkStats(globalLimit, stats) + } + + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), + isBroadcastable = false) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), + isBroadcastable = false) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + + /** Check estimated stats when cbo is turned on/off. 
*/ + private def checkStats( + plan: LogicalPlan, + expectedStatsCboOn: Statistics, + expectedStatsCboOff: Statistics): Unit = { + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + } + + /** Check estimated stats when it's the same whether cbo is turned on or off. */ + private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = + checkStats(plan, expectedStats, expectedStats) +} + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. + */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) extends LogicalPlan { + override def output: Seq[Attribute] = Nil + override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: CatalystConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala deleted file mode 100644 index 212d57a9bcf95..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.statsEstimation - -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType - - -class StatsConfSuite extends StatsEstimationTestBase { - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - // Return the statistics estimated by cbo - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) - // Invalidate statistics - plan.invalidateStatsCache() - // Return the simple statistics - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) - } -} - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bbb31dbc8f3de..1f547c5a2a8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared spark.sessionState.conf.autoBroadcastJoinThreshold) } - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.stats(conf).sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.stats(conf).sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) From 
1f6c090c15f355a0c2aad736f8291fcdee5c556d Mon Sep 17 00:00:00 2001
From: actuaryzhang
Date: Mon, 6 Mar 2017 21:55:11 -0800
Subject: [PATCH 22/78] [SPARK-19818][SPARKR] rbind should check for name consistency of input data frames

## What changes were proposed in this pull request?

Added checks for name consistency of input data frames in union.

## How was this patch tested?

new test.

Author: actuaryzhang

Closes #17159 from actuaryzhang/sparkRUnion.
---
 R/pkg/R/DataFrame.R | 8 +++++++-
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 7 +++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index e33d0d8e29d49..97e0c9edeab48 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
 #'
 #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame
 #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL.
+#' Input SparkDataFrames can have different schemas (names and data types).
 #'
 #' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
@@ -2685,7 +2686,8 @@ setMethod("unionAll",
 #' Union two or more SparkDataFrames
 #'
-#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL.
+#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method
+#' requires that the input SparkDataFrames have the same column names.
 #'
 #' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
@@ -2709,6 +2711,10 @@ setMethod("unionAll",
 setMethod("rbind",
 signature(... = "SparkDataFrame"),
 function(x, ..., deparse.level = 1) {
+ nm <- lapply(list(x, ...), names)
+ if (length(unique(nm)) != 1) {
+ stop("Names of input data frames are different.")
+ }
 if (nargs() == 3) {
 union(x, ...)
 } else {
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7c096597fea66..620b633637138 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1850,6 +1850,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", {
 expect_equal(count(unioned2), 12)
 expect_equal(first(unioned2)$name, "Michael")

+ df3 <- df2
+ names(df3)[1] <- "newName"
+ expect_error(rbind(df, df3),
+ "Names of input data frames are different.")
+ expect_error(rbind(df, df2, df3),
+ "Names of input data frames are different.")
+
 excepted <- arrange(except(df, df2), desc(df$age))
 expect_is(unioned, "SparkDataFrame")
 expect_equal(count(excepted), 2)
From e52499ea9c32326b399b50bf0e3f26278da3feb2 Mon Sep 17 00:00:00 2001
From: windpiger
Date: Mon, 6 Mar 2017 22:36:43 -0800
Subject: [PATCH 23/78] [SPARK-19832][SQL] DynamicPartitionWriteTask get partitionPath should escape the partition name

## What changes were proposed in this pull request?

Currently in DynamicPartitionWriteTask, when we get the partitionPath of a partition, we only escape the partition value, not the partition name. This causes problems for partition names that contain special characters, for example:

1) If the partition name contains '%' etc., two partition paths are created in the filesystem: an escaped path like '/path/a%25b=1' and an unescaped path like '/path/a%b=1'. The inserted data is stored under the unescaped path, while SHOW PARTITIONS on the table returns 'a%25b=1', where the partition name is escaped. So the two are not consistent.
I think the data should be stored in the escaped path in the filesystem; Hive 2.0.0 takes the same approach.

2) If the partition name contains ':', constructing new Path("/path", "a:b") throws an exception, because a colon is illegal in a relative path:
```
java.lang.IllegalArgumentException: java.net.URISyntaxException: Relative path in absolute URI: a:b
at org.apache.hadoop.fs.Path.initialize(Path.java:205)
at org.apache.hadoop.fs.Path.<init>(Path.java:171)
at org.apache.hadoop.fs.Path.<init>(Path.java:88)
... 48 elided
Caused by: java.net.URISyntaxException: Relative path in absolute URI: a:b
at java.net.URI.checkPath(URI.java:1823)
at java.net.URI.<init>(URI.java:745)
at org.apache.hadoop.fs.Path.initialize(Path.java:202)
... 50 more
```

## How was this patch tested?

unit test added

Author: windpiger

Closes #17173 from windpiger/fixDatasourceSpecialCharPartitionName.
---
 .../datasources/FileFormatWriter.scala | 2 +-
 .../sql/execution/command/DDLSuite.scala | 23 ++++++++++++
 .../sql/hive/execution/HiveDDLSuite.scala | 35 ++++++++++++++++++-
 3 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 950e5ca0d6210..30a09a9ad3370 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -341,7 +341,7 @@ object FileFormatWriter extends Logging {
 Seq(Cast(c, StringType, Option(desc.timeZoneId))),
 Seq(StringType))
 val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
- val partitionName = Literal(c.name + "=") :: str :: Nil
+ val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil
 if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName
 }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 6ffa58bcd9af1..b2199fdf90e5c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1995,6 +1995,29 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
 }
 }

+ Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars =>
+ test(s"data source table:partition column name containing $specialChars") {
+ withTable("t") {
+ withTempDir { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string, `$specialChars` string)
+ |USING parquet
+ |PARTITIONED BY(`$specialChars`)
+ |LOCATION '$dir'
+ """.stripMargin)
+
+ assert(dir.listFiles().isEmpty)
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
+ val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
+ val partFile = new File(dir, partEscaped)
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
+ }
+ }
+ }
+ }
+
 Seq("a b", "a:b", "a%b").foreach { specialChars =>
 test(s"location uri contains $specialChars for datasource table") {
 withTable("t", "t1") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index e956c9abae514..df2c1cee942b0 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException}
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.hive.HiveExternalCatalog
@@ -1690,6 +1690,39 @@ class HiveDDLSuite
 }
 }

+ Seq("parquet", "hive").foreach { datasource =>
+ Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars =>
+ test(s"partition column name of $datasource table containing $specialChars") {
+ withTable("t") {
+ withTempDir { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string, `$specialChars` string)
+ |USING $datasource
+ |PARTITIONED BY(`$specialChars`)
+ |LOCATION '$dir'
+ """.stripMargin)
+
+ assert(dir.listFiles().isEmpty)
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
+ val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
+ val partFile = new File(dir, partEscaped)
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
+
+ withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4")
+ val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4"
+ val partFile1 = new File(dir, partEscaped1)
+ assert(partFile1.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil)
+ }
+ }
+ }
+ }
+ }
+ }
+
 Seq("a b", "a:b", "a%b").foreach { specialChars =>
 test(s"datasource table: location uri contains $specialChars") {
 withTable("t", "t1") {
From 932196d9e30453e0827ee3cd8a81cb306b7a24d9 Mon Sep 17 00:00:00 2001
From: wangzhenhua
Date: Mon, 6 Mar 2017 23:53:53 -0800
Subject: [PATCH 24/78] [SPARK-17075][SQL][FOLLOWUP] fix filter estimation issues

## What changes were proposed in this pull request?

1. support boolean type in binary expression estimation.
2. deal with compound Not conditions.
3. avoid converting BigInt/BigDecimal directly to double unless it's within the range (0, 1).
4. reorganize test code.

## How was this patch tested?

modify related test cases.

Author: wangzhenhua
Author: Zhenhua Wang

Closes #17148 from wzhfy/fixFilter.
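
For the compound-condition handling, the combination rules introduced here can be summarized by this standalone model (a toy predicate ADT for illustration, not the Catalyst classes; leaf selectivities stand for single-condition estimates):

```
// Toy predicate tree standing in for Catalyst expressions.
sealed trait Pred
case class Leaf(sel: Option[Double]) extends Pred // None = not estimable
case class And(l: Pred, r: Pred) extends Pred
case class Or(l: Pred, r: Pred) extends Pred
case class Not(p: Pred) extends Pred

// Selectivity under the independence assumption; unsupported legs fall back
// to the conservative estimate 1.0 instead of invalidating the whole result.
def selectivity(p: Pred): Option[Double] = p match {
  case Leaf(s) => s
  case And(l, r) =>
    Some(selectivity(l).getOrElse(1.0) * selectivity(r).getOrElse(1.0))
  case Or(l, r) =>
    val (p1, p2) = (selectivity(l).getOrElse(1.0), selectivity(r).getOrElse(1.0))
    Some(p1 + p2 - p1 * p2)
  // Compound NOT is rewritten through De Morgan's laws, as in the patch.
  case Not(And(l, r)) => selectivity(Or(Not(l), Not(r)))
  case Not(Or(l, r)) => selectivity(And(Not(l), Not(r)))
  case Not(q) => selectivity(q).map(1.0 - _)
}
```

For instance, with the new per-condition estimates, `Not(cint > 3 AND cint <= 6)` over `cint` in [1, 10] gives leg selectivities 7/9 and 5/9; De Morgan turns this into an OR of the negations, 2/9 + 4/9 - (2/9)(4/9) = 46/81 ≈ 0.57, so 10 rows shrink to ceil(5.7) = 6, matching the expectedRowCount = 6 asserted in the reworked FilterEstimationSuite.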
--- .../statsEstimation/FilterEstimation.scala | 174 ++++---- .../FilterEstimationSuite.scala | 397 +++++++++--------- 2 files changed, 297 insertions(+), 274 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0c928832d7d22..b10785b05d6c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def estimate: Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // save a mutable copy of colStats so that we can later change it recursively + // Save a mutable copy of colStats so that we can later change it recursively. colStatsMap.setInitValues(childStats.attributeStats) - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } + // Estimate selectivity of this filter predicate, and update column stats if needed. + // For not-supported condition, set filter selectivity to a conservative estimate 100% + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) - val newColStats = colStatsMap.toColumnStats + val newColStats = if (filterSelectivity == 0) { + // The output is empty, we don't need to keep column stats. + AttributeMap[ColumnStat](Nil) + } else { + colStatsMap.toColumnStats + } val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) @@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * Returns a percentage of rows meeting a condition in Filter node. + * If it's a single condition, we calculate the percentage directly. + * If it's a compound condition, it is decomposed into multiple single conditions linked with + * AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. This is needed for * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. + * For logical OR and NOT conditions, we do not update stats after a condition estimation. 
* * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column @@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update) - val percent2 = calculateFilterSelectivity(cond2, update) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) - case _ => calculateSingleCondition(condition, update) + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None + } + + case _ => + calculateSingleCondition(condition, update) } } @@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = if (isNull) { - nullPercent.toDouble + nullPercent } else { - 1.0 - nullPercent.toDouble + 1.0 - nullPercent } - Some(percent) + Some(percent.toDouble) } /** @@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo attr: Attribute, literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + attr.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType | DateType | TimestampType | BooleanType => evaluateBinaryForNumeric(op, attr, literal, update) case StringType | BinaryType => // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) None - case _ => - // TODO: support boolean type. - None } } @@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. 
* This method evaluates the equality predicate for all data types.
 *
+ * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo;
+ * if the literal is null, the condition will be changed to IsNull after optimization.
+ * So we don't need specific logic for EqualNullSafe here.
+ *
 * @param attr an Attribute (or a column)
 * @param literal a literal value (or constant)
 * @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -323,7 +329,7 @@
 colStatsMap(attr) = newStats
 }
- Some(1.0 / ndv.toDouble)
+ Some((1.0 / BigDecimal(ndv)).toDouble)
 } else {
 Some(0.0)
 }
@@ -394,12 +400,12 @@
 // return the filter selectivity. Without advanced statistics such as histograms,
 // we have to assume uniform distribution.
- Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+ Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble))
 }

 /**
 * Returns a percentage of rows meeting a binary comparison expression.
- * This method evaluate expression for Numeric columns only.
+ * This method evaluates expressions for Numeric/Date/Timestamp/Boolean columns.
 *
 * @param op a binary comparison operator uch as =, <, <=, >, >=
 * @param attr an Attribute (or a column)
@@ -414,53 +420,66 @@
 literal: Literal,
 update: Boolean): Option[Double] = {

- var percent = 1.0
 val colStat = colStatsMap(attr)
- val statsRange =
- Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val max = BigDecimal(statsRange.max)
+ val min = BigDecimal(statsRange.min)
+ val ndv = BigDecimal(colStat.distinctCount)

 // determine the overlapping degree between predicate range and column's range
- val literalValueBD = BigDecimal(literal.value.toString)
+ val numericLiteral = if (literal.dataType == BooleanType) {
+ if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
+ } else {
+ BigDecimal(literal.value.toString)
+ }
 val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
 case _: LessThan =>
- (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+ (numericLiteral <= min, numericLiteral > max)
 case _: LessThanOrEqual =>
- (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+ (numericLiteral < min, numericLiteral >= max)
 case _: GreaterThan =>
- (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+ (numericLiteral >= max, numericLiteral < min)
 case _: GreaterThanOrEqual =>
- (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+ (numericLiteral > max, numericLiteral <= min)
 }

+ var percent = BigDecimal(1.0)
 if (noOverlap) {
 percent = 0.0
 } else if (completeOverlap) {
 percent = 1.0
 } else {
- // this is partial overlap case
- val literalDouble = literalValueBD.toDouble
- val maxDouble = BigDecimal(statsRange.max).toDouble
- val minDouble = BigDecimal(statsRange.min).toDouble
-
+ // This is the partial overlap case:
 // Without advanced statistics like histogram, we assume uniform data distribution.
 // We just prorate the adjusted range over the initial range to compute filter selectivity.
- // For ease of computation, we convert all relevant numeric values to Double.
+ assert(max > min) percent = op match { case _: LessThan => - (literalDouble - minDouble) / (maxDouble - minDouble) + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == min) { + // The boundary value is the only satisfying value. + 1.0 / ndv } else { - (literalDouble - minDouble) / (maxDouble - minDouble) + (numericLiteral - min) / (max - min) } case _: GreaterThan => - (maxDouble - literalDouble) / (maxDouble - minDouble) + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == max) { + 1.0 / ndv } else { - (maxDouble - literalDouble) / (maxDouble - minDouble) + (max - numericLiteral) / (max - min) } } @@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val newValue = convertBoundValue(attr.dataType, literal.value) var newMax = colStat.max var newMin = colStat.min + var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdv < 1) newNdv = 1 + op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue + case _: GreaterThan | _: GreaterThanOrEqual => + // If new ndv is 1, then new max must be equal to new min. + newMin = if (newNdv == 1) newMax else newValue + case _: LessThan | _: LessThanOrEqual => + newMax = if (newNdv == 1) newMin else newValue } - val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = + colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) colStatsMap(attr) = newStats } } - Some(percent) + Some(percent.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 8be74ced7bb71..4691913c8c986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. 
// First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val arInt = AttributeReference("cint", IntegerType)() - val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val attrInt = AttributeReference("cint", IntegerType)() + val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) // only 2 values - val arBool = AttributeReference("cbool", BooleanType)() - val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + val attrBool = AttributeReference("cbool", BooleanType)() + val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") - val arDate = AttributeReference("cdate", DateType)() - val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + val attrDate = AttributeReference("cdate", DateType)() + val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", DoubleType)() - val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + val attrDouble = AttributeReference("cdouble", DoubleType)() + val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) // Sixth column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" - val arString = AttributeReference("cstring", StringType)() - val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + val attrString = AttributeReference("cstring", StringType)() + val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + val attributeMap = AttributeMap(Seq( + attrInt -> colStatInt, + attrBool -> colStatBool, + attrDate -> colStatDate, + attrDecimal -> colStatDecimal, + attrDouble -> colStatDouble, + attrString -> colStatString)) + test("cint = 2") { validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( - arInt, - Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 
10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint = 0") { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint < 3") { validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint < 0") { // This is a corner case since literal 0 is smaller than min. validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint <= 3") { validateEstimatedStats( - arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint > 10") { // This is a corner case since max value is 10. 
validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint >= 6") { validateEstimatedStats( - arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint IS NULL") { validateEstimatedStats( - arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint IS NOT NULL") { validateEstimatedStats( - arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 10) + Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) } test("cint > 3 AND cint <= 6") { - val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - 4) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) } test("cint = 3 OR cint = 6") { - val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + test("Not(cint > 3 AND cint <= 6)") { + val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 6) + } + + test("Not(cint <= 3 OR cint > 6)") { + val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 5) + } + + test("Not(cint = 3 AND cstring < 'A8')") { + val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 10) + } + + 
test("Not(cint = 3 OR cstring < 'A8')") { + val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 7) + Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 7) } test("cbool = true") { validateEstimatedStats( - arBool, - Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 5) + Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool > false") { - // bool comparison is not supported yet, so stats remain same. 
validateEstimatedStats( - arBool, - Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 10) + Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrDate, Literal(d20170102)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal(d20170103)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrDate, Literal(d20170103)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( - arDate, - Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( - arDecimal, - Filter(EqualTo(arDecimal, Literal(dec_0_40)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8), - 1) + Filter(EqualTo(attrDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( - arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = 
Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( - arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), + Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( - arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 1) + Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 1) } - // There is no min/max statistics for String type. We estimate 10 rows returned. - test("cstring < 'A2'") { + test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( - arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 10) + Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 10) } - // This is a corner test case. We want to test if we can handle the case when the number of - // valid values in IN clause is greater than the number of distinct values for a given column. - // For example, column has only 2 distinct values 1 and 6. - // The predicate is: column IN (1, 2, 3, 4, 5). test("cint IN (1, 2, 3, 4, 5)") { + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
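+    // In this case the estimate must be capped by the child's statistics: at most 2 distinct
+    // values and 2 rows can match, even though the IN list contains 5 values.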
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4) val cornerChildStatsTestplan = StatsTestPlan( - outputList = Seq(arInt), + outputList = Seq(attrInt), rowCount = 2L, - attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = tableRowCount, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arBool -> childColStatBool, - arDate -> childColStatDate, - arDecimal -> childColStatDecimal, - arDouble -> childColStatDouble, - arString -> childColStatString - )) - ) + attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a)))) } private def validateEstimatedStats( - ar: AttributeReference, filterNode: Filter, - expectedColStats: ColumnStat, - rowCount: Int): Unit = { - - val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) - - val filteredStats = filterNode.stats(conf) - assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount.get == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) - - // If the filter has a binary operator (including those nested inside - // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the - // operator, and then check again. - val rewrittenFilter = filterNode transformExpressionsDown { - case EqualTo(ar: AttributeReference, l: Literal) => - EqualTo(l, ar) - - case LessThan(ar: AttributeReference, l: Literal) => - GreaterThan(l, ar) - case LessThanOrEqual(ar: AttributeReference, l: Literal) => - GreaterThanOrEqual(l, ar) - - case GreaterThan(ar: AttributeReference, l: Literal) => - LessThan(l, ar) - case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - LessThanOrEqual(l, ar) + expectedColStats: Seq[(Attribute, ColumnStat)], + expectedRowCount: Int): Unit = { + + // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the + // sides of the attribute and the literal, reverse the operator, and then check again. 
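+    // For example, LessThan(attrInt, Literal(3)) becomes GreaterThan(Literal(3), attrInt);
+    // both forms must produce exactly the same estimated statistics.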
+    val swappedFilter = filterNode transformExpressionsDown {
+      case EqualTo(attr: Attribute, l: Literal) =>
+        EqualTo(l, attr)
+
+      case LessThan(attr: Attribute, l: Literal) =>
+        GreaterThan(l, attr)
+      case LessThanOrEqual(attr: Attribute, l: Literal) =>
+        GreaterThanOrEqual(l, attr)
+
+      case GreaterThan(attr: Attribute, l: Literal) =>
+        LessThan(l, attr)
+      case GreaterThanOrEqual(attr: Attribute, l: Literal) =>
+        LessThanOrEqual(l, attr)
+    }
+
+    val testFilters = if (swappedFilter != filterNode) {
+      Seq(swappedFilter, filterNode)
+    } else {
+      Seq(filterNode)
     }

-    if (rewrittenFilter != filterNode) {
-      validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount)
+    testFilters.foreach { filter =>
+      val expectedAttributeMap = AttributeMap(expectedColStats)
+      val expectedStats = Statistics(
+        sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap),
+        rowCount = Some(expectedRowCount),
+        attributeStats = expectedAttributeMap)
+      assert(filter.stats(conf) == expectedStats)
     }
   }
 }

From 030acdd1f06f49383079c306b63e874ad738851f Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Tue, 7 Mar 2017 09:00:14 -0800
Subject: [PATCH 25/78] [SPARK-19637][SQL] Add to_json in FunctionRegistry

## What changes were proposed in this pull request?

This PR adds entries in `FunctionRegistry` and supports `to_json` in SQL.

## How was this patch tested?

Added tests in `JsonFunctionsSuite`.

Author: Takeshi Yamamuro

Closes #16981 from maropu/SPARK-19637.
---
 .../catalyst/analysis/FunctionRegistry.scala  |  3 +
 .../expressions/jsonExpressions.scala         | 41 +++++++++++-
 .../sql-tests/inputs/json-functions.sql       |  8 +++
 .../sql-tests/results/json-functions.sql.out  | 63 +++++++++++++++++++
 .../apache/spark/sql/JsonFunctionsSuite.scala | 23 +++++++
 5 files changed, 136 insertions(+), 2 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/json-functions.sql.out

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9c9465f6b8def..556fa9901701b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -421,6 +421,9 @@ object FunctionRegistry {
     expression[BitwiseOr]("|"),
     expression[BitwiseXor]("^"),

+    // json
+    expression[StructToJson]("to_json"),
+
     // Cast aliases (SPARK-16730)
     castAlias("boolean", BooleanType),
     castAlias("tinyint", ByteType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index dbff62efdddb6..18b5f2f7ed2e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -23,11 +23,12 @@ import scala.util.parsing.combinator.RegexParsers

 import com.fasterxml.jackson.core._

+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json._
-import 
org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -330,7 +331,7 @@ case class GetJsonObject(json: Expression, path: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Return a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", extended = """ Examples: > SELECT _FUNC_('{"a":1, "b":2}', 'a', 'b'); @@ -564,6 +565,17 @@ case class JsonToStruct( /** * Converts a [[StructType]] to a json output string. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + extended = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + """) +// scalastyle:on line.size.limit case class StructToJson( options: Map[String, String], child: Expression, @@ -573,6 +585,14 @@ case class StructToJson( def this(options: Map[String, String], child: Expression) = this(options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression) = this(Map.empty, child, None) + def this(child: Expression, options: Expression) = + this( + options = StructToJson.convertToMapData(options), + child = child, + timeZoneId = None) + @transient lazy val writer = new CharArrayWriter() @@ -613,3 +633,20 @@ case class StructToJson( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil } + +object StructToJson { + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql new file mode 100644 index 0000000000000..9308560451bf5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -0,0 +1,8 @@ +-- to_json +describe function to_json; +describe function extended to_json; +select to_json(named_struct('a', 1, 'b', 2)); +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out new file mode 100644 index 0000000000000..d8aa4fb9fa788 --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -0,0 +1,63 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+describe function to_json
+-- !query 0 schema
+struct<function_desc:string>
+-- !query 0 output
+Class: org.apache.spark.sql.catalyst.expressions.StructToJson
+Function: to_json
+Usage: to_json(expr[, options]) - Returns a json string with a given struct value
+
+
+-- !query 1
+describe function extended to_json
+-- !query 1 schema
+struct<function_desc:string>
+-- !query 1 output
+Class: org.apache.spark.sql.catalyst.expressions.StructToJson
+Extended Usage:
+    Examples:
+      > SELECT to_json(named_struct('a', 1, 'b', 2));
+       {"a":1,"b":2}
+      > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
+       {"time":"26/08/2015"}
+
+Function: to_json
+Usage: to_json(expr[, options]) - Returns a json string with a given struct value
+
+
+-- !query 2
+select to_json(named_struct('a', 1, 'b', 2))
+-- !query 2 schema
+struct
+-- !query 2 output
+{"a":1,"b":2}
+
+
+-- !query 3
+select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'))
+-- !query 3 schema
+struct
+-- !query 3 output
+{"time":"26/08/2015"}
+
+
+-- !query 4
+select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'))
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+Must use a map() function for options;; line 1 pos 7
+
+
+-- !query 5
+select to_json()
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function to_json; line 1 pos 7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 953d161ec2a1d..cdea3b9a0f79f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -197,4 +197,27 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       .select(to_json($"struct").as("json"))
     checkAnswer(dfTwo, readBackTwo)
   }
+
+  test("SPARK-19637 Support to_json in SQL") {
+    val df1 = Seq(Tuple1(Tuple1(1))).toDF("a")
+    checkAnswer(
+      df1.selectExpr("to_json(a)"),
+      Row("""{"_1":1}""") :: Nil)
+
+    val df2 = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a")
+    checkAnswer(
+      df2.selectExpr("to_json(a, map('timestampFormat', 'dd/MM/yyyy HH:mm'))"),
+      Row("""{"_1":"26/08/2015 18:00"}""") :: Nil)
+
+    val errMsg1 = intercept[AnalysisException] {
+      df2.selectExpr("to_json(a, named_struct('a', 1))")
+    }
+    assert(errMsg1.getMessage.startsWith("Must use a map() function for options"))
+
+    val errMsg2 = intercept[AnalysisException] {
+      df2.selectExpr("to_json(a, map('a', 1))")
+    }
+    assert(errMsg2.getMessage.startsWith(
+      "A type of keys and values in map() must be string, but got"))
+  }
 }

From c05baabf10dd4c808929b4ae7a6d118aba6dd665 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 7 Mar 2017 09:21:58 -0800
Subject: [PATCH 26/78] [SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should
 un-cache all cached plans that refer to this table

## What changes were proposed in this pull request?

When un-caching a table, we should not only remove the cache entry for this table, but also un-cache any other cached plans that refer to this table.

This PR also includes some refactors:

1. 
use `java.util.LinkedList` to store the cache entries, so that it's safer to remove elements while iterating 2. rename `invalidateCache` to `recacheByPlan`, which is more obvious about what it does. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17097 from cloud-fan/cache. --- .../spark/sql/execution/CacheManager.scala | 118 ++++++++++-------- .../execution/columnar/InMemoryRelation.scala | 6 - .../spark/sql/execution/command/ddl.scala | 3 +- .../InsertIntoDataSourceCommand.scala | 5 +- .../spark/sql/internal/CatalogImpl.scala | 23 ++-- .../apache/spark/sql/CachedTableSuite.scala | 50 +++++--- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 6 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- 9 files changed, 119 insertions(+), 98 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 80138510dc9ee..0ea806d6cb50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging @@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) class CacheManager extends Logging { @transient - private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + private val cachedData = new java.util.LinkedList[CachedData] @transient private val cacheLock = new ReentrantReadWriteLock @@ -70,7 +72,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } @@ -88,46 +90,81 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.queryExecution.analyzed + val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sparkSession = query.sparkSession - cachedData += - CachedData( - planToCache, - InMemoryRelation( - sparkSession.sessionState.conf.useCompression, - sparkSession.sessionState.conf.columnBatchSize, - storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, - tableName)) + cachedData.add(CachedData( + planToCache, + InMemoryRelation( + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, + storageLevel, + sparkSession.sessionState.executePlan(planToCache).executedPlan, + tableName))) } } /** - * Tries to remove the data for the given [[Dataset]] from the cache. - * No operation, if it's already uncached. + * Un-cache all the cache entries that refer to the given plan. + */ + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + } + + /** + * Un-cache all the cache entries that refer to the given plan. 
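+   * Unlike the `Dataset` overload above, this variant takes the analyzed logical plan
+   * directly, so it can be used when no `Dataset` is at hand.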
*/ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { - val planToCache = query.queryExecution.analyzed - val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) - val found = dataIndex >= 0 - if (found) { - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) - cachedData.remove(dataIndex) + def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + val it = cachedData.iterator() + while (it.hasNext) { + val cd = it.next() + if (cd.plan.find(_.sameResult(plan)).isDefined) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + it.remove() + } } - found + } + + /** + * Tries to re-cache all the cache entries that refer to the given plan. + */ + def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) + } + + private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + val it = cachedData.iterator() + val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] + while (it.hasNext) { + val cd = it.next() + if (condition(cd.plan)) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist() + // Remove the cache entry before we create a new one, so that we can have a different + // physical plan. + it.remove() + val newCache = InMemoryRelation( + useCompression = cd.cachedRepresentation.useCompression, + batchSize = cd.cachedRepresentation.batchSize, + storageLevel = cd.cachedRepresentation.storageLevel, + child = spark.sessionState.executePlan(cd.plan).executedPlan, + tableName = cd.cachedRepresentation.tableName) + needToRecache += cd.copy(cachedRepresentation = newCache) + } + } + + needToRecache.foreach(cachedData.add) } /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.queryExecution.analyzed) + lookupCachedData(query.logicalPlan) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { - cachedData.find(cd => plan.sameResult(cd.plan)) + cachedData.asScala.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ @@ -145,40 +182,17 @@ class CacheManager extends Logging { } /** - * Invalidates the cache of any data that contains `plan`. Note that it is possible that this - * function will over invalidate. - */ - def invalidateCache(plan: LogicalPlan): Unit = writeLock { - cachedData.foreach { - case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => - data.cachedRepresentation.recache() - case _ => - } - } - - /** - * Invalidates the cache of any data that contains `resourcePath` in one or more + * Tries to re-cache all the cache entries that contain `resourcePath` in one or more * `HadoopFsRelation` node(s) as part of its logical plan. 
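+   * The given path is qualified against its FileSystem (see `fs.makeQualified` below)
+   * before it is matched against the paths referenced by cached relations.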
*/ - def invalidateCachedPath( - sparkSession: SparkSession, resourcePath: String): Unit = writeLock { + def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock { val (fs, qualifiedPath) = { val path = new Path(resourcePath) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + (fs, fs.makeQualified(path)) } - cachedData.filter { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true - case _ => false - }.foreach { data => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - } + recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 37bd95e737786..36037ac003728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -85,12 +85,6 @@ case class InMemoryRelation( buildBuffers() } - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() - } - private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index b5c60423514cb..9d3c55060dfb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -199,8 +199,7 @@ case class DropTableCommand( } } try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } catch { case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand( val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) - // Invalidate the cache. - sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this + // data source relation. 
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index ed07ff3ff0599..53374859f13f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def dropTempView(viewName: String): Boolean = { - sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } /** @@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sparkSession, logicalPlan) + val table = sparkSession.table(tableIdent) + if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) // Cache it again. 
- sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table)) + sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } } @@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def refreshByPath(resourcePath: String): Unit = { - sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath) + sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 2a0e088437fda..7a7d52b21427a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -24,15 +24,15 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} private case class BigData(s: String) @@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext maybeBlock.nonEmpty } - private def getNumInMemoryRelations(plan: LogicalPlan): Int = { + private def getNumInMemoryRelations(ds: Dataset[_]): Int = { + val plan = ds.queryExecution.withCachedData var sum = plan.collect { case _: InMemoryRelation => 1 }.sum plan.transformAllExpressions { case e: SubqueryExpression => @@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData) + getNumInMemoryRelations(spark.table("testData")) } spark.catalog.cacheTable("testData") @@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext localRelation.createOrReplaceTempView("localRelation") spark.catalog.cacheTable("localRelation") - assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1) + assert(getNumInMemoryRelations(localRelation) == 1) } test("SPARK-19093 Caching in side subquery") { withTempView("t1") { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 2) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) } } @@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t4") // Nested predicate subquery - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 
FROM t3 WHERE c1 = 1)) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 3) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 3) // Scalar subquery and predicate subquery - val cachedPlan2 = + val ds2 = sql( """ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) @@ -630,8 +631,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext |EXISTS (SELECT c1 FROM t3) |OR |c1 IN (SELECT c1 FROM t4) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan2) == 4) + """.stripMargin) + assert(getNumInMemoryRelations(ds2) == 4) + } + } + + test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { + withTable("t") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + spark.catalog.cacheTable("t") + spark.table("t").select($"i").cache() + checkAnswer(spark.table("t").select($"i"), Row(1)) + assertCached(spark.table("t").select($"i")) + + Utils.deleteRecursively(path) + spark.sessionState.catalog.refreshTable(TableIdentifier("t")) + spark.catalog.uncacheTable("t") + assert(spark.table("t").select($"i").count() == 0) + assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3c57ee4c8b8f6..b8536d0c1bd58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -393,8 +393,8 @@ case class InsertIntoHiveTable( logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) } - // Invalidate the cache. - sparkSession.catalog.uncacheTable(table.qualifiedName) + // un-cache this table. + sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 8ccc2b7527f24..2b3f36064c1f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto tempPath.delete() table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") - sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") - checkAnswer( - table("refreshTable"), - table("src").collect()) + sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet") + checkAnswer(table("refreshTable"), table("src")) // Cache the table. 
sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 3512c4a890313..81af24979d822 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + From 4a9034b17374cf19c77cb74e36c86cd085d59602 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Tue, 7 Mar 2017 11:24:20 -0800 Subject: [PATCH 27/78] [SPARK-17498][ML] StringIndexer enhancement for handling unseen labels ## What changes were proposed in this pull request? This PR is an enhancement to ML StringIndexer. Before this PR, String Indexer only supports "skip"/"error" options to deal with unseen records. But those unseen records might still be useful and user would like to keep the unseen labels in certain use cases, This PR enables StringIndexer to support keeping unseen labels as indices [numLabels]. '''Before StringIndexer().setHandleInvalid("skip") StringIndexer().setHandleInvalid("error") '''After support the third option "keep" StringIndexer().setHandleInvalid("keep") ## How was this patch tested? Test added in StringIndexerSuite Signed-off-by: VinceShieh (Please fill in changes proposed in this fix) Author: VinceShieh Closes #16883 from VinceShieh/spark-17498. --- docs/ml-features.md | 22 ++++++- .../spark/ml/feature/StringIndexer.scala | 65 ++++++++++++++----- .../spark/ml/feature/StringIndexerSuite.scala | 34 ++++++---- project/MimaExcludes.scala | 4 ++ 4 files changed, 95 insertions(+), 30 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 57605bafbf4c3..dad1c6db18f8b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -503,6 +503,7 @@ for more details on the API. `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -542,12 +543,13 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. 
-Additionally, there are two strategies regarding how `StringIndexer` will handle
+Additionally, there are three strategies regarding how `StringIndexer` will handle
 unseen labels when you have fit a `StringIndexer` on one dataset and then use it to
 transform another:

 - throw an exception (which is the default)
 - skip the row containing the unseen label entirely
+- put unseen labels in a special additional bucket, at index numLabels

 **Examples**

 Let's go back to our previous example but this time reuse our previously defined
 `StringIndexer` on the following dataset:

 ~~~~
  id | category
 ----|----------
  0  | a
  1  | b
  2  | c
  3  | d
+ 4  | e
 ~~~~

 If you've not set how `StringIndexer` handles unseen labels or set it to
 "error", an exception will be thrown.
 However, if you had called `setHandleInvalid("skip")`, the following dataset
 will be generated:

 ~~~~
  id | category | categoryIndex
 ----|----------|---------------
  0  | a        | 0.0
  2  | c        | 1.0
 ~~~~

-Notice that the row containing "d" does not appear.
+Notice that the rows containing "d" or "e" do not appear.
+
+If you call `setHandleInvalid("keep")`, the following dataset
+will be generated:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0  | a        | 0.0
+ 1  | b        | 2.0
+ 2  | c        | 1.0
+ 3  | d        | 3.0
+ 4  | e        | 3.0
+~~~~
+
+Notice that the rows containing "d" or "e" are mapped to index "3.0".
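+
+A minimal end-to-end sketch of the `keep` option (the data and column names here are
+illustrative, and it assumes the behavior documented above):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.StringIndexer
+
+// Fit on data that only contains the labels "a" and "b".
+val train = spark.createDataFrame(Seq((0, "a"), (1, "b"), (2, "a"))).toDF("id", "category")
+// "d" is unseen at fit time.
+val test = spark.createDataFrame(Seq((3, "a"), (4, "d"))).toDF("id", "category")
+
+val model = new StringIndexer()
+  .setInputCol("category")
+  .setOutputCol("categoryIndex")
+  .setHandleInvalid("keep") // unseen labels go to index numLabels instead of failing
+  .fit(train)
+
+// "a" -> 0.0 (most frequent at fit time), "d" -> 2.0 (numLabels, the unseen bucket)
+model.transform(test).show()
+{% endhighlight %}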
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a503411b63612..810b02febbe77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** + * Param for how to handle unseen labels. Options are 'skip' (filter out rows with + * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional + * bucket, at index numLabels. + * Default: "error" + * @group param + */ + @Since("1.6.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + + "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + + "at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + + /** @group getParam */ + @Since("1.6.0") + def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { + private[feature] val SKIP_UNSEEN_LABEL: String = "skip" + private[feature] val ERROR_UNSEEN_LABEL: String = "error" + private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -144,7 +169,6 @@ class StringIndexerModel ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -163,25 +187,34 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) - val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else { - throw new SparkException(s"Unseen label: $label.") - } + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case _ => labels } val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val filteredDataset = getHandleInvalid match { - case "skip" => + val (filteredDataset, keepInvalid) = getHandleInvalid match { + case StringIndexer.SKIP_UNSEEN_LABEL => val filterer = udf { label: String => labelToIndex.contains(label) } - dataset.where(filterer(dataset($(inputCol)))) - case _ => dataset + (dataset.where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) } + + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. 
To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + } + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d0e63c9d669c..188dffb3dd55f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -64,7 +64,7 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) - val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) val df = data.toDF("id", "label") val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() @@ -75,22 +75,32 @@ class StringIndexerSuite intercept[SparkException] { indexer.transform(df2).collect() } - val indexerSkipInvalid = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .setHandleInvalid("skip") - .fit(df) + + indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformed = indexerSkipInvalid.transform(df2) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedSkip = indexer.transform(df2) + val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 - val expected = Set((0, 1.0), (1, 0.0)) - assert(output === expected) + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + // Verify that we keep the unseen records + val transformedKeep = indexer.transform(df2) + val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 + val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexer with a numeric input column") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56b8c0b95e8a4..bd4528bd21264 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -914,6 +914,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. 
      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
+    ) ++ Seq(
+      // [SPARK-17498] StringIndexer enhancement for handling unseen labels
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
     ) ++ Seq(
       // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
       ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")

From d69aeeaff4f90ce92ee9e84f24905ea9efa7ece2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 7 Mar 2017 11:32:36 -0800
Subject: [PATCH 28/78] [SPARK-19516][DOC] update public doc to use
 SparkSession instead of SparkContext

## What changes were proposed in this pull request?

After Spark 2.0, `SparkSession` becomes the new entry point for Spark applications. We should update the public documents to reflect this.

## How was this patch tested?

N/A

Author: Wenchen Fan

Closes #16856 from cloud-fan/doc.
---
 docs/quick-start.md                           | 153 ++++++++----------
 ...ming-guide.md => rdd-programming-guide.md} |  26 +--
 2 files changed, 76 insertions(+), 103 deletions(-)
 rename docs/{programming-guide.md => rdd-programming-guide.md} (99%)

diff --git a/docs/quick-start.md b/docs/quick-start.md
index aa4319a23325c..b88ae5f6bb313 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -10,12 +10,13 @@ description: Quick start tutorial for Spark SPARK_VERSION_SHORT

 This tutorial provides a quick introduction to using Spark. We will
 first introduce the API through Spark's interactive shell (in Python or Scala),
 then show how to write applications in Java, Scala, and Python.
-See the [programming guide](programming-guide.html) for a more complete reference.

 To follow along with this guide, first download a packaged release of Spark from the
 [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS,
 you can download a package for any version of Hadoop.

+Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend switching to Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset.
+
 # Interactive Analysis with the Spark Shell

 ## Basics
@@ -29,28 +30,28 @@ or Python. Start it by running the following in the Spark directory:

     ./bin/spark-shell

-Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory:
+Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. 
Let's make a new Dataset from the text of the README file in the Spark source directory: {% highlight scala %} -scala> val textFile = sc.textFile("README.md") -textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at :25 +scala> val textFile = spark.read.textFile("README.md") +textFile: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from Dataset directly, by calling some actions, or transform the Dataset to get a new one. For more details, please read the _[API doc](api/scala/index.html#org.apache.spark.sql.Dataset)_. {% highlight scala %} -scala> textFile.count() // Number of items in this RDD +scala> textFile.count() // Number of items in this Dataset res0: Long = 126 // May be different from yours as README.md will change over time, similar to other outputs -scala> textFile.first() // First item in this RDD +scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) -linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at :27 +linesWithSpark: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} We can chain together transformations and actions: @@ -65,32 +66,32 @@ res3: Long = 15 ./bin/pyspark -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Due to Python's dynamic nature, we don't need the Dataset to be strongly-typed in Python. As a result, all Datasets in Python are Dataset[Row], and we call it `DataFrame` to be consistent with the data frame concept in Pandas and R. Let's make a new DataFrame from the text of the README file in the Spark source directory: {% highlight python %} ->>> textFile = sc.textFile("README.md") +>>> textFile = spark.read.text("README.md") {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from DataFrame directly, by calling some actions, or transform the DataFrame to get a new one. For more details, please read the _[API doc](api/python/index.html#pyspark.sql.DataFrame)_. 
{% highlight python %} ->>> textFile.count() # Number of items in this RDD +>>> textFile.count() # Number of rows in this DataFrame 126 ->>> textFile.first() # First item in this RDD -u'# Apache Spark' +>>> textFile.first() # First row in this DataFrame +Row(value=u'# Apache Spark') {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this DataFrame to a new one. We call `filter` to return a new DataFrame with a subset of the lines in the file. {% highlight python %} ->>> linesWithSpark = textFile.filter(lambda line: "Spark" in line) +>>> linesWithSpark = textFile.filter(textFile.value.contains("Spark")) {% endhighlight %} We can chain together transformations and actions: {% highlight python %} ->>> textFile.filter(lambda line: "Spark" in line).count() # How many lines contain "Spark"? +>>> textFile.filter(textFile.value.contains("Spark")).count() # How many lines contain "Spark"? 15 {% endhighlight %} @@ -98,8 +99,8 @@ We can chain together transformations and actions:
-## More on RDD Operations -RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words: +## More on Dataset Operations +Dataset actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
@@ -109,7 +110,7 @@ scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a res4: Long = 15 {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: +This first maps a line to an integer value, creating a new Dataset. `reduce` is called on that Dataset to find the largest word count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: {% highlight scala %} scala> import java.lang.Math @@ -122,11 +123,11 @@ res5: Int = 15 One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight scala %} -scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b) -wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at :28 +scala> val wordCounts = textFile.flatMap(line => line.split(" ")).groupByKey(identity).count() +wordCounts: org.apache.spark.sql.Dataset[(String, Long)] = [value: string, count(1): bigint] {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we call `flatMap` to transform a Dataset of lines to a Dataset of words, and then combine `groupByKey` and `count` to compute the per-word counts in the file as a Dataset of (String, Long) pairs. To collect the word counts in our shell, we can call `collect`: {% highlight scala %} scala> wordCounts.collect() @@ -137,37 +138,24 @@ res6: Array[(String, Int)] = Array((means,1), (under,2), (this,3), (Because,1),
{% highlight python %}
->>> textFile.map(lambda line: len(line.split())).reduce(lambda a, b: a if (a > b) else b)
-15
+>>> from pyspark.sql.functions import *
+>>> textFile.select(size(split(textFile.value, "\s+")).name("numWords")).agg(max(col("numWords"))).collect()
+[Row(max(numWords)=15)]
{% endhighlight %}

-This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Python [anonymous functions (lambdas)](https://docs.python.org/2/reference/expressions.html#lambda),
-but we can also pass any top-level Python function we want.
-For example, we'll define a `max` function to make this code easier to understand:
-
-{% highlight python %}
->>> def max(a, b):
-...     if a > b:
-...         return a
-...     else:
-...         return b
-...
-
->>> textFile.map(lambda line: len(line.split())).reduce(max)
-15
-{% endhighlight %}
+This first maps a line to an integer value and aliases it as "numWords", creating a new DataFrame. `agg` is called on that DataFrame to find the largest word count. The arguments to `select` and `agg` are both _[Column](api/python/index.html#pyspark.sql.Column)_; we can use `df.colName` to get a column from a DataFrame. We can also import pyspark.sql.functions, which provides a lot of convenient functions to build a new Column from an old one.

One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily:

{% highlight python %}
+>>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count()
{% endhighlight %}

+Here, we use the `explode` function in `select` to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`:

{% highlight python %}
>>> wordCounts.collect()
+[Row(word=u'online', count=1), Row(word=u'graphs', count=1), ...]
{% endhighlight %}
@@ -181,7 +169,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This {% highlight scala %} scala> linesWithSpark.cache() -res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at :27 +res7: linesWithSpark.type = [value: string] scala> linesWithSpark.count() res8: Long = 15 @@ -193,7 +181,7 @@ res9: Long = 15 It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/spark-shell` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
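For readers trying this outside the docs, a minimal sketch of the caching flow the hunk above documents (assuming a `spark-shell` session, where `spark` is predefined; the `unpersist` call is an illustrative addition, not part of the patch):

```scala
// Sketch only: run inside spark-shell, where `spark` already exists.
val textFile = spark.read.textFile("README.md")
val linesWithSpark = textFile.filter(line => line.contains("Spark")).cache() // mark for caching

linesWithSpark.count() // first action computes the Dataset and populates the cache
linesWithSpark.count() // subsequent actions are served from memory

linesWithSpark.unpersist() // release the cached blocks when finished
```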
@@ -211,7 +199,7 @@ a cluster, as described in the [programming guide](programming-guide.html#initia It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/pyspark` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -228,20 +216,17 @@ named `SimpleApp.scala`:
{% highlight scala %}
/* SimpleApp.scala */
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
-import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession

object SimpleApp {
  def main(args: Array[String]) {
    val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system
-    val conf = new SparkConf().setAppName("Simple Application")
-    val sc = new SparkContext(conf)
-    val logData = sc.textFile(logFile, 2).cache()
+    val spark = SparkSession.builder.appName("Simple Application").getOrCreate()
+    val logData = spark.read.textFile(logFile).cache()
    val numAs = logData.filter(line => line.contains("a")).count()
    val numBs = logData.filter(line => line.contains("b")).count()
    println(s"Lines with a: $numAs, Lines with b: $numBs")
-    sc.stop()
+    spark.stop()
  }
}
{% endhighlight %}

@@ -251,16 +236,13 @@ Subclasses of `scala.App` may not work correctly.

This program just counts the number of lines containing 'a' and the number containing 'b' in the
Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is
-installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext,
-we initialize a SparkContext as part of the program.
+installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession,
+we initialize a SparkSession as part of the program.

-We pass the SparkContext constructor a
-[SparkConf](api/scala/index.html#org.apache.spark.SparkConf)
-object which contains information about our
-application.
+We call `SparkSession.builder` to construct a `SparkSession`, then set the application name, and finally call `getOrCreate` to get the `SparkSession` instance.

-Our application depends on the Spark API, so we'll also include an sbt configuration file,
-`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that
+Our application depends on the Spark API, so we'll also include an sbt configuration file,
+`build.sbt`, which explains that Spark is a dependency.
This file also adds a repository that
Spark depends on:

{% highlight scala %}
@@ -270,7 +252,7 @@ version := "1.0"

scalaVersion := "{{site.SCALA_VERSION}}"

-libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}"
+libraryDependencies += "org.apache.spark" %% "spark-sql" % "{{site.SPARK_VERSION}}"
{% endhighlight %}

For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt`
@@ -309,34 +291,28 @@ We'll create a very simple Spark application, `SimpleApp.java`:
{% highlight java %}
/* SimpleApp.java */
-import org.apache.spark.api.java.*;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.SparkSession;

public class SimpleApp {
  public static void main(String[] args) {
    String logFile = "YOUR_SPARK_HOME/README.md"; // Should be some file on your system
-    SparkConf conf = new SparkConf().setAppName("Simple Application");
-    JavaSparkContext sc = new JavaSparkContext(conf);
-    JavaRDD<String> logData = sc.textFile(logFile).cache();
+    SparkSession spark = SparkSession.builder().appName("Simple Application").getOrCreate();
+    Dataset<String> logData = spark.read().textFile(logFile).cache();

    long numAs = logData.filter(s -> s.contains("a")).count();
    long numBs = logData.filter(s -> s.contains("b")).count();

    System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs);
-
-    sc.stop();
+
+    spark.stop();
  }
}
{% endhighlight %}

-This program just counts the number of lines containing 'a' and the number containing 'b' in a text
-file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed.
-As with the Scala example, we initialize a SparkContext, though we use the special
-`JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by
-`JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes
-that extend `spark.api.java.function.Function`. The
-[Spark programming guide](programming-guide.html) describes these differences in more detail.
+This program just counts the number of lines containing 'a' and the number containing 'b' in the
+Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is
+installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession,
+we initialize a SparkSession as part of the program.

To build the program, we also write a Maven `pom.xml` file that lists Spark as a dependency.
Note that Spark artifacts are tagged with a Scala version.

@@ -352,7 +328,7 @@ Note that Spark artifacts are tagged with a Scala version.
      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-core_{{site.SCALA_BINARY_VERSION}}</artifactId>
+      <artifactId>spark-sql_{{site.SCALA_BINARY_VERSION}}</artifactId>
      <version>{{site.SPARK_VERSION}}</version>

@@ -395,27 +371,25 @@ As an example, we'll create a simple Spark application, `SimpleApp.py`:

{% highlight python %}
"""SimpleApp.py"""
-from pyspark import SparkContext
+from pyspark.sql import SparkSession

logFile = "YOUR_SPARK_HOME/README.md"  # Should be some file on your system
-sc = SparkContext("local", "Simple App")
-logData = sc.textFile(logFile).cache()
+spark = SparkSession.builder.appName("SimpleApp").getOrCreate()
+logData = spark.read.text(logFile).cache()

-numAs = logData.filter(lambda s: 'a' in s).count()
-numBs = logData.filter(lambda s: 'b' in s).count()
+numAs = logData.filter(logData.value.contains('a')).count()
+numBs = logData.filter(logData.value.contains('b')).count()

print("Lines with a: %i, lines with b: %i" % (numAs, numBs))

-sc.stop()
+spark.stop()
{% endhighlight %}

This program just counts the number of lines containing 'a' and the number containing 'b' in a
text file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed.
-As with the Scala and Java examples, we use a SparkContext to create RDDs.
-We can pass Python functions to Spark, which are automatically serialized along with any variables
-that they reference.
+As with the Scala and Java examples, we use a SparkSession to create Datasets.

For applications that use custom classes or third-party libraries, we can also add code
dependencies to `spark-submit` through its `--py-files` argument by packaging them into a
.zip file (see `spark-submit --help` for details).
@@ -438,8 +412,7 @@ Lines with a: 46, Lines with b: 23

# Where to Go from Here
Congratulations on running your first Spark application!

-* For an in-depth overview of the API, start with the [Spark programming guide](programming-guide.html),
-  or see "Programming Guides" menu for other components.
+* For an in-depth overview of the API, start with the [RDD programming guide](rdd-programming-guide.html) and the [SQL programming guide](sql-programming-guide.html), or see "Programming Guides" menu for other components.
* For running applications on a cluster, head to the [deployment overview](cluster-overview.html).
* Finally, Spark includes several samples in the `examples` directory
([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples),
diff --git a/docs/programming-guide.md b/docs/rdd-programming-guide.md
similarity index 99%
rename from docs/programming-guide.md
rename to docs/rdd-programming-guide.md
index 6740dbe0014b4..cad9ff4e646e5 100644
--- a/docs/programming-guide.md
+++ b/docs/rdd-programming-guide.md
@@ -24,7 +24,7 @@ along with if you launch Spark's interactive shell -- either `bin/spark-shell` f
-Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} +Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} by default. (Spark can be built to work with other versions of Scala, too.) To write applications in Scala, you will need to use a compatible Scala version (e.g. {{site.SCALA_BINARY_VERSION}}.X). @@ -76,10 +76,10 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency Finally, you need to import some Spark classes into your program. Add the following lines: -{% highlight scala %} -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.SparkConf +{% highlight java %} +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.SparkConf; {% endhighlight %}
@@ -244,13 +244,13 @@ use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running $ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -To use the Jupyter notebook (previously known as the IPython notebook), +To use the Jupyter notebook (previously known as the IPython notebook), {% highlight bash %} $ PYSPARK_DRIVER_PYTHON=jupyter ./bin/pyspark {% endhighlight %} -You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. +You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of @@ -811,7 +811,7 @@ The variables within the closure sent to each executor are now copies and thus, In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1230,8 +1230,8 @@ storage levels is: -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, -so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1346,7 +1346,7 @@ As a user, you can create named or unnamed accumulators. As seen in the image be Accumulators in the Spark UI

-Tracking accumulators in the UI can be useful for understanding the progress of +Tracking accumulators in the UI can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python).
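The accumulator example that the surrounding text refers to falls outside the diff context; for reference, the documented Spark 2.x Scala pattern looks roughly like this (assuming an active `SparkContext` named `sc`):

```scala
// Sketch of the accumulator pattern discussed above (assumes an active SparkContext `sc`).
val accum = sc.longAccumulator("My Accumulator") // created on the driver

// Tasks may only add to the accumulator; they cannot read it.
sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum.add(x))

// Only the driver can read the value, via `value`.
println(accum.value) // 10
```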
@@ -1355,7 +1355,7 @@ running stages (NOTE: this is not yet supported in Python).
A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()`
to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using
-the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
+the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
using its `value` method.

The code below shows an accumulator being used to add up the elements of an array:
@@ -1409,7 +1409,7 @@ Note that, when programmers define their own type of AccumulatorV2, the resultin
A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()`
to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using
-the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
+the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
using its `value` method.

The code below shows an accumulator being used to add up the elements of an array:

From 49570ed05d44f96549c49929f35c1c202556731a Mon Sep 17 00:00:00 2001
From: uncleGen
Date: Tue, 7 Mar 2017 12:24:53 -0800
Subject: [PATCH 29/78] [SPARK-19803][TEST] flaky BlockManagerReplicationSuite test failure

## What changes were proposed in this pull request?

200ms may be too short. Give more time for replication to happen and for the new block to be reported to the master.

## How was this patch tested?

Tested manually.

Author: uncleGen
Author: dylon

Closes #17144 from uncleGen/SPARK-19803.
---
 .../spark/storage/BlockManagerReplicationSuite.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index ccede34b8cb4d..75dc04038debc 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -489,12 +489,12 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav
      Thread.sleep(200)
    }

-    // giving enough time for replication complete and locks released
-    Thread.sleep(500)
-
-    val newLocations = master.getLocations(blockId).toSet
+    val newLocations = eventually(timeout(5 seconds), interval(10 millis)) {
+      val _newLocations = master.getLocations(blockId).toSet
+      assert(_newLocations.size === replicationFactor)
+      _newLocations
+    }
    logInfo(s"New locations : $newLocations")
-    assert(newLocations.size === replicationFactor)

    // there should only be one common block manager between initial and new locations
    assert(newLocations.intersect(blockLocations.toSet).size === 1)

From 6f4684622a951806bebe7652a14f7d1ce03e24c7 Mon Sep 17 00:00:00 2001
From: Jason White
Date: Tue, 7 Mar 2017 13:14:37 -0800
Subject: [PATCH 30/78] [SPARK-19561] [PYTHON] cast TimestampType.toInternal output to long
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

Cast the output of `TimestampType.toInternal` to long to allow for proper Timestamp creation in DataFrames near the epoch.

## How was this patch tested?
Added a new test that fails without the change. dongjoon-hyun davies Mind taking a look? The contribution is my original work and I license the work to the project under the project’s open source license. Author: Jason White Closes #16896 from JasonMWhite/SPARK-19561. --- python/pyspark/sql/tests.py | 6 ++++++ python/pyspark/sql/types.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 81f3d1d36a342..4d48ef694d68f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1555,6 +1555,12 @@ def test_time_with_timezone(self): self.assertEqual(now, now1) self.assertEqual(now, utcnow1) + # regression test for SPARK-19561 + def test_datetime_at_epoch(self): + epoch = datetime.datetime.fromtimestamp(0) + df = self.spark.createDataFrame([Row(date=epoch)]) + self.assertEqual(df.first()['date'], epoch) + def test_decimal(self): from decimal import Decimal schema = StructType([StructField("decimal", DecimalType(10, 5))]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb3709..1d31f25efad52 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -189,7 +189,7 @@ def toInternal(self, dt): if dt is not None: seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())) - return int(seconds) * 1000000 + dt.microsecond + return long(seconds) * 1000000 + dt.microsecond def fromInternal(self, ts): if ts is not None: From 2e30c0b9bcaa6f7757bd85d1f1ec392d5f916f83 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Tue, 7 Mar 2017 21:29:08 +0000 Subject: [PATCH 31/78] [SPARK-19702][MESOS] Increase default refuse_seconds timeout in the Mesos Spark Dispatcher ## What changes were proposed in this pull request? Increase default refuse_seconds timeout, and make it configurable. See JIRA for details on how this reduces the risk of starvation. ## How was this patch tested? Unit tests, Manual testing, and Mesos/Spark integration test suite cc susanxhuynh skonto jmlvanre Author: Michael Gummelt Closes #17031 from mgummelt/SPARK-19702-suppress-revive. --- .../cluster/mesos/MesosClusterScheduler.scala | 75 +++++++++++++------ .../MesosCoarseGrainedSchedulerBackend.scala | 69 +++++++---------- .../MesosFineGrainedSchedulerBackend.scala | 19 +++-- .../cluster/mesos/MesosSchedulerUtils.scala | 60 +++++++++++---- .../mesos/MesosClusterSchedulerSuite.scala | 51 ++++++++----- ...osCoarseGrainedSchedulerBackendSuite.scala | 7 +- .../spark/scheduler/cluster/mesos/Utils.scala | 11 +++ 7 files changed, 187 insertions(+), 105 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 2760f31b12fa7..1bc6f71860c3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -152,6 +152,7 @@ private[spark] class MesosClusterScheduler( // is registered with Mesos master. 
@volatile protected var ready = false private var masterInfo: Option[MasterInfo] = None + private var schedulerDriver: SchedulerDriver = _ def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { val c = new CreateSubmissionResponse @@ -168,9 +169,8 @@ private[spark] class MesosClusterScheduler( return c } c.submissionId = desc.submissionId - queuedDriversState.persist(desc.submissionId, desc) - queuedDrivers += desc c.success = true + addDriverToQueue(desc) } c } @@ -191,7 +191,7 @@ private[spark] class MesosClusterScheduler( // 4. Check if it has already completed. if (launchedDrivers.contains(submissionId)) { val task = launchedDrivers(submissionId) - mesosDriver.killTask(task.taskId) + schedulerDriver.killTask(task.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -324,7 +324,7 @@ private[spark] class MesosClusterScheduler( ready = false metricsSystem.report() metricsSystem.stop() - mesosDriver.stop(true) + schedulerDriver.stop(true) } override def registered( @@ -340,6 +340,8 @@ private[spark] class MesosClusterScheduler( stateLock.synchronized { this.masterInfo = Some(masterInfo) + this.schedulerDriver = driver + if (!pendingRecover.isEmpty) { // Start task reconciliation if we need to recover. val statuses = pendingRecover.collect { @@ -506,11 +508,10 @@ private[spark] class MesosClusterScheduler( } private class ResourceOffer( - val offerId: OfferID, - val slaveId: SlaveID, - var resources: JList[Resource]) { + val offer: Offer, + var remainingResources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offerId}, resources: ${resources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}" } } @@ -518,16 +519,16 @@ private[spark] class MesosClusterScheduler( val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.resources, "cpus", desc.cores) + partitionResources(offer.remainingResources, "cpus", desc.cores) val (finalResources, memResourcesToUse) = partitionResources(remainingResources.asJava, "mem", desc.mem) - offer.resources = finalResources.asJava + offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") - .setSlaveId(offer.slaveId) + .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) @@ -549,23 +550,29 @@ private[spark] class MesosClusterScheduler( val driverCpu = submission.cores val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") - val offerOption = currentOffers.find { o => - getResource(o.resources, "cpus") >= driverCpu && - getResource(o.resources, "mem") >= driverMem + val offerOption = currentOffers.find { offer => + getResource(offer.remainingResources, "cpus") >= driverCpu && + getResource(offer.remainingResources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) try { val task = createTaskInfo(submission, offer) queuedTasks += task - 
logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, task.getTaskId, offer.slaveId, - None, new Date(), None, getDriverFrameworkID(submission)) + val newState = new MesosClusterSubmissionState( + submission, + task.getTaskId, + offer.offer.getSlaveId, + None, + new Date(), + None, + getDriverFrameworkID(submission)) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, newState) afterLaunchCallback(submission.submissionId) @@ -588,7 +595,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList) }.toList stateLock.synchronized { @@ -615,8 +622,8 @@ private[spark] class MesosClusterScheduler( driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - for (o <- currentOffers if !tasks.contains(o.offerId)) { - driver.declineOffer(o.offerId) + for (offer <- currentOffers if !tasks.contains(offer.offer.getId)) { + declineOffer(driver, offer.offer, None, Some(getRejectOfferDuration(conf))) } } @@ -662,6 +669,12 @@ private[spark] class MesosClusterScheduler( override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { val taskId = status.getTaskId.getValue + + logInfo(s"Received status update: taskId=${taskId}" + + s" state=${status.getState}" + + s" message=${status.getMessage}" + + s" reason=${status.getReason}"); + stateLock.synchronized { if (launchedDrivers.contains(taskId)) { if (status.getReason == Reason.REASON_RECONCILIATION && @@ -682,8 +695,7 @@ private[spark] class MesosClusterScheduler( val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - pendingRetryDrivers += newDriverDescription - pendingRetryDriversState.persist(taskId, newDriverDescription) + addDriverToPending(newDriverDescription, taskId); } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { removeFromLaunchedDrivers(taskId) state.finishDate = Some(new Date()) @@ -746,4 +758,21 @@ private[spark] class MesosClusterScheduler( def getQueuedDriversSize: Int = queuedDrivers.size def getLaunchedDriversSize: Int = launchedDrivers.size def getPendingRetryDriversSize: Int = pendingRetryDrivers.size + + private def addDriverToQueue(desc: MesosDriverDescription): Unit = { + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + revive() + } + + private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { + pendingRetryDriversState.persist(taskId, desc) + pendingRetryDrivers += desc + revive() + } + + private def revive(): Unit = { + logInfo("Reviving Offers.") + schedulerDriver.reviveOffers() + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f69c223ab9b6d..85c2e9c76f4b0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala 
@@ -26,6 +26,7 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf @@ -119,11 +120,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) // Reject offers when we reached the maximum number of cores for this framework private val rejectOfferDurationForReachedMaxCores = - getRejectOfferDurationForReachedMaxCores(sc) + getRejectOfferDurationForReachedMaxCores(sc.conf) // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { @@ -146,6 +147,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( @volatile var appId: String = _ + private var schedulerDriver: SchedulerDriver = _ + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -252,9 +255,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - appId = frameworkId.getValue - mesosExternalShuffleClient.foreach(_.init(appId)) + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { + this.appId = frameworkId.getValue + this.mesosExternalShuffleClient.foreach(_.init(appId)) + this.schedulerDriver = driver markRegistered() } @@ -293,46 +299,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def declineUnmatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { offers.foreach { offer => - declineOffer(d, offer, Some("unmet constraints"), + declineOffer( + driver, + offer, + Some("unmet constraints"), Some(rejectOfferDurationForUnmetConstraints)) } } - private def declineOffer( - d: org.apache.mesos.SchedulerDriver, - offer: Offer, - reason: Option[String] = None, - refuseSeconds: Option[Long] = None): Unit = { - - val id = offer.getId.getValue - val offerAttributes = toAttributeMap(offer.getAttributesList) - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus") - val ports = getRangeResource(offer.getResourcesList, "ports") - - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" + - s" cpu: $cpus port: $ports for $refuseSeconds seconds" + - reason.map(r => s" (reason: $r)").getOrElse("")) - - refuseSeconds match { - case Some(seconds) => - val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() - d.declineOffer(offer.getId, filters) - case _ => d.declineOffer(offer.getId) - } - } - /** * Launches executors on accepted offers, and declines unused offers. Executors are launched * round-robin on offers. 
* - * @param d SchedulerDriver + * @param driver SchedulerDriver * @param offers Mesos offers that match attribute constraints */ private def handleMatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -358,15 +343,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" ports: $ports") } - d.launchTasks( + driver.launchTasks( Collections.singleton(offer.getId), offerTasks.asJava) } else if (totalCoresAcquired >= maxCores) { // Reject an offer for a configurable amount of time to avoid starving other frameworks - declineOffer(d, offer, Some("reached spark.cores.max"), + declineOffer(driver, + offer, + Some("reached spark.cores.max"), Some(rejectOfferDurationForReachedMaxCores)) } else { - declineOffer(d, offer) + declineOffer( + driver, + offer) } } } @@ -582,8 +571,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Close the mesos external shuffle client if used mesosExternalShuffleClient.foreach(_.close()) - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } @@ -634,13 +623,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { - if (mesosDriver == null) { + if (schedulerDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false } else { for (executorId <- executorIds) { val taskId = TaskID.newBuilder().setValue(executorId).build() - mesosDriver.killTask(taskId) + schedulerDriver.killTask(taskId) } // no need to adjust `executorLimitOption` since the AllocationManager already communicated // the desired limit through a call to `doRequestTotalExecutors`. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 7e561916a71e2..215271302ec51 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} @@ -65,7 +66,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( // reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) + + private var schedulerDriver: SchedulerDriver = _ @volatile var appId: String = _ @@ -89,6 +92,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * * @param availableResources Available resources that is offered by Mesos * @param execId The executor id to assign to this new executor. 
* @return A tuple of the new mesos executor info and the remaining available resources. @@ -178,10 +182,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) + this.schedulerDriver = driver markRegistered() } } @@ -383,13 +390,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def stop() { - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } override def reviveOffers() { - mesosDriver.reviveOffers() + schedulerDriver.reviveOffers() } override def frameworkMessage( @@ -426,7 +433,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - mesosDriver.killTask( + schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 1d742fefbbacf..3f25535cb5ec2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -46,9 +46,6 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) - // Driver for talking to Mesos - protected var mesosDriver: SchedulerDriver = null - /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. 
* @@ -115,10 +112,6 @@ trait MesosSchedulerUtils extends Logging { */ def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { - if (mesosDriver != null) { - registerLatch.await() - return - } @volatile var error: Option[Exception] = None @@ -128,8 +121,7 @@ trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { try { - mesosDriver = newDriver - val ret = mesosDriver.run() + val ret = newDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) @@ -379,12 +371,24 @@ trait MesosSchedulerUtils extends Logging { } } - protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + private def getRejectOfferDurationStr(conf: SparkConf): String = { + conf.get("spark.mesos.rejectOfferDuration", "120s") + } + + protected def getRejectOfferDuration(conf: SparkConf): Long = { + Utils.timeStringAsSeconds(getRejectOfferDurationStr(conf)) + } + + protected def getRejectOfferDurationForUnmetConstraints(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForUnmetConstraints", + getRejectOfferDurationStr(conf)) } - protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s") + protected def getRejectOfferDurationForReachedMaxCores(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForReachedMaxCores", + getRejectOfferDurationStr(conf)) } /** @@ -438,6 +442,7 @@ trait MesosSchedulerUtils extends Logging { /** * The values of the non-zero ports to be used by the executor process. 
+ * * @param conf the spark config to use * @return the ono-zero values of the ports */ @@ -521,4 +526,33 @@ trait MesosSchedulerUtils extends Logging { case TaskState.KILLED => MesosTaskState.TASK_KILLED case TaskState.LOST => MesosTaskState.TASK_LOST } + + protected def declineOffer( + driver: org.apache.mesos.SchedulerDriver, + offer: Offer, + reason: Option[String] = None, + refuseSeconds: Option[Long] = None): Unit = { + + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val ports = getRangeResource(offer.getResourcesList, "ports") + + logDebug(s"Declining offer: $id with " + + s"attributes: $offerAttributes " + + s"mem: $mem " + + s"cpu: $cpus " + + s"port: $ports " + + refuseSeconds.map(s => s"for ${s} seconds ").getOrElse("") + + reason.map(r => s" (reason: $r)").getOrElse("")) + + refuseSeconds match { + case Some(seconds) => + val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() + driver.declineOffer(offer.getId, filters) + case _ => + driver.declineOffer(offer.getId) + } + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index b9d098486b675..32967b04cd346 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -53,19 +53,32 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi override def start(): Unit = { ready = true } } scheduler.start() + scheduler.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) + } + + private def testDriverDescription(submissionId: String): MesosDriverDescription = { + new MesosDriverDescription( + "d1", + "jar", + 1000, + 1, + true, + command, + Map[String, String](), + submissionId, + new Date()) } test("can queue drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + verify(driver, times(1)).reviveOffers() + + val response2 = scheduler.submitDriver(testDriverDescription("s2")) assert(response2.success) + val state = scheduler.getSchedulerState() val queuedDrivers = state.queuedDrivers.toList assert(queuedDrivers(0).submissionId == response.submissionId) @@ -75,9 +88,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi test("can kill queued drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) val killResponse = scheduler.killDriver(response.submissionId) assert(killResponse.success) @@ -238,18 +249,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("can kill supervised drivers") { - val driver = 
mock[SchedulerDriver] val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") - scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { - ready = true - mesosDriver = driver - } - } - scheduler.start() + setScheduler(conf.getAll.toMap) val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, @@ -291,4 +294,16 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(state.launchedDrivers.isEmpty) assert(state.finishedDrivers.size == 1) } + + test("Declines offer with refuse seconds = 120.") { + setScheduler() + + val filter = Filters.newBuilder().setRefuseSeconds(120).build() + val offerId = OfferID.newBuilder().setValue("o1").build() + val offer = Utils.createOffer(offerId.getValue, "s1", 1000, 1) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + verify(driver, times(1)).declineOffer(offerId, filter) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 78346e9744957..98033bec6dd68 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -552,17 +552,14 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient // override to avoid race condition with the driver thread on `mesosDriver` - override def startScheduler(newDriver: SchedulerDriver): Unit = { - mesosDriver = newDriver - } + override def startScheduler(newDriver: SchedulerDriver): Unit = {} override def stopExecutors(): Unit = { stopCalled = true } - - markRegistered() } backend.start() + backend.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) backend } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 7ebb294aa9080..2a67cbc913ffe 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -28,6 +28,17 @@ import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ object Utils { + + val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() + .setValue("test-framework-id") + .build() + + val TEST_MASTER_INFO = MasterInfo.newBuilder() + .setId("test-master") + .setIp(0) + .setPort(0) + .build() + def createOffer( offerId: String, slaveId: String, From 8e41c2eed873e215b13215844ba5ba73a8906c5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Mar 2017 16:21:18 -0800 Subject: [PATCH 32/78] [SPARK-19857][YARN] Correctly calculate next credential update time. Add parentheses so that both lines form a single statement; also add a log message so that the issue becomes more explicit if it shows up again. Tested manually with integration test that exercises the feature. Author: Marcelo Vanzin Closes #17198 from vanzin/SPARK-19857. 
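For context, the bug class fixed here comes from Scala's semicolon inference: a binary expression split across lines without parentheses is parsed as two statements, and the second line becomes a discarded unary expression. A standalone sketch (illustrative only, not code from the patch):

```scala
// Illustration of the pitfall (not Spark code).
def buggy(deadline: Long): Long = {
  val remaining = deadline
  - System.currentTimeMillis() // parsed as its own statement and silently discarded
  remaining                    // returns `deadline` unchanged
}

def fixed(deadline: Long): Long = {
  val remaining = (deadline
    - System.currentTimeMillis()) // parentheses keep both lines in one expression
  remaining
}
```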
--- .../spark/deploy/yarn/security/CredentialUpdater.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 2fdb70a73c754..41b7b5d60b038 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -60,7 +60,7 @@ private[spark] class CredentialUpdater( if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.") + logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -81,8 +81,8 @@ private[spark] class CredentialUpdater( UserGroupInformation.getCurrentUser.addCredentials(newCredentials) logInfo("Credentials updated from credentials file.") - val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis() + val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) + - System.currentTimeMillis()) if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime } else { // If current credential file is older than expected, sleep 1 hour and check again. @@ -100,6 +100,7 @@ private[spark] class CredentialUpdater( TimeUnit.HOURS.toMillis(1) } + logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") credentialUpdater.schedule( credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) } From 47b2f68a885b7a2fc593ac7a55cd19742016364d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 Mar 2017 17:14:26 -0800 Subject: [PATCH 33/78] Revert "[SPARK-19561] [PYTHON] cast TimestampType.toInternal output to long" This reverts commit 711addd46e98e42deca97c5b9c0e55fddebaa458. 
--- python/pyspark/sql/tests.py | 6 ------ python/pyspark/sql/types.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4d48ef694d68f..81f3d1d36a342 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1555,12 +1555,6 @@ def test_time_with_timezone(self): self.assertEqual(now, now1) self.assertEqual(now, utcnow1) - # regression test for SPARK-19561 - def test_datetime_at_epoch(self): - epoch = datetime.datetime.fromtimestamp(0) - df = self.spark.createDataFrame([Row(date=epoch)]) - self.assertEqual(df.first()['date'], epoch) - def test_decimal(self): from decimal import Decimal schema = StructType([StructField("decimal", DecimalType(10, 5))]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1d31f25efad52..26b54a7fb3709 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -189,7 +189,7 @@ def toInternal(self, dt): if dt is not None: seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())) - return long(seconds) * 1000000 + dt.microsecond + return int(seconds) * 1000000 + dt.microsecond def fromInternal(self, ts): if ts is not None: From c96d14abae5962a7b15239319c2a151b95f7db94 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Tue, 7 Mar 2017 20:19:30 -0800 Subject: [PATCH 34/78] [SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs ## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-19843 Created wrapper classes (`IntWrapper`, `LongWrapper`) to wrap the result of parsing (which are primitive types). In case of problem in parsing, the method would return a boolean. ## How was this patch tested? - Added new unit tests - Ran a prod job which had conversion from string -> int and verified the outputs ## Performance Tiny regression when all strings are valid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------- trunk 502 / 522 33.4 29.9 1.0X SPARK-19843 493 / 503 34.0 29.4 1.0X ``` Huge gain when all strings are invalid integers ``` conversion to int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------- trunk 33913 / 34219 0.5 2021.4 1.0X SPARK-19843 154 / 162 108.8 9.2 220.0X ``` Author: Tejas Patil Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe. 
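Distilled from the Cast.scala hunk below, the new calling convention looks roughly like this in Scala (a sketch; the wrapper classes themselves live in UTF8String.java):

```scala
import org.apache.spark.unsafe.types.UTF8String

// One wrapper is allocated per cast expression and reused across rows;
// parse failures become SQL NULLs instead of thrown NumberFormatExceptions.
val wrapper = new UTF8String.IntWrapper()

def castStringToInt(s: UTF8String): Any =
  if (s.toInt(wrapper)) wrapper.value else null

castStringToInt(UTF8String.fromString("123"))  // 123
castStringToInt(UTF8String.fromString("12ab")) // null, no exception raised
```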
--- .../apache/spark/unsafe/types/UTF8String.java | 120 +++++++++------- .../spark/unsafe/types/UTF8StringSuite.java | 128 +++++++++++++++++- .../spark/sql/catalyst/expressions/Cast.scala | 81 ++++++----- 3 files changed, 247 insertions(+), 82 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 10a7cb1d06659..7abe0fa80ad7c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -850,11 +850,8 @@ public UTF8String translate(Map dict) { return fromString(sb.toString()); } - private int getDigit(byte b) { - if (b >= '0' && b <= '9') { - return b - '0'; - } - throw new NumberFormatException(toString()); + public static class LongWrapper { + public long value = 0; } /** @@ -862,14 +859,18 @@ private int getDigit(byte b) { * * Note that, in this method we accumulate the result in negative format, and convert it to * positive format at the end, if this string is not started with '-'. This is because min value - * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and - * Integer.MIN_VALUE is '-2147483648'. + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. * * This code is mostly copied from LazyLong.parseLong in Hive. + * + * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would + * be set in `toLongResult` + * @return true if the parsing was successful else false */ - public long toLong() { + public boolean toLong(LongWrapper toLongResult) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -878,7 +879,7 @@ public long toLong() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -897,20 +898,25 @@ public long toLong() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop. if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we - // can just use `result > 0` to check overflow. If result overflows, we should stop and throw - // exception. + // can just use `result > 0` to check overflow. If result overflows, we should stop. if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -918,8 +924,9 @@ public long toLong() { // part will not change the number, but we will verify that the fractional part // is well formed. 
while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -927,11 +934,16 @@ public long toLong() { if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - return result; + toLongResult.value = result; + return true; + } + + public static class IntWrapper { + public int value = 0; } /** @@ -946,10 +958,14 @@ public long toLong() { * * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance * reasons, like Hive does. + * + * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would + * be set in `intWrapper` + * @return true if the parsing was successful else false */ - public int toInt() { + public boolean toInt(IntWrapper intWrapper) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -958,7 +974,7 @@ public int toInt() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -977,20 +993,25 @@ public int toInt() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop and - // throw exception. + // we can just use `result > 0` to check overflow. If result overflows, we should stop if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -998,8 +1019,9 @@ public int toInt() { // part will not change the number, but we will verify that the fractional part // is well formed. 
while (offset < numBytes) {
-      if (getDigit(getByte(offset)) == -1) {
-        throw new NumberFormatException(toString());
+      byte currentByte = getByte(offset);
+      if (currentByte < '0' || currentByte > '9') {
+        return false;
       }
       offset++;
     }
@@ -1007,31 +1029,33 @@
     if (!negative) {
       result = -result;
       if (result < 0) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
-
-    return result;
+    intWrapper.value = result;
+    return true;
   }

-  public short toShort() {
-    int intValue = toInt();
-    short result = (short) intValue;
-    if (result != intValue) {
-      throw new NumberFormatException(toString());
+  public boolean toShort(IntWrapper intWrapper) {
+    if (toInt(intWrapper)) {
+      int intValue = intWrapper.value;
+      short result = (short) intValue;
+      if (result == intValue) {
+        return true;
+      }
     }
-
-    return result;
+    return false;
   }

-  public byte toByte() {
-    int intValue = toInt();
-    byte result = (byte) intValue;
-    if (result != intValue) {
-      throw new NumberFormatException(toString());
+  public boolean toByte(IntWrapper intWrapper) {
+    if (toInt(intWrapper)) {
+      int intValue = intWrapper.value;
+      byte result = (byte) intValue;
+      if (result == intValue) {
+        return true;
+      }
     }
-
-    return result;
+    return false;
   }

   @Override
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 6f6e0ef0e4855..c376371abdf90 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -22,9 +22,7 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.*;

 import com.google.common.collect.ImmutableMap;
 import org.apache.spark.unsafe.Platform;
@@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException {
       .writeTo(outputStream);
     assertEquals("大千世界", outputStream.toString("UTF-8"));
   }
+
+  @Test
+  public void testToShort() throws IOException {
+    Map<String, Short> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", (short) 1);
+    inputToExpectedOutput.put("+1", (short) 1);
+    inputToExpectedOutput.put("-1", (short) -1);
+    inputToExpectedOutput.put("0", (short) 0);
+    inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
+    inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
+    inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);
+
+    Random rand = new Random();
+    for (int i = 0; i < 10; i++) {
+      short value = (short) rand.nextInt();
+      inputToExpectedOutput.put(String.valueOf(value), value);
+    }
+
+    IntWrapper wrapper = new IntWrapper();
+    for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
+      assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
+      assertEquals((short) entry.getValue(), wrapper.value);
+    }
+
+    List<String> negativeInputs =
+      Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700");
+
+    for (String negativeInput : negativeInputs) {
+      assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
+    }
+  }
+
+  @Test
+  public void testToByte() throws IOException {
+    Map<String, Byte> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", (byte) 1);
+    inputToExpectedOutput.put("+1", (byte) 1);
+    inputToExpectedOutput.put("-1", (byte) -1);
inputToExpectedOutput.put("0", (byte) 0); + inputToExpectedOutput.put("111.12345678901234567890", (byte) 111); + inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + byte value = (byte) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper)); + assertEquals((byte) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper)); + } + } + + @Test + public void testToInt() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1); + inputToExpectedOutput.put("+1", 1); + inputToExpectedOutput.put("-1", -1); + inputToExpectedOutput.put("0", 0); + inputToExpectedOutput.put("11111.1234567", 11111); + inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + int value = rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper)); + assertEquals((int) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper)); + } + } + + @Test + public void testToLong() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1L); + inputToExpectedOutput.put("+1", 1L); + inputToExpectedOutput.put("-1", -1L); + inputToExpectedOutput.put("0", 0L); + inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L); + inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + long value = rand.nextLong(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + LongWrapper wrapper = new LongWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper)); + assertEquals((long) entry.getValue(), wrapper.value); + } + + List negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", + "1234567890123456789012345678901234"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a36d3507d92ec..7c60f7d57a99e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - +import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast { @@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toShort catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toShort(result)) { + result.value.toShort + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toByte catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toByte(result)) { + result.value.toByte + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) - case ByteType => castToByteCode(from) - case ShortType => castToShortCode(from) - case IntegerType => castToIntCode(from) + case ByteType => castToByteCode(from, ctx) + case ShortType => castToShortCode(from, ctx) + case IntegerType => castToIntCode(from, ctx) case FloatType => castToFloatCode(from) - case LongType => castToLongCode(from) + case LongType => castToLongCode(from, ctx) case DoubleType => castToDoubleCode(from) case array: ArrayType => @@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = $c != 0;" } - private[this] def castToByteCode(from: DataType): CastFunction = from match { + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { 
      case StringType =>
+        val wrapper = ctx.freshName("wrapper")
+        ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+          s"$wrapper = new UTF8String.IntWrapper();")
        (c, evPrim, evNull) =>
          s"""
-            try {
-              $evPrim = $c.toByte();
-            } catch (java.lang.NumberFormatException e) {
+            if ($c.toByte($wrapper)) {
+              $evPrim = (byte) $wrapper.value;
+            } else {
              $evNull = true;
            }
          """
@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
      (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
  }

-  private[this] def castToShortCode(from: DataType): CastFunction = from match {
+  private[this] def castToShortCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
    case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+        s"$wrapper = new UTF8String.IntWrapper();")
      (c, evPrim, evNull) =>
        s"""
-          try {
-            $evPrim = $c.toShort();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toShort($wrapper)) {
+            $evPrim = (short) $wrapper.value;
+          } else {
            $evNull = true;
          }
        """
@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
      (c, evPrim, evNull) => s"$evPrim = (short) $c;"
  }

-  private[this] def castToIntCode(from: DataType): CastFunction = from match {
+  private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
    case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+        s"$wrapper = new UTF8String.IntWrapper();")
      (c, evPrim, evNull) =>
        s"""
-          try {
-            $evPrim = $c.toInt();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toInt($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
            $evNull = true;
          }
        """
@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
      (c, evPrim, evNull) => s"$evPrim = (int) $c;"
  }

-  private[this] def castToLongCode(from: DataType): CastFunction = from match {
+  private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
    case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.LongWrapper", wrapper,
+        s"$wrapper = new UTF8String.LongWrapper();")
+
      (c, evPrim, evNull) =>
        s"""
-          try {
-            $evPrim = $c.toLong();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toLong($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
            $evNull = true;
          }
        """

From b9783a92f7ba0c3b22d7dceae7a3185de17dedcc Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Tue, 7 Mar 2017 20:25:38 -0800
Subject: [PATCH 35/78] [SPARK-18389][SQL] Disallow cyclic view reference

## What changes were proposed in this pull request?

Disallow cyclic view references. A cyclic view reference may be created by the following queries:
```
CREATE VIEW testView AS SELECT id FROM tbl
CREATE VIEW testView2 AS SELECT id FROM testView
ALTER VIEW testView AS SELECT * FROM testView2
```
In the above example, a reference cycle (testView -> testView2 -> testView) exists.

We disallow cyclic view references by checking, in the ALTER VIEW command, whether the `analyzedPlan` contains the same `View` node as the altered view; if it does, we prevent the operation and throw an AnalysisException.

## How was this patch tested?

Tested by `SQLViewSuite.test("correctly handle a cyclic view reference")`.

Author: jiangxingbo

Closes #17152 from jiangxb1987/cyclic-view.
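To make the detection strategy concrete outside of Spark's plan machinery, here is a minimal self-contained Scala sketch of the same path-accumulating walk. The `refs` map is a hypothetical stand-in for the analyzed plan's view references; none of these names are Spark API.

```scala
// Hypothetical stand-in for the analyzed plan: view name -> views it references.
val refs: Map[String, Seq[String]] = Map(
  "testView" -> Seq("testView2"), // the proposed new definition of testView
  "testView2" -> Seq("testView"))

// Walk the reference graph, carrying the path so the error can report the cycle.
def checkCycle(current: String, path: Seq[String], altered: String): Unit = {
  val newPath = path :+ current
  if (current == altered) {
    sys.error(s"Recursive view $altered detected (cycle: ${newPath.mkString(" -> ")})")
  }
  refs.getOrElse(current, Nil).foreach(checkCycle(_, newPath, altered))
}

// Start from the views referenced by the *new* definition of the altered view.
refs("testView").foreach(checkCycle(_, Seq("testView"), "testView"))
// => Recursive view testView detected (cycle: testView -> testView2 -> testView)
```

The real implementation below additionally descends into subquery expressions, since a view reference hidden in a WHERE EXISTS clause forms a cycle just as well.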
--- .../spark/sql/execution/command/views.scala | 61 ++++++++++++++++++- .../spark/sql/execution/SQLViewSuite.scala | 35 +++++++++-- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 921c84895598c..00f0acab21aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder @@ -154,6 +154,10 @@ case class CreateViewCommand( } else if (tableMetadata.tableType != CatalogTableType.VIEW) { throw new AnalysisException(s"$name is not a view") } else if (replace) { + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val viewIdent = tableMetadata.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) } else { @@ -283,6 +287,10 @@ case class AlterViewAsCommand( throw new AnalysisException(s"${viewMeta.identifier} is not a view.") } + // Detect cyclic view reference on ALTER VIEW. + val viewIdent = viewMeta.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + val newProperties = generateViewProperties(viewMeta.properties, session, analyzedPlan) val updatedViewMeta = viewMeta.copy( @@ -358,4 +366,53 @@ object ViewHelper { generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) } + + /** + * Recursively search the logical plan to detect cyclic view references, throw an + * AnalysisException if cycle detected. + * + * A cyclic view reference is a cycle of reference dependencies, for example, if the following + * statements are executed: + * CREATE VIEW testView AS SELECT id FROM tbl + * CREATE VIEW testView2 AS SELECT id FROM testView + * ALTER VIEW testView AS SELECT * FROM testView2 + * The view `testView` references `testView2`, and `testView2` also references `testView`, + * therefore a reference cycle (testView -> testView2 -> testView) exists. + * + * @param plan the logical plan we detect cyclic view references from. + * @param path the path between the altered view and current node. + * @param viewIdent the table identifier of the altered view, we compare two views by the + * `desc.identifier`. + */ + def checkCyclicViewReference( + plan: LogicalPlan, + path: Seq[TableIdentifier], + viewIdent: TableIdentifier): Unit = { + plan match { + case v: View => + val ident = v.desc.identifier + val newPath = path :+ ident + // If the table identifier equals to the `viewIdent`, current view node is the same with + // the altered view. We detect a view reference cycle, should throw an AnalysisException. 
+        if (ident == viewIdent) {
+          throw new AnalysisException(s"Recursive view $viewIdent detected " +
+            s"(cycle: ${newPath.mkString(" -> ")})")
+        } else {
+          v.children.foreach { child =>
+            checkCyclicViewReference(child, newPath, viewIdent)
+          }
+        }
+      case _ =>
+        plan.children.foreach(child => checkCyclicViewReference(child, path, viewIdent))
+    }
+
+    // Detect cyclic references from subqueries.
+    plan.expressions.foreach { expr =>
+      expr match {
+        case s: SubqueryExpression =>
+          checkCyclicViewReference(s.plan, path, viewIdent)
+        case _ => // Do nothing.
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index 0e5a1dc6ab629..2ca2206bb9d44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -609,12 +609,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
     }
   }

-  // TODO: Check for cyclic view references on ALTER VIEW.
-  ignore("correctly handle a cyclic view reference") {
-    withView("view1", "view2") {
+  test("correctly handle a cyclic view reference") {
+    withView("view1", "view2", "view3") {
       sql("CREATE VIEW view1 AS SELECT * FROM jt")
       sql("CREATE VIEW view2 AS SELECT * FROM view1")
-      intercept[AnalysisException](sql("ALTER VIEW view1 AS SELECT * FROM view2"))
+      sql("CREATE VIEW view3 AS SELECT * FROM view2")
+
+      // Detect cyclic view reference on ALTER VIEW.
+      val e1 = intercept[AnalysisException] {
+        sql("ALTER VIEW view1 AS SELECT * FROM view2")
+      }.getMessage
+      assert(e1.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+        "-> `default`.`view2` -> `default`.`view1`)"))
+
+      // Detect the most left cycle when there exists multiple cyclic view references.
+      val e2 = intercept[AnalysisException] {
+        sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2")
+      }.getMessage
+      assert(e2.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+        "-> `default`.`view3` -> `default`.`view2` -> `default`.`view1`)"))
+
+      // Detect cyclic view reference on CREATE OR REPLACE VIEW.
+      val e3 = intercept[AnalysisException] {
+        sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2")
+      }.getMessage
+      assert(e3.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+        "-> `default`.`view2` -> `default`.`view1`)"))
+
+      // Detect cyclic view reference from subqueries.
+      val e4 = intercept[AnalysisException] {
+        sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)")
+      }.getMessage
+      assert(e4.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+        "-> `default`.`view2` -> `default`.`view1`)"))
     }
   }
 }

From ca849ac4e8fc520a4a12949b62b9730c5dfa097d Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Tue, 7 Mar 2017 20:32:51 -0800
Subject: [PATCH 36/78] [SPARK-19841][SS] watermarkPredicate should filter based on keys

## What changes were proposed in this pull request?

`StreamingDeduplicateExec.watermarkPredicate` should filter based on keys. Otherwise, it may generate a wrong answer if the watermark column in `keyExpression` has a different position in the row. `StateStoreSaveExec` has the same code, but its parent can make sure that the watermark column positions in `keyExpression` and `row` are the same.

## How was this patch tested?

The added test.

Author: Shixiong Zhu

Closes #17183 from zsxwing/SPARK-19841.
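The root cause is easiest to see with ordinals. A generated predicate resolves the event-time column to a position in one specific schema and then reads rows positionally, so a predicate bound to the child's output reads the wrong field when applied to state-store keys whose columns are ordered differently. A small illustrative sketch with plain arrays standing in for `InternalRow` (all names here are invented for the example):

```scala
case class Attr(name: String)

// "Bind" a watermark predicate to a schema: resolve the event-time column to
// an ordinal once, then read rows positionally, as the generated code does.
def bind(schema: Seq[Attr], col: String, watermarkMs: Long): Array[Long] => Boolean = {
  val ordinal = schema.indexWhere(_.name == col)
  row => row(ordinal) <= watermarkMs // true = older than the watermark, evict
}

val childOutput = Seq(Attr("time"), Attr("id")) // rows flowing through the operator
val keySchema   = Seq(Attr("id"), Attr("time")) // keys kept in the state store

val rowPredicate = bind(childOutput, "time", watermarkMs = 10L)
val keyPredicate = bind(keySchema, "time", watermarkMs = 10L)

assert(rowPredicate(Array(5L, 42L))) // row (time=5, id=42): correctly evicted
assert(keyPredicate(Array(42L, 5L))) // key (id=42, time=5): evicted with the key binding
// The same key evaluated with the row binding reads id=42 as the timestamp and
// wrongly keeps the entry -- the SPARK-19841 bug:
assert(!rowPredicate(Array(42L, 5L)))
```

Hence the split in the diff below into `watermarkPredicateForKeys` (bound to `keyExpressions`, for state-store eviction) and `watermarkPredicate` (bound to `child.output`, for filtering input rows).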
--- .../streaming/statefulOperators.scala | 28 +++++++++++++------ .../sql/streaming/DeduplicateSuite.scala | 19 +++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d92529748b6ac..cbf656a2044dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -68,7 +68,7 @@ trait StateStoreWriter extends StatefulOperator { } /** An operator that supports watermark. */ -trait WatermarkSupport extends SparkPlan { +trait WatermarkSupport extends UnaryExecNode { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] @@ -76,8 +76,8 @@ trait WatermarkSupport extends SparkPlan { /** The watermark value. */ def eventTimeWatermark: Option[Long] - /** Generate a predicate that matches data older than the watermark */ - lazy val watermarkPredicate: Option[Predicate] = { + /** Generate an expression that matches data older than the watermark */ + lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) @@ -96,9 +96,19 @@ trait WatermarkSupport extends SparkPlan { } logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) + evictionExpression } } + + /** Generate a predicate based on keys that matches data older than the watermark */ + lazy val watermarkPredicateForKeys: Option[Predicate] = + watermarkExpression.map(newPredicate(_, keyExpressions)) + + /** + * Generate a predicate based on the child output that matches data older than the watermark. 
+   */
+  lazy val watermarkPredicate: Option[Predicate] =
+    watermarkExpression.map(newPredicate(_, child.output))
 }

 /**
@@ -192,7 +202,7 @@ case class StateStoreSaveExec(
           }

           // Assumption: Append mode can be done only when watermark has been specified
-          store.remove(watermarkPredicate.get.eval _)
+          store.remove(watermarkPredicateForKeys.get.eval _)
           store.commit()
           numTotalStateRows += store.numKeys()
@@ -215,7 +225,9 @@ case class StateStoreSaveExec(
             override def hasNext: Boolean = {
               if (!baseIterator.hasNext) {
                 // Remove old aggregates if watermark specified
-                if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
+                if (watermarkPredicateForKeys.nonEmpty) {
+                  store.remove(watermarkPredicateForKeys.get.eval _)
+                }
                 store.commit()
                 numTotalStateRows += store.numKeys()
                 false
@@ -361,7 +373,7 @@ case class StreamingDeduplicateExec(
     val numUpdatedStateRows = longMetric("numUpdatedStateRows")

     val baseIterator = watermarkPredicate match {
-      case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+      case Some(predicate) => iter.filter(row => !predicate.eval(row))
       case None => iter
     }
@@ -381,7 +393,7 @@ case class StreamingDeduplicateExec(
     }

     CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
-      watermarkPredicate.foreach(f => store.remove(f.eval _))
+      watermarkPredicateForKeys.foreach(f => store.remove(f.eval _))
       store.commit()
       numTotalStateRows += store.numKeys()
     })
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index 7ea716231e5dc..a15c2cff930fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -249,4 +249,23 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
     }
   }
+
+  test("SPARK-19841: watermarkPredicate should filter based on keys") {
+    val input = MemoryStream[(Int, Int)]
+    val df = input.toDS.toDF("time", "id")
+      .withColumn("time", $"time".cast("timestamp"))
+      .withWatermark("time", "1 second")
+      .dropDuplicates("id", "time") // Change the column positions
+      .select($"id")
+    testStream(df)(
+      AddData(input, 1 -> 1, 1 -> 1, 1 -> 2),
+      CheckLastBatch(1, 2),
+      AddData(input, 1 -> 1, 2 -> 3, 2 -> 4),
+      CheckLastBatch(3, 4),
+      AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark
+      CheckLastBatch(5, 6),
+      AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark
+      CheckLastBatch(7)
+    )
+  }
 }

From d8830c5039d9c7c5ef03631904c32873ab558e22 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Tue, 7 Mar 2017 20:34:55 -0800
Subject: [PATCH 37/78] [SPARK-19859][SS] The new watermark should override the old one

## What changes were proposed in this pull request?

The new watermark should override the old one. Otherwise, we would just pick up the first column that has a watermark, which may be unexpected.

## How was this patch tested?

The new test.

Author: Shixiong Zhu

Closes #17199 from zsxwing/SPARK-19859.
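For reference, the user-visible behavior after this patch, sketched against the public API. This assumes a SparkSession with implicits in scope, and uses `queryExecution.analyzed` in place of the package-private `df.logicalPlan` that the test below uses:

```scala
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark

val df = Seq((1L, 2L)).toDF("_1", "_2")
  .withColumn("first", $"_1".cast("timestamp"))
  .withColumn("second", $"_2".cast("timestamp"))
  .withWatermark("first", "1 minute")   // superseded by the next call...
  .withWatermark("second", "2 minutes") // ...which now strips "first"'s delay metadata

// Only "second" still carries the event-time watermark metadata.
val watermarked = df.queryExecution.analyzed.output
  .filter(_.metadata.contains(EventTimeWatermark.delayKey))
assert(watermarked.map(_.name) == Seq("second"))
```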
---
 .../plans/logical/EventTimeWatermark.scala      |  7 +++++++
 .../sql/streaming/EventTimeWatermarkSuite.scala | 14 ++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 77309ce391a1a..62f68a6d7b528 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -42,6 +42,13 @@ case class EventTimeWatermark(
         .putLong(EventTimeWatermark.delayKey, delay.milliseconds)
         .build()
       a.withMetadata(updatedMetadata)
+    } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
+      // Remove existing watermark
+      val updatedMetadata = new MetadataBuilder()
+        .withMetadata(a.metadata)
+        .remove(EventTimeWatermark.delayKey)
+        .build()
+      a.withMetadata(updatedMetadata)
     } else {
       a
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index c34d119734cc0..c768525bc6855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.OutputMode._
@@ -305,6 +306,19 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
     )
   }

+  test("the new watermark should override the old one") {
+    val df = MemoryStream[(Long, Long)].toDF()
+      .withColumn("first", $"_1".cast("timestamp"))
+      .withColumn("second", $"_2".cast("timestamp"))
+      .withWatermark("first", "1 minute")
+      .withWatermark("second", "2 minutes")
+
+    val eventTimeColumns = df.logicalPlan.output
+      .filter(_.metadata.contains(EventTimeWatermark.delayKey))
+    assert(eventTimeColumns.size === 1)
+    assert(eventTimeColumns(0).name === "second")
+  }
+
   private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
     val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
     assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)

From 56e1bd337ccb03cb01702e4260e4be59d2aa0ead Mon Sep 17 00:00:00 2001
From: Asher Krim
Date: Tue, 7 Mar 2017 20:36:46 -0800
Subject: [PATCH 38/78] [SPARK-17629][ML] methods to return synonyms directly

## What changes were proposed in this pull request?

Provide methods to return synonyms directly, without wrapping them in a DataFrame. In performance-sensitive applications (such as user-facing APIs) the round trip to and from DataFrames is costly and unnecessary. The methods are named ``findSynonymsArray`` to make the return type clear, which also implies a local data structure.

## How was this patch tested?

Updated Word2Vec tests.

Author: Asher Krim

Closes #16811 from Krimit/w2vFindSynonymsLocal.
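A short usage sketch of the new local-return API next to the existing DataFrame variant (assumes a fitted `Word2VecModel` named `model`):

```scala
// DataFrame variant: wraps the result via SparkSession.createDataFrame.
val synonymsDF = model.findSynonyms("a", 2) // columns: "word", "similarity"

// New local variant: the same results as an Array, with no DataFrame round trip.
val synonyms: Array[(String, Double)] = model.findSynonymsArray("a", 2)
synonyms.foreach { case (word, similarity) =>
  println(s"$word -> $similarity")
}
```

Note in the diff below that the DataFrame overloads are now implemented on top of `findSynonymsArray`, so both variants share one code path.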
--- .../apache/spark/ml/feature/Word2Vec.scala | 37 ++++++++++++++++--- .../spark/ml/feature/Word2VecSuite.scala | 20 +++++++--- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 42e8a66a62b61..4ca062c0b5adf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -227,25 +227,50 @@ class Word2VecModel private[ml] ( /** * Find "num" number of words closest in similarity to the given word, not - * including the word itself. Returns a dataframe with the words and the - * cosine similarities between the synonyms and the given word. + * including the word itself. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word vector. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words whose vector representation most similar to the supplied vector. + * Find "num" number of words whose vector representation is most similar to the supplied vector. * If the supplied vector is the vector representation of a word in the model's vocabulary, - * that word will be in the results. Returns a dataframe with the words and the cosine + * that word will be in the results. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine * similarities between the synonyms and the given word vector. */ @Since("2.0.0") def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(vec, num) + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. 
+   */
+  @Since("2.2.0")
+  def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+    wordVectors.findSynonyms(word, num)
   }

   /** @group setParam */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 613cc3d60b227..2043a16c15f1a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       .setSeed(42L)
       .fit(docDF)

-    val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
-    val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map {
+    val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078))
+    val findSynonymsResult = model.findSynonyms("a", 2).rdd.map {
       case Row(w: String, sim: Double) => (w, sim)
-    }.collect().unzip
+    }.collectAsMap()
+
+    expected.foreach {
+      case (expectedSynonym, expectedSimilarity) =>
+        assert(findSynonymsResult.contains(expectedSynonym))
+        assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5)
+    }

-    assert(synonyms === Array("b", "c"))
-    expectedSimilarity.zip(similarity).foreach {
-      case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+    val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap
+    findSynonymsResult.foreach {
+      case (expectedSynonym, expectedSimilarity) =>
+        assert(findSynonymsArrayResult.contains(expectedSynonym))
+        assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5)
     }
   }

From 314e48a3584bad4b486b046bbf0159d64ba857bc Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Wed, 8 Mar 2017 01:32:42 -0800
Subject: [PATCH 39/78] [SPARK-18055][SQL] Use correct mirror in ExpressionEncoder

Previously, we were using the mirror of the passed-in `TypeTag` when reflecting to build an encoder. This fails when the outer class is built in (i.e. `Seq`'s default mirror is based on the root classloader) but inner classes (i.e. `A` in `Seq[A]`) are defined in the REPL or a library.

This patch changes us to always reflect using a mirror created from the context classloader.

Author: Michael Armbrust

Closes #17201 from marmbrus/replSeqEncoder.
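The scala-reflect mechanics behind the fix, sketched independently of Spark using only the standard `scala.reflect` API (`contextMirror` and `typeIn` are names made up for this sketch; the actual helper in the diff below is `ScalaReflection.mirror`):

```scala
import scala.reflect.runtime.universe._

// A TypeTag carries the mirror of the classloader that created it. For a
// builtin outer type like Seq, that may be the root classloader's mirror,
// which cannot see classes defined in the REPL.
def contextMirror: Mirror =
  runtimeMirror(Thread.currentThread().getContextClassLoader)

// Migrating the tag into a mirror built from the context classloader lets
// reflection resolve both the builtin Seq and the REPL-defined element type.
def typeIn[T: TypeTag]: Type = typeTag[T].in(contextMirror).tpe
```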
---
 .../test/scala/org/apache/spark/repl/ReplSuite.scala | 11 +++++++++++
 .../sql/catalyst/encoders/ExpressionEncoder.scala    |  4 ++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 55c91675ed3ba..121a02a9be0a1 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -473,4 +473,15 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("AssertionError", output)
     assertDoesNotContain("Exception", output)
   }
+
+  test("newProductSeqEncoder with REPL defined class") {
+    val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
+      """
+        |case class Click(id: Int)
+        |spark.implicits.newProductSeqEncoder[Click]
+      """.stripMargin)
+
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0782143d465b3..93fc565a53419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
 object ExpressionEncoder {
   def apply[T : TypeTag](): ExpressionEncoder[T] = {
     // We convert the not-serializable TypeTag into StructType and ClassTag.
-    val mirror = typeTag[T].mirror
-    val tpe = typeTag[T].tpe
+    val mirror = ScalaReflection.mirror
+    val tpe = typeTag[T].in(mirror).tpe

     if (ScalaReflection.optionOfProductType(tpe)) {
       throw new UnsupportedOperationException(

From 1fa58868bc6635ff2119264665bd3d00b4b1253a Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 8 Mar 2017 02:05:01 -0800
Subject: [PATCH 40/78] [ML][MINOR] Separate estimator and model params for read/write test.

## What changes were proposed in this pull request?

Since we allow an ```Estimator``` and its ```Model``` to have different params (see ```ALSParams``` and ```ALSModelParams```), we should pass in test params for the estimator and the model separately to ```testEstimatorAndModelReadWrite```.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang

Closes #17151 from yanboliang/test-rw.
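What the new signature looks like at a call site, using the ALS change from the diff below (a sketch; assumes the suite's fixtures are in scope):

```scala
// Estimator params may include fit-only settings that the model does not
// carry, so the two maps are passed separately: the first is set on the
// estimator before fit and save/load, the second is verified on the model.
testEstimatorAndModelReadWrite(
  new ALS(),
  ratings.toDF(),
  allEstimatorParamSettings,
  allModelParamSettings,
  checkModelData)
```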
--- .../DecisionTreeClassifierSuite.scala | 8 +++-- .../classification/GBTClassifierSuite.scala | 3 +- .../ml/classification/LinearSVCSuite.scala | 2 +- .../LogisticRegressionSuite.scala | 2 +- .../ml/classification/NaiveBayesSuite.scala | 3 +- .../RandomForestClassifierSuite.scala | 3 +- .../ml/clustering/BisectingKMeansSuite.scala | 4 +-- .../ml/clustering/GaussianMixtureSuite.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 4 ++- .../BucketedRandomProjectionLSHSuite.scala | 2 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 3 +- .../spark/ml/feature/MinHashLSHSuite.scala | 2 +- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 4 +-- .../spark/ml/recommendation/ALSSuite.scala | 35 +++++++------------ .../AFTSurvivalRegressionSuite.scala | 3 +- .../DecisionTreeRegressorSuite.scala | 5 +-- .../ml/regression/GBTRegressorSuite.scala | 3 +- .../GeneralizedLinearRegressionSuite.scala | 1 + .../regression/IsotonicRegressionSuite.scala | 2 +- .../ml/regression/LinearRegressionSuite.scala | 2 +- .../RandomForestRegressorSuite.scala | 3 +- .../spark/ml/util/DefaultReadWriteTest.scala | 14 ++++---- 23 files changed, 59 insertions(+), 54 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c711e7fa9dc67..10de50306a5ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite // Categorical splits with tree depth 2 val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) - testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), - checkModelData) + allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0598943c3d4be..0cddb37281b39 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 
fe47176a4aaa6..4c63a2a88c6c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } val svm = new LinearSVC() testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings, - checkModelData) + LinearSVCSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index d89a958eed45a..affaa573749e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2089,7 +2089,7 @@ class LogisticRegressionSuite } val lr = new LogisticRegression() testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings, - checkModelData) + LogisticRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 37d7991fe8dd8..4d5d299d1408f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(model.theta === model2.theta) } val nb = new NaiveBayes() - testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, + NaiveBayesSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 44e1585ee514b..c3003cec73b41 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -218,7 +218,8 @@ class RandomForestClassifierSuite val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 30513c1e276ae..200a892f6c694 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -138,8 +138,8 @@ class BisectingKMeansSuite assert(model.clusterCenters === model2.clusterCenters) } val bisectingKMeans = new BisectingKMeans() - testEstimatorAndModelReadWrite( - bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + 
testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, + BisectingKMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index c500c5b3e365a..61da897b666f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov)) } val gm = new GaussianMixture() - testEstimatorAndModelReadWrite(gm, dataset, + testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings, GaussianMixtureSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e10127f7d108f..ca05b9c389f65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.clusterCenters === model2.clusterCenters) } val kmeans = new KMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, + KMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 9aa11fbdbe868..75aa0be61a3ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(model2.getDocConcentration) absTol 1e-6) } val lda = new LDA() - testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, + LDASuite.allParamSettings, checkModelData) } test("read/write DistributedLDAModel") { @@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ab937685a555c..91eac9e733312 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite } val mh = new BucketedRandomProjectionLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0) - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } test("hashFunction") { diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 482e5d54260d4..d6925da97d57e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.selectedFeatures === model2.selectedFeatures) } val nb = new ChiSqSelector - testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, + ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3461cdf82460f..a2f009310fd7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } val mh = new MinHashLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values") - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 74c7461401905..076d55c180548 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul model2.freqItemsets.sort("items").collect()) } val fPGrowth = new FPGrowth() - testEstimatorAndModelReadWrite( - fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, + FPGrowthSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e494ea89e63bd..a177ed13bf8ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -518,37 +518,26 @@ class ALSSuite } test("read/write") { - import ALSSuite._ - val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - val als = new ALS() - allEstimatorParamSettings.foreach { case (p, v) => - als.set(als.getParam(p), v) - } val spark = this.spark import spark.implicits._ - val model = als.fit(ratings.toDF()) - - // Test Estimator save/load - val als2 = testDefaultReadWrite(als) - allEstimatorParamSettings.foreach { case (p, v) => - val param = als.getParam(p) - assert(als.get(param).get === als2.get(param).get) - } + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - // Test Model save/load - val model2 = testDefaultReadWrite(model) - allModelParamSettings.foreach { case (p, v) => - val param = model.getParam(p) - assert(model.get(param).get === model2.get(param).get) - } - 
assert(model.rank === model2.rank) def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { df.select("id", "features").collect().map { case r => (r.getInt(0), r.getAs[Array[Float]](1)) }.toSet } - assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) - assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + + def checkModelData(model: ALSModel, model2: ALSModel): Unit = { + assert(model.rank === model2.rank) + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } + + val als = new ALS() + testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings, + allModelParamSettings, checkModelData) } test("input type validation") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 3cd4b0ac308ef..708185a0943df 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite } val aft = new AFTSurvivalRegression() testEstimatorAndModelReadWrite(aft, datasetMultivariate, - AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings, + checkModelData) } test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 15fa26e8b5272..0e91284d03d98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) testEstimatorAndModelReadWrite(dt, categoricalData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) testEstimatorAndModelReadWrite(dt, continuousData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dcf3f9a1ea9b2..03c2f97797bce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(gbt, continuousData, 
allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index add28a72b6808..401911763fa3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression() testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 8cbb2acad243e..f41a3601b1fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -178,7 +178,7 @@ class IsotonicRegressionSuite val ir = new IsotonicRegression() testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, - checkModelData) + IsotonicRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 584a1b272f6c8..6a51e75e12a36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -985,7 +985,7 @@ class LinearRegressionSuite } val lr = new LinearRegression() testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, - checkModelData) + LinearRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c08335f9f84af..3bf0445ebd3dd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 553b8725b30a3..bfe8f12258bb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends 
TempDirectory { self: Suite => * - Check Params on Estimator and Model * - Compare model data * - * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s. * * @param estimator Estimator to test * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testParams Set of [[Param]] values to set in estimator + * @param testEstimatorParams Set of [[Param]] values to set in estimator + * @param testModelParams Set of [[Param]] values to set in model * @param checkModelData Method which takes the original and loaded [[Model]] and compares their * data. This method does not need to check [[Param]] values. * @tparam E Type of [[Estimator]] @@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: Dataset[_], - testParams: Map[String, Any], + testEstimatorParams: Map[String, Any], + testModelParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized. - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => estimator.set(estimator.getParam(p), v) } val model = estimator.fit(dataset) // Test Estimator save/load val estimator2 = testDefaultReadWrite(estimator) - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => val param = estimator.getParam(p) assert(estimator.get(param).get === estimator2.get(param).get) } // Test Model save/load val model2 = testDefaultReadWrite(model) - testParams.foreach { case (p, v) => + testModelParams.foreach { case (p, v) => val param = model.getParam(p) assert(model.get(param).get === model2.get(param).get) } From 81303f7ca7808d51229411dce8feeed8c23dbe15 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 8 Mar 2017 02:09:36 -0800 Subject: [PATCH 41/78] [SPARK-19806][ML][PYSPARK] PySpark GeneralizedLinearRegression supports tweedie distribution. ## What changes were proposed in this pull request? PySpark ```GeneralizedLinearRegression``` supports tweedie distribution. ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #17146 from yanboliang/spark-19806. --- .../GeneralizedLinearRegression.scala | 8 +-- python/pyspark/ml/regression.py | 61 ++++++++++++++++--- python/pyspark/ml/tests.py | 20 ++++++ 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 110764dc074f7..3be8b533ee3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for the power in the variance function of the Tweedie distribution which provides * the relationship between the variance and mean of the distribution. - * Only applicable for the Tweedie family. + * Only applicable to the Tweedie family. * (see * Tweedie Distribution (Wikipedia)) * Supported values: 0 and [1, Inf). 
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val variancePower: DoubleParam = new DoubleParam(this, "variancePower", "The power in the variance function of the Tweedie distribution which characterizes " + "the relationship between the variance and mean of the distribution. " + - "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).", + "Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).", (x: Double) => x >= 1.0 || x == 0.0) /** @group getParam */ @@ -106,7 +106,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam def getLink: String = $(link) /** - * Param for the index in the power link function. Only applicable for the Tweedie family. + * Param for the index in the power link function. Only applicable to the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" @@ -116,7 +116,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam */ @Since("2.2.0") final val linkPower: DoubleParam = new DoubleParam(this, "linkPower", - "The index in the power link function. Only applicable for the Tweedie family.") + "The index in the power link function. Only applicable to the Tweedie family.") /** @group getParam */ @Since("2.2.0") diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b199bf282e4f2..3c3fcc8d9b8d8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports - "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family - is listed below. The first link function of each family is the default one. + "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for + each family are listed below. The first link function of each family is the default one. * "gaussian" -> "identity", "log", "inverse" @@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha * "gamma" -> "inverse", "identity", "log" + * "tweedie" -> power link function specified through "linkPower". \ + The default link power in the tweedie family is 1 - variancePower. + .. seealso:: `GLM `_ >>> from pyspark.ml.linalg import Vectors @@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model.
Supported options: " + - "gaussian (default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson, gamma and tweedie.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + @@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha "and sqrt.", typeConverter=TypeConverters.toString) linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + "predictor) column name", typeConverter=TypeConverters.toString) + variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " + + "of the Tweedie distribution which characterizes the relationship " + + "between the variance and mean of the distribution. Only applicable " + + "for the Tweedie family. Supported values: 0 and [1, Inf).", + typeConverter=TypeConverters.toFloat) + linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + + "Only applicable to the Tweedie family.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + variancePower=0.0) kwargs = self._input_kwargs + self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1428,6 +1445,34 @@ def getLink(self): """ return self.getOrDefault(self.link) + @since("2.2.0") + def setVariancePower(self, value): + """ + Sets the value of :py:attr:`variancePower`. + """ + return self._set(variancePower=value) + + @since("2.2.0") + def getVariancePower(self): + """ + Gets the value of variancePower or its default value. 
+ """ + return self.getOrDefault(self.variancePower) + + @since("2.2.0") + def setLinkPower(self, value): + """ + Sets the value of :py:attr:`linkPower`. + """ + return self._set(linkPower=value) + + @since("2.2.0") + def getLinkPower(self): + """ + Gets the value of linkPower or its default value. + """ + return self.getOrDefault(self.linkPower) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 352416055791e..f052f5bb770c6 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1223,6 +1223,26 @@ def test_apply_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(features[i])) +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From 3f9f9180c2e695ad468eb813df5feec41e169531 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 8 Mar 2017 11:31:01 +0000 Subject: [PATCH 42/78] [SPARK-19693][SQL] Make the SET mapreduce.job.reduces automatically converted to spark.sql.shuffle.partitions ## What changes were proposed in this pull request? Make the `SET mapreduce.job.reduces` automatically converted to `spark.sql.shuffle.partitions`, it's similar to `SET mapred.reduce.tasks`. ## How was this patch tested? unit tests Author: Yuming Wang Closes #17020 from wangyum/SPARK-19693. --- .../sql/execution/command/SetCommand.scala | 17 +++++++++++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 4 ++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++++++ 3 files changed, 33 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7afa4e78a3786..5f12830ee621f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -60,6 +60,23 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Replaced.MAPREDUCE_JOB_REDUCES, Some(value))) => + val runFunc = (sparkSession: SparkSession) => { + logWarning( + s"Property ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} is Hadoop's property, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} for automatically " + + "determining the number of reducers is not supported." 
+ throw new IllegalArgumentException(msg) + } else { + sparkSession.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } + } + (keyValueOutput, runFunc) + case Some((key @ SetCommand.VariableName(name), Some(value))) => val runFunc = (sparkSession: SparkSession) => { sparkSession.conf.set(name, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 461dfe3a66e1b..fd3acd42e8315 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -677,6 +677,10 @@ object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } + + object Replaced { + val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 468ea0551298e..d9e0196c57957 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1019,6 +1019,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { spark.sessionState.conf.clear() } + test("SET mapreduce.job.reduces automatically converted to spark.sql.shuffle.partitions") { + spark.sessionState.conf.clear() + val before = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + val newConf = before + 1 + sql(s"SET mapreduce.job.reduces=${newConf.toString}") + val after = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + assert(before != after) + assert(newConf === after) + intercept[IllegalArgumentException](sql(s"SET mapreduce.job.reduces=-1")) + spark.sessionState.conf.clear() + } + test("apply schema") { val schema1 = StructType( StructField("f1", IntegerType, false) :: From 9ea201cf6482c9c62c9428759d238063db62d66e Mon Sep 17 00:00:00 2001 From: Anthony Truchet Date: Wed, 8 Mar 2017 11:44:25 +0000 Subject: [PATCH 43/78] [SPARK-16440][MLLIB] Ensure broadcasted variables are destroyed even in case of exception ## What changes were proposed in this pull request? Ensure broadcasted variables are destroyed even in case of an exception. ## How was this patch tested? Word2VecSuite was run locally Author: Anthony Truchet Closes #14299 from AnthonyTruchet/SPARK-16440.
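As background for the diff below, here is a minimal standalone sketch of the pattern this patch applies (the `withBroadcast` helper and its payload are illustrative, not part of the patch): create the broadcast, run the work inside `try`, and destroy the broadcast in `finally`, so cleanup happens even when the body throws.

```scala
import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast

object BroadcastCleanupSketch {
  // Runs `body` with a broadcast of `value` and guarantees the broadcast is
  // destroyed afterwards, whether `body` returns normally or throws.
  def withBroadcast[V: ClassTag, T](sc: SparkContext, value: V)(body: Broadcast[V] => T): T = {
    val bc = sc.broadcast(value)
    try {
      body(bc)
    } finally {
      bc.destroy() // reached even on exception, unlike cleanup placed after the job
    }
  }
}
```

The patch itself inlines this shape rather than adding a helper: the body of `fit` moves into a private `doFit` method, which is invoked inside `try` with the three broadcasts destroyed in `finally`.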
--- .../apache/spark/mllib/feature/Word2Vec.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 2364d43aaa0e2..531c8b07910fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) + try { + doFit(dataset, sc, expTable, bcVocab, bcVocabHash) + } finally { + expTable.destroy(blocking = false) + bcVocab.destroy(blocking = false) + bcVocabHash.destroy(blocking = false) + } + } + + private def doFit[S <: Iterable[String]]( + dataset: RDD[S], sc: SparkContext, + expTable: Broadcast[Array[Float]], + bcVocab: Broadcast[Array[VocabWord]], + bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = { // each partition is a collection of sentences, // will be translated into arrays of Index integer val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => @@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging { bcSyn1Global.destroy(false) } newSentences.unpersist() - expTable.destroy(false) - bcVocab.destroy(false) - bcVocabHash.destroy(false) val wordArray = vocab.map(_.word) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) From e44274870dee308f4e3e8ce79457d8d19693b6e5 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Wed, 8 Mar 2017 16:01:28 +0100 Subject: [PATCH 44/78] [SPARK-17080][SQL] join reorder ## What changes were proposed in this pull request? Reorder the joins using a dynamic programming algorithm (Selinger paper): First we put all items (basic joined nodes) into level 1, then we build all two-way joins at level 2 from plans at level 1 (single items), then build all 3-way joins from plans at previous levels (two-way joins and single items), then 4-way joins ... etc, until we build all n-way joins and pick the best plan among them. When building m-way joins, we only keep the best plan (with the lowest cost) for the same set of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among plans (A J B) J C, (A J C) J B and (B J C) J A. Thus, the plans maintained for each level when reordering four items A, B, C, D are as follows: ``` level 1: p({A}), p({B}), p({C}), p({D}) level 2: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) level 3: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) level 4: p({A, B, C, D}) ``` where p({A, B, C, D}) is the final output plan. For cost evaluation, since physical costs for operators are not available currently, we use cardinalities and sizes to compute costs. ## How was this patch tested? add test cases Author: wangzhenhua Author: Zhenhua Wang Closes #17138 from wzhfy/joinReorder. 
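To make the level-by-level search concrete before the diff, here is a small self-contained Scala sketch of the same dynamic program. The base cardinalities and the `joinCard` selectivity are made-up stand-ins for the real cost model (which weighs row count against size, as the new `Cost` class below shows); each level keeps only the cheapest plan per item set.

```scala
object JoinReorderSketch {
  // A candidate plan: the items it covers, a printable join order, its output
  // cardinality, and its cumulative cost (sum of intermediate cardinalities).
  case class Plan(items: Set[Int], order: String, card: Long, cost: Long)

  // Hypothetical base-table cardinalities and a toy join-size estimate.
  val baseCard = Map(0 -> 1000L, 1 -> 20L, 2 -> 100L, 3 -> 2000L)
  def joinCard(left: Long, right: Long): Long = left * right / 100

  def search(items: Seq[Int]): Plan = {
    // best(S) = cheapest plan found for item set S, seeded with single items.
    val best = scala.collection.mutable.Map[Set[Int], Plan](
      items.map(i => Set(i) -> Plan(Set(i), s"t$i", baseCard(i), 0L)): _*)
    for (level <- 2 to items.size) {
      val candidates = best.values.toSeq // plans from all previous levels
      for {
        l <- candidates
        r <- candidates
        if l.items.size + r.items.size == level && (l.items & r.items).isEmpty
      } {
        val card = joinCard(l.card, r.card)
        val plan = Plan(l.items ++ r.items, s"(${l.order} J ${r.order})",
          card, l.cost + r.cost + card)
        // Keep only the best plan for the same set of items.
        if (!best.get(plan.items).exists(_.cost <= plan.cost)) best(plan.items) = plan
      }
    }
    best(items.toSet) // the single plan covering all n items
  }

  def main(args: Array[String]): Unit = println(search(Seq(0, 1, 2, 3)))
}
```

The production rule in the diff additionally tracks which join conditions each plan has consumed, pushes cartesian products to the end by assigning them a very large cost, and prunes unneeded columns with a `Project` between joins.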
--- .../spark/sql/catalyst/CatalystConf.scala | 8 + .../optimizer/CostBasedJoinReorder.scala | 297 ++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../catalyst/optimizer/JoinReorderSuite.scala | 194 ++++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../StatsEstimationTestBase.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 16 + .../sql/execution/SparkSqlParserSuite.scala | 2 +- 8 files changed, 521 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 5f50ce1ba68ff..fb99cb27b847b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -60,6 +60,12 @@ trait CatalystConf { * Enables CBO for estimation of plan statistics when set true. */ def cboEnabled: Boolean + + /** Enables join reorder in CBO. */ + def joinReorderEnabled: Boolean + + /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */ + def joinReorderDPThreshold: Int } @@ -75,6 +81,8 @@ case class SimpleCatalystConf( runSQLonFile: Boolean = true, crossJoinEnabled: Boolean = false, cboEnabled: Boolean = false, + joinReorderEnabled: Boolean = false, + joinReorderDPThreshold: Int = 12, warehousePath: String = "/user/hive/warehouse", sessionLocalTimeZone: String = TimeZone.getDefault().getID) extends CatalystConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala new file mode 100644 index 0000000000000..b694561e5372d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} +import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} +import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * Cost-based join reorder. + * We may have several join reorder algorithms in the future. 
This class is the entry of these + * algorithms, and chooses which one to use. + */ +case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.cboEnabled || !conf.joinReorderEnabled) { + plan + } else { + val result = plan transform { + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) => + reorder(p, p.outputSet) + case j @ Join(_, _, _: InnerLike, _) => + reorder(j, j.outputSet) + } + // After reordering is finished, convert OrderedJoin back to Join + result transform { + case oj: OrderedJoin => oj.join + } + } + } + + def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + val (items, conditions) = extractInnerJoins(plan) + val result = + // Do reordering if the number of items is appropriate and join conditions exist. + // We also need to check if costs of all items can be evaluated. + if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && + items.forall(_.stats(conf).rowCount.isDefined)) { + JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan) + } else { + plan + } + // Set consecutive join nodes ordered. + replaceWithOrderedJoin(result) + } + + /** + * Extract consecutive inner joinable items and join conditions. + * This method works for bushy trees and left/right deep trees. + */ + private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + plan match { + case Join(left, right, _: InnerLike, cond) => + val (leftPlans, leftConditions) = extractInnerJoins(left) + val (rightPlans, rightConditions) = extractInnerJoins(right) + (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++ + leftConditions ++ rightConditions) + case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(join) + case _ => + (Seq(plan), Set()) + } + } + + private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { + case j @ Join(left, right, _: InnerLike, cond) => + val replacedLeft = replaceWithOrderedJoin(left) + val replacedRight = replaceWithOrderedJoin(right) + OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) + case p @ Project(_, join) => + p.copy(child = replaceWithOrderedJoin(join)) + case _ => + plan + } + + /** This is a wrapper class for a join node that has been ordered. */ + private case class OrderedJoin(join: Join) extends BinaryNode { + override def left: LogicalPlan = join.left + override def right: LogicalPlan = join.right + override def output: Seq[Attribute] = join.output + } +} + +/** + * Reorder the joins using a dynamic programming algorithm. This implementation is based on the + * paper: Access Path Selection in a Relational Database Management System. + * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf + * + * First we put all items (basic joined nodes) into level 0, then we build all two-way joins + * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans + * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we + * build all n-way joins and pick the best plan among them. + * + * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set + * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among + * plans (A J B) J C, (A J C) J B and (B J C) J A. 
+ * + * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows: + * level 0: p({A}), p({B}), p({C}), p({D}) + * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) + * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) + * level 3: p({A, B, C, D}) + * where p({A, B, C, D}) is the final output plan. + * + * For cost evaluation, since physical costs for operators are not available currently, we use + * cardinalities and sizes to compute costs. + */ +object JoinReorderDP extends PredicateHelper { + + def search( + conf: CatalystConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + topOutput: AttributeSet): Option[LogicalPlan] = { + + // Level i maintains all found plans for i + 1 items. + // Create the initial plans: each plan is a single item with zero cost. + val itemIndex = items.zipWithIndex + val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map { + case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) + }.toMap) + + for (lev <- 1 until items.length) { + // Build plans for the next level. + foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) + } + + val plansLastLevel = foundPlans(items.length - 1) + if (plansLastLevel.isEmpty) { + // Failed to find a plan, fall back to the original plan + None + } else { + // There must be only one plan at the last level, which contains all items. + assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length) + Some(plansLastLevel.head._2.plan) + } + } + + /** Find all possible plans at the next level, based on existing levels. */ + private def searchLevel( + existingLevels: Seq[JoinPlanMap], + conf: CatalystConf, + conditions: Set[Expression], + topOutput: AttributeSet): JoinPlanMap = { + + val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] + var k = 0 + val lev = existingLevels.length - 1 + // Build plans for the next level from plans at level k (one side of the join) and level + // lev - k (the other side of the join). + // For the lower level k, we only need to search from 0 to lev - k, because when building + // a join from A and B, both A J B and B J A are handled. + while (k <= lev - k) { + val oneSideCandidates = existingLevels(k).values.toSeq + for (i <- oneSideCandidates.indices) { + val oneSidePlan = oneSideCandidates(i) + val otherSideCandidates = if (k == lev - k) { + // Both sides of a join are at the same level, no need to repeat for previous ones. + oneSideCandidates.drop(i) + } else { + existingLevels(lev - k).values.toSeq + } + + otherSideCandidates.foreach { otherSidePlan => + // Should not join two overlapping item sets. + if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { + val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) + // Check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. + val existingPlan = nextLevel.get(joinPlan.itemIds) + if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) { + nextLevel.update(joinPlan.itemIds, joinPlan) + } + } + } + } + k += 1 + } + nextLevel.toMap + } + + /** Build a new join node. 
*/ + private def buildJoin( + oneJoinPlan: JoinPlan, + otherJoinPlan: JoinPlan, + conf: CatalystConf, + conditions: Set[Expression], + topOutput: AttributeSet): JoinPlan = { + + val onePlan = oneJoinPlan.plan + val otherPlan = otherJoinPlan.plan + // Now both onePlan and otherPlan become intermediate joins, so the cost of the + // new join should also include their own cardinalities and sizes. + val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) { + // We consider cartesian product very expensive, thus set a very large cost for it. + // This enables to plan all the cartesian products at the end, because having a cartesian + // product as an intermediate join will significantly increase a plan's cost, making it + // impossible to be selected as the best plan for the items, unless there's no other choice. + Cost( + rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue), + size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue)) + } else { + val onePlanStats = onePlan.stats(conf) + val otherPlanStats = otherPlan.stats(conf) + Cost( + rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get + + otherJoinPlan.cost.rows + otherPlanStats.rowCount.get, + size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes + + otherJoinPlan.cost.size + otherPlanStats.sizeInBytes) + } + + // Put the deeper side on the left, tend to build a left-deep tree. + val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) + } else { + (otherPlan, onePlan) + } + val joinConds = conditions + .filterNot(l => canEvaluate(l, onePlan)) + .filterNot(r => canEvaluate(r, otherPlan)) + .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet)) + // We use inner join whether join condition is empty or not. Since cross join is + // equivalent to inner join without condition. + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) + } else { + newJoin + } + + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + JoinPlan(itemIds, newPlan, collectedJoinConds, newCost) + } + + private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match { + case Join(_, _, _, None) => true + case Project(_, Join(_, _, _, None)) => true + case _ => false + } + + /** Map[set of item ids, join plan for these items] */ + type JoinPlanMap = Map[Set[Int], JoinPlan] + + /** + * Partial join order in a specific level. + * + * @param itemIds Set of item ids participating in this partial plan. + * @param plan The plan tree with the lowest cost for these items found so far. + * @param joinConds Join conditions included in the plan. + * @param cost The cost of this plan is the sum of costs of all intermediate joins. + */ + case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost) +} + +/** This class defines the cost model. */ +case class Cost(rows: BigInt, size: BigInt) { + /** + * An empirical value for the weights of cardinality (number of rows) in the cost formula: + * cost = rows * weight + size * (1 - weight), usually cardinality is more important than size. 
+ */ + val weight = 0.7 + + def lessThan(other: Cost): Boolean = { + if (other.rows == 0 || other.size == 0) { + false + } else { + val relativeRows = BigDecimal(rows) / BigDecimal(other.rows) + val relativeSize = BigDecimal(size) / BigDecimal(other.size) + relativeRows * weight + relativeSize * (1 - weight) < 1 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 036da3ad2062f..d5bbc6e8acc94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,6 +118,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCreateMapOps) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: Batch("Typed Filter Optimization", fixedPoint, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala new file mode 100644 index 0000000000000..1b2f7a66b6a0b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.util._ + + +class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + /** Set up tables and columns for testing */ + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + // Table t1/t4: big table with two columns + private val t1 = StatsTestPlan( + outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), + rowCount = 1000, + // size = rows * (overhead + column length) + size = Some(1000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr), + rowCount = 2000, + size = Some(2000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo))) + + // Table t2/t3: small table with only one column + private val t2 = StatsTestPlan( + outputList = Seq("t2.k-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3.v-1-100").map(nameToAttr), + rowCount = 100, + size = Some(100 * (8 + 4)), + attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + + test("reorder 3 tables") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost = cost(t1 J t2) = 1000 * 20 / 5 = 
4000 + // In contrast, the cost of the best plan: + // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000 + // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than + // the original order (t1 J t2) J t3. + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables - put cross join at the end") { + val originalPlan = + t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, None) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables with pure-attribute project") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("don't reorder if project contains non-attribute") { + val originalPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10")) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select("key".attr) + + assertEqualPlans(originalPlan, originalPlan) + } + + test("reorder 4 tables (bushy tree)") { + val originalPlan = + t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000, + // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000 + // In contrast, the cost of the best plan (a bushy tree): + // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000, + // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000. + val bestPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + private def assertEqualPlans( + originalPlan: LogicalPlan, + groundTruthBestPlan: LogicalPlan): Unit = { + val optimized = Optimize.execute(originalPlan.analyze) + val normalized1 = normalizePlan(normalizeExprIds(optimized)) + val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. 
*/ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ => + plan1 == plan2 + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 3b7e5e938a8e4..e9b7a0c6ad671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -62,7 +62,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * - Sample the seed will replaced by 0L. * - Join conditions will be resorted by hashCode. */ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index c56b41ce37636..9b2b8dbe1bf4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, Logica import org.apache.spark.sql.types.{IntegerType, StringType} -class StatsEstimationTestBase extends SparkFunSuite { +trait StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. */ protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) @@ -48,7 +48,7 @@ class StatsEstimationTestBase extends SparkFunSuite { /** * This class is used for unit-testing. It's a logical plan whose output and stats are passed in. */ -protected case class StatsTestPlan( +case class StatsTestPlan( outputList: Seq[Attribute], rowCount: BigInt, attributeStats: AttributeMap[ColumnStat], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fd3acd42e8315..94e3fa7dd13f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -668,6 +668,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val JOIN_REORDER_ENABLED = + buildConf("spark.sql.cbo.joinReorder.enabled") + .doc("Enables join reorder in CBO.") + .booleanConf + .createWithDefault(false) + + val JOIN_REORDER_DP_THRESHOLD = + buildConf("spark.sql.cbo.joinReorder.dp.threshold") + .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") + .intConf + .createWithDefault(12) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. 
"GMT", "America/Los_Angeles", etc.""") @@ -885,6 +897,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + + override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index d44a6e41cb347..a4d012cd76115 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -45,7 +45,7 @@ class SparkSqlParserSuite extends PlanTest { * Normalizes plans: * - CreateTable the createTime in tableDesc will replaced by -1L. */ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + override def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan match { case CreateTable(tableDesc, mode, query) => val newTableDesc = tableDesc.copy(createTime = -1L) From 5f7d835d380c1a558a4a6d8366140cd96ee202eb Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 8 Mar 2017 16:18:17 +0100 Subject: [PATCH 45/78] [SPARK-19865][SQL] remove the view identifier in SubqueryAlias ## What changes were proposed in this pull request? Since we have a `View` node now, we can remove the view identifier in `SubqueryAlias`, which was used to indicate a view node before. ## How was this patch tested? Update the related test cases. Author: jiangxingbo Closes #17210 from jiangxb1987/SubqueryAlias. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../sql/catalyst/catalog/SessionCatalog.scala | 8 ++++---- .../apache/spark/sql/catalyst/dsl/package.scala | 4 ++-- .../spark/sql/catalyst/optimizer/subquery.scala | 8 ++++---- .../spark/sql/catalyst/parser/AstBuilder.scala | 6 +++--- .../plans/logical/basicLogicalOperators.scala | 3 +-- .../sql/catalyst/analysis/AnalysisSuite.scala | 16 ++++++++-------- .../catalyst/catalog/SessionCatalogSuite.scala | 6 +++--- .../catalyst/optimizer/ColumnPruningSuite.scala | 8 ++++---- .../EliminateSubqueryAliasesSuite.scala | 6 +++--- .../optimizer/JoinOptimizationSuite.scala | 8 ++++---- .../sql/catalyst/parser/PlanParserSuite.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/joins/BroadcastJoinSuite.scala | 3 --- .../sql/hive/HiveMetastoreCatalogSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 16 files changed, 42 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ffa5aed30e19f..93666f14958e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -598,7 +598,7 @@ class Analyzer( execute(child) } view.copy(child = newChild) - case p @ SubqueryAlias(_, view: View, _) => + case p @ SubqueryAlias(_, view: View) => val newChild = resolveRelation(view) p.copy(child = newChild) case _ => plan @@ -2363,7 +2363,7 @@ class Analyzer( */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child, _) => child + case SubqueryAlias(_, child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 498bfbde9d7a1..831e37aac1246 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -578,7 +578,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef, None) + SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -591,17 +591,17 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) + SubqueryAlias(table, child) } else { val tableRelation = CatalogRelation( metadata, // we assume all the columns are nullable. 
metadata.dataSchema.asNullable.toAttributes, metadata.partitionSchema.asNullable.toAttributes) - SubqueryAlias(table, tableRelation, None) + SubqueryAlias(table, tableRelation) } } else { - SubqueryAlias(table, tempTables(table), None) + SubqueryAlias(table, tempTables(table)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c062e4e84bcdd..0f0d90494f98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -346,7 +346,7 @@ package object dsl { orderSpec: Seq[SortOrder]): LogicalPlan = Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) - def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan, None) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) @@ -368,7 +368,7 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None) + def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 4d62cce9da0ac..fb7ce6aecea53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -169,7 +169,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. 
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { - case SubqueryAlias(_, child, _) => evalPlan(child) + case SubqueryAlias(_, child) => evalPlan(child) case Filter(condition, child) => val bindings = evalPlan(child) if (bindings.isEmpty) bindings @@ -227,7 +227,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart += p bottomPart = child - case s @ SubqueryAlias(_, child, _) => + case s @ SubqueryAlias(_, child) => topPart += s bottomPart = child @@ -298,8 +298,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart.reverse.foreach { case Project(projList, _) => subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s @ SubqueryAlias(alias, _, None) => - subqueryRoot = SubqueryAlias(alias, subqueryRoot, None) + case s @ SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) case op => sys.error(s"Unexpected operator $op in corelated subquery") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2e091f4dda69..3cf11adc1953b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -108,7 +108,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * This is only used for Common Table Expressions. */ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.query), None) + SubqueryAlias(ctx.name.getText, plan(ctx.query)) } /** @@ -666,7 +666,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { case Some(strictIdentifier) => - SubqueryAlias(strictIdentifier, table, None) + SubqueryAlias(strictIdentifier, table) case _ => table } tableWithAlias.optionalMap(ctx.sample)(withSample) @@ -731,7 +731,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create an alias (SubqueryAlias) for a LogicalPlan. 
*/ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { - SubqueryAlias(alias.getText, plan, None) + SubqueryAlias(alias.getText, plan) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4d27ff2acdbad..70c5ed4b07c9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -792,8 +792,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, - child: LogicalPlan, - view: Option[TableIdentifier]) + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 01737e0a17341..893bb1b74cea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -62,23 +62,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation)) assertAnalysisError( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Seq("cannot resolve")) checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) } @@ -374,8 +374,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val query = Project(Seq($"x.key", $"y.key"), Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - Project(Seq($"y.key"), SubqueryAlias("y", input, None)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + Project(Seq($"y.key"), SubqueryAlias("y", input)), Cross, None)) assertAnalysisSuccess(query) @@ -435,10 +435,10 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { test("resolve as with an already existed alias") { checkAnalysis( Project(Seq(UnresolvedAttribute("tbl2.a")), - SubqueryAlias("tbl", testRelation, None).as("tbl2")), + SubqueryAlias("tbl", testRelation).as("tbl2")), Project(testRelation.output, testRelation), caseSensitive = false) - checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), testRelation) + checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index ffc272c6c0c39..328a16c4bf024 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -437,7 +437,7 @@ class SessionCatalogSuite extends PlanTest { .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1, None)) + == SubqueryAlias("tbl1", tempTable1)) // Then, if that does not exist, look up the relation in the current database sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head @@ -454,11 +454,11 @@ class SessionCatalogSuite extends PlanTest { val view = View(desc = metadata, output = metadata.schema.toAttributes, child = CatalystSqlParser.parsePlan(metadata.viewText.get)) comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) + SubqueryAlias("view1", view)) // Look up a view using current database of the session catalog. sessionCatalog.setCurrentDatabase("db3") comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) + SubqueryAlias("view1", view)) } test("table exists") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 5bd1bc80c3b8a..589607e3ad5cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -320,16 +320,16 @@ class ColumnPruningSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index a8aeedbd62759..9b6d68aee803a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -46,13 +46,13 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate top level subquery") { val input = LocalRelation('a.int, 'b.int) - val query = SubqueryAlias("a", input, None) + val query = SubqueryAlias("a", input) comparePlans(afterOptimization(query), input) } test("eliminate 
mid-tree subquery") { val input = LocalRelation('a.int, 'b.int) - val query = Filter(TrueLiteral, SubqueryAlias("a", input, None)) + val query = Filter(TrueLiteral, SubqueryAlias("a", input)) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) @@ -61,7 +61,7 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate multiple subqueries") { val input = LocalRelation('a.int, 'b.int) val query = Filter(TrueLiteral, - SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None)) + SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 65dd6225cea07..985e49069da90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -129,15 +129,15 @@ class JoinOptimizationSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 67d5d2202b680..411777d6e85a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -79,7 +79,7 @@ class PlanParserSuite extends PlanTest { def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { val ctes = namedPlans.map { case (name, cte) => - name -> SubqueryAlias(name, cte, None) + name -> SubqueryAlias(name, cte) } With(plan, ctes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1b04623596073..f00311fc322d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1093,7 +1093,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan, None) + SubqueryAlias(alias, logicalPlan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 9c55357ab9bc1..26c45e092dc65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,15 +22,12 @@ import scala.reflect.ClassTag 
import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
-import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{LongType, ShortType}
-import org.apache.spark.util.Utils
 
 /**
  * Test various broadcast join operators.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 892a22ddfafc8..cf552b4a88b2c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -64,7 +64,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils {
     spark.sql("create view vw1 as select 1 as id")
     val plan = spark.sql("select id from vw1").queryExecution.analyzed
     val aliases = plan.collect {
-      case x @ SubqueryAlias("vw1", _, Some(TableIdentifier("vw1", Some("default")))) => x
+      case x @ SubqueryAlias("vw1", _) => x
     }
     assert(aliases.size == 1)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index be9a5fd71bd25..236135dcff523 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1030,7 +1030,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") {
       sql("CREATE TABLE explodeTest (key bigInt)")
       table("explodeTest").queryExecution.analyzed match {
-        case SubqueryAlias(_, r: CatalogRelation, _) => // OK
+        case SubqueryAlias(_, r: CatalogRelation) => // OK
         case _ =>
           fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation")
       }

From 9a6ac7226fd09d570cae08d0daea82d9bca189a0 Mon Sep 17 00:00:00 2001
From: Xiao Li
Date: Wed, 8 Mar 2017 09:36:01 -0800
Subject: [PATCH 46/78] [SPARK-19601][SQL] Fix CollapseRepartition rule to preserve shuffle-enabled Repartition

### What changes were proposed in this pull request?

As observed by felixcheung in https://github.com/apache/spark/pull/16739, when users call the shuffle-enabled `repartition` API, they expect the number of partitions they get to be exactly the number they provided, even if they later call the shuffle-disabled `coalesce`.

Currently, the `CollapseRepartition` rule does not consider whether shuffle is enabled or not. Thus, we get the following unexpected results:

```Scala
val df = spark.range(0, 10000, 1, 5)
val df2 = df.repartition(10)
assert(df2.coalesce(13).rdd.getNumPartitions == 5)
assert(df2.coalesce(7).rdd.getNumPartitions == 5)
assert(df2.coalesce(3).rdd.getNumPartitions == 3)
```

This PR fixes the issue by preserving the shuffle-enabled Repartition.

### How was this patch tested?
Added a test case

Author: Xiao Li

Closes #16933 from gatorsmile/CollapseRepartition.
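[Editor's note: before the diff, the collapse decision this rule implements can be summarized by a small standalone model. This is a sketch over a simplified plan ADT; `Plan`, `Relation`, and `collapse` are illustrative names, not the actual Catalyst classes.]

```Scala
// Toy model, for intuition only: Repartition(n, shuffle = false) stands for
// the coalesce API, Repartition(n, shuffle = true) for the repartition API.
sealed trait Plan
case class Relation(name: String) extends Plan
case class Repartition(numPartitions: Int, shuffle: Boolean, child: Plan) extends Plan

def collapse(plan: Plan): Plan = plan match {
  // A coalesce on top of a shuffle may only be dropped when it cannot lower
  // the partition count that the shuffle already produced.
  case r @ Repartition(n, false, child @ Repartition(m, true, _)) =>
    if (n >= m) child else r
  // In every other adjacent pairing, the top operator wins.
  case r @ Repartition(_, _, Repartition(_, _, grandChild)) =>
    r.copy(child = grandChild)
  case other => other
}

val shuffled = Repartition(10, shuffle = true, Relation("df"))
assert(collapse(Repartition(13, shuffle = false, shuffled)) == shuffled)  // stays at 10
assert(collapse(Repartition(7, shuffle = false, shuffled)) ==
  Repartition(7, shuffle = false, shuffled))                              // 7 partitions
assert(collapse(Repartition(3, shuffle = true, shuffled)) ==
  Repartition(3, shuffle = true, Relation("df")))                         // collapsed into one
```

The actual rule in Optimizer.scala below additionally handles `RepartitionByExpression`, which always shuffles, so a repartition-by-expression on top always replaces whatever repartitioning sits beneath it.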
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 +- .../spark/sql/catalyst/dsl/package.scala | 3 + .../sql/catalyst/optimizer/Optimizer.scala | 32 ++-- .../plans/logical/basicLogicalOperators.scala | 16 +- .../optimizer/CollapseRepartitionSuite.scala | 153 ++++++++++++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 10 +- .../spark/sql/execution/PlannerSuite.scala | 9 +- 7 files changed, 178 insertions(+), 49 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 620b633637138..9735fe3201553 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2592,8 +2592,8 @@ test_that("coalesce, repartition, numPartitions", { df2 <- repartition(df1, 10) expect_equal(getNumPartitions(df2), 10) - expect_equal(getNumPartitions(coalesce(df2, 13)), 5) - expect_equal(getNumPartitions(coalesce(df2, 7)), 5) + expect_equal(getNumPartitions(coalesce(df2, 13)), 10) + expect_equal(getNumPartitions(coalesce(df2, 7)), 7) expect_equal(getNumPartitions(coalesce(df2, 3)), 3) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 0f0d90494f98c..35ca2a0aa53a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -370,6 +370,9 @@ package object dsl { def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) + def coalesce(num: Integer): LogicalPlan = + Repartition(num, shuffle = false, logicalPlan) + def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d5bbc6e8acc94..caafa1c134cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -564,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] { } /** - * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations - * by keeping only the one. - * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]]. - * 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]]. - * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single - * [[RepartitionByExpression]] with the expression and last number of partition. 
+ * Combines adjacent [[RepartitionOperation]] operators */ object CollapseRepartition extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // Case 1 - case Repartition(numPartitions, shuffle, Repartition(_, _, child)) => - Repartition(numPartitions, shuffle, child) - // Case 2 - case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) - // Case 3 - case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = numPartitions) - // Case 3 - case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) + // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression, + // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child + // enables the shuffle. Returns the child node if the last numPartitions is bigger; + // otherwise, keep unchanged. + // 2) In the other cases, returns the top node with the child's child + case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match { + case (false, true) => if (r.numPartitions >= child.numPartitions) child else r + case _ => r.copy(child = child.child) + } + // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression + // we can remove the child. + case r @ RepartitionByExpression(_, child: RepartitionOperation, _) => + r.copy(child = child.child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 70c5ed4b07c9a..31b6ed48a2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -842,6 +842,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * A base interface for [[RepartitionByExpression]] and [[Repartition]] + */ +abstract class RepartitionOperation extends UnaryNode { + def shuffle: Boolean + def numPartitions: Int + override def output: Seq[Attribute] = child.output +} + /** * Returns a new RDD that has exactly `numPartitions` partitions. Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user @@ -849,9 +858,8 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { * of the output requires some specific ordering or distribution of the data. 
 */
 case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
-  extends UnaryNode {
+  extends RepartitionOperation {
   require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
-  override def output: Seq[Attribute] = child.output
 }
 
 /**
@@ -863,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Int) extends UnaryNode {
+    numPartitions: Int) extends RepartitionOperation {
 
   require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def shuffle: Boolean = true
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
index 8952c72fe42fe..59d2dc46f00ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
@@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int, 'b.int)
 
+
+  test("collapse two adjacent coalesces into one") {
+    // Always respects the top coalesce and removes the useless coalesce below it
+    val query1 = testRelation
+      .coalesce(10)
+      .coalesce(20)
+    val query2 = testRelation
+      .coalesce(30)
+      .coalesce(20)
+
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer = testRelation.coalesce(20).analyze
+
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
+  }
+
   test("collapse two adjacent repartitions into one") {
-    val query = testRelation
+    // Always respects the top repartition and removes the useless repartition below it
+    val query1 = testRelation
+      .repartition(10)
+      .repartition(20)
+    val query2 = testRelation
+      .repartition(30)
+      .repartition(20)
+
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer = testRelation.repartition(20).analyze
+
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
+  }
+
+  test("coalesce above repartition") {
+    // Remove the useless coalesce above a repartition
+    val query1 = testRelation
       .repartition(10)
+      .coalesce(20)
+
+    val optimized1 = Optimize.execute(query1.analyze)
+    val correctAnswer1 = testRelation.repartition(10).analyze
+
+    comparePlans(optimized1, correctAnswer1)
+
+    // No change in this case
+    val query2 = testRelation
+      .repartition(30)
+      .coalesce(20)
+
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer2 = query2.analyze
+
+    comparePlans(optimized2, correctAnswer2)
+  }
+
+  test("repartition above coalesce") {
+    // Always respects the top repartition and removes the useless coalesce below it
+    val query1 = testRelation
+      .coalesce(10)
+      .repartition(20)
+    val query2 = testRelation
+      .coalesce(30)
       .repartition(20)
 
-    val optimized = Optimize.execute(query.analyze)
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
     val correctAnswer = testRelation.repartition(20).analyze
 
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
   }
 
-  test("collapse repartition and repartitionBy into one") {
-    val query = testRelation
+  test("repartitionBy above repartition") {
+    // Always respects the top repartitionBy and removes the useless repartition below it
+    val query1 = testRelation
       .repartition(10)
       .distribute('a)(20)
+    val query2 = testRelation
+      .repartition(30)
+      .distribute('a)(20)
 
-    val optimized = Optimize.execute(query.analyze)
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
     val correctAnswer = testRelation.distribute('a)(20).analyze
 
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
   }
 
-  test("collapse repartitionBy and repartition into one") {
-    val query = testRelation
+  test("repartitionBy above coalesce") {
+    // Always respects the top repartitionBy and removes the useless coalesce below it
+    val query1 = testRelation
+      .coalesce(10)
+      .distribute('a)(20)
+    val query2 = testRelation
+      .coalesce(30)
       .distribute('a)(20)
-      .repartition(10)
 
-    val optimized = Optimize.execute(query.analyze)
-    val correctAnswer = testRelation.distribute('a)(10).analyze
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer = testRelation.distribute('a)(20).analyze
 
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
+  }
+
+  test("repartition above repartitionBy") {
+    // Always respects the top repartition and removes the useless distribute below it
+    val query1 = testRelation
+      .distribute('a)(10)
+      .repartition(20)
+    val query2 = testRelation
+      .distribute('a)(30)
+      .repartition(20)
+
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer = testRelation.repartition(20).analyze
+
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
+
+  }
+
+  test("coalesce above repartitionBy") {
+    // Remove the useless coalesce above a repartition
+    val query1 = testRelation
+      .distribute('a)(10)
+      .coalesce(20)
+
+    val optimized1 = Optimize.execute(query1.analyze)
+    val correctAnswer1 = testRelation.distribute('a)(10).analyze
+
+    comparePlans(optimized1, correctAnswer1)
+
+    // No change in this case
+    val query2 = testRelation
+      .distribute('a)(30)
+      .coalesce(20)
+
+    val optimized2 = Optimize.execute(query2.analyze)
+    val correctAnswer2 = query2.analyze
+
+    comparePlans(optimized2, correctAnswer2)
   }
 
   test("collapse two adjacent repartitionBys into one") {
-    val query = testRelation
+    // Always respects the top repartitionBy
+    val query1 = testRelation
       .distribute('b)(10)
       .distribute('a)(20)
+    val query2 = testRelation
+      .distribute('b)(30)
+      .distribute('a)(20)
 
-    val optimized = Optimize.execute(query.analyze)
+    val optimized1 = Optimize.execute(query1.analyze)
+    val optimized2 = Optimize.execute(query2.analyze)
     val correctAnswer = testRelation.distribute('a)(20).analyze
 
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized1, correctAnswer)
+    comparePlans(optimized2, correctAnswer)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f00311fc322d8..16edb35b1d43f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2441,11 +2441,11 @@ class Dataset[T]
private[sql](
   }
 
   /**
-   * Returns a new Dataset that has exactly `numPartitions` partitions.
-   * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
-   * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
-   * the 100 new partitions will claim 10 of the current partitions. If a larger number of
-   * partitions is requested, it will stay at the current number of partitions.
+   * Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
+   * are requested. If a larger number of partitions is requested, it will stay at the current
+   * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
+   * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
+   * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
    *
    * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
    * this may result in your computation taking place on fewer nodes than
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0bfc92fdb6218..02ccebd22bdf9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext {
     val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
     def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
     assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
-    assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
+    assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
     doubleRepartitioned.queryExecution.optimizedPlan match {
-      case r: Repartition =>
-        assert(r.numPartitions === 5)
-        assert(r.shuffle === false)
+      case Repartition(numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>
+        assert(numPartitions === 5)
+        assert(shuffle === false)
+        assert(shuffleChild === true)
     }
   }

From e420fd4592615d91cdcbca674ac58bcca6ab2ff3 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Wed, 8 Mar 2017 09:38:05 -0800
Subject: [PATCH 47/78] [SPARK-19843][SQL][FOLLOWUP] Classdoc for `IntWrapper` and `LongWrapper`

## What changes were proposed in this pull request?

This is as per the suggestion by rxin at:
https://github.com/apache/spark/pull/17184#discussion_r104841735

## How was this patch tested?

N/A, as this is a documentation-only change.

Author: Tejas Patil

Closes #17205 from tejasapatil/SPARK-19843_followup.

---
 .../apache/spark/unsafe/types/UTF8String.java | 19 +++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 7abe0fa80ad7c..4c28075bd9386 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -850,10 +850,25 @@ public UTF8String translate(Map dict) {
     return fromString(sb.toString());
   }
 
+  /**
+   * Wrapper over `long` to allow the result of parsing a long from a string to be accessed via
+   * reference. This is done solely for better performance and is not expected to be used by
+   * end users.
+   */
   public static class LongWrapper {
     public long value = 0;
   }
 
+  /**
+   * Wrapper over `int` to allow the result of parsing an integer from a string to be accessed
+   * via reference. This is done solely for better performance and is not expected to be used by
+   * end users.
+   *
+   * {@link LongWrapper} could have been used here, but using `int` directly saves the extra cost
+   * of conversion from `long` -> `int`.
+   */
+  public static class IntWrapper {
+    public int value = 0;
+  }
+
   /**
    * Parses this UTF8String to long.
    *
@@ -942,10 +957,6 @@ public boolean toLong(LongWrapper toLongResult) {
     return true;
   }
 
-  public static class IntWrapper {
-    public int value = 0;
-  }
-
   /**
    * Parses this UTF8String to int.
    *

From f3387d97487cbef894b6963bc008f6a5c4294a85 Mon Sep 17 00:00:00 2001
From: windpiger
Date: Wed, 8 Mar 2017 10:48:53 -0800
Subject: [PATCH 48/78] [SPARK-19864][SQL][TEST] provide a makeQualifiedPath function to optimize some code

## What changes were proposed in this pull request?

Currently there are lots of places that make a path qualified; it is better to provide a single function to do this, so the code will be simpler.

## How was this patch tested?
N/A

Author: windpiger

Closes #17204 from windpiger/addQualifiledPathUtil.

---
 .../spark/sql/execution/command/DDLSuite.scala | 11 +----------
 .../apache/spark/sql/test/SQLTestUtils.scala   | 14 +++++++++++++-
 .../sql/hive/execution/HiveDDLSuite.scala      | 18 +++++------------
 .../hive/orc/OrcHadoopFsRelationSuite.scala    | 11 ++++------
 .../sources/JsonHadoopFsRelationSuite.scala    |  9 +++------
 .../sources/ParquetHadoopFsRelationSuite.scala |  9 +++------
 .../SimpleTextHadoopFsRelationSuite.scala      |  9 +++------
 7 files changed, 32 insertions(+), 49 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index b2199fdf90e5c..c1f8b2b3d9605 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -132,13 +132,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     }
   }
 
-  private def makeQualifiedPath(path: String): URI = {
-    // copy-paste from SessionCatalog
-    val hadoopPath = new Path(path)
-    val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration)
-    fs.makeQualified(hadoopPath).toUri
-  }
-
   test("Create Database using Default Warehouse Path") {
     val catalog = spark.sessionState.catalog
     val dbName = "db1"
@@ -2086,9 +2079,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
       Seq(1).toDF("a").write.saveAsTable("t")
       val tblloc = new File(loc, "t")
       val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
-      val tblPath = new Path(tblloc.getAbsolutePath)
-      val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf())
-      assert(table.location == fs.makeQualified(tblPath).toUri)
+      assert(table.location == makeQualifiedPath(tblloc.getAbsolutePath))
       assert(tblloc.listFiles().nonEmpty)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d4afb9d8af6f8..9201954b66d10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -18,13 +18,14 @@ package org.apache.spark.sql.test
 
 import java.io.File
+import java.net.URI
 import
java.util.UUID import scala.language.implicitConversions import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite @@ -294,6 +295,17 @@ private[sql] trait SQLTestUtils test(name) { runOnThread() } } } + + /** + * This method is used to make the given path qualified, when a path + * does not contain a scheme, this path will not be changed after the default + * FileSystem is changed. + */ + def makeQualifiedPath(path: String): URI = { + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) + fs.makeQualified(hadoopPath).toUri + } } private[sql] object SQLTestUtils { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index df2c1cee942b0..10d929a4a0ef8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1654,10 +1654,8 @@ class HiveDDLSuite |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val dirPath = new Path(dir.getAbsolutePath) - val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(new Path(table.location) == fs.makeQualified(dirPath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } @@ -1675,10 +1673,8 @@ class HiveDDLSuite |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val dirPath = new Path(dir.getAbsolutePath) - val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(new Path(table.location) == fs.makeQualified(dirPath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) @@ -1792,9 +1788,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val path = new Path(loc.getAbsolutePath) - val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) - assert(table.location == fs.makeQualified(path).toUri) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) assert(new Path(table.location).toString.contains(specialChars)) assert(loc.listFiles().isEmpty) @@ -1822,9 +1816,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - val path = new Path(loc.getAbsolutePath) - val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) - assert(table.location == fs.makeQualified(path).toUri) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) assert(new Path(table.location).toString.contains(specialChars)) assert(loc.listFiles().isEmpty) @@ -1871,7 +1863,7 @@ class HiveDDLSuite val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) val tblPath = new Path(tblloc.getAbsolutePath) val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf()) - assert(table.location == fs.makeQualified(tblPath).toUri) + assert(table.location == makeQualifiedPath(tblloc.getAbsolutePath)) assert(tblloc.listFiles().nonEmpty) } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 4f771caa1db27..ba0a7605da71c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ @@ -42,12 +42,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index d79edee5b1a4c..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -21,8 +21,8 @@ import java.math.BigDecimal import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -38,12 +38,9 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") .saveAsTextFile(partitionDir.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 03207ab869d12..dce5bb7ddba66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,8 +23,8 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetOutputFormat -import 
org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -44,12 +44,9 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
 
   test("save()/load() - partitioned table - simple queries - partition columns in data") {
     withTempDir { file =>
-      val basePath = new Path(file.getCanonicalPath)
-      val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
-      val qualifiedBasePath = fs.makeQualified(basePath)
-
       for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
-        val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+        val partitionDir = new Path(
+          CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
         sparkContext
           .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
           .toDF("a", "b", "p1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index a47a2246ddc3c..2ec593b95c9b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.sources
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.types._
 
@@ -45,12 +45,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
 
   test("save()/load() - partitioned table - simple queries - partition columns in data") {
     withTempDir { file =>
-      val basePath = new Path(file.getCanonicalPath)
-      val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
-      val qualifiedBasePath = fs.makeQualified(basePath)
-
       for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
-        val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+        val partitionDir = new Path(
+          CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
         sparkContext
           .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1")
           .saveAsTextFile(partitionDir.toString)

From e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 Mon Sep 17 00:00:00 2001
From: Wojtek Szymanski
Date: Wed, 8 Mar 2017 12:36:16 -0800
Subject: [PATCH 49/78] [SPARK-19727][SQL] Fix for round function that modifies original column

## What changes were proposed in this pull request?

Fix for the SQL `round` function that modifies the original column when the underlying data frame is created from a local product:

    import org.apache.spark.sql.functions._

    case class NumericRow(value: BigDecimal)

    val df = spark.createDataFrame(Seq(NumericRow(BigDecimal("1.23456789"))))

    df.show()
    +--------------------+
    |               value|
    +--------------------+
    |1.234567890000000000|
    +--------------------+

    df.withColumn("value_rounded", round('value)).show()

    // before
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.000000000000000000|            1|
    +--------------------+-------------+

    // after
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.234567890000000000|            1|
    +--------------------+-------------+

## How was this patch tested?
New unit test added to existing suite `org.apache.spark.sql.MathFunctionsSuite` Author: Wojtek Szymanski Closes #17075 from wojtek-szymanski/SPARK-19727. --- .../sql/catalyst/CatalystTypeConverters.scala | 6 +--- .../spark/sql/catalyst/expressions/Cast.scala | 13 +++++++-- .../expressions/decimalExpressions.scala | 10 ++----- .../expressions/mathExpressions.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 28 +++++++++++++------ .../apache/spark/sql/types/DecimalSuite.scala | 8 +++++- .../apache/spark/sql/MathFunctionsSuite.scala | 12 ++++++++ 7 files changed, 54 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5b9161551a7af..d4ebdb139fe0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,11 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - if (decimal.changePrecision(dataType.precision, dataType.scale)) { - decimal - } else { - null - } + decimal.toPrecision(dataType.precision, dataType.scale).orNull } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7c60f7d57a99e..1049915986d9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -352,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } + /** + * Create new `Decimal` with precision and scale given in `decimalType` (if any), + * returning null if it overflows or creating a new `value` and returning it if successful. + * + */ + private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = + value.toPrecision(decimalType.precision, decimalType.scale).orNull + + private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { @@ -360,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) + buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => - b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + b => toPrecision(b.asInstanceOf[Decimal], target) case t: IntegralType => b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) case x: FractionalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841149..c2211ae5d594b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true - override def nullSafeEval(input: Any): Any = { - val d = input.asInstanceOf[Decimal].clone() - if (d.changePrecision(dataType.precision, dataType.scale)) { - d - } else { - null - } - } + override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 65273a77b1054..dea5f85cb08cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression, child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null + decimal.toPrecision(decimal.precision, _scale, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 089c84d5f7736..e8f6884c025c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,6 +21,7 @@ import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.AnalysisException /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -222,6 +223,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) } + /** + * Create new `Decimal` with given precision and scale. 
+ * + * @return `Some(decimal)` if successful or `None` if overflow would occur + */ + private[sql] def toPrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + val copy = clone() + if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + } + /** * Update precision and scale while keeping our value the same, and return true if successful. * @@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this def floor: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } def ceil: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 52d0692524d0f..714883a4099cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - test("changePrecision() on compact decimal should respect rounding mode") { + test("changePrecision/toPrecision on compact decimal should respect rounding mode") { Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => Seq("", "-").foreach { sign => @@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val d = Decimal(unscaled, 8, 1) assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + + val copy = d.toPrecision(10, 0, mode).orNull + assert(copy !== null) + assert(d.ne(copy)) + assert(d === copy) + assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 37443d0342980..328c5395ec91e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with data frame from a local Seq of Product") { + val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") + checkAnswer( + df.withColumn("value_rounded", round('value)), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) + ) + checkAnswer( + df.withColumn("value_brounded", bround('value)), + Seq(Row(BigDecimal("5.9"), 
BigDecimal("6")))
+    )
+  }
+
   test("exp") {
     testOneToOneMathFunction(exp, math.exp)
   }

From 1bf9012380de2aa7bdf39220b55748defde8b700 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Wed, 8 Mar 2017 13:18:07 -0800
Subject: [PATCH 50/78] [SPARK-19858][SS] Add output mode to flatMapGroupsWithState and disallow invalid cases

## What changes were proposed in this pull request?

Add an output mode parameter to `flatMapGroupsWithState` and just define `mapGroupsWithState` as `flatMapGroupsWithState(Update)`. `UnsupportedOperationChecker` is modified to disallow unsupported cases.

- Batch mapGroupsWithState or flatMapGroupsWithState is always allowed.
- For streaming (map/flatMap)GroupsWithState, see the following table:

| Operators | Supported Query Output Mode |
| ------------- | ------------- |
| flatMapGroupsWithState(Update) without aggregation | Update |
| flatMapGroupsWithState(Update) with aggregation | None |
| flatMapGroupsWithState(Append) without aggregation | Append |
| flatMapGroupsWithState(Append) before aggregation | Append, Update, Complete |
| flatMapGroupsWithState(Append) after aggregation | None |
| Multiple flatMapGroupsWithState(Append)s | Append |
| Multiple mapGroupsWithStates | None |
| Mixing mapGroupsWithStates and flatMapGroupsWithStates | None |
| Other cases of multiple flatMapGroupsWithState | None |

## How was this patch tested?

The added unit tests. Here are the tests related to (map/flatMap)GroupsWithState:

```
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation: supported (1 millisecond)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Append)s on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation: supported (0 milliseconds)
[info] - batch plan - flatMapGroupsWithState - multiple flatMapGroupsWithState(Update)s on batch relation: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in update mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in append mode: not supported (7 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation without aggregation in complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Append mode: not supported (11 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Update mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation with aggregation in Complete mode: not supported (5 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation without aggregation in update mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Append mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Update mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation before aggregation in Complete mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on streaming relation after aggregation in Update mode: not supported (4 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation in complete mode: not supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Append output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Append) on batch relation inside streaming relation in Update output mode: supported (1 millisecond)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Append output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - flatMapGroupsWithState(Update) on batch relation inside streaming relation in Update output mode: supported (0 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are in append mode: supported (2 milliseconds)
[info] - streaming plan - flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some are not in append mode: not supported (7 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in append mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation without aggregation in complete mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Append mode: not supported (6 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Update mode: not supported (3 milliseconds)
[info] - streaming plan - mapGroupsWithState - mapGroupsWithState on streaming relation with aggregation in Complete mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are in append mode: not supported (4 milliseconds)
[info] - streaming plan - mapGroupsWithState - mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation: not supported (4 milliseconds)
```

Author: Shixiong Zhu

Closes #17197 from zsxwing/mapgroups-check.
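[Editor's note: the support matrix above condenses into a small predicate. The following is a standalone sketch for intuition only, not the actual `UnsupportedOperationChecker` code; the `Mode` ADT and the `supported` function are illustrative names.]

```Scala
// Toy model of the support matrix (assumed names, not Spark APIs).
// `funcMode` is the mode passed to flatMapGroupsWithState; `queryMode` is
// the query's output mode.
sealed trait Mode
case object Append extends Mode
case object Update extends Mode
case object Complete extends Mode

def supported(
    funcMode: Mode,
    queryMode: Mode,
    hasAggregation: Boolean,
    funcAfterAggregation: Boolean = false): Boolean =
  (funcMode, hasAggregation) match {
    case (Update, false) => queryMode == Update   // Update func alone: Update mode only
    case (Update, true)  => false                 // Update func never mixes with aggregation
    case (Append, false) => queryMode == Append   // Append func alone: Append mode only
    case (Append, true)  => !funcAfterAggregation // Append func only *before* aggregation
    case (Complete, _)   => false                 // Complete is not a valid func mode
  }

assert(supported(Update, Update, hasAggregation = false))
assert(!supported(Update, Append, hasAggregation = false))
assert(supported(Append, Complete, hasAggregation = true))   // func before aggregation
assert(!supported(Append, Append, hasAggregation = true, funcAfterAggregation = true))
```

Since `mapGroupsWithState` is defined as `flatMapGroupsWithState(Update)`, it follows the Update rows of the matrix, with the extra restriction (enforced separately in the checker below) that multiple `mapGroupsWithState`s are never allowed.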
--- .../UnsupportedOperationChecker.scala | 77 ++++++- .../sql/catalyst/plans/logical/object.scala | 24 ++- .../streaming/InternalOutputModes.scala | 15 ++ .../analysis/UnsupportedOperationsSuite.scala | 203 ++++++++++++++++-- .../streaming/InternalOutputModesSuite.scala | 48 +++++ .../spark/sql/KeyValueGroupedDataset.scala | 91 +++++++- .../spark/sql/execution/SparkStrategies.scala | 21 +- .../streaming/IncrementalExecution.scala | 4 +- .../streaming/statefulOperators.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 16 +- .../apache/spark/sql/JavaDatasetSuite.java | 2 + ...cala => FlatMapGroupsWithStateSuite.scala} | 46 ++-- .../test/DataStreamReaderWriterSuite.scala | 41 +--- 13 files changed, 485 insertions(+), 107 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{MapGroupsWithStateSuite.scala => FlatMapGroupsWithStateSuite.scala} (88%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 397f5cfe2a540..a9ff61e0e8802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -51,6 +51,37 @@ object UnsupportedOperationChecker { subplan.collect { case a: Aggregate if a.isStreaming => a } } + val mapGroupsWithStates = plan.collect { + case f: FlatMapGroupsWithState if f.isStreaming && f.isMapGroupsWithState => f + } + + // Disallow multiple `mapGroupsWithState`s. + if (mapGroupsWithStates.size >= 2) { + throwError( + "Multiple mapGroupsWithStates are not supported on a streaming DataFrames/Datasets")(plan) + } + + val flatMapGroupsWithStates = plan.collect { + case f: FlatMapGroupsWithState if f.isStreaming && !f.isMapGroupsWithState => f + } + + // Disallow mixing `mapGroupsWithState`s and `flatMapGroupsWithState`s + if (mapGroupsWithStates.nonEmpty && flatMapGroupsWithStates.nonEmpty) { + throwError( + "Mixing mapGroupsWithStates and flatMapGroupsWithStates are not supported on a " + + "streaming DataFrames/Datasets")(plan) + } + + // Only allow multiple `FlatMapGroupsWithState(Append)`s in append mode. 
+ if (flatMapGroupsWithStates.size >= 2 && ( + outputMode != InternalOutputModes.Append || + flatMapGroupsWithStates.exists(_.outputMode != InternalOutputModes.Append) + )) { + throwError( + "Multiple flatMapGroupsWithStates are not supported when they are not all in append mode" + + " or the output mode is not append on a streaming DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -116,9 +147,49 @@ object UnsupportedOperationChecker { throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") - case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty => - throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + - "streaming DataFrame/Dataset") + // mapGroupsWithState: Allowed only when no aggregation + Update output mode + case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState => + if (collectStreamingAggregates(plan).isEmpty) { + if (outputMode != InternalOutputModes.Update) { + throwError("mapGroupsWithState is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + } else { + // Allowed when no aggregation + Update output mode + } + } else { + throwError("mapGroupsWithState is not supported with aggregation " + + "on a streaming DataFrame/Dataset") + } + + // flatMapGroupsWithState without aggregation + case m: FlatMapGroupsWithState + if m.isStreaming && collectStreamingAggregates(plan).isEmpty => + m.outputMode match { + case InternalOutputModes.Update => + if (outputMode != InternalOutputModes.Update) { + throwError("flatMapGroupsWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + case InternalOutputModes.Append => + if (outputMode != InternalOutputModes.Append) { + throwError("flatMapGroupsWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + } + + // flatMapGroupsWithState(Update) with aggregation + case m: FlatMapGroupsWithState + if m.isStreaming && m.outputMode == InternalOutputModes.Update + && collectStreamingAggregates(plan).nonEmpty => + throwError("flatMapGroupsWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + + // flatMapGroupsWithState(Append) with aggregation + case m: FlatMapGroupsWithState + if m.isStreaming && m.outputMode == InternalOutputModes.Append + && collectStreamingAggregates(m).nonEmpty => + throwError("flatMapGroupsWithState in append mode is not supported after " + + s"aggregation on a streaming DataFrame/Dataset") case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0be4823bbc895..617239f56cdd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.streaming.OutputMode 
import org.apache.spark.sql.types._ object CatalystSerde { @@ -317,13 +318,15 @@ case class MapGroups( trait LogicalKeyedState[S] /** Factory for constructing new `MapGroupsWithState` nodes. */ -object MapGroupsWithState { +object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], + outputMode: OutputMode, + isMapGroupsWithState: Boolean, child: LogicalPlan): LogicalPlan = { - val mapped = new MapGroupsWithState( + val mapped = new FlatMapGroupsWithState( func, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), @@ -332,7 +335,9 @@ object MapGroupsWithState { CatalystSerde.generateObjAttr[U], encoderFor[S].resolveAndBind().deserializer, encoderFor[S].namedExpressions, - child) + outputMode, + child, + isMapGroupsWithState) CatalystSerde.serialize[U](mapped) } } @@ -350,8 +355,10 @@ object MapGroupsWithState { * @param outputObjAttr used to define the output object * @param stateDeserializer used to deserialize state before calling `func` * @param stateSerializer used to serialize updated state after calling `func` + * @param outputMode the output mode of `func` + * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method */ -case class MapGroupsWithState( +case class FlatMapGroupsWithState( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, @@ -360,7 +367,14 @@ case class MapGroupsWithState( outputObjAttr: Attribute, stateDeserializer: Expression, stateSerializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectProducer + outputMode: OutputMode, + child: LogicalPlan, + isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer { + + if (isMapGroupsWithState) { + assert(outputMode == OutputMode.Update) + } +} /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index 351bd6fff4adf..bdf2baf7361d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -44,4 +44,19 @@ private[sql] object InternalOutputModes { * aggregations, it will be equivalent to `Append` mode. */ case object Update extends OutputMode + + + def apply(outputMode: String): OutputMode = { + outputMode.toLowerCase match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case "update" => + OutputMode.Update + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. 
" + + "Accepted output modes are 'append', 'complete', 'update'") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 82be69a0f7d7b..200c39f43a6b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} +import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} @@ -138,29 +138,202 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) - // MapGroupsWithState: Not supported after a streaming aggregation val att = new AttributeReference(name = "a", dataType = LongType)() - assertSupportedInBatchPlan( - "mapGroupsWithState - mapGroupsWithState on batch relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) + // FlatMapGroupsWithState: Both function modes equivalent and supported in batch. + for (funcMode <- Seq(Append, Update)) { + assertSupportedInBatchPlan( + s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)) + + assertSupportedInBatchPlan( + s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))) + } + + // FlatMapGroupsWithState(Update) in streaming without aggregation + assertSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in update mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + outputMode = Update) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + outputMode = Append, + expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append")) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + outputMode = Complete, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + + // FlatMapGroupsWithState(Update) in streaming with aggregation + for (outputMode <- Seq(Append, Update, Complete)) { + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " + + s"with aggregation in $outputMode mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode, + expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation")) + } + // FlatMapGroupsWithState(Append) in streaming without aggregation assertSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), outputMode = Append) assertNotSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), - Aggregate(Nil, aggExprs("c"), streamRelation)), + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + "on streaming relation without aggregation in update mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + outputMode = Update, + expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update")) + + // FlatMapGroupsWithState(Append) in streaming with aggregation + for (outputMode <- Seq(Append, Update, Complete)) { + assertSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation before aggregation in $outputMode mode", + Aggregate( + Seq(attributeWithWatermark), + aggExprs("c"), + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + outputMode = outputMode) + } + + for (outputMode <- Seq(Append, Update)) { + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation after aggregation in $outputMode mode", + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode, + expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation")) + } + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - " + + "flatMapGroupsWithState(Update) on streaming relation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), outputMode = Complete, - expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + // FlatMapGroupsWithState inside batch relation should always be allowed + for (funcMode <- Seq(Append, Update)) { + for (outputMode <- Seq(Append, Update)) { // Complete is not supported without aggregation + assertSupportedInStreamingPlan( + s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " + + s"streaming relation in $outputMode output mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation), + outputMode = outputMode + ) + } + } + + // multiple FlatMapGroupsWithStates assertSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), - outputMode = Append - ) + "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " + + "in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" + + " are not in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + outputMode = Append, + expectedMsgs = Seq("multiple flatMapGroupsWithState", "append")) + + // mapGroupsWithState + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, + isMapGroupsWithState = true), + outputMode = Append, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. + expectedMsgs = Seq("mapGroupsWithState", "append")) + + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState " + + "on streaming relation without aggregation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, + isMapGroupsWithState = true), + outputMode = Complete, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + + for (outputMode <- Seq(Append, Update, Complete)) { + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation " + + s"with aggregation in $outputMode mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation), + isMapGroupsWithState = true), + outputMode = outputMode, + expectedMsgs = Seq("mapGroupsWithState", "with aggregation")) + } + + // multiple mapGroupsWithStates + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + + "in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, + isMapGroupsWithState = true), + isMapGroupsWithState = true), + outputMode = Append, + expectedMsgs = Seq("multiple mapGroupsWithStates")) + + // mixing mapGroupsWithStates and flatMapGroupsWithStates + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - " + + "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, + isMapGroupsWithState = false), + isMapGroupsWithState = true), + outputMode = Append, + expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) // Deduplicate assertSupportedInStreamingPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala new file mode 100644 index 0000000000000..201dac35ed2d8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.streaming.OutputMode + +class InternalOutputModesSuite extends SparkFunSuite { + + test("supported strings") { + def testMode(outputMode: String, expected: OutputMode): Unit = { + assert(InternalOutputModes(outputMode) === expected) + } + + testMode("append", OutputMode.Append) + testMode("Append", OutputMode.Append) + testMode("complete", OutputMode.Complete) + testMode("Complete", OutputMode.Complete) + testMode("update", OutputMode.Update) + testMode("Update", OutputMode.Update) + } + + test("unsupported strings") { + def testMode(outputMode: String): Unit = { + val acceptedModes = Seq("append", "update", "complete") + val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode)) + (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + testMode("Xyz") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3a548c251f5b1..ab956ffd642e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.streaming.OutputMode /** * :: Experimental :: @@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) + val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + child = logicalPlan)) } /** @@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
* @since 2.1.1 @@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } Dataset[U]( sparkSession, - MapGroupsWithState[K, V, S, U]( + FlatMapGroupsWithState[K, V, S, U]( func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, - logicalPlan)) + outputMode, + isMapGroupsWithState = false, + child = logicalPlan)) + } + + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = { + flatMapGroupsWithState(func, InternalOutputModes(outputMode)) } /** @@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. * @param func Function to be called on every group. + * @param outputMode The output mode of the function. * @param stateEncoder Encoder for the state type. * @param outputEncoder Encoder for the output type. * @@ -324,13 +367,45 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala, + outputMode )(stateEncoder, outputEncoder) } + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. 
Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: String, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 20bf4925dbec5..0f7aa3709c1cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Strategy to convert MapGroupsWithState logical operator to physical operator + * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ object MapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( + case FlatMapGroupsWithState( + f, + keyDeser, + valueDeser, + groupAttr, + dataAttr, + outputAttr, + stateDeser, + stateSer, + outputMode, + child, + _) => + val execPlan = FlatMapGroupsWithStateExec( f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, planLater(child)) execPlan :: Nil @@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + case logical.FlatMapGroupsWithState( + f, key, value, grouping, data, output, _, _, _, child, _) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ffdcd9b19d058..610ce5e1ebf5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -103,11 +103,11 @@ class IncrementalExecution( child, Some(stateId), Some(currentEventTimeWatermark)) - case MapGroupsWithStateExec( + case FlatMapGroupsWithStateExec( f, kDeser, 
vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - MapGroupsWithStateExec( + FlatMapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index cbf656a2044dc..c3075a3eacaac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -257,8 +257,8 @@ case class StateStoreSaveExec( } -/** Physical operator for executing streaming mapGroupsWithState. */ -case class MapGroupsWithStateExec( +/** Physical operator for executing streaming flatMapGroupsWithState. */ +case class FlatMapGroupsWithStateExec( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0f7a33723cccc..c8fda8cd83598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} @@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def outputMode(outputMode: String): DataStreamWriter[T] = { - this.outputMode = outputMode.toLowerCase match { - case "append" => - OutputMode.Append - case "complete" => - OutputMode.Complete - case "update" => - OutputMode.Update - case _ => - throw new IllegalArgumentException(s"Unknown output mode $outputMode. 
" + - "Accepted output modes are 'append', 'complete', 'update'") - } + this.outputMode = InternalOutputModes(outputMode) this } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index e3b0e37ccab05..d06e35bb44d08 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,7 @@ import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -205,6 +206,7 @@ public void testGroupBy() { } return Collections.singletonList(sb.toString()).iterator(); }, + OutputMode.Append(), Encoders.LONG(), Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala similarity index 88% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6cf4d51f99333..902b842e97aa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: Long) -class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), StopStream, @@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA Iterator((key, values.size)) } checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF, Seq(("a", 2), ("b", 1)).toDF) } @@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .groupByKey(x => x) .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - testStream(result, Append)( + testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), @@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA ) } - 
test("mapGroupsWithState - streaming + aggregation") { + test("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { @@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - (key, "-1") + Iterator(key -> "-1") } else { state.update(RunningCount(count)) - (key, count.toString) + Iterator(key -> count.toString) } } @@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str) .groupByKey(_._1) .count() @@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA testQuietly("StateStore.abort on task failure handling") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) (key, count) @@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => - MapGroupsWithStateSuite.failInTask = value + FlatMapGroupsWithStateSuite.failInTask = value true } - testStream(result, Append)( + testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), CheckLastBatch(("a", 1L)), @@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count ) } + + test("disallow complete mode") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + Iterator[String]() + } + + var e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete) + } + assert(e.getMessage === "The output mode of function should be append or update") + + e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete") + } + assert(e.getMessage === "The output mode of function should be append or update") + } } -object MapGroupsWithStateSuite { +object FlatMapGroupsWithStateSuite { var failInTask = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 0470411a0f108..f61dcdcbcf718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -24,8 +24,7 @@ import scala.concurrent.duration._ import org.apache.hadoop.fs.Path import org.mockito.Mockito._ 
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.scalatest.PrivateMethodTester.PrivateMethod
+import org.scalatest.BeforeAndAfter

 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.streaming._
@@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
   }
 }

-class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester {
+class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {

   private def newMetadataDir =
     Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr

   private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath

-  test("supported strings in outputMode(string)") {
-    val outputModeMethod = PrivateMethod[OutputMode]('outputMode)
-
-    def testMode(outputMode: String, expected: OutputMode): Unit = {
-      val df = spark.readStream
-        .format("org.apache.spark.sql.streaming.test")
-        .load()
-      val w = df.writeStream
-      w.outputMode(outputMode)
-      val setOutputMode = w invokePrivate outputModeMethod()
-      assert(setOutputMode === expected)
-    }
-
-    testMode("append", OutputMode.Append)
-    testMode("Append", OutputMode.Append)
-    testMode("complete", OutputMode.Complete)
-    testMode("Complete", OutputMode.Complete)
-    testMode("update", OutputMode.Update)
-    testMode("Update", OutputMode.Update)
-  }
-
-  test("unsupported strings in outputMode(string)") {
-    def testMode(outputMode: String): Unit = {
-      val acceptedModes = Seq("append", "update", "complete")
-      val df = spark.readStream
-        .format("org.apache.spark.sql.streaming.test")
-        .load()
-      val w = df.writeStream
-      val e = intercept[IllegalArgumentException](w.outputMode(outputMode))
-      (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
-        assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
-      }
-    }
-    testMode("Xyz")
-  }
-
   test("check foreach() catches null writers") {
     val df = spark.readStream
       .format("org.apache.spark.sql.streaming.test")

From 6570cfd7abe349dc6d2151f2ac9dc662e7465a79 Mon Sep 17 00:00:00 2001
From: Kunal Khamar
Date: Wed, 8 Mar 2017 13:06:22 -0800
Subject: [PATCH 51/78] [SPARK-19540][SQL] Add ability to clone SparkSession wherein cloned session has an identical copy of the SessionState

Forking a newSession() from a SparkSession currently creates a new SparkSession
that does not retain SessionState (i.e. temporary tables, SQL config, registered
functions, etc.). This change adds a cloneSession() method that creates a new
SparkSession with a copy of the parent's SessionState.

Subsequent changes to the base session are not propagated to the cloned session;
the clone is independent once created. If the base session is changed after the
clone has been created, say the user registers a new UDF, the new UDF will not be
available inside the clone. The same goes for configs and temp tables.

Tested with unit tests.

Author: Kunal Khamar
Author: Shixiong Zhu

Closes #16826 from kunalkhamar/fork-sparksession.
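To make the isolation semantics concrete, the sketch below shows the intended
behavior. It is illustrative only: the view name, config values and UDF are made
up, and `cloneSession()` is `private[sql]` in this patch, so it is reachable only
from Spark-internal code such as tests.

```scala
// Minimal sketch, assuming a running SparkSession `spark`.
spark.range(10).createOrReplaceTempView("base_view")
spark.conf.set("spark.sql.shuffle.partitions", "42")

val cloned = spark.cloneSession()  // private[sql] in this patch

// State that existed at clone time is copied over.
assert(cloned.sql("SELECT * FROM base_view").count() == 10)
assert(cloned.conf.get("spark.sql.shuffle.partitions") == "42")

// Changes made after the clone do not propagate in either direction.
cloned.conf.set("spark.sql.shuffle.partitions", "7")
assert(spark.conf.get("spark.sql.shuffle.partitions") == "42")

spark.udf.register("plusOne", (i: Int) => i + 1)
// "plusOne" is now registered only on the base session, not on `cloned`.
```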
--- .../spark/sql/catalyst/CatalystConf.scala | 7 +- .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../sql/catalyst/catalog/SessionCatalog.scala | 38 ++- .../catalog/SessionCatalogSuite.scala | 55 ++++ .../spark/sql/ExperimentalMethods.scala | 6 + .../org/apache/spark/sql/SparkSession.scala | 59 +++- .../sql/execution/datasources/rules.scala | 3 +- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../spark/sql/internal/SessionState.scala | 235 ++++++++++------ .../apache/spark/sql/SessionStateSuite.scala | 162 +++++++++++ .../spark/sql/internal/CatalogSuite.scala | 21 +- .../sql/internal/SQLConfEntrySuite.scala | 18 ++ .../spark/sql/test/TestSQLContext.scala | 20 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../spark/sql/hive/HiveSessionCatalog.scala | 92 ++++-- .../spark/sql/hive/HiveSessionState.scala | 261 +++++++++++++----- .../sql/hive/client/HiveClientImpl.scala | 2 + .../apache/spark/sql/hive/test/TestHive.scala | 67 ++--- .../sql/hive/HiveSessionCatalogSuite.scala | 112 ++++++++ .../sql/hive/HiveSessionStateSuite.scala | 41 +++ 20 files changed, 981 insertions(+), 236 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index fb99cb27b847b..cff0efa979932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -66,6 +66,8 @@ trait CatalystConf { /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */ def joinReorderDPThreshold: Int + + override def clone(): CatalystConf = throw new CloneNotSupportedException() } @@ -85,4 +87,7 @@ case class SimpleCatalystConf( joinReorderDPThreshold: Int = 12, warehousePath: String = "/user/hive/warehouse", sessionLocalTimeZone: String = TimeZone.getDefault().getID) - extends CatalystConf + extends CatalystConf { + + override def clone(): SimpleCatalystConf = this.copy() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 556fa9901701b..0dcb44081f608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -64,6 +64,8 @@ trait FunctionRegistry { /** Clear all registered functions. */ def clear(): Unit + /** Create a copy of this registry with identical functions as this registry. 
*/ + override def clone(): FunctionRegistry = throw new CloneNotSupportedException() } class SimpleFunctionRegistry extends FunctionRegistry { @@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.clear() } - def copy(): SimpleFunctionRegistry = synchronized { + override def clone(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => registry.registerFunction(name, info, builder) @@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def clone(): FunctionRegistry = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 831e37aac1246..6cfc4a4321316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -50,7 +50,6 @@ object SessionCatalog { class SessionCatalog( externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, - functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: CatalystConf, hadoopConf: Configuration, @@ -66,16 +65,19 @@ class SessionCatalog( this( externalCatalog, new GlobalTempViewManager("global_temp"), - DummyFunctionResourceLoader, functionRegistry, conf, new Configuration(), CatalystSqlParser) + functionResourceLoader = DummyFunctionResourceLoader } // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) + this( + externalCatalog, + new SimpleFunctionRegistry, + SimpleCatalystConf(caseSensitiveAnalysis = true)) } /** List of temporary tables, mapping from table name to their logical plan. */ @@ -89,6 +91,8 @@ class SessionCatalog( @GuardedBy("this") protected var currentDb = formatDatabaseName(DEFAULT_DATABASE) + @volatile var functionResourceLoader: FunctionResourceLoader = _ + /** * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. @@ -987,6 +991,9 @@ class SessionCatalog( * by a tuple (resource type, resource uri). */ def loadFunctionResources(resources: Seq[FunctionResource]): Unit = { + if (functionResourceLoader == null) { + throw new IllegalStateException("functionResourceLoader has not yet been initialized") + } resources.foreach(functionResourceLoader.loadResource) } @@ -1182,4 +1189,29 @@ class SessionCatalog( } } + /** + * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and + * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. 
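+   * Since `externalCatalog` and `globalTempViewManager` remain shared references,
+   * databases, permanent tables and global temp views stay visible to both catalogs;
+   * only session-local state is duplicated.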
+ */ + def newSessionCatalogWith( + conf: CatalystConf, + hadoopConf: Configuration, + functionRegistry: FunctionRegistry, + parser: ParserInterface): SessionCatalog = { + val catalog = new SessionCatalog( + externalCatalog, + globalTempViewManager, + functionRegistry, + conf, + hadoopConf, + parser) + + synchronized { + catalog.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) + } + + catalog + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 328a16c4bf024..7e74dcdef0e27 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import org.apache.hadoop.conf.Configuration + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ @@ -1197,6 +1199,59 @@ class SessionCatalogSuite extends PlanTest { } } + test("clone SessionCatalog - temp views") { + val externalCatalog = newEmptyCatalog() + val original = new SessionCatalog(externalCatalog) + val tempTable1 = Range(1, 10, 1, 10) + original.createTempView("copytest1", tempTable1, overrideIfExists = false) + + // check if tables copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getTempView("copytest1") == Some(tempTable1)) + + // check if clone and original independent + clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) + assert(original.getTempView("copytest1") == Some(tempTable1)) + + val tempTable2 = Range(1, 20, 2, 10) + original.createTempView("copytest2", tempTable2, overrideIfExists = false) + assert(clone.getTempView("copytest2").isEmpty) + } + + test("clone SessionCatalog - current db") { + val externalCatalog = newEmptyCatalog() + val db1 = "db1" + val db2 = "db2" + val db3 = "db3" + + externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) + externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) + externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) + + val original = new SessionCatalog(externalCatalog) + original.setCurrentDatabase(db1) + + // check if current db copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getCurrentDatabase == db1) + + // check if clone and original independent + clone.setCurrentDatabase(db2) + assert(original.getCurrentDatabase == db1) + original.setCurrentDatabase(db3) + assert(clone.getCurrentDatabase == db2) + } + test("SPARK-19737: detect undefined functions without triggering relation resolution") { import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 1e8ba51e59e33..bd8dd6ea3fe0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() { @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + override def clone(): ExperimentalMethods = { + val result = new ExperimentalMethods + result.extraStrategies = extraStrategies + result.extraOptimizations = extraOptimizations + result + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index afc1827e7eece..49562578b23cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils @@ -67,15 +66,22 @@ import org.apache.spark.util.Utils * .config("spark.some.config.option", "some-value") * .getOrCreate() * }}} + * + * @param sparkContext The Spark context associated with this Spark session. + * @param existingSharedState If supplied, use the existing shared state + * instead of creating a new one. + * @param parentSessionState If supplied, inherit all session state (i.e. temporary + * views, SQL config, UDFs etc) from parent. */ @InterfaceStability.Stable class SparkSession private( @transient val sparkContext: SparkContext, - @transient private val existingSharedState: Option[SharedState]) + @transient private val existingSharedState: Option[SharedState], + @transient private val parentSessionState: Option[SessionState]) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None) + this(sc, None, None) } sparkContext.assertNotStopped() @@ -108,6 +114,7 @@ class SparkSession private( /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. + * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent. * * This is internal to Spark and there is no guarantee on interface stability. 
* @@ -116,9 +123,13 @@ class SparkSession private( @InterfaceStability.Unstable @transient lazy val sessionState: SessionState = { - SparkSession.reflect[SessionState, SparkSession]( - SparkSession.sessionStateClassName(sparkContext.conf), - self) + parentSessionState + .map(_.clone(this)) + .getOrElse { + SparkSession.instantiateSessionState( + SparkSession.sessionStateClassName(sparkContext.conf), + self) + } } /** @@ -208,7 +219,25 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState)) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + } + + /** + * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext` + * and shared state. All the state of this session (i.e. SQL configurations, temporary tables, + * registered functions) is copied over, and the cloned session is set up with the same shared + * state as this session. The cloned session is independent of this session, that is, any + * non-global change in either session is not reflected in the other. + * + * @note Other than the `SparkContext`, all shared state is initialized lazily. + * This method will force the initialization of the shared state to ensure that parent + * and child sessions are set up with the same shared state. If the underlying catalog + * implementation is Hive, this will initialize the metastore, which may take some time. + */ + private[sql] def cloneSession(): SparkSession = { + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + result.sessionState // force copy of SessionState + result } @@ -971,16 +1000,18 @@ object SparkSession { } /** - * Helper method to create an instance of [[T]] using a single-arg constructor that - * accepts an [[Arg]]. + * Helper method to create an instance of `SessionState` based on `className` from conf. + * The result is either `SessionState` or `HiveSessionState`. */ - private def reflect[T, Arg <: AnyRef]( + private def instantiateSessionState( className: String, - ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = { + sparkSession: SparkSession): SessionState = { + try { + // get `SessionState.apply(SparkSession)` val clazz = Utils.classForName(className) - val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass) - ctor.newInstance(ctorArg).asInstanceOf[T] + val method = clazz.getMethod("apply", sparkSession.getClass) + method.invoke(null, sparkSession).asInstanceOf[SessionState] } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 4d781b96abace..8b598cc60e778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { * Preprocess [[CreateTable]], to do some normalization and checking. 
*/ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private val catalog = sparkSession.sessionState.catalog + // catalog is a def and not a val/lazy val as the latter would introduce a circular reference + private def catalog = sparkSession.sessionState.catalog def apply(plan: LogicalPlan): LogicalPlan = plan transform { // When we CREATE TABLE without specifying the table schema, we should fail the query if diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 94e3fa7dd13f7..1244f690fd829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def clear(): Unit = { settings.clear() } + + override def clone(): SQLConf = { + val result = new SQLConf + getAllConfs.foreach { + case(k, v) => if (v ne null) result.setConfString(k, v) + } + result + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 69085605113ea..ce80604bd3657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -22,38 +22,49 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.AnalyzeTableCommand import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} +import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager /** * A class that holds all session-specific state in a given [[SparkSession]]. + * @param sparkContext The [[SparkContext]]. + * @param sharedState The shared state. + * @param conf SQL-specific key-value configurations. + * @param experimentalMethods The experimental methods. + * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param catalog Internal catalog for managing table and database states. + * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. + * @param streamingQueryManager Interface to start and stop + * [[org.apache.spark.sql.streaming.StreamingQuery]]s. 
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] */ -private[sql] class SessionState(sparkSession: SparkSession) { +private[sql] class SessionState( + sparkContext: SparkContext, + sharedState: SharedState, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods, + val functionRegistry: FunctionRegistry, + val catalog: SessionCatalog, + val sqlParser: ParserInterface, + val analyzer: Analyzer, + val streamingQueryManager: StreamingQueryManager, + val queryExecutionCreator: LogicalPlan => QueryExecution) { - // Note: These are all lazy vals because they depend on each other (e.g. conf) and we - // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. - - /** - * SQL-specific key-value configurations. - */ - lazy val conf: SQLConf = new SQLConf - - def newHadoopConf(): Configuration = { - val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration) - conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) } - hadoopConf - } + def newHadoopConf(): Configuration = SessionState.newHadoopConf( + sparkContext.hadoopConfiguration, + conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { val hadoopConf = newHadoopConf() @@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) { hadoopConf } - lazy val experimentalMethods = new ExperimentalMethods - - /** - * Internal catalog for managing functions registered by the user. - */ - lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - /** * A class for loading resources specified by a function. */ - lazy val functionResourceLoader: FunctionResourceLoader = { + val functionResourceLoader: FunctionResourceLoader = { new FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { case JarResource => addJar(resource.uri) - case FileResource => sparkSession.sparkContext.addFile(resource.uri) + case FileResource => sparkContext.addFile(resource.uri) case ArchiveResource => throw new AnalysisException( "Archive is not allowed to be loaded. If YARN mode is used, " + @@ -90,93 +94,78 @@ private[sql] class SessionState(sparkSession: SparkSession) { } } - /** - * Internal catalog for managing table and database states. - */ - lazy val catalog = new SessionCatalog( - sparkSession.sharedState.externalCatalog, - sparkSession.sharedState.globalTempViewManager, - functionResourceLoader, - functionRegistry, - conf, - newHadoopConf(), - sqlParser) - /** * Interface exposed to the user for registering user-defined functions. * Note that the user-defined functions must be deterministic. */ - lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * Logical query plan analyzer for resolving unresolved attributes and relations. - */ - lazy val analyzer: Analyzer = { - new Analyzer(catalog, conf) { - override val extendedResolutionRules = - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules = - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(conf) :: - DataSourceAnalysis(conf) :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) - } - } + val udf: UDFRegistration = new UDFRegistration(functionRegistry) /** * Logical query plan optimizer. 
*/ - lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) - - /** - * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - */ - lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) /** * Planner that converts optimized logical plans to physical plans. */ def planner: SparkPlanner = - new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) + new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. */ - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + val listenerManager: ExecutionListenerManager = new ExecutionListenerManager /** - * Interface to start and stop [[StreamingQuery]]s. + * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ - lazy val streamingQueryManager: StreamingQueryManager = { - new StreamingQueryManager(sparkSession) - } + def clone(newSparkSession: SparkSession): SessionState = { + val sparkContext = newSparkSession.sparkContext + val confCopy = conf.clone() + val functionRegistryCopy = functionRegistry.clone() + val sqlParser: ParserInterface = new SparkSqlParser(confCopy) + val catalogCopy = catalog.newSessionCatalogWith( + confCopy, + SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), + functionRegistryCopy, + sqlParser) + val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - private val jarClassLoader: NonClosableMutableURLClassLoader = - sparkSession.sharedState.jarClassLoader + SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - // Automatically extract all entries and put it in our SQLConf - // We need to call it after all of vals have been initialized. - sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) => - conf.setConfString(k, v) + new SessionState( + sparkContext, + newSparkSession.sharedState, + confCopy, + experimentalMethods.clone(), + functionRegistryCopy, + catalogCopy, + sqlParser, + SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), + new StreamingQueryManager(newSparkSession), + queryExecutionCreator) } // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) + def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } + /** + * Add a jar path to [[SparkContext]] and the classloader. + * + * Note: this method does not seem to access any session state, but the subclass `HiveSessionState` needs + * to add the jar to its hive client for the current session. Hence, it still needs to be in + * [[SessionState]].
+ */ def addJar(path: String): Unit = { - sparkSession.sparkContext.addJar(path) - + sparkContext.addJar(path) val uri = new Path(path).toUri val jarURL = if (uri.getScheme == null) { // `path` is a local file path without a URL scheme @@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) { // `path` is a URL with a scheme uri.toURL } - jarClassLoader.addURL(jarURL) - Thread.currentThread().setContextClassLoader(jarClassLoader) + sharedState.jarClassLoader.addURL(jarURL) + Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader) + } +} + + +private[sql] object SessionState { + + def apply(sparkSession: SparkSession): SessionState = { + apply(sparkSession, new SQLConf) + } + + def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = { + val sparkContext = sparkSession.sparkContext + + // Automatically extract all entries and put them in our SQLConf + mergeSparkConf(sqlConf, sparkContext.getConf) + + val functionRegistry = FunctionRegistry.builtin.clone() + + val sqlParser: ParserInterface = new SparkSqlParser(sqlConf) + + val catalog = new SessionCatalog( + sparkSession.sharedState.externalCatalog, + sparkSession.sharedState.globalTempViewManager, + functionRegistry, + sqlConf, + newHadoopConf(sparkContext.hadoopConfiguration, sqlConf), + sqlParser) + + val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf) + + val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession) + + val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan) + + val sessionState = new SessionState( + sparkContext, + sparkSession.sharedState, + sqlConf, + new ExperimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + streamingQueryManager, + queryExecutionCreator) + // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before + // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller + // cannot use SessionCatalog before we return SessionState. + catalog.functionResourceLoader = sessionState.functionResourceLoader + sessionState + } + + def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { + val newHadoopConf = new Configuration(hadoopConf) + sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } + newHadoopConf + } + + /** + * Create a logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`. + */ + private def createAnalyzer( + sparkSession: SparkSession, + catalog: SessionCatalog, + sqlConf: SQLConf): Analyzer = { + new Analyzer(catalog, sqlConf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new FindDataSourceTable(sparkSession) :: + new ResolveSQLOnFile(sparkSession) :: Nil + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + PreprocessTableCreation(sparkSession) :: + PreprocessTableInsertion(sqlConf) :: + DataSourceAnalysis(sqlConf) :: Nil + + override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) + } } /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations.
+ * Extract entries from `SparkConf` and put them in the `SQLConf` */ - def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) + def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala new file mode 100644 index 0000000000000..2d5e37242a58b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +class SessionStateSuite extends SparkFunSuite + with BeforeAndAfterEach with BeforeAndAfterAll { + + /** + * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this + * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared + * with all Hive test suites. 
+ */ + protected var activeSession: SparkSession = _ + + override def beforeAll(): Unit = { + activeSession = SparkSession.builder().master("local").getOrCreate() + } + + override def afterAll(): Unit = { + if (activeSession != null) { + activeSession.stop() + activeSession = null + } + super.afterAll() + } + + test("fork new session and inherit RuntimeConfig options") { + val key = "spark-config-clone" + try { + activeSession.conf.set(key, "active") + + // inheritance + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.conf ne activeSession.conf) + assert(forkedSession.conf.get(key) == "active") + + // independence + forkedSession.conf.set(key, "forked") + assert(activeSession.conf.get(key) == "active") + activeSession.conf.set(key, "dontcopyme") + assert(forkedSession.conf.get(key) == "forked") + } finally { + activeSession.conf.unset(key) + } + } + + test("fork new session and inherit function registry and udf") { + val testFuncName1 = "strlenScala" + val testFuncName2 = "addone" + try { + activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState.functionRegistry ne + activeSession.sessionState.functionRegistry) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + + // independence + forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) + assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + activeSession.udf.register(testFuncName2, (_: Int) + 1) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) + } finally { + activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) + activeSession.sessionState.functionRegistry.dropFunction(testFuncName2) + } + } + + test("fork new session and inherit experimental methods") { + val originalExtraOptimizations = activeSession.experimental.extraOptimizations + val originalExtraStrategies = activeSession.experimental.extraStrategies + try { + object DummyRule1 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + object DummyRule2 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + val optimizations = List(DummyRule1, DummyRule2) + activeSession.experimental.extraOptimizations = optimizations + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.experimental ne activeSession.experimental) + assert(forkedSession.experimental.extraOptimizations.toSet == + activeSession.experimental.extraOptimizations.toSet) + + // independence + forkedSession.experimental.extraOptimizations = List(DummyRule2) + assert(activeSession.experimental.extraOptimizations == optimizations) + activeSession.experimental.extraOptimizations = List(DummyRule1) + assert(forkedSession.experimental.extraOptimizations == List(DummyRule2)) + } finally { + activeSession.experimental.extraOptimizations = originalExtraOptimizations + activeSession.experimental.extraStrategies = originalExtraStrategies + } + } + + test("fork new sessions and run query on inherited table") { + def checkTableExists(sparkSession: SparkSession): Unit = { + QueryTest.checkAnswer(sparkSession.sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 
1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + val spark = activeSession + // Cannot use `import activeSession.implicits._` due to the compiler limitation. + import spark.implicits._ + + try { + activeSession + .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString))) + .toDF("int", "str") + .createOrReplaceTempView("df") + checkTableExists(activeSession) + + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState ne activeSession.sessionState) + checkTableExists(forkedSession) + checkTableExists(activeSession.cloneSession()) // ability to clone multiple times + checkTableExists(forkedSession.cloneSession()) // clone of clone + } finally { + activeSession.sql("drop table df") + } + } + + test("fork new session and inherit reference to SharedState") { + val forkedSession = activeSession.cloneSession() + assert(activeSession.sharedState eq forkedSession.sharedState) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 989a7f2698171..fcb8ffbc6edd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -493,6 +493,25 @@ class CatalogSuite } } - // TODO: add tests for the rest of them + test("clone Catalog") { + // need to test tempTables are cloned + assert(spark.catalog.listTables().collect().isEmpty) + createTempTable("my_temp_table") + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // inheritance + val forkedSession = spark.cloneSession() + assert(spark ne forkedSession) + assert(spark.catalog ne forkedSession.catalog) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // independence + dropTable("my_temp_table") // drop table in original session + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + forkedSession.sessionState.catalog + .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 0e3a5ca9d71dd..f2456c7704064 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite { } assert(e2.getMessage === "The maximum size of the cache must not be negative") } + + test("clone SQLConf") { + val original = new SQLConf + val key = "spark.sql.SQLConfEntrySuite.clone" + assert(original.getConfString(key, "noentry") === "noentry") + + // inheritance + original.setConfString(key, "orig") + val clone = original.clone() + assert(original ne clone) + assert(clone.getConfString(key, "noentry") === "orig") + + // independence + clone.setConfString(key, "clone") + assert(original.getConfString(key, "noentry") === "orig") + original.setConfString(key, "dontcopyme") + assert(clone.getConfString(key, "noentry") === "clone") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 8ab6db175da5d..898a2fb4f329b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - override lazy val sessionState: SessionState = new SessionState(self) { - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def clear(): Unit = { - super.clear() - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } - } + override lazy val sessionState: SessionState = SessionState( + this, + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } } - } - } + }) // Needed for Java tests def loadTestData(): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4d3b6c3cec1c6..d135dfa9f4157 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,8 +41,9 @@ import org.apache.spark.sql.types._ * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. */ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { - private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - private lazy val tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + // these are def_s and not val/lazy val since the latter would introduce circular references + private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] + private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index f1ea86890c210..6b7599e3d3401 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} @@ -43,31 +43,23 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, globalTempViewManager: GlobalTempViewManager, - sparkSession: SparkSession, - functionResourceLoader: FunctionResourceLoader, + private val 
metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, parser: ParserInterface) extends SessionCatalog( - externalCatalog, - globalTempViewManager, - functionResourceLoader, - functionRegistry, - conf, - hadoopConf, - parser) { + externalCatalog, + globalTempViewManager, + functionRegistry, + conf, + hadoopConf, + parser) { // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- - // Catalog for handling data source tables. TODO: This really doesn't belong here since it is - // essentially a cache for metastore tables. However, it relies on a lot of session-specific - // things so it would be a lot of work to split its functionality between HiveSessionCatalog - // and HiveCatalog. We should still do it at some point... - private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) - // These 2 rules must be run before all other DDL post-hoc resolution rules, i.e. // `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions @@ -77,10 +69,51 @@ private[sql] class HiveSessionCatalog( metastoreCatalog.hiveDefaultTableFilePath(name) } + /** + * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and + * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. + */ + def newSessionCatalogWith( + newSparkSession: SparkSession, + conf: SQLConf, + hadoopConf: Configuration, + functionRegistry: FunctionRegistry, + parser: ParserInterface): HiveSessionCatalog = { + val catalog = HiveSessionCatalog( + newSparkSession, + functionRegistry, + conf, + hadoopConf, + parser) + + synchronized { + catalog.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) + } + + catalog + } + + /** + * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add + * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However, + * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we add a new version of + * `newSessionCatalogWith` and disable this one. + * + * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly. + */ + override def newSessionCatalogWith( + conf: CatalystConf, + hadoopConf: Configuration, + functionRegistry: FunctionRegistry, + parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException( + "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession") + // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { val key = metastoreCatalog.getQualifiedTableName(table) - sparkSession.sessionState.catalog.tableRelationCache.getIfPresent(key) + tableRelationCache.getIfPresent(key) } override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { @@ -217,3 +250,28 @@ private[sql] class HiveSessionCatalog( "histogram_numeric" ) } + +private[sql] object HiveSessionCatalog { + + def apply( + sparkSession: SparkSession, + functionRegistry: FunctionRegistry, + conf: SQLConf, + hadoopConf: Configuration, + parser: ParserInterface): HiveSessionCatalog = { + // Catalog for handling data source tables.
TODO: This really doesn't belong here since it is + // essentially a cache for metastore tables. However, it relies on a lot of session-specific + // things so it would be a lot of work to split its functionality between HiveSessionCatalog + // and HiveCatalog. We should still do it at some point... + val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) + + new HiveSessionCatalog( + sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], + sparkSession.sharedState.globalTempViewManager, + metastoreCatalog, + functionRegistry, + conf, + hadoopConf, + parser) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 5a08a6bc66f6b..cb8bcb8591bd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -17,89 +17,65 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Analyzer -import org.apache.spark.sql.execution.SparkPlanner +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.streaming.StreamingQueryManager /** * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. + * @param sparkContext The [[SparkContext]]. + * @param sharedState The shared state. + * @param conf SQL-specific key-value configurations. + * @param experimentalMethods The experimental methods. + * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param catalog Internal catalog for managing table and database states that uses Hive client for + * interacting with the metastore. + * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * @param metadataHive The Hive metadata client. + * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. + * @param streamingQueryManager Interface to start and stop + * [[org.apache.spark.sql.streaming.StreamingQuery]]s. + * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] + * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies */ -private[hive] class HiveSessionState(sparkSession: SparkSession) - extends SessionState(sparkSession) { - - self => - - /** - * A Hive client used for interacting with the metastore. - */ - lazy val metadataHive: HiveClient = - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.newSession() - - /** - * Internal catalog for managing table and database states. 
- */ - override lazy val catalog = { - new HiveSessionCatalog( - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], - sparkSession.sharedState.globalTempViewManager, - sparkSession, - functionResourceLoader, - functionRegistry, +private[hive] class HiveSessionState( + sparkContext: SparkContext, + sharedState: SharedState, + conf: SQLConf, + experimentalMethods: ExperimentalMethods, + functionRegistry: FunctionRegistry, + override val catalog: HiveSessionCatalog, + sqlParser: ParserInterface, + val metadataHive: HiveClient, + analyzer: Analyzer, + streamingQueryManager: StreamingQueryManager, + queryExecutionCreator: LogicalPlan => QueryExecution, + val plannerCreator: () => SparkPlanner) + extends SessionState( + sparkContext, + sharedState, conf, - newHadoopConf(), - sqlParser) - } - - /** - * An analyzer that uses the Hive metastore. - */ - override lazy val analyzer: Analyzer = { - new Analyzer(catalog, conf) { - override val extendedResolutionRules = - new ResolveHiveSerdeTable(sparkSession) :: - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules = - new DetermineTableStats(sparkSession) :: - catalog.ParquetConversions :: - catalog.OrcConversions :: - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(conf) :: - DataSourceAnalysis(conf) :: - HiveAnalysis :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck) - } - } + experimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + streamingQueryManager, + queryExecutionCreator) { self => /** * Planner that takes into account Hive-specific strategies. */ - override def planner: SparkPlanner = { - new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) - with HiveStrategies { - override val sparkSession: SparkSession = self.sparkSession - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy, - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } - } - } + override def planner: SparkPlanner = plannerCreator() // ------------------------------------------------------ @@ -146,4 +122,149 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) } + /** + * Get an identical copy of the `HiveSessionState`. + * This should ideally reuse the `SessionState.clone` but cannot do so. + * Doing that will throw an exception when trying to clone the catalog. 
+ */ + override def clone(newSparkSession: SparkSession): HiveSessionState = { + val sparkContext = newSparkSession.sparkContext + val confCopy = conf.clone() + val functionRegistryCopy = functionRegistry.clone() + val experimentalMethodsCopy = experimentalMethods.clone() + val sqlParser: ParserInterface = new SparkSqlParser(confCopy) + val catalogCopy = catalog.newSessionCatalogWith( + newSparkSession, + confCopy, + SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), + functionRegistryCopy, + sqlParser) + val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) + + val hiveClient = + newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .newSession() + + SessionState.mergeSparkConf(confCopy, sparkContext.getConf) + + new HiveSessionState( + sparkContext, + newSparkSession.sharedState, + confCopy, + experimentalMethodsCopy, + functionRegistryCopy, + catalogCopy, + sqlParser, + hiveClient, + HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), + new StreamingQueryManager(newSparkSession), + queryExecutionCreator, + HiveSessionState.createPlannerCreator( + newSparkSession, + confCopy, + experimentalMethodsCopy)) + } + +} + +private[hive] object HiveSessionState { + + def apply(sparkSession: SparkSession): HiveSessionState = { + apply(sparkSession, new SQLConf) + } + + def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = { + val initHelper = SessionState(sparkSession, conf) + + val sparkContext = sparkSession.sparkContext + + val catalog = HiveSessionCatalog( + sparkSession, + initHelper.functionRegistry, + initHelper.conf, + SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf), + initHelper.sqlParser) + + val metadataHive: HiveClient = + sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .newSession() + + val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf) + + val plannerCreator = createPlannerCreator( + sparkSession, + initHelper.conf, + initHelper.experimentalMethods) + + val hiveSessionState = new HiveSessionState( + sparkContext, + sparkSession.sharedState, + initHelper.conf, + initHelper.experimentalMethods, + initHelper.functionRegistry, + catalog, + initHelper.sqlParser, + metadataHive, + analyzer, + initHelper.streamingQueryManager, + initHelper.queryExecutionCreator, + plannerCreator) + catalog.functionResourceLoader = hiveSessionState.functionResourceLoader + hiveSessionState + } + + /** + * Create a logical query plan `Analyzer` with rules specific to a `HiveSessionState`.
+ */ + private def createAnalyzer( + sparkSession: SparkSession, + catalog: HiveSessionCatalog, + sqlConf: SQLConf): Analyzer = { + new Analyzer(catalog, sqlConf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new ResolveHiveSerdeTable(sparkSession) :: + new FindDataSourceTable(sparkSession) :: + new ResolveSQLOnFile(sparkSession) :: Nil + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + new DetermineTableStats(sparkSession) :: + catalog.ParquetConversions :: + catalog.OrcConversions :: + PreprocessTableCreation(sparkSession) :: + PreprocessTableInsertion(sqlConf) :: + DataSourceAnalysis(sqlConf) :: + HiveAnalysis :: Nil + + override val extendedCheckRules = Seq(PreWriteCheck) + } + } + + private def createPlannerCreator( + associatedSparkSession: SparkSession, + sqlConf: SQLConf, + experimentalMethods: ExperimentalMethods): () => SparkPlanner = { + () => + new SparkPlanner( + associatedSparkSession.sparkContext, + sqlConf, + experimentalMethods.extraStrategies) + with HiveStrategies { + + override val sparkSession: SparkSession = associatedSparkSession + + override def strategies: Seq[Strategy] = { + experimentalMethods.extraStrategies ++ Seq( + FileSourceStrategy, + DataSourceStrategy, + SpecialLimits, + InMemoryScans, + HiveTableScans, + Scripts, + Aggregation, + JoinSelection, + BasicOperators + ) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 469c9d84de054..6e1f429286cfa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -278,6 +278,8 @@ private[hive] class HiveClientImpl( state.getConf.setClassLoader(clientLoader.classLoader) // Set the thread local metastore client to the client associated with this HiveClientImpl. Hive.set(client) + // Replace conf in the thread local Hive with current conf + Hive.get(conf) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index efc2f0098454b..076c40d45932b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,16 +30,17 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation} +import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.internal.{SharedState, SQLConf} +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. @@ -84,7 +85,7 @@ class TestHiveContext( new TestHiveContext(sparkSession.newSession()) } - override def sessionState: TestHiveSessionState = sparkSession.sessionState + override def sessionState: HiveSessionState = sparkSession.sessionState def setCacheTables(c: Boolean): Unit = { sparkSession.setCacheTables(c) @@ -144,11 +145,35 @@ private[hive] class TestHiveSparkSession( existingSharedState.getOrElse(new SharedState(sc)) } - // TODO: Let's remove TestHiveSessionState. Otherwise, we are not really testing the reflection - // logic based on the setting of CATALOG_IMPLEMENTATION. 
@transient - override lazy val sessionState: TestHiveSessionState = - new TestHiveSessionState(self) + override lazy val sessionState: HiveSessionState = { + val testConf = + new SQLConf { + clear() + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + override def clear(): Unit = { + super.clear() + TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) } + } + } + val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan) + val initHelper = HiveSessionState(this, testConf) + SessionState.mergeSparkConf(testConf, sparkContext.getConf) + + new HiveSessionState( + sparkContext, + sharedState, + testConf, + initHelper.experimentalMethods, + initHelper.functionRegistry, + initHelper.catalog, + initHelper.sqlParser, + initHelper.metadataHive, + initHelper.analyzer, + initHelper.streamingQueryManager, + queryExecutionCreator, + initHelper.plannerCreator) + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), loadTestTables) @@ -492,26 +517,6 @@ private[hive] class TestHiveQueryExecution( } } -private[hive] class TestHiveSessionState( - sparkSession: TestHiveSparkSession) - extends HiveSessionState(sparkSession) { self => - - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - override def clear(): Unit = { - super.clear() - TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) } - } - } - } - - override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = { - new TestHiveQueryExecution(sparkSession, plan) - } -} - private[hive] object TestHiveContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala new file mode 100644 index 0000000000000..3b0f59b15916c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.net.URI + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +class HiveSessionCatalogSuite extends TestHiveSingleton { + + test("clone HiveSessionCatalog") { + val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + + val tempTableName1 = "copytest1" + val tempTableName2 = "copytest2" + try { + val tempTable1 = Range(1, 10, 1, 10) + original.createTempView(tempTableName1, tempTable1, overrideIfExists = false) + + // check if tables copied over + val clone = original.newSessionCatalogWith( + spark, + new SQLConf, + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getTempView(tempTableName1) == Some(tempTable1)) + + // check if clone and original independent + clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false) + assert(original.getTempView(tempTableName1) == Some(tempTable1)) + + val tempTable2 = Range(1, 20, 2, 10) + original.createTempView(tempTableName2, tempTable2, overrideIfExists = false) + assert(clone.getTempView(tempTableName2).isEmpty) + } finally { + // Drop the created temp views from the global singleton HiveSession. + original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true) + original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true) + } + } + + test("clone SessionCatalog - current db") { + val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + val originalCurrentDatabase = original.getCurrentDatabase + val db1 = "db1" + val db2 = "db2" + val db3 = "db3" + try { + original.createDatabase(newDb(db1), ignoreIfExists = true) + original.createDatabase(newDb(db2), ignoreIfExists = true) + original.createDatabase(newDb(db3), ignoreIfExists = true) + + original.setCurrentDatabase(db1) + + // check if tables copied over + val clone = original.newSessionCatalogWith( + spark, + new SQLConf, + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + + // check if current db copied over + assert(original ne clone) + assert(clone.getCurrentDatabase == db1) + + // check if clone and original independent + clone.setCurrentDatabase(db2) + assert(original.getCurrentDatabase == db1) + original.setCurrentDatabase(db3) + assert(clone.getCurrentDatabase == db2) + } finally { + // Drop the created databases from the global singleton HiveSession. 
+ original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true) + original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true) + original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true) + original.setCurrentDatabase(originalCurrentDatabase) + } + } + + def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) + + def newDb(name: String): CatalogDatabase = { + CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala new file mode 100644 index 0000000000000..67c77fb62f4e1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton + +/** + * Run all tests from `SessionStateSuite` with a `HiveSessionState`. + */ +class HiveSessionStateSuite extends SessionStateSuite + with TestHiveSingleton with BeforeAndAfterEach { + + override def beforeAll(): Unit = { + // Reuse the singleton session + activeSession = spark + } + + override def afterAll(): Unit = { + // Set activeSession to null to avoid stopping the singleton session + activeSession = null + super.afterAll() + } +}

From 455129020ca7f6a162f6f2486a87cc43512cfd2c Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 8 Mar 2017 13:43:09 -0800
Subject: [PATCH 52/78] [SPARK-15463][SQL] Add an API to load DataFrame from Dataset[String] storing CSV

## What changes were proposed in this pull request?

This PR proposes to add an API that loads a `DataFrame` from a `Dataset[String]` storing CSV. It allows pre-processing before the data is parsed as CSV, which enables workarounds for many narrow cases, for example, as below:

- Case 1 - pre-processing

```scala
val df = spark.read.text("...")
// Pre-processing with this.
spark.read.csv(df.as[String])
```

- Case 2 - use other input formats

```scala
val rdd = spark.sparkContext.newAPIHadoopFile("/file.csv.lzo",
  classOf[com.hadoop.mapreduce.LzoTextInputFormat],
  classOf[org.apache.hadoop.io.LongWritable],
  classOf[org.apache.hadoop.io.Text])
val stringRdd = rdd.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength))

spark.read.csv(stringRdd.toDS)
```

## How was this patch tested?

Added tests in `CSVSuite` and built with Scala 2.10.

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon

Closes #16854 from HyukjinKwon/SPARK-15463.
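As an extra illustration beyond the examples above (not part of the patch), a minimal sketch of how the new overload composes with the `header` and `inferSchema` options; the file path and the comment-filtering step are hypothetical:

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// A sketch assuming a Spark build that contains this patch (2.2.0+).
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._

// Read raw lines and drop hypothetical comment lines before CSV parsing.
val raw: Dataset[String] = spark.read.text("/tmp/cars.csv").as[String] // hypothetical path
val cleaned: Dataset[String] = raw.filter(line => !line.startsWith("#"))

// The first remaining line is treated as the header, and column types are
// inferred by going over the data once, as the API docs below describe.
val df: DataFrame = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv(cleaned)
```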
--- .../apache/spark/sql/DataFrameReader.scala | 71 ++++++++++++++++--- .../datasources/csv/CSVDataSource.scala | 49 +++++++------ .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/csv/UnivocityParser.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 27 +++++++ 5 files changed, 121 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 41470ae6aae19..a5e38e25b1ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema @@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { createParser) } - // Check a field requirement for corrupt records here to throw an exception in a driver side - schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) @@ -398,6 +392,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { csv(Seq(path): _*) } + /** + * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. + * + * If the schema is not specified using the `schema` function and the `inferSchema` option is enabled, + * this function goes through the input once to determine the input schema. + * + * If the schema is not specified using the `schema` function and the `inferSchema` option is disabled, + * it determines the columns as string types and it reads only the first line to determine the + * names and the number of fields.
+ * + * @param csvDataset input Dataset with one CSV row per record + * @since 2.2.0 + */ + def csv(csvDataset: Dataset[String]): DataFrame = { + val parsedOptions: CSVOptions = new CSVOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone) + val filteredLines: Dataset[String] = + CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) + val maybeFirstLine: Option[String] = filteredLines.take(1).headOption + + val schema = userSpecifiedSchema.getOrElse { + TextInputCSVDataSource.inferFromDataset( + sparkSession, + csvDataset, + maybeFirstLine, + parsedOptions) + } + + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + + val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) + }.getOrElse(filteredLines.rdd) + + val parsed = linesWithoutHeader.mapPartitions { iter => + val parser = new UnivocityParser(schema, parsedOptions) + iter.flatMap(line => parser.parse(line)) + } + + Dataset.ofRows( + sparkSession, + LogicalRDD(schema.toAttributes, parsed)(sparkSession)) + } + /** * Loads a CSV file and returns the result as a `DataFrame`. * @@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + private def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 47567032b0195..35ff924f27ce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources.csv -import java.io.InputStream import java.nio.charset.{Charset, StandardCharsets} -import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { - case Some(firstLine) => - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, 
caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) - case None => - // If the first line could not be read, just return the empty schema. - Some(StructType(Nil)) - } + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)) + } + + /** + * Infers the schema from a `Dataset` that stores CSV string records. + */ + def inferFromDataset( + sparkSession: SparkSession, + csv: Dataset[String], + maybeFirstLine: Option[String], + parsedOptions: CSVOptions): StructType = maybeFirstLine match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case None => + // If the first line could not be read, just return the empty schema. + StructType(Nil) } private def createBaseDataset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 50503385ad6d1..0b1e5dac2da66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} -private[csv] class CSVOptions( +class CSVOptions( @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3b3b87e4354d6..e42ea3fa391f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[csv] class UnivocityParser( +class UnivocityParser( schema: StructType, requiredSchema: StructType, private val options: CSVOptions) extends Logging { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index eaedede349134..4435e4df38ef6 ---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkTypes = true) } + test("simple csv test with string dataset") { + val csvDataset = spark.read.text(testFile(carsFile)).as[String] + val cars = spark.read + .option("header", "true") + .option("inferSchema", "true") + .csv(csvDataset) + + verifyCars(cars, withHeader = true, checkTypes = true) + + val carsWithoutHeader = spark.read + .option("header", "false") + .csv(csvDataset) + + verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false) + } + test("test inferring booleans") { val result = spark.read .format("csv") @@ -1088,4 +1104,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, spark.emptyDataFrame) } } + + test("Empty string dataset produces empty dataframe and keep user-defined schema") { + val df1 = spark.read.csv(spark.emptyDataset[String]) + assert(df1.schema === spark.emptyDataFrame.schema) + checkAnswer(df1, spark.emptyDataFrame) + + val schema = StructType(StructField("a", StringType) :: Nil) + val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String]) + assert(df2.schema === schema) + } + } From a3648b5d4f99ff9461d02f53e9ec71787a3abf51 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 8 Mar 2017 14:35:07 -0800 Subject: [PATCH 53/78] [SPARK-19813] maxFilesPerTrigger combo latestFirst may miss old files in combination with maxFileAge in FileStreamSource ## What changes were proposed in this pull request? **The Problem** There is a file stream source option called maxFileAge which limits how old the files can be, relative to the latest file that has been seen. This is used to limit the files that need to be remembered as "processed". Files older than the latest processed files are ignored. This value defaults to 7 days. This causes a problem when latestFirst = true and maxFilesPerTrigger is set to fewer files than the total number of files to be processed. Here is what happens in each combination: 1) latestFirst = false - Since files are processed in order, there won't be any unprocessed file older than the latest processed file. All files will be processed. 2) latestFirst = true AND maxFilesPerTrigger is not set - The maxFileAge thresholding mechanism takes one batch to initialize. Since maxFilesPerTrigger is not set, all old files get processed in the first batch, and so no file is left behind. 3) latestFirst = true AND maxFilesPerTrigger is set to X - The first batch processes the latest X files. That sets the threshold to (latest file - maxFileAge), so files older than this threshold will never be considered for processing. The bug is with case 3. **The Solution** Ignore `maxFileAge` when both `maxFilesPerTrigger` and `latestFirst` are set. ## How was this patch tested? Regression test in `FileStreamSourceSuite` Author: Burak Yavuz Closes #17153 from brkyvz/maxFileAge.
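To make case 3 concrete, here is a minimal sketch of a query that hits the bug before this patch. The input directory and session setup are hypothetical; `latestFirst`, `maxFilesPerTrigger`, and `maxFileAge` are the actual file-source options:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("latest-first-sketch").getOrCreate()

// Suppose /data/in holds many week-old files plus one fresh file. With
// latestFirst = true and maxFilesPerTrigger = 1, the first batch reads only
// the newest file; that moves the age threshold to (newest - maxFileAge),
// so the older files fall below it and are never processed.
val lines = spark.readStream
  .option("latestFirst", "true")
  .option("maxFilesPerTrigger", "1")
  .option("maxFileAge", "7d") // the default, shown here for clarity
  .text("/data/in")
```

After the fix, this combination behaves as if `maxFileAge` were unset, so the older files are still picked up in subsequent batches.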
--- .../streaming/FileStreamOptions.scala | 5 +- .../streaming/FileStreamSource.scala | 14 +++- .../sql/streaming/FileStreamSourceSuite.scala | 82 +++++++++++-------- 3 files changed, 63 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 2f802d782f5ad..e7ba901945490 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -38,7 +38,10 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging } /** - * Maximum age of a file that can be found in this directory, before it is deleted. + * Maximum age of a file that can be found in this directory, before it is ignored. For the + * first batch all files will be considered valid. If `latestFirst` is set to `true` and + * `maxFilesPerTrigger` is set, then this parameter will be ignored, because old files that are + * valid, and should be processed, may be ignored. Please refer to SPARK-19813 for details. * * The max age is specified with respect to the timestamp of the latest file, and not the * timestamp of the current system. That this means if the last file has timestamp 1000, and the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 6a7263ca45d85..0f09b0a0c8f25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -66,23 +66,29 @@ class FileStreamSource( private val fileSortOrder = if (sourceOptions.latestFirst) { logWarning( - """'latestFirst' is true. New files will be processed first. - |It may affect the watermark value""".stripMargin) + """'latestFirst' is true. New files will be processed first, which may affect the watermark + |value. In addition, 'maxFileAge' will be ignored.""".stripMargin) implicitly[Ordering[Long]].reverse } else { implicitly[Ordering[Long]] } + private val maxFileAgeMs: Long = if (sourceOptions.latestFirst && maxFilesPerBatch.isDefined) { + Long.MaxValue + } else { + sourceOptions.maxFileAgeMs + } + /** A mapping from a file that we have processed to some timestamp it was last modified. */ // Visible for testing and debugging in production. - val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) + val seenFiles = new SeenFilesMap(maxFileAgeMs) metadataLog.allFiles().foreach { entry => seenFiles.add(entry.path, entry.timestamp) } seenFiles.purge() - logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = ${sourceOptions.maxFileAgeMs}") + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = $maxFileAgeMs") /** * Returns the maximum offset that can be retrieved from the source. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 1586850c77fca..0517b0a800e53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1173,6 +1173,41 @@ class FileStreamSourceSuite extends FileStreamSourceTest { SerializedOffset(str.trim) } + private def runTwoBatchesAndVerifyResults( + src: File, + latestFirst: Boolean, + firstBatch: String, + secondBatch: String, + maxFileAge: Option[String] = None): Unit = { + val srcOptions = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1") ++ + maxFileAge.map("maxFileAge" -> _) + val fileStream = createFileStream( + "text", + src.getCanonicalPath, + options = srcOptions) + val clock = new StreamManualClock() + testStream(fileStream)( + StartStream(trigger = ProcessingTime(10), triggerClock = clock), + AssertOnQuery { _ => + // Block until the first batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(0)) + } + true + }, + CheckLastBatch(firstBatch), + AdvanceManualClock(10), + AssertOnQuery { _ => + // Block until the second batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(10)) + } + true + }, + CheckLastBatch(secondBatch) + ) + } + test("FileStreamSource - latestFirst") { withTempDir { src => // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time. @@ -1180,42 +1215,23 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val f2 = stringToFile(new File(src, "2.txt"), "2") f2.setLastModified(f1.lastModified + 1000) - def runTwoBatchesAndVerifyResults( - latestFirst: Boolean, - firstBatch: String, - secondBatch: String): Unit = { - val fileStream = createFileStream( - "text", - src.getCanonicalPath, - options = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1")) - val clock = new StreamManualClock() - testStream(fileStream)( - StartStream(trigger = ProcessingTime(10), triggerClock = clock), - AssertOnQuery { _ => - // Block until the first batch finishes. - eventually(timeout(streamingTimeout)) { - assert(clock.isStreamWaitingAt(0)) - } - true - }, - CheckLastBatch(firstBatch), - AdvanceManualClock(10), - AssertOnQuery { _ => - // Block until the second batch finishes. - eventually(timeout(streamingTimeout)) { - assert(clock.isStreamWaitingAt(10)) - } - true - }, - CheckLastBatch(secondBatch) - ) - } - // Read oldest files first, so the first batch is "1", and the second batch is "2". - runTwoBatchesAndVerifyResults(latestFirst = false, firstBatch = "1", secondBatch = "2") + runTwoBatchesAndVerifyResults(src, latestFirst = false, firstBatch = "1", secondBatch = "2") // Read latest files first, so the first batch is "2", and the second batch is "1". - runTwoBatchesAndVerifyResults(latestFirst = true, firstBatch = "2", secondBatch = "1") + runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1") + } + } + + test("SPARK-19813: Ignore maxFileAge when maxFilesPerTrigger and latestFirst is used") { + withTempDir { src => + // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time. 
+ val f1 = stringToFile(new File(src, "1.txt"), "1") + val f2 = stringToFile(new File(src, "2.txt"), "2") + f2.setLastModified(f1.lastModified + 3600 * 1000 /* 1 hour later */) + + runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1", + maxFileAge = Some("1m") /* 1 minute */) } } From d809ceed9762d5bbb04170e45f38751713112dd8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 8 Mar 2017 17:33:49 -0800 Subject: [PATCH 54/78] [MINOR][SQL] The analyzer rules are fired twice for cases when AnalysisException is raised from analyzer. ## What changes were proposed in this pull request? In general we have a checkAnalysis phase which validates the logical plan and throws AnalysisException on semantic errors. However, we can also throw AnalysisException from a few analyzer rules like ResolveSubquery. I found that we fire the analyzer rules twice for queries that throw AnalysisException from one of the analyzer rules. This is a very minor fix; we don't strictly have to fix it. I just got confused seeing the rule getting fired twice when I was not expecting it. ## How was this patch tested? Tested manually. Author: Dilip Biswal Closes #17214 from dilipbiswal/analyis_twice. --- .../org/apache/spark/sql/execution/QueryExecution.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 6ec2f4d840862..9a3656ddc79f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -46,9 +46,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def planner = sparkSession.sessionState.planner def assertAnalyzed(): Unit = { - try sparkSession.sessionState.analyzer.checkAnalysis(analyzed) catch { + // Analyzer is invoked outside the try block to avoid calling it again from within the + // catch block below. + analyzed + try { + sparkSession.sessionState.analyzer.checkAnalysis(analyzed) + } catch { case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) ae.setStackTrace(e.getStackTrace) throw ae } From 09829be621f0f9bb5076abb3d832925624699fa9 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 8 Mar 2017 23:12:10 -0800 Subject: [PATCH 55/78] [SPARK-19235][SQL][TESTS] Enable Test Cases in DDLSuite with Hive Metastore ### What changes were proposed in this pull request? So far, the test cases in DDLSuite only verify the behavior of InMemoryCatalog. That means they do not cover the scenarios using HiveExternalCatalog. Thus, we need to improve the existing test suite to run these cases using the Hive metastore. When porting these test cases, a bug in `SET LOCATION` was found: `path` is not set when the location is changed. After this PR, a few changes are made, as summarized below: - `DDLSuite` becomes an abstract class. Both `InMemoryCatalogedDDLSuite` and `HiveCatalogedDDLSuite` extend it. `InMemoryCatalogedDDLSuite` is using `InMemoryCatalog`. `HiveCatalogedDDLSuite` is using `HiveExternalCatalog`. - `InMemoryCatalogedDDLSuite` contains all the existing test cases in `DDLSuite`. - `HiveCatalogedDDLSuite` contains a subset of `DDLSuite`. The following test cases are excluded: 1.
The following test cases only make sense for `InMemoryCatalog`: ``` test("desc table for parquet data source table using in-memory catalog") test("create a managed Hive source table") test("create an external Hive source table") test("Create Hive Table As Select") ``` 2. The following test cases cannot be ported because we are unable to alter the table provider when using the Hive metastore. In future PRs we need to improve the test cases so that altering the table provider is not needed: ``` test("alter table: set location (datasource table)") test("alter table: set properties (datasource table)") test("alter table: unset properties (datasource table)") test("alter table: set serde (datasource table)") test("alter table: set serde partition (datasource table)") test("alter table: change column (datasource table)") test("alter table: add partition (datasource table)") test("alter table: drop partition (datasource table)") test("alter table: rename partition (datasource table)") test("drop table - data source table") ``` **TODO**: in future PRs, we need to remove `HiveDDLSuite` and move the test cases to either `DDLSuite`, `InMemoryCatalogedDDLSuite` or `HiveCatalogedDDLSuite`. ### How was this patch tested? N/A Author: Xiao Li Author: gatorsmile Closes #16592 from gatorsmile/refactorDDLSuite. --- .../sql/execution/command/DDLSuite.scala | 456 ++++++++++-------- .../apache/spark/sql/test/SQLTestUtils.scala | 5 + .../sql/hive/execution/HiveDDLSuite.scala | 157 +++--- 3 files changed, 345 insertions(+), 273 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index c1f8b2b3d9605..aa335c4453dd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -30,23 +30,164 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { - private val escapedIdentifier = "`(.+)`".r +class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() } finally { - Utils.deleteRecursively(new File("spark-warehouse")) + Utils.deleteRecursively(new File(spark.sessionState.conf.warehousePath)) super.afterEach() } } + protected override def generateTable( + catalog: SessionCatalog, + name: TableIdentifier): CatalogTable = { + val storage = + CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL, + storage = storage, + schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") + .add("a", "int") + .add("b", "int"), + provider = Some("parquet"), + partitionColumnNames = Seq("a", "b"), + createTime = 0L, +
tracksPartitionsInCatalog = true) + } + + test("desc table for parquet data source table using in-memory catalog") { + val tabName = "tab1" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") + + checkAnswer( + sql(s"DESC $tabName").select("col_name", "data_type", "comment"), + Row("a", "int", "test") + ) + } + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + + test("create a managed Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $tabName (i INT, j STRING)") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE")) + } + } + + test("create an external Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + withTempDir { tempDir => + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql( + s""" + |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |LOCATION '${tempDir.toURI}' + """.stripMargin) + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE")) + } + } + } + + test("Create Hive Table As Select") { + import testImplicits._ + withTable("t", "t1") { + var e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT 1 as a, 1 as b") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + + spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT a, b from t1") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + } + } + +} + +abstract class DDLSuite extends QueryTest with SQLTestUtils { + + protected def isUsingHiveMetastore: Boolean = { + spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" + } + + protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + + private val escapedIdentifier = "`(.+)`".r + + protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table + + private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { + props.filterNot(p => Seq("serialization.format", "path").contains(p._1)) + } + + private def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { + assert(normalizeCatalogTable(actual) == 
normalizeCatalogTable(expected)) + } + /** * Strip backticks, if any, from the string. */ @@ -75,33 +216,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { ignoreIfExists = false) } - private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { - val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = None, - outputFormat = None, - serde = None, - compressed = false, - properties = Map()) - val metadata = new MetadataBuilder() - .putString("key", "value") - .build() - CatalogTable( - identifier = name, - tableType = CatalogTableType.EXTERNAL, - storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), - provider = Some("parquet"), - partitionColumnNames = Seq("a", "b"), - createTime = 0L, - tracksPartitionsInCatalog = true) - } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) } @@ -115,6 +229,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } + private def getDBPath(dbName: String): URI = { + val warehousePath = s"file:${spark.sessionState.conf.warehousePath.stripPrefix("file:")}" + new Path(warehousePath, s"$dbName.db").toUri + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -138,11 +257,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { try { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbName) - val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db") assert(db1 == CatalogDatabase( dbName, "", - expectedLocation, + getDBPath(dbName), Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbName)) @@ -185,16 +303,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", - expectedLocation, + getDBPath(dbNameWithoutBackTicks), Map.empty)) - intercept[DatabaseAlreadyExistsException] { + // TODO: HiveExternalCatalog should throw DatabaseAlreadyExistsException + val e = intercept[AnalysisException] { sql(s"CREATE DATABASE $dbName") - } + }.getMessage + assert(e.contains(s"already exists")) } finally { catalog.reset() } @@ -413,19 +532,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("desc table for parquet data source table using in-memory catalog") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tab1" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") - - checkAnswer( - sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") - ) - } - } - test("Alter/Describe Database") { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -433,7 +539,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => try { 
val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db") + val location = getDBPath(dbNameWithoutBackTicks) sql(s"CREATE DATABASE $dbName") @@ -477,7 +583,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + // TODO: Unify the exception. + if (isUsingHiveMetastore) { + assert(message.contains(s"NoSuchObjectException: $dbNameWithoutBackTicks")) + } else { + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + } message = intercept[AnalysisException] { sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -506,7 +617,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName RESTRICT") }.getMessage - assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) + // TODO: Unify the exception. + if (isUsingHiveMetastore) { + assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) + } else { + assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) + } catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false) @@ -537,7 +653,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) val expectedTable = generateTable(catalog, expectedTableIdent) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table in a specific db") { @@ -546,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val tableIdent1 = TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) val expectedTable = generateTable(catalog, tableIdent1) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table using") { @@ -731,52 +847,28 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testSetLocation(isDatasourceTable = false) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - test("alter table: set properties") { testSetProperties(isDatasourceTable = false) } - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - test("alter table: unset properties") { testUnsetProperties(isDatasourceTable = false) } - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - // TODO: move this test to HiveDDLSuite.scala ignore("alter table: set serde") { testSetSerde(isDatasourceTable = false) } - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - // TODO: move this test to HiveDDLSuite.scala ignore("alter table: set serde partition") { testSetSerdePartition(isDatasourceTable = false) } - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - test("alter table: change column") { testChangeColumn(isDatasourceTable = false) } - test("alter table: change 
column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -805,10 +897,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testAddPartitions(isDatasourceTable = false) } - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -821,7 +909,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testRecoverPartitions() { + protected def testRecoverPartitions() { val catalog = spark.sessionState.catalog // table to alter does not exist intercept[AnalysisException] { @@ -860,8 +948,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) - assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") - assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + if (!isUsingHiveMetastore) { + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } else { + // After ALTER TABLE, the statistics of the first partition are removed by the Hive metastore + assert(catalog.getPartition(tableIdent, part1).parameters.get("numFiles").isEmpty) + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } } finally { fs.delete(root, true) } @@ -875,10 +969,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testDropPartitions(isDatasourceTable = false) } - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } @@ -887,10 +977,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testRenamePartitions(isDatasourceTable = false) } - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - test("show table extended") { withTempView("show1a", "show2b") { sql( @@ -971,11 +1057,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testDropTable(isDatasourceTable = false) } - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - - private def testDropTable(isDatasourceTable: Boolean): Unit = { + protected def testDropTable(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") @@ -1011,9 +1093,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { tableIdent: TableIdentifier): Unit = { catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( provider = Some("csv"))) + assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) } - private def testSetProperties(isDatasourceTable: Boolean): Unit = { + protected def testSetProperties(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent =
TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") @@ -1022,7 +1105,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } assert(getProps.isEmpty) // set table properties @@ -1038,7 +1125,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") @@ -1047,7 +1134,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } // unset table properties sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')") @@ -1071,7 +1162,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(getProps == Map("x" -> "y")) } - private def testSetLocation(isDatasourceTable: Boolean): Unit = { + protected def testSetLocation(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") @@ -1082,24 +1173,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) + assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) + assert( + normalizeSerdeProp(catalog.getPartition(tableIdent, partSpec).storage.properties).isEmpty) + // Verify that the location is set to the expected string def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { catalog.getTableMetadata(tableIdent).storage } - if (isDatasourceTable) { - if (spec.isDefined) { - assert(storageFormat.properties.isEmpty) - assert(storageFormat.locationUri === Some(expected)) - } else { - assert(storageFormat.locationUri === Some(expected)) - } - } else { - assert(storageFormat.locationUri === Some(expected)) - } + // TODO(gatorsmile): fix the bug in alter table set location. 
+ // if (isUsingHiveMetastore) { + // assert(storageFormat.properties.get("path") === expected) + // } + assert(storageFormat.locationUri === Some(expected)) } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") @@ -1124,7 +1212,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testSetSerde(isDatasourceTable: Boolean): Unit = { + protected def testSetSerde(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") @@ -1132,8 +1220,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) + def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } + } + if (isUsingHiveMetastore) { + assert(catalog.getTableMetadata(tableIdent).storage.serde == + Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } else { + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) + } + checkSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -1146,31 +1247,30 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e1.getMessage.contains("datasource")) assert(e2.getMessage.contains("datasource")) } else { - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + val newSerde = "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$newSerde'") + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(newSerde)) + checkSerdeProps(Map.empty[String, String]) + val serde2 = "org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$serde2' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(serde2)) + checkSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) // table to alter does not exist intercept[AnalysisException] { 
sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") } } - private def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") @@ -1183,8 +1283,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } + } + if (isUsingHiveMetastore) { + assert(catalog.getPartition(tableIdent, spec).storage.serde == + Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } else { + assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) + } + checkPartitionSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -1199,26 +1312,23 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } else { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.jadoop'") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + checkPartitionSerdeProps(Map.empty[String, String]) sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.madoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) " + "SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) } // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { sql("ALTER TABLE tab1 PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) } // table to alter does not exist intercept[AnalysisException] { @@ -1226,7 +1336,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1247,7 +1357,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION 
(a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(new URI("paris"))) + val partitionLocation = if (isUsingHiveMetastore) { + val tableLocation = catalog.getTableMetadata(tableIdent).storage.locationUri + assert(tableLocation.isDefined) + makeQualifiedPath(new Path(tableLocation.get.toString, "paris")) + } else { + new URI("paris") + } + + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(partitionLocation)) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database @@ -1277,7 +1395,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(part1, part2, part3, part4, part5)) } - private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1330,7 +1448,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).isEmpty) } - private def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") @@ -1374,7 +1492,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } - private def testChangeColumn(isDatasourceTable: Boolean): Unit = { + protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -1474,35 +1592,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { ) } - test("create a managed Hive source table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tbl" - withTable(tabName) { - val e = intercept[AnalysisException] { - sql(s"CREATE TABLE $tabName (i INT, j STRING)") - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE")) - } - } - - test("create an external Hive source table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - withTempDir { tempDir => - val tabName = "tbl" - withTable(tabName) { - val e = intercept[AnalysisException] { - sql( - s""" - |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |LOCATION '${tempDir.toURI}' - """.stripMargin) - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE")) - } - } - } - test("create a data source table without schema") { import testImplicits._ withTempPath { tempDir => @@ -1541,22 +1630,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("Create Hive Table As Select") { - import testImplicits._ - withTable("t", "t1") { - var e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT 1 as a, 1 as b") - }.getMessage - assert(e.contains("Hive support is required to CREATE 
Hive TABLE (AS SELECT)")) - - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") - e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT a, b from t1") - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) - } - } - test("Create Data Source Table As Select") { import testImplicits._ withTable("t", "t1", "t2") { @@ -1580,7 +1653,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop default database") { - Seq("true", "false").foreach { caseSensitive => + val caseSensitiveOptions = if (isUsingHiveMetastore) Seq("false") else Seq("true", "false") + caseSensitiveOptions.foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { var message = intercept[AnalysisException] { sql("DROP DATABASE default") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9201954b66d10..12fc8993d7396 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -306,6 +306,11 @@ private[sql] trait SQLTestUtils val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + def makeQualifiedPath(path: Path): URI = { + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + fs.makeQualified(path).toUri + } } private[sql] object SQLTestUtils { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 10d929a4a0ef8..fce055048d72f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -27,16 +27,88 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{MetadataBuilder, StructType} + +// TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite +class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { + override def afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + spark.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } + + protected override def generateTable( + catalog: SessionCatalog, + name: TableIdentifier): CatalogTable = { + val storage = + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + 
inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL, + storage = storage, + schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") + .add("a", "int") + .add("b", "int"), + provider = Some("hive"), + partitionColumnNames = Seq("a", "b"), + createTime = 0L, + tracksPartitionsInCatalog = true) + } + + protected override def normalizeCatalogTable(table: CatalogTable): CatalogTable = { + val nondeterministicProps = Set( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_by", + "last_modified_time", + "Owner:", + "COLUMN_STATS_ACCURATE", + // The following are hive specific schema parameters which we do not need to match exactly. + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize" + ) + + table.copy( + createTime = 0L, + lastAccessTime = 0L, + owner = "", + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + // View texts are checked separately + viewText = None + ) + } + +} class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { @@ -1719,61 +1791,6 @@ class HiveDDLSuite } } - Seq("a b", "a:b", "a%b").foreach { specialChars => - test(s"datasource table: location uri contains $specialChars") { - withTable("t", "t1") { - withTempDir { dir => - val loc = new File(dir, specialChars) - loc.mkdir() - spark.sql( - s""" - |CREATE TABLE t(a string) - |USING parquet - |LOCATION '$loc' - """.stripMargin) - - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new Path(loc.getAbsolutePath).toUri) - assert(new Path(table.location).toString.contains(specialChars)) - - assert(loc.listFiles().isEmpty) - spark.sql("INSERT INTO TABLE t SELECT 1") - assert(loc.listFiles().length >= 1) - checkAnswer(spark.table("t"), Row("1") :: Nil) - } - - withTempDir { dir => - val loc = new File(dir, specialChars) - loc.mkdir() - spark.sql( - s""" - |CREATE TABLE t1(a string, b string) - |USING parquet - |PARTITIONED BY(b) - |LOCATION '$loc' - """.stripMargin) - - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == new Path(loc.getAbsolutePath).toUri) - assert(new Path(table.location).toString.contains(specialChars)) - - assert(loc.listFiles().isEmpty) - spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") - val partFile = new File(loc, "b=2") - assert(partFile.listFiles().length >= 1) - checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) - - spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") - val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") - assert(!partFile1.exists()) - val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") - assert(partFile2.listFiles().length >= 1) - checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) - } - } - } - } - Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"hive table: location uri contains $specialChars") { withTable("t") { @@ -1848,28 +1865,4 @@ class HiveDDLSuite } 
} } - - Seq("a b", "a:b", "a%b").foreach { specialChars => - test(s"location uri contains $specialChars for database") { - try { - withTable("t") { - withTempDir { dir => - val loc = new File(dir, specialChars) - spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") - spark.sql("USE tmpdb") - - Seq(1).toDF("a").write.saveAsTable("t") - val tblloc = new File(loc, "t") - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val tblPath = new Path(tblloc.getAbsolutePath) - val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf()) - assert(table.location == makeQualifiedPath(tblloc.getAbsolutePath)) - assert(tblloc.listFiles().nonEmpty) - } - } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") - } - } - } } From 029e40b412e332c9f0fff283d604e203066c78c0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 8 Mar 2017 23:15:52 -0800 Subject: [PATCH 56/78] [SPARK-19874][BUILD] Hide API docs for org.apache.spark.sql.internal ## What changes were proposed in this pull request? The API docs should not include the "org.apache.spark.sql.internal" package because they are internal private APIs. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17217 from zsxwing/SPARK-19874. --- project/SparkBuild.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 93a31897c9fc1..e52baf51aed1a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -655,6 +655,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } From eeb1d6db878641d9eac62d0869a90fe80c1f4461 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 8 Mar 2017 23:23:10 -0800 Subject: [PATCH 57/78] [SPARK-19859][SS][FOLLOW-UP] The new watermark should override the old one. ## What changes were proposed in this pull request? A follow up to SPARK-19859: - extract the calculation of `delayMs` and reuse it. - update EventTimeWatermarkExec - use the correct `delayMs` in EventTimeWatermark ## How was this patch tested? Jenkins. Author: uncleGen Closes #17221 from uncleGen/SPARK-19859. --- .../plans/logical/EventTimeWatermark.scala | 9 ++++++++- .../streaming/EventTimeWatermarkExec.scala | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 62f68a6d7b528..06196b5afb031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -24,6 +24,12 @@ import org.apache.spark.unsafe.types.CalendarInterval object EventTimeWatermark { /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */ val delayKey = "spark.watermarkDelayMs" + + def getDelayMs(delay: CalendarInterval): Long = { + // We define month as `31 days` to simplify calculation. 
+ val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + delay.milliseconds + delay.months * millisPerMonth + } } /** @@ -37,9 +43,10 @@ case class EventTimeWatermark( // Update the metadata on the eventTime column to include the desired delay. override val output: Seq[Attribute] = child.output.map { a => if (a semanticEquals eventTime) { + val delayMs = EventTimeWatermark.getDelayMs(delay) val updatedMetadata = new MetadataBuilder() .withMetadata(a.metadata) - .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .putLong(EventTimeWatermark.delayKey, delayMs) .build() a.withMetadata(updatedMetadata) } else if (a.metadata.contains(EventTimeWatermark.delayKey)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 5a9a99e11188e..25cf609fc336e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -84,10 +84,7 @@ case class EventTimeWatermarkExec( child: SparkPlan) extends SparkPlan { val eventTimeStats = new EventTimeStatsAccum() - val delayMs = { - val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 - delay.milliseconds + delay.months * millisPerMonth - } + val delayMs = EventTimeWatermark.getDelayMs(delay) sparkContext.register(eventTimeStats) @@ -105,10 +102,16 @@ case class EventTimeWatermarkExec( override val output: Seq[Attribute] = child.output.map { a => if (a semanticEquals eventTime) { val updatedMetadata = new MetadataBuilder() - .withMetadata(a.metadata) - .putLong(EventTimeWatermark.delayKey, delayMs) - .build() - + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delayMs) + .build() + a.withMetadata(updatedMetadata) + } else if (a.metadata.contains(EventTimeWatermark.delayKey)) { + // Remove existing watermark + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(EventTimeWatermark.delayKey) + .build() a.withMetadata(updatedMetadata) } else { a From 274973d2a32ff4eb28545b50a3135e4784eb3047 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 9 Mar 2017 01:18:17 -0800 Subject: [PATCH 58/78] [SPARK-19763][SQL] qualified external datasource table location stored in catalog ## What changes were proposed in this pull request? If we create an external datasource table with a non-qualified location, we should qualify it before storing it in the catalog. ``` CREATE TABLE t(a string) USING parquet LOCATION '/path/xx' CREATE TABLE t1(a string, b string) USING parquet PARTITIONED BY(b) LOCATION '/path/xx' ``` When we get the table from the catalog, the location should be qualified, e.g. 'file:/path/xx'. ## How was this patch tested? unit test added Author: windpiger Closes #17095 from windpiger/tablepathQualified.
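For reference, qualifying a location boils down to asking the location's file system to resolve it, which is what this patch does through the session's Hadoop configuration. Here is a minimal standalone sketch of the same idea using Hadoop's `Path`/`FileSystem` API directly (a fresh `Configuration` stands in for Spark's session Hadoop conf):

```scala
import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Resolve a bare location such as /path/xx against its file system,
// yielding a fully qualified URI such as file:/path/xx (or hdfs://host/path/xx).
def qualify(location: URI): URI = {
  val hadoopPath = new Path(location)
  val fs = hadoopPath.getFileSystem(new Configuration())
  fs.makeQualified(hadoopPath).toUri
}
```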
--- .../sql/catalyst/catalog/SessionCatalog.scala | 14 +++++- .../sql/execution/command/DDLSuite.scala | 50 +++++++++++++++---- .../spark/sql/internal/CatalogSuite.scala | 3 +- .../spark/sql/sources/PathOptionSuite.scala | 9 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 5 -- .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../spark/sql/hive/client/VersionsSuite.scala | 6 +-- .../sql/hive/execution/HiveDDLSuite.scala | 4 +- 8 files changed, 64 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6cfc4a4321316..bfcdb70fe47c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -259,7 +259,19 @@ class SessionCatalog( val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) validateName(table) - val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + + val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined + && !tableDefinition.storage.locationUri.get.isAbsolute) { + // make the location of the table qualified. + val qualifiedTableLocation = + makeQualifiedPath(tableDefinition.storage.locationUri.get) + tableDefinition.copy( + storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), + identifier = TableIdentifier(table, Some(db))) + } else { + tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + } + requireDbExists(db) externalCatalog.createTable(newTableDefinition, ignoreIfExists) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index aa335c4453dd1..5f70a8ce8918b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -230,8 +230,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } private def getDBPath(dbName: String): URI = { - val warehousePath = s"file:${spark.sessionState.conf.warehousePath.stripPrefix("file:")}" - new Path(warehousePath, s"$dbName.db").toUri + val warehousePath = makeQualifiedPath(s"${spark.sessionState.conf.warehousePath}") + new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } test("the qualified path of a database is stored in the catalog") { @@ -1360,7 +1360,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val partitionLocation = if (isUsingHiveMetastore) { val tableLocation = catalog.getTableMetadata(tableIdent).storage.locationUri assert(tableLocation.isDefined) - makeQualifiedPath(new Path(tableLocation.get.toString, "paris")) + makeQualifiedPath(new Path(tableLocation.get.toString, "paris").toString) } else { new URI("paris") } @@ -1909,7 +1909,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) dir.delete assert(!dir.exists) @@ -1950,7 +1950,7 @@ abstract class DDLSuite extends QueryTest with 
SQLTestUtils { |LOCATION "$dir" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1976,7 +1976,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new URI(dir.getAbsolutePath)) + + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) dir.delete() checkAnswer(spark.table("t"), Nil) @@ -2032,7 +2033,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } @@ -2051,7 +2052,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) @@ -2099,7 +2100,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) assert(new Path(table.location).toString.contains(specialChars)) assert(loc.listFiles().isEmpty) @@ -2120,7 +2121,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) assert(new Path(table.location).toString.contains(specialChars)) assert(loc.listFiles().isEmpty) @@ -2162,4 +2163,33 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + test("the qualified path of a datasource table is stored in the catalog") { + withTable("t", "t1") { + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$dir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location.toString.startsWith("file:/")) + } + + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$dir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location.toString.startsWith("file:/")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index fcb8ffbc6edd0..9742b3b2d5c29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.internal import java.io.File -import java.net.URI import org.scalatest.BeforeAndAfterEach @@ -459,7 +458,7 @@ class CatalogSuite options = Map("path" -> dir.getAbsolutePath)) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.EXTERNAL) - assert(table.storage.locationUri.get == new URI(dir.getAbsolutePath)) + assert(table.storage.locationUri.get == makeQualifiedPath(dir.getAbsolutePath)) Seq((1)).toDF("i").write.insertInto("t") assert(dir.exists() && dir.listFiles().nonEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 7ab339e005295..60adee4599b0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -75,7 +75,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path') """.stripMargin) - assert(getPathOption("src") == Some("/tmp/path")) + assert(getPathOption("src") == Some("file:/tmp/path")) } // should exist even path option is not specified when creating table @@ -88,15 +88,16 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { test("path option also exist for write path") { withTable("src") { withTempPath { p => - val path = new Path(p.getAbsolutePath).toString sql( s""" |CREATE TABLE src |USING ${classOf[TestOptionsSource].getCanonicalName} - |OPTIONS (PATH '$path') + |OPTIONS (PATH '$p') |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == path) + assert(CatalogUtils.stringToURI( + spark.table("src").schema.head.metadata.getString("path")) == + makeQualifiedPath(p.getAbsolutePath)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 12fc8993d7396..9201954b66d10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -306,11 +306,6 @@ private[sql] trait SQLTestUtils val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } - - def makeQualifiedPath(path: Path): URI = { - val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) - fs.makeQualified(path).toUri - } } private[sql] object SQLTestUtils { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index cf552b4a88b2c..079358b29a191 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive -import java.net.URI - import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -142,7 +140,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.serde === Some(serde)) assert(hiveTable.tableType === CatalogTableType.EXTERNAL) - 
assert(hiveTable.storage.locationUri === Some(new URI(path.getAbsolutePath))) + assert(hiveTable.storage.locationUri === Some(makeQualifiedPath(dir.getAbsolutePath))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index dd624eca6b7b0..6025f8adbce28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -658,19 +658,17 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w val tPath = new Path(spark.sessionState.conf.warehousePath, "t") Seq("1").toDF("a").write.saveAsTable("t") - val expectedPath = s"file:${tPath.toUri.getPath.stripSuffix("/")}" val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == CatalogUtils.stringToURI(expectedPath)) + assert(table.location == makeQualifiedPath(tPath.toString)) assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) checkAnswer(spark.table("t"), Row("1") :: Nil) val t1Path = new Path(spark.sessionState.conf.warehousePath, "t1") spark.sql("create table t1 using parquet as select 2 as a") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - val expectedPath1 = s"file:${t1Path.toUri.getPath.stripSuffix("/")}" - assert(table1.location == CatalogUtils.stringToURI(expectedPath1)) + assert(table1.location == makeQualifiedPath(t1Path.toString)) assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) checkAnswer(spark.table("t1"), Row(2) :: Nil) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index fce055048d72f..23aea24697785 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1681,7 +1681,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } @@ -1701,7 +1701,7 @@ class HiveDDLSuite """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == new URI(dir.getAbsolutePath)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) From 206030bd12405623c00c1ff334663984b9250adb Mon Sep 17 00:00:00 2001 From: Jason White Date: Thu, 9 Mar 2017 10:34:54 -0800 Subject: [PATCH 59/78] [SPARK-19561][SQL] add int case handling for TimestampType MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Add handling of input of type `Int` for dataType `TimestampType` to `EvaluatePython.scala`. Py4J serializes ints smaller than MIN_INT or larger than MAX_INT to Long, which are handled correctly already, but values between MIN_INT and MAX_INT are serialized to Int. These range limits correspond to roughly half an hour on either side of the epoch. 
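A quick check of that figure: `TimestampType` is stored internally as microseconds since the epoch, so the `Int` range covers:

```scala
// Int.MaxValue microseconds expressed in minutes:
val maxMicros = Int.MaxValue.toLong  // 2147483647 µs
val minutes = maxMicros / 1e6 / 60   // ≈ 35.8 minutes on either side of the epoch
```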
As a result, PySpark doesn't allow TimestampType values to be created in this range.

Alternatives attempted: patching the `TimestampType.toInternal` function to cast return values to `long`, so Py4J would always serialize them to Scala Long. Python 3 does not have a `long` type, so this approach failed there.

## How was this patch tested?

Added a new PySpark-side test that fails without the change.

The contribution is my original work and I license the work to the project under the project's open source license.

Resubmission of https://github.com/apache/spark/pull/16896. The original PR didn't go through Jenkins and broke the build. davies dongjoon-hyun cloud-fan Could you kick off a Jenkins run for me? It passed everything for me locally, but it's possible something has changed in the last few weeks.

Author: Jason White 

Closes #17200 from JasonMWhite/SPARK-19561.

---
 python/pyspark/sql/tests.py                          | 8 ++++++++
 .../spark/sql/execution/python/EvaluatePython.scala  | 2 ++
 2 files changed, 10 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 81f3d1d36a342..1b873e957888c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1555,6 +1555,14 @@ def test_time_with_timezone(self):
         self.assertEqual(now, now1)
         self.assertEqual(now, utcnow1)
 
+    # regression test for SPARK-19561
+    def test_datetime_at_epoch(self):
+        epoch = datetime.datetime.fromtimestamp(0)
+        df = self.spark.createDataFrame([Row(date=epoch)])
+        first = df.select('date', lit(epoch).alias('lit_date')).first()
+        self.assertEqual(first['date'], epoch)
+        self.assertEqual(first['lit_date'], epoch)
+
     def test_decimal(self):
         from decimal import Decimal
         schema = StructType([StructField("decimal", DecimalType(10, 5))])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 46fd54e5c7420..fcd84705f7e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -112,6 +112,8 @@ object EvaluatePython {
     case (c: Int, DateType) => c
 
     case (c: Long, TimestampType) => c
+    // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+    case (c: Int, TimestampType) => c.toLong
 
     case (c, StringType) => UTF8String.fromString(c.toString)
 

From b60b9fc10a1ee52c3c021a4a5faf10f92f83e3c9 Mon Sep 17 00:00:00 2001
From: Jimmy Xiang 
Date: Thu, 9 Mar 2017 10:52:18 -0800
Subject: [PATCH 60/78] [SPARK-19757][CORE] DriverEndpoint#makeOffers race
 against CoarseGrainedSchedulerBackend#killExecutors

## What changes were proposed in this pull request?

While some executors are being killed due to idleness, if some new tasks come in, the driver could assign them to executors that are being killed. These tasks will fail later when the executors are lost. This patch makes sure CoarseGrainedSchedulerBackend#killExecutors and DriverEndpoint#makeOffers are properly synchronized.

## How was this patch tested?

manual tests

Author: Jimmy Xiang 

Closes #17091 from jxiang/spark-19757.
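To make the race concrete, a toy sketch (hypothetical names, not the actual Spark classes) of the check-then-act hazard that the patch closes by doing the liveness check and the scheduling decision under one lock:

```scala
import scala.collection.mutable

object ToyBackend {
  private val lock = new Object
  private val aliveExecutors = mutable.Set("exec-1", "exec-2")

  // Kill path: removes the executor under the shared lock.
  def killIdle(id: String): Unit = lock.synchronized {
    aliveExecutors -= id
  }

  // Offer path: snapshot and decide atomically with respect to killIdle, so an
  // executor cannot disappear between the liveness check and task placement.
  // (Actually launching the tasks can still happen outside the lock.)
  def makeOffers(): Seq[String] = lock.synchronized {
    aliveExecutors.toSeq
  }
}
```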
---
 .../CoarseGrainedSchedulerBackend.scala            | 38 +++++++++++++------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 94abe30bb12f2..7e2cfaccfc7ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -222,12 +222,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
 
     // Make fake resource offers on all executors
     private def makeOffers() {
-      // Filter out executors under killing
-      val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
-      val workOffers = activeExecutors.map { case (id, executorData) =>
-        new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
-      }.toIndexedSeq
-      launchTasks(scheduler.resourceOffers(workOffers))
+      // Make sure no executor is killed while some task is launching on it
+      val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+        // Filter out executors under killing
+        val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
+        val workOffers = activeExecutors.map { case (id, executorData) =>
+          new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
+        }.toIndexedSeq
+        scheduler.resourceOffers(workOffers)
+      }
+      if (!taskDescs.isEmpty) {
+        launchTasks(taskDescs)
+      }
     }
 
     override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -240,12 +246,20 @@
 
     // Make fake resource offers on just one executor
     private def makeOffers(executorId: String) {
-      // Filter out executors under killing
-      if (executorIsAlive(executorId)) {
-        val executorData = executorDataMap(executorId)
-        val workOffers = IndexedSeq(
-          new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
-        launchTasks(scheduler.resourceOffers(workOffers))
+      // Make sure no executor is killed while some task is launching on it
+      val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+        // Filter out executors under killing
+        if (executorIsAlive(executorId)) {
+          val executorData = executorDataMap(executorId)
+          val workOffers = IndexedSeq(
+            new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+          scheduler.resourceOffers(workOffers)
+        } else {
+          Seq.empty
+        }
+      }
+      if (!taskDescs.isEmpty) {
+        launchTasks(taskDescs)
       }
     }
 

From 3232e54f2fcb8d2072cba4bc763ef29d5d8d325f Mon Sep 17 00:00:00 2001
From: jinxing 
Date: Thu, 9 Mar 2017 10:56:19 -0800
Subject: [PATCH 61/78] [SPARK-19793] Use clock.getTimeMillis when marking a
 task as finished in TaskSetManager.

## What changes were proposed in this pull request?

TaskSetManager currently uses `System.currentTimeMillis` when marking a task as finished in `handleSuccessfulTask` and `handleFailedTask`. Thus a developer cannot control a task's finishing time in a unit test. In `handleSuccessfulTask`, a task's duration is computed as `System.currentTimeMillis` - launchTime (which can be set via `clock`), so the result is not correct.

## How was this patch tested?

Existing tests.

Author: jinxing 

Closes #17133 from jinxing64/SPARK-19793.
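For context (not part of the patch): `org.apache.spark.util.ManualClock` advances only when told to, which is what makes the finish time, and therefore the task duration, deterministic once `markFinished` takes its time from the injected clock. A minimal sketch:

```scala
import org.apache.spark.util.ManualClock

val clock = new ManualClock()            // starts at time 0
val launchTime = clock.getTimeMillis()
clock.advance(1000)                      // the test decides how long the task "ran"
val finishTime = clock.getTimeMillis()
assert(finishTime - launchTime == 1000)  // duration is now fully under test control
```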
--- .../scala/org/apache/spark/scheduler/TaskInfo.scala | 6 ++++-- .../org/apache/spark/scheduler/TaskSetManager.scala | 6 +++--- .../apache/spark/scheduler/TaskSetManagerSuite.scala | 10 +++++++++- .../scala/org/apache/spark/ui/StagePageSuite.scala | 2 +- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 59680139e7af3..9843eab4f1346 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -70,11 +70,13 @@ class TaskInfo( var killed = false - private[spark] def markGettingResult(time: Long = System.currentTimeMillis) { + private[spark] def markGettingResult(time: Long) { gettingResultTime = time } - private[spark] def markFinished(state: TaskState, time: Long = System.currentTimeMillis) { + private[spark] def markFinished(state: TaskState, time: Long) { + // finishTime should be set larger than 0, otherwise "finished" below will return false. + assert(time > 0) finishTime = time if (state == TaskState.FAILED) { failed = true diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 19ebaf817e24e..11633bef3cfc7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -667,7 +667,7 @@ private[spark] class TaskSetManager( */ def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) - info.markGettingResult() + info.markGettingResult(clock.getTimeMillis()) sched.dagScheduler.taskGettingResult(info) } @@ -695,7 +695,7 @@ private[spark] class TaskSetManager( def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index - info.markFinished(TaskState.FINISHED) + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) removeRunningTask(tid) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. 
To avoid the SPARK-7655 issue, we should not @@ -739,7 +739,7 @@ private[spark] class TaskSetManager( return } removeRunningTask(tid) - info.markFinished(state) + info.markFinished(state, clock.getTimeMillis()) val index = info.index copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2c2cda9f318eb..f36bcd8504b05 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -192,6 +192,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) + clock.advance(1) // Tell it the task has finished manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) @@ -377,6 +378,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock + clock.advance(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -394,6 +396,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock + clock.advance(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted @@ -427,6 +430,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // affinity to exec1 on host1 - which we will fail. val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) val clock = new ManualClock + clock.advance(1) // We don't directly use the application blacklist, but its presence triggers blacklisting // within the taskset. 
      val mockListenerBus = mock(classOf[LiveListenerBus])
@@ -551,7 +555,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
       Seq(TaskLocation("host1", "execB")),
       Seq(TaskLocation("host2", "execC")),
       Seq())
-    val manager = new TaskSetManager(sched, taskSet, 1, clock = new ManualClock)
+    val clock = new ManualClock()
+    clock.advance(1)
+    val manager = new TaskSetManager(sched, taskSet, 1, clock = clock)
     sched.addExecutor("execA", "host1")
     manager.executorAdded()
     sched.addExecutor("execC", "host2")
@@ -904,6 +910,7 @@
       assert(task.executorId === k)
     }
     assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+    clock.advance(1)
     // Complete the 3 tasks and leave 1 task in running
     for (id <- Set(0, 1, 2)) {
       manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
@@ -961,6 +968,7 @@
       tasks += task
     }
     assert(sched.startedTasks.toSet === (0 until 5).toSet)
+    clock.advance(1)
     // Complete 3 tasks and leave 2 tasks in running
     for (id <- Set(0, 1, 2)) {
       manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 11482d187aeca..38030e066080f 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -77,7 +77,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
         val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
         jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
         jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
-        taskInfo.markFinished(TaskState.FINISHED)
+        taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis())
         val taskMetrics = TaskMetrics.empty
         taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
         jobListener.onTaskEnd(

From 40da4d181d648308de85fdcabc5c098ee861949a Mon Sep 17 00:00:00 2001
From: Liwei Lin 
Date: Thu, 9 Mar 2017 11:02:44 -0800
Subject: [PATCH 62/78] [SPARK-19715][STRUCTURED STREAMING] Option to Strip
 Paths in FileSource

## What changes were proposed in this pull request?

Today, we compare the whole path when deciding if a file is new in the FileSource for structured streaming. However, this causes false negatives in cases where the path has changed in a cosmetic way (e.g. changing `s3n` to `s3a`).

This patch adds an option `fileNameOnly` that causes the new-file check to be based only on the filename (the whole path is still stored in the log).

## Usage

```scala
spark
  .readStream
  .option("fileNameOnly", true)
  .text("s3n://bucket/dir1/dir2")
  .writeStream
  ...
```

## How was this patch tested?

Added a test case.

Author: Liwei Lin 

Closes #17120 from lw-lin/filename-only.
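For illustration (not part of the patch): the name-only comparison amounts to keying the seen-files map on the last path component, roughly:

```scala
import java.net.URI
import org.apache.hadoop.fs.Path

// Sketch of the dedup key with fileNameOnly enabled vs. disabled.
def dedupKey(path: String, fileNameOnly: Boolean): String =
  if (fileNameOnly) new Path(new URI(path)).getName else path

dedupKey("s3n://bucket/dir1/dir2/data.txt", fileNameOnly = true)   // "data.txt"
dedupKey("s3a://bucket/dir1/dir2/data.txt", fileNameOnly = true)   // "data.txt" -> already seen
dedupKey("s3a://bucket/dir1/dir2/data.txt", fileNameOnly = false)  // full path  -> looks new
```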
---
 .../structured-streaming-programming-guide.md      | 12 +++++--
 .../streaming/FileStreamOptions.scala              | 34 ++++++++++++++-----
 .../streaming/FileStreamSource.scala               | 25 +++++++++-----
 .../sql/streaming/FileStreamSourceSuite.scala      | 22 ++++++++++--
 4 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 6af47b6efba2c..995ac77a4fb3b 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -1052,10 +1052,18 @@ Here are the details of all the sinks in Spark.
         Append
         path: path to the output directory, must be specified.
+        maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
-        latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files(default: false)
-
+        latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false)
+
+        fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same:
+
+        · "file:///dataset.txt"
+        · "s3://a/dataset.txt"
+        · "s3n://a/b/dataset.txt"
+        · "s3a://a/b/c/dataset.txt"
+
For file-format-specific options, see the related methods in DataFrameWriter (Scala/Java/Python). E.g. for "parquet" format options see DataFrameWriter.parquet() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index e7ba901945490..d54ed44b43bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -61,13 +61,29 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging * Whether to scan latest files first. If it's true, when the source finds unprocessed files in a * trigger, it will first process the latest files. */ - val latestFirst: Boolean = parameters.get("latestFirst").map { str => - try { - str.toBoolean - } catch { - case _: IllegalArgumentException => - throw new IllegalArgumentException( - s"Invalid value '$str' for option 'latestFirst', must be 'true' or 'false'") - } - }.getOrElse(false) + val latestFirst: Boolean = withBooleanParameter("latestFirst", false) + + /** + * Whether to check new files based on only the filename instead of on the full path. + * + * With this set to `true`, the following files would be considered as the same file, because + * their filenames, "dataset.txt", are the same: + * - "file:///dataset.txt" + * - "s3://a/dataset.txt" + * - "s3n://a/b/dataset.txt" + * - "s3a://a/b/c/dataset.txt" + */ + val fileNameOnly: Boolean = withBooleanParameter("fileNameOnly", false) + + private def withBooleanParameter(name: String, default: Boolean) = { + parameters.get(name).map { str => + try { + str.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid value '$str' for option '$name', must be 'true' or 'false'") + } + }.getOrElse(default) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0f09b0a0c8f25..411a15ffceb6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.net.URI + import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{FileStatus, Path} @@ -79,9 +81,16 @@ class FileStreamSource( sourceOptions.maxFileAgeMs } + private val fileNameOnly = sourceOptions.fileNameOnly + if (fileNameOnly) { + logWarning("'fileNameOnly' is enabled. Make sure your file names are unique (e.g. using " + + "UUID), otherwise, files with the same name but under different paths will be considered " + + "the same and causes data lost.") + } + /** A mapping from a file that we have processed to some timestamp it was last modified. */ // Visible for testing and debugging in production. - val seenFiles = new SeenFilesMap(maxFileAgeMs) + val seenFiles = new SeenFilesMap(maxFileAgeMs, fileNameOnly) metadataLog.allFiles().foreach { entry => seenFiles.add(entry.path, entry.timestamp) @@ -268,7 +277,7 @@ object FileStreamSource { * To prevent the hash map from growing indefinitely, a purge function is available to * remove files "maxAgeMs" older than the latest file. 
   */
-  class SeenFilesMap(maxAgeMs: Long) {
+  class SeenFilesMap(maxAgeMs: Long, fileNameOnly: Boolean) {
     require(maxAgeMs >= 0)
 
     /** Mapping from file to its timestamp. */
@@ -280,9 +289,13 @@ object FileStreamSource {
     /** Timestamp for the last purge operation. */
     private var lastPurgeTimestamp: Timestamp = 0L
 
+    @inline private def stripPathIfNecessary(path: String) = {
+      if (fileNameOnly) new Path(new URI(path)).getName else path
+    }
+
     /** Add a new file to the map. */
     def add(path: String, timestamp: Timestamp): Unit = {
-      map.put(path, timestamp)
+      map.put(stripPathIfNecessary(path), timestamp)
       if (timestamp > latestTimestamp) {
         latestTimestamp = timestamp
       }
@@ -295,7 +308,7 @@ object FileStreamSource {
     def isNewFile(path: String, timestamp: Timestamp): Boolean = {
       // Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that
       // is older than (latestTimestamp - maxAgeMs) but has not been purged yet.
-      timestamp >= lastPurgeTimestamp && !map.containsKey(path)
+      timestamp >= lastPurgeTimestamp && !map.containsKey(stripPathIfNecessary(path))
     }
 
     /** Removes aged entries and returns the number of files removed. */
@@ -314,9 +327,5 @@ object FileStreamSource {
     }
 
     def size: Int = map.size()
-
-    def allEntries: Seq[(String, Timestamp)] = {
-      map.asScala.toSeq
-    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 0517b0a800e53..f705da3d6a709 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1236,7 +1236,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
   }
 
   test("SeenFilesMap") {
-    val map = new SeenFilesMap(maxAgeMs = 10)
+    val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false)
 
     map.add("a", 5)
     assert(map.size == 1)
@@ -1269,8 +1269,26 @@
     assert(map.isNewFile("e", 20))
   }
 
+  test("SeenFilesMap with fileNameOnly = true") {
+    val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = true)
+
+    map.add("file:///a/b/c/d", 5)
+    map.add("file:///a/b/c/e", 5)
+    assert(map.size === 2)
+
+    assert(!map.isNewFile("d", 5))
+    assert(!map.isNewFile("file:///d", 5))
+    assert(!map.isNewFile("file:///x/d", 5))
+    assert(!map.isNewFile("file:///x/y/d", 5))
+
+    map.add("s3:///bucket/d", 5)
+    map.add("s3n:///bucket/d", 5)
+    map.add("s3a:///bucket/d", 5)
+    assert(map.size === 2)
+  }
+
   test("SeenFilesMap should only consider a file old if it is earlier than last purge time") {
-    val map = new SeenFilesMap(maxAgeMs = 10)
+    val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false)
 
     map.add("a", 20)
     assert(map.size == 1)

From 30b18e69361746b4d656474374d8b486bb48a19e Mon Sep 17 00:00:00 2001
From: uncleGen 
Date: Thu, 9 Mar 2017 11:07:31 -0800
Subject: [PATCH 63/78] [SPARK-19861][SS] watermark should not be a negative
 time.

## What changes were proposed in this pull request?

A `watermark` delay should not be negative. A negative delay is invalid, so validate it before the query actually runs.

## How was this patch tested?

Added a new unit test.

Author: uncleGen 
Author: dylon 

Closes #17202 from uncleGen/SPARK-19861.
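In user-facing terms (not part of the patch), the new `require` turns a silently nonsensical watermark into an immediate error. Assuming a streaming Dataset `df` with a timestamp column `eventTime`:

```scala
df.withWatermark("eventTime", "10 minutes")   // ok
df.withWatermark("eventTime", "-10 seconds")  // fails fast:
// java.lang.IllegalArgumentException: requirement failed:
//   delay threshold (-10 seconds) should not be negative.
```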
---
 .../scala/org/apache/spark/sql/Dataset.scala       |  4 +++-
 .../streaming/EventTimeWatermarkSuite.scala        | 23 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 16edb35b1d43f..0a4d3a93a07e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -563,7 +563,7 @@ class Dataset[T] private[sql](
    * @param eventTime the name of the column that contains the event time of the row.
    * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest
    *                       record that has been processed in the form of an interval
-   *                       (e.g. "1 minute" or "5 hours").
+   *                       (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative.
    *
    * @group streaming
    * @since 2.1.0
@@ -576,6 +576,8 @@ class Dataset[T] private[sql](
     val parsedDelay =
       Option(CalendarInterval.fromString("interval " + delayThreshold))
         .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'"))
+    require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
+      s"delay threshold ($delayThreshold) should not be negative.")
     EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index c768525bc6855..7614ea5eb3c01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -306,6 +306,29 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
     )
   }
 
+  test("delay threshold should not be negative.") {
+    val inputData = MemoryStream[Int].toDF()
+    var e = intercept[IllegalArgumentException] {
+      inputData.withWatermark("value", "-1 year")
+    }
+    assert(e.getMessage contains "should not be negative.")
+
+    e = intercept[IllegalArgumentException] {
+      inputData.withWatermark("value", "1 year -13 months")
+    }
+    assert(e.getMessage contains "should not be negative.")
+
+    e = intercept[IllegalArgumentException] {
+      inputData.withWatermark("value", "1 month -40 days")
+    }
+    assert(e.getMessage contains "should not be negative.")
+
+    e = intercept[IllegalArgumentException] {
+      inputData.withWatermark("value", "-10 seconds")
+    }
+    assert(e.getMessage contains "should not be negative.")
+  }
+
   test("the new watermark should override the old one") {
     val df = MemoryStream[(Long, Long)].toDF()
       .withColumn("first", $"_1".cast("timestamp"))

From cabe1df8606e7e5b9e6efb106045deb3f39f5f13 Mon Sep 17 00:00:00 2001
From: Jeff Zhang 
Date: Thu, 9 Mar 2017 11:44:34 -0800
Subject: [PATCH 64/78] [SPARK-12334][SQL][PYSPARK] Support read from multiple
 input paths for orc file in DataFrameReader.orc

Besides the issue in the Spark API, this also fixes two minor issues in PySpark:
- support reading from multiple input paths for ORC
- support reading from multiple input paths for text

Author: Jeff Zhang 

Closes #10307 from zjffdu/SPARK-12334.
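For reference (not part of the patch): the PySpark fix forwards a list of paths to the JVM-side varargs method, which the Scala API already exposed. The Scala equivalent:

```scala
// Rows from both input paths end up in one DataFrame.
val df = spark.read.orc("/data/orc/part1", "/data/orc/part2")
df.count()
```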
--- python/pyspark/sql/readwriter.py | 14 ++++++++------ python/pyspark/sql/tests.py | 5 +++++ .../org/apache/spark/sql/DataFrameReader.scala | 6 +++--- .../apache/spark/sql/hive/orc/OrcQuerySuite.scala | 9 +++++++++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 45fb9b7591529..4354345ebc550 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, timeZone=None, wholeFile=None): """ - Loads a JSON file and returns the results as a :class:`DataFrame`. + Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_(newline-delimited JSON) is supported by default. For JSON (one record per file), set the `wholeFile` parameter to ``true``. @@ -169,7 +169,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - :param path: string represents path to the JSON dataset, + :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param primitivesAsString: infers all primitive values as a string type. If None is set, @@ -252,7 +252,7 @@ def func(iterator): jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) else: - raise TypeError("path can be only string or RDD") + raise TypeError("path can be only string, list or RDD") @since(1.4) def table(self, tableName): @@ -269,7 +269,7 @@ def table(self, tableName): @since(1.4) def parquet(self, *paths): - """Loads a Parquet file, returning the result as a :class:`DataFrame`. + """Loads Parquet files, returning the result as a :class:`DataFrame`. You can set the following Parquet-specific option(s) for reading Parquet files: * ``mergeSchema``: sets whether we should merge schemas collected from all \ @@ -407,7 +407,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non @since(1.5) def orc(self, path): - """Loads an ORC file, returning the result as a :class:`DataFrame`. + """Loads ORC files, returning the result as a :class:`DataFrame`. .. note:: Currently ORC support is only available together with Hive support. 
@@ -415,7 +415,9 @@ def orc(self, path): >>> df.dtypes [('a', 'bigint'), ('b', 'int'), ('c', 'int')] """ - return self._df(self._jreader.orc(path)) + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.orc(_to_seq(self._spark._sc, path))) @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b873e957888c..f0a9a0400e392 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -450,6 +450,11 @@ def test_wholefile_csv(self): Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] self.assertEqual(ages_newlines.collect(), expected) + def test_read_multiple_orc_file(self): + df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", + "python/test_support/sql/orc_partitioned/b=1/c=1"]) + self.assertEqual(2, df.count()) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a5e38e25b1ec5..4f4cc93117494 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -262,7 +262,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file and returns the results as a `DataFrame`. + * Loads JSON files and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by * default. For JSON (one record per file), set the `wholeFile` option to true. @@ -438,7 +438,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a CSV file and returns the result as a `DataFrame`. + * Loads CSV files and returns the result as a `DataFrame`. * * This function will go through the input once to determine the input schema if `inferSchema` * is enabled. To avoid going through the entire data once, disable `inferSchema` option or @@ -549,7 +549,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads an ORC file and returns the result as a `DataFrame`. + * Loads ORC files and returns the result as a `DataFrame`. 
   *
   * @param paths input paths
   * @since 2.0.0
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 38a5477796a4a..5d8ba9d7c85d1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.Utils
 
 case class AllDataTypesWithNonPrimitiveType(
     stringField: String,
@@ -611,4 +612,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
       }
     }
   }
+
+  test("read from multiple orc input paths") {
+    val path1 = Utils.createTempDir()
+    val path2 = Utils.createTempDir()
+    makeOrcFile((1 to 10).map(Tuple1.apply), path1)
+    makeOrcFile((1 to 10).map(Tuple1.apply), path2)
+    assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count())
+  }
 }

From f79371ad86d94da14bd1ddb53e99a388017b6892 Mon Sep 17 00:00:00 2001
From: Budde 
Date: Thu, 9 Mar 2017 12:55:33 -0800
Subject: [PATCH 65/78] [SPARK-19611][SQL] Introduce configurable table schema
 inference

## Summary of changes

Add a new configuration option that allows Spark SQL to infer a case-sensitive schema from a Hive Metastore table's data files when a case-sensitive schema can't be read from the table properties.

- Add spark.sql.hive.caseSensitiveInferenceMode param to SQLConf
- Add schemaPreservesCase field to CatalogTable (set to false when schema can't successfully be read from Hive table props)
- Perform schema inference in HiveMetastoreCatalog if schemaPreservesCase is false, depending on spark.sql.hive.caseSensitiveInferenceMode
- Add alterTableSchema() method to the ExternalCatalog interface
- Add HiveSchemaInferenceSuite tests
- Refactor and move ParquetFileFormat.mergeMetastoreParquetSchema() to HiveMetastoreCatalog.mergeWithMetastoreSchema
- Move schema merging tests from ParquetSchemaSuite to HiveSchemaInferenceSuite

[JIRA for this change](https://issues.apache.org/jira/browse/SPARK-19611)

## How was this patch tested?

The tests in ```HiveSchemaInferenceSuite``` should verify that schema inference is working as expected. ```ExternalCatalogSuite``` has also been extended to cover the new ```alterTableSchema()``` API.

Author: Budde 

Closes #16944 from budde/SPARK-19611.
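In practice (not part of the patch), the mode is selected per session through the new conf key; `INFER_AND_SAVE` is the default:

```scala
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "INFER_ONLY")   // infer, don't write back
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "NEVER_INFER")  // use metastore schema as-is
```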
--- .../catalyst/catalog/ExternalCatalog.scala | 15 +- .../catalyst/catalog/InMemoryCatalog.scala | 10 + .../sql/catalyst/catalog/interface.scala | 8 +- .../catalog/ExternalCatalogSuite.scala | 15 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 3 +- .../parquet/ParquetFileFormat.scala | 65 ---- .../apache/spark/sql/internal/SQLConf.scala | 22 ++ .../parquet/ParquetSchemaSuite.scala | 82 ----- .../spark/sql/hive/HiveExternalCatalog.scala | 23 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 100 +++++- .../sql/hive/HiveSchemaInferenceSuite.scala | 305 ++++++++++++++++++ 11 files changed, 489 insertions(+), 159 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 31eded4deba7d..08a01e8601897 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression - +import org.apache.spark.sql.types.StructType /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -104,6 +104,19 @@ abstract class ExternalCatalog { */ def alterTable(tableDefinition: CatalogTable): Unit + /** + * Alter the schema of a table identified by the provided database and table name. The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). + * + * @param db Database that table to alter schema for exists in + * @param table Name of table to alter schema for + * @param schema Updated schema to be used for the table (must contain existing partition and + * bucket columns) + */ + def alterTableSchema(db: String, table: String, schema: StructType): Unit + def getTable(db: String, table: String): CatalogTable def getTableOption(db: String, table: String): Option[CatalogTable] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 80aba4af9436c..5cc6b0abc6fde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types.StructType /** * An in-memory (ephemeral) implementation of the system catalog. 
@@ -297,6 +298,15 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition } + override def alterTableSchema( + db: String, + table: String, + schema: StructType): Unit = synchronized { + requireTableExists(db, table) + val origTable = catalog(db).tables(table).table + catalog(db).tables(table).table = origTable.copy(schema = schema) + } + override def getTable(db: String, table: String): CatalogTable = synchronized { requireTableExists(db, table) catalog(db).tables(table).table diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 4452c479875fa..e3631b0c07737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -163,6 +163,11 @@ case class BucketSpec( * @param tracksPartitionsInCatalog whether this table's partition metadata is stored in the * catalog. If false, it is inferred automatically based on file * structure. + * @param schemaPresevesCase Whether or not the schema resolved for this table is case-sensitive. + * When using a Hive Metastore, this flag is set to false if a case- + * sensitive schema was unable to be read from the table properties. + * Used to trigger case-sensitive schema inference at query time, when + * configured. */ case class CatalogTable( identifier: TableIdentifier, @@ -180,7 +185,8 @@ case class CatalogTable( viewText: Option[String] = None, comment: Option[String] = None, unsupportedFeatures: Seq[String] = Seq.empty, - tracksPartitionsInCatalog: Boolean = false) { + tracksPartitionsInCatalog: Boolean = false, + schemaPreservesCase: Boolean = true) { import CatalogTable._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 07ccd68698e94..7820f39d96426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -240,6 +240,19 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("alter table schema") { + val catalog = newBasicCatalog() + val tbl1 = catalog.getTable("db2", "tbl1") + val newSchema = StructType(Seq( + StructField("new_field_1", IntegerType), + StructField("new_field_2", StringType), + StructField("a", IntegerType), + StructField("b", StringType))) + catalog.alterTableSchema("db2", "tbl1", newSchema) + val newTbl1 = catalog.getTable("db2", "tbl1") + assert(newTbl1.schema == newSchema) + } + test("get table") { assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index af1eaa1f23746..37e3dfabd0b21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -491,7 +491,8 @@ class TreeNodeSuite extends SparkFunSuite { "lastAccessTime" -> -1, "tracksPartitionsInCatalog" -> false, "properties" -> JNull, - "unsupportedFeatures" -> List.empty[String])) + "unsupportedFeatures" -> List.empty[String], + "schemaPreservesCase" -> JBool(true))) // For unknown case class, returns JNull. val bigValue = new Array[Int](10000) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 828949eddc8ec..5313c2f3746a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -475,71 +475,6 @@ object ParquetFileFormat extends Logging { } } - /** - * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore - * schema and Parquet schema. - * - * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the - * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't - * distinguish binary and string). This method generates a correct schema by merging Metastore - * schema data types and Parquet schema field names. - */ - def mergeMetastoreParquetSchema( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - def schemaConflictMessage: String = - s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: - |${metastoreSchema.prettyJson} - | - |Parquet schema: - |${parquetSchema.prettyJson} - """.stripMargin - - val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) - - assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) - - val ordinalMap = metastoreSchema.zipWithIndex.map { - case (field, index) => field.name.toLowerCase -> index - }.toMap - - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => - ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) - - StructType(metastoreSchema.zip(reorderedParquetSchema).map { - // Uses Parquet field names but retains Metastore data types. - case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => - mSchema.copy(name = pSchema.name) - case _ => - throw new SparkException(schemaConflictMessage) - }) - } - - /** - * Returns the original schema from the Parquet file with any missing nullable fields from the - * Hive Metastore schema merged in. - * - * When constructing a DataFrame from a collection of structured data, the resulting object has - * a schema corresponding to the union of the fields present in each element of the collection. - * Spark SQL simply assigns a null value to any field that isn't present for a particular row. - * In some cases, it is possible that a given table partition stored as a Parquet file doesn't - * contain a particular nullable field in its schema despite that field being present in the - * table schema obtained from the Hive Metastore. 
This method returns a schema representing the - * Parquet file schema along with any additional nullable fields from the Metastore schema - * merged in. - */ - private[parquet] def mergeMissingNullableFields( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap - val missingFields = metastoreSchema - .map(_.name.toLowerCase) - .diff(parquetSchema.map(_.name.toLowerCase)) - .map(fieldMap(_)) - .filter(_.nullable) - StructType(parquetSchema ++ missingFields) - } - /** * Reads Parquet footers in multi-threaded manner. * If the config "spark.sql.files.ignoreCorruptFiles" is set to true, we will ignore the corrupted diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1244f690fd829..8e3f567b7dd90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -296,6 +296,25 @@ object SQLConf { .longConf .createWithDefault(250 * 1024 * 1024) + object HiveCaseSensitiveInferenceMode extends Enumeration { + val INFER_AND_SAVE, INFER_ONLY, NEVER_INFER = Value + } + + val HIVE_CASE_SENSITIVE_INFERENCE = buildConf("spark.sql.hive.caseSensitiveInferenceMode") + .doc("Sets the action to take when a case-sensitive schema cannot be read from a Hive " + + "table's properties. Although Spark SQL itself is not case-sensitive, Hive compatible file " + + "formats such as Parquet are. Spark SQL must use a case-preserving schema when querying " + + "any table backed by files containing case-sensitive field names or queries may not return " + + "accurate results. Valid options include INFER_AND_SAVE (the default mode-- infer the " + + "case-sensitive schema from the underlying data files and write it back to the table " + + "properties), INFER_ONLY (infer the schema but don't attempt to write it to the table " + + "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " + + "instead of inferring).") + .stringConf + .transform(_.toUpperCase()) + .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) + .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. 
It applies when all the columns " + @@ -792,6 +811,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) + def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = + HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538f..6aa940afbb2c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -368,88 +368,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } - test("merge with metastore schema") { - // Field type conflict resolution - assertResult( - StructType(Seq( - StructField("lowerCase", StringType), - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("lowercase", StringType), - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // MetaStore schema is subset of parquet schema - assertResult( - StructType(Seq( - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // Metastore schema contains additional non-nullable fields. - assert(intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false), - StructField("lowerCase", BinaryType, nullable = false))), - - StructType(Seq( - StructField("UPPERCase", IntegerType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - - // Conflicting non-nullable field names - intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType, nullable = false))), - StructType(Seq(StructField("lowerCase", BinaryType)))) - } - } - - test("merge missing nullable fields from Metastore schema") { - // Standard case: Metastore schema contains additional nullable fields not present - // in the Parquet file schema. - assertResult( - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true)))) { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - } - - // Merge should fail if the Metastore contains any additional fields that are not - // nullable. 
- assert(intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = false))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - } - test("schema merging failure error message") { import testImplicits._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 9ab4624594924..78aa2bd2494f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -597,6 +597,25 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + val withNewSchema = rawTable.copy(schema = schema) + // Add table metadata such as table schema, partition columns, etc. to table properties. + val updatedTable = withNewSchema.copy( + properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) + try { + client.alterTable(updatedTable) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + "compatible way. Updating Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema)) + } + } + override def getTable(db: String, table: String): CatalogTable = withClient { restoreTableMetadata(getRawTable(db, table)) } @@ -690,10 +709,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat "different from the schema when this table was created by Spark SQL" + s"(${schemaFromTableProps.simpleString}). 
We have to fall back to the table schema " + "from Hive metastore which is not case preserving.") - hiveTable + hiveTable.copy(schemaPreservesCase = false) } } else { - hiveTable + hiveTable.copy(schemaPreservesCase = false) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d135dfa9f4157..056af495590f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.hive import java.net.URI +import scala.util.control.NonFatal + import com.google.common.util.concurrent.Striped import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} @@ -32,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.orc.OrcFileFormat +import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ /** @@ -44,6 +48,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // these are def_s and not val/lazy val since the latter would introduce circular references private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + import HiveMetastoreCatalog._ private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase @@ -130,6 +135,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) + val fileFormat = fileFormatClass.newInstance() + val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { @@ -170,16 +177,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } + val (dataSchema, updatedTable) = + inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) + val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = relation.tableMeta.dataSchema, + dataSchema = dataSchema, // We don't support hive bucketed tables, only ones we write out. 
bucketSpec = None, - fileFormat = fileFormatClass.newInstance(), + fileFormat = fileFormat, options = options)(sparkSession = sparkSession) - - val created = LogicalRelation(fsRelation, catalogTable = Some(relation.tableMeta)) + val created = LogicalRelation(fsRelation, catalogTable = Some(updatedTable)) tableRelationCache.put(tableIdentifier, created) created } @@ -196,17 +205,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormatClass, None) val logicalRelation = cached.getOrElse { + val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat) val created = LogicalRelation( DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Some(metastoreSchema), + userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(relation.tableMeta)) + catalogTable = Some(updatedTable)) tableRelationCache.put(tableIdentifier, created) created @@ -218,6 +228,54 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log result.copy(expectedOutputAttributes = Some(relation.output)) } + private def inferIfNeeded( + relation: CatalogRelation, + options: Map[String, String], + fileFormat: FileFormat, + fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { + val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode + val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase + val tableName = relation.tableMeta.identifier.unquotedString + if (shouldInfer) { + logInfo(s"Inferring case-sensitive schema for table $tableName (inference mode: " + + s"$inferenceMode)") + val fileIndex = fileIndexOpt.getOrElse { + val rootPath = new Path(relation.tableMeta.location) + new InMemoryFileIndex(sparkSession, Seq(rootPath), options, None) + } + + val inferredSchema = fileFormat + .inferSchema( + sparkSession, + options, + fileIndex.listFiles(Nil).flatMap(_.files)) + .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) + + inferredSchema match { + case Some(schema) => + if (inferenceMode == INFER_AND_SAVE) { + updateCatalogSchema(relation.tableMeta.identifier, schema) + } + (schema, relation.tableMeta.copy(schema = schema)) + case None => + logWarning(s"Unable to infer schema for table $tableName from file format " + + s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.") + (relation.tableMeta.schema, relation.tableMeta) + } + } else { + (relation.tableMeta.schema, relation.tableMeta) + } + } + + private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try { + val db = identifier.database.get + logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}") + sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema) + } catch { + case NonFatal(ex) => + logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) + } + /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. 
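[Editor's note: to make the new behavior concrete, here is a minimal sketch, not part of the patch, of how the three inference modes wired up above could be exercised from user code. The table name `t` and its mixed-case Parquet layout are hypothetical; the config key and mode names come from the `SQLConf` change in this patch.]
```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

// Default mode added by this patch: infer the case-sensitive schema from the
// data files and persist it back to the table properties on first access.
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "INFER_AND_SAVE")
spark.table("t").printSchema() // would show e.g. fieldOne rather than fieldone

// Infer for the current session only; the metastore is left untouched.
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "INFER_ONLY")

// Skip inference and fall back to the case-insensitive metastore schema.
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "NEVER_INFER")
```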
@@ -287,3 +345,33 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } } + +private[hive] object HiveMetastoreCatalog { + def mergeWithMetastoreSchema( + metastoreSchema: StructType, + inferredSchema: StructType): StructType = try { + // Find any nullable fields in metastore schema that are missing from the inferred schema. + val metastoreFields = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingNullables = metastoreFields + .filterKeys(!inferredSchema.map(_.name.toLowerCase).contains(_)) + .values + .filter(_.nullable) + + // Merge missing nullable fields to inferred schema and build a case-insensitive field map. + val inferredFields = StructType(inferredSchema ++ missingNullables) + .map(f => f.name.toLowerCase -> f).toMap + StructType(metastoreFields.map { case(name, field) => + field.copy(name = inferredFields(name).name) + }.toSeq) + } catch { + case NonFatal(_) => + val msg = s"""Detected conflicting schemas when merging the schema obtained from the Hive + | Metastore with the one inferred from the file format. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Inferred schema: + |${inferredSchema.prettyJson} + """.stripMargin + throw new SparkException(msg) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala new file mode 100644 index 0000000000000..78955803819cf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.util.concurrent.{Executors, TimeUnit} + +import scala.util.Random + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.execution.datasources.FileStatusCache +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +class HiveSchemaInferenceSuite + extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { + + import HiveSchemaInferenceSuite._ + import HiveExternalCatalog.DATASOURCE_SCHEMA_PREFIX + + override def beforeEach(): Unit = { + super.beforeEach() + FileStatusCache.resetForTesting() + } + + override def afterEach(): Unit = { + super.afterEach() + spark.sessionState.catalog.tableRelationCache.invalidateAll() + FileStatusCache.resetForTesting() + } + + private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val client = externalCatalog.client + + // Return a copy of the given schema with all field names converted to lower case. + private def lowerCaseSchema(schema: StructType): StructType = { + StructType(schema.map(f => f.copy(name = f.name.toLowerCase))) + } + + // Create a Hive external test table containing the given field and partition column names. + // Returns a case-sensitive schema for the table. + private def setupExternalTable( + fileType: String, + fields: Seq[String], + partitionCols: Seq[String], + dir: File): StructType = { + // Treat all table fields as bigints... 
+ val structFields = fields.map { field => + StructField( + name = field, + dataType = LongType, + nullable = true, + metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "bigint").build()) + } + // and all partition columns as ints + val partitionStructFields = partitionCols.map { field => + StructField( + // Partition column case isn't preserved + name = field.toLowerCase, + dataType = IntegerType, + nullable = true, + metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "int").build()) + } + val schema = StructType(structFields ++ partitionStructFields) + + // Write some test data (partitioned if specified) + val writer = spark.range(NUM_RECORDS) + .selectExpr((fields ++ partitionCols).map("id as " + _): _*) + .write + .partitionBy(partitionCols: _*) + .mode("overwrite") + fileType match { + case ORC_FILE_TYPE => + writer.orc(dir.getAbsolutePath) + case PARQUET_FILE_TYPE => + writer.parquet(dir.getAbsolutePath) + } + + // Create Hive external table with lowercased schema + val serde = HiveSerDe.serdeMap(fileType) + client.createTable( + CatalogTable( + identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = Option(new java.net.URI(dir.getAbsolutePath)), + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + compressed = false, + properties = Map("serialization.format" -> "1")), + schema = schema, + provider = Option("hive"), + partitionColumnNames = partitionCols.map(_.toLowerCase), + properties = Map.empty), + true) + + // Add partition records (if specified) + if (!partitionCols.isEmpty) { + spark.catalog.recoverPartitions(TEST_TABLE_NAME) + } + + // Check that the table returned by HiveExternalCatalog has schemaPreservesCase set to false + // and that the raw table returned by the Hive client doesn't have any Spark SQL properties + // set (table needs to be obtained from client since HiveExternalCatalog filters these + // properties out). 
+ assert(!externalCatalog.getTable(DATABASE, TEST_TABLE_NAME).schemaPreservesCase) + val rawTable = client.getTable(DATABASE, TEST_TABLE_NAME) + assert(rawTable.properties.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)) == Map.empty) + schema + } + + private def withTestTables( + fileType: String)(f: (Seq[String], Seq[String], StructType) => Unit): Unit = { + // Test both a partitioned and unpartitioned Hive table + val tableFields = Seq( + (Seq("fieldOne"), Seq("partCol1", "partCol2")), + (Seq("fieldOne", "fieldTwo"), Seq.empty[String])) + + tableFields.foreach { case (fields, partCols) => + withTempDir { dir => + val schema = setupExternalTable(fileType, fields, partCols, dir) + withTable(TEST_TABLE_NAME) { f(fields, partCols, schema) } + } + } + } + + private def withFileTypes(f: (String) => Unit): Unit + = Seq(ORC_FILE_TYPE, PARQUET_FILE_TYPE).foreach(f) + + private def withInferenceMode(mode: InferenceMode)(f: => Unit): Unit = { + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> "true", + SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key -> mode.toString)(f) + } + + private val inferenceKey = SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key + + private def testFieldQuery(fields: Seq[String]): Unit = { + if (!fields.isEmpty) { + val query = s"SELECT * FROM ${TEST_TABLE_NAME} WHERE ${Random.shuffle(fields).head} >= 0" + assert(spark.sql(query).count == NUM_RECORDS) + } + } + + private def testTableSchema(expectedSchema: StructType): Unit + = assert(spark.table(TEST_TABLE_NAME).schema == expectedSchema) + + withFileTypes { fileType => + test(s"$fileType: schema should be inferred and saved when INFER_AND_SAVE is specified") { + withInferenceMode(INFER_AND_SAVE) { + withTestTables(fileType) { (fields, partCols, schema) => + testFieldQuery(fields) + testFieldQuery(partCols) + testTableSchema(schema) + + // Verify the catalog table now contains the updated schema and properties + val catalogTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + assert(catalogTable.schemaPreservesCase) + assert(catalogTable.schema == schema) + assert(catalogTable.partitionColumnNames == partCols.map(_.toLowerCase)) + } + } + } + } + + withFileTypes { fileType => + test(s"$fileType: schema should be inferred but not stored when INFER_ONLY is specified") { + withInferenceMode(INFER_ONLY) { + withTestTables(fileType) { (fields, partCols, schema) => + val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + testFieldQuery(fields) + testFieldQuery(partCols) + testTableSchema(schema) + // Catalog table shouldn't be altered + assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable) + } + } + } + } + + withFileTypes { fileType => + test(s"$fileType: schema should not be inferred when NEVER_INFER is specified") { + withInferenceMode(NEVER_INFER) { + withTestTables(fileType) { (fields, partCols, schema) => + val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + // Only check the table schema as the test queries will break + testTableSchema(lowerCaseSchema(schema)) + assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable) + } + } + } + } + + test("mergeWithMetastoreSchema() should return expected results") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), 
+ + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("Detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + + // Check that merging missing nullable fields works as expected. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("Detected conflicting schemas")) + } +} + +object HiveSchemaInferenceSuite { + private val NUM_RECORDS = 10 + private val DATABASE = "default" + private val TEST_TABLE_NAME = "test_table" + private val ORC_FILE_TYPE = "orc" + private val PARQUET_FILE_TYPE = "parquet" +} From 82138e09b9ad8d9609d5c64d6c11244b8f230be7 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 9 Mar 2017 17:42:10 -0800 Subject: [PATCH 66/78] [SPARK-19886] Fix reportDataLoss if statement in SS KafkaSource ## What changes were proposed in this pull request? Fix the inverted branches of the `throw new IllegalStateException` if statement. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #17228 from brkyvz/kafka-cause-fix.
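[Editor's note: the defect is easiest to see in isolation. Below is a distilled sketch, with hypothetical method names, of the swapped branches this patch fixes: the buggy version drops the cause exactly when one exists, and attaches an always-null cause otherwise.]
```scala
// Buggy shape: the two branch bodies are inverted.
def reportBuggy(message: String, cause: Throwable = null): Unit = {
  if (cause != null) {
    throw new IllegalStateException(message)        // BUG: the real cause is lost
  } else {
    throw new IllegalStateException(message, cause) // cause is always null here
  }
}

// Fixed shape, matching the patch: the cause is preserved when present.
def reportFixed(message: String, cause: Throwable = null): Unit = {
  if (cause != null) {
    throw new IllegalStateException(message, cause)
  } else {
    throw new IllegalStateException(message)
  }
}
```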
--- .../sql/kafka010/CachedKafkaConsumer.scala | 33 +++++++++++------- .../kafka010/CachedKafkaConsumerSuite.scala | 34 +++++++++++++++++++ 2 files changed, 54 insertions(+), 13 deletions(-) create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 15b28256e825e..6d76904fb0e59 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -273,19 +273,7 @@ private[kafka010] case class CachedKafkaConsumer private( message: String, cause: Throwable = null): Unit = { val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}" - if (failOnDataLoss) { - if (cause != null) { - throw new IllegalStateException(finalMessage) - } else { - throw new IllegalStateException(finalMessage, cause) - } - } else { - if (cause != null) { - logWarning(finalMessage) - } else { - logWarning(finalMessage, cause) - } - } + reportDataLoss0(failOnDataLoss, finalMessage, cause) } private def close(): Unit = consumer.close() @@ -398,4 +386,23 @@ private[kafka010] object CachedKafkaConsumer extends Logging { consumer } } + + private def reportDataLoss0( + failOnDataLoss: Boolean, + finalMessage: String, + cause: Throwable = null): Unit = { + if (failOnDataLoss) { + if (cause != null) { + throw new IllegalStateException(finalMessage, cause) + } else { + throw new IllegalStateException(finalMessage) + } + } else { + if (cause != null) { + logWarning(finalMessage, cause) + } else { + logWarning(finalMessage) + } + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala new file mode 100644 index 0000000000000..7aa7dd096c07b --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.sql.test.SharedSQLContext + +class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } +} From 5949e6c4477fd3cb07a6962dbee48b4416ea65dd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 9 Mar 2017 22:58:52 -0800 Subject: [PATCH 67/78] [SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing ## What changes were proposed in this pull request? This PR improves the performance of Dataset.map() for primitive types by removing boxing/unboxing operations. This is based on [the discussion](https://github.com/apache/spark/pull/16391#discussion_r93788919) with cloud-fan. Currently, Catalyst generates a method call to the `apply()` method of an anonymous function written in Scala. The argument and return types of that method are `java.lang.Object`. As a result, each method call for a primitive value involves a pair of unboxing and boxing for calling this `apply()` method and a pair of boxing and unboxing for returning from this `apply()` method. This PR instead calls a specialized version of the `apply()` method, without boxing and unboxing. For example, if the argument and return types are `int`, this PR generates a method call to `apply$mcII$sp`. This PR supports any combination of `Int`, `Long`, `Float`, and `Double`. The following benchmark results, produced with [this program](https://github.com/apache/spark/pull/16391/files), show a 4.7x speedup for the Dataset case. Here is the Dataset part of this program. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1923 / 1952 52.0 19.2 1.0X DataFrame 526 / 548 190.2 5.3 3.7X Dataset 3094 / 3154 32.3 30.9 0.6X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1883 / 1892 53.1 18.8 1.0X DataFrame 502 / 642 199.1 5.0 3.7X Dataset 657 / 784 152.2 6.6 2.9X ``` ```java def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ val rdd = spark.sparkContext.range(0, numRows) val ds = spark.range(0, numRows) val func = (l: Long) => l + 1 val benchmark = new Benchmark("back-to-back map", numRows) ...
benchmark.addCase("Dataset") { iter => var res = ds.as[Long] var i = 0 while (i < numChains) { res = res.map(func) i += 1 } res.queryExecution.toRdd.foreach(_ => Unit) } benchmark } ``` A motivating example ```java Seq(1, 2, 3).toDS.map(i => i * 7).show ``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ Object mapelements_funcResult = null; /* 054 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 055 */ if (mapelements_funcResult == null) { /* 056 */ mapelements_isNull = true; /* 057 */ } 
else { /* 058 */ mapelements_value = (Integer) mapelements_funcResult; /* 059 */ } /* 060 */ /* 061 */ } /* 062 */ /* 063 */ } /* 064 */ /* 065 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 066 */ /* 067 */ if (mapelements_isNull) { /* 068 */ serializefromobject_rowWriter.setNullAt(0); /* 069 */ } else { /* 070 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 071 */ } /* 072 */ append(serializefromobject_result); /* 073 */ if (shouldStop()) return; /* 074 */ } /* 075 */ } /* 076 */ } ``` Generated code with this PR (lines 48-56 are changed) ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = 
false; /* 052 */ if (!mapelements_isNull) { /* 053 */ mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue); /* 054 */ } /* 055 */ /* 056 */ } /* 057 */ /* 058 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 059 */ /* 060 */ if (mapelements_isNull) { /* 061 */ serializefromobject_rowWriter.setNullAt(0); /* 062 */ } else { /* 063 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 064 */ } /* 065 */ append(serializefromobject_result); /* 066 */ if (shouldStop()) return; /* 067 */ } /* 068 */ } /* 069 */ } ``` Java bytecode for methods for `i => i * 7` ```java $ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class Compiled from "Test.scala" public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable { public static final long serialVersionUID; public final int apply(int); Code: 0: aload_0 1: iload_1 2: invokevirtual #18 // Method apply$mcII$sp:(I)I 5: ireturn public int apply$mcII$sp(int); Code: 0: iload_1 1: bipush 7 3: imul 4: ireturn public final java.lang.Object apply(java.lang.Object); Code: 0: aload_0 1: aload_1 2: invokestatic #29 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I 5: invokevirtual #31 // Method apply:(I)I 8: invokestatic #35 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer; 11: areturn public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5); Code: 0: aload_0 1: invokespecial #42 // Method scala/runtime/AbstractFunction1$mcII$sp."":()V 4: return } ``` ## How was this patch tested? Added new test suites to `DatasetPrimitiveSuite`. Author: Kazuaki Ishizaki Closes #17172 from kiszk/SPARK-19008. 
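[Editor's note: the specialization this patch relies on is easy to observe directly. The following REPL-style check, illustrative only and not part of the patch, lists the `apply` variants scalac emits for a primitive lambda.]
```scala
val f: Int => Int = i => i * 7

// scalac emits both the generic boxing bridge apply(Object) and the
// specialized apply$mcII$sp(int); this patch makes Catalyst call the latter.
f.getClass.getMethods
  .filter(_.getName.startsWith("apply"))
  .map(_.toString)
  .sorted
  .foreach(println)
// Prints, among others:
//   public int ...apply$mcII$sp(int)
//   public final java.lang.Object ...apply(java.lang.Object)
```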
--- .../sql/catalyst/plans/logical/object.scala | 38 +++++- .../apache/spark/sql/execution/objects.scala | 6 +- .../apache/spark/sql/DatasetBenchmark.scala | 122 +++++++++++++++++- .../spark/sql/DatasetPrimitiveSuite.scala | 51 ++++++++ 4 files changed, 208 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 617239f56cdd3..7f4462e583607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { @@ -211,13 +212,48 @@ case class TypedFilter( def typedCondition(input: Expression): Expression = { val (funcClass, methodName) = func match { case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" - case _ => classOf[Any => Boolean] -> "apply" + case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) Invoke(funcObj, methodName, BooleanType, input :: Nil) } } +object FunctionUtils { + private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = { + dt match { + case BooleanType if isOutput => Some("Z") + case IntegerType => Some("I") + case LongType => Some("J") + case FloatType => Some("F") + case DoubleType => Some("D") + case _ => None + } + } + + def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = { + // load "scala.Function1" using Java API to avoid requirements of type parameters + Utils.classForName("scala.Function1") -> { + // if a pair of an argument and return types is one of specific types + // whose specialized method (apply$mc..$sp) is generated by scalac, + // Catalyst generated a direct method call to the specialized method. + // The followings are references for this specialization: + // http://www.scala-lang.org/api/2.12.0/scala/Function1.html + // https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/ + // SpecializeTypes.scala + // http://www.cakesolutions.net/teamblogs/scala-dissection-functions + // http://axel22.github.io/2013/11/03/specialization-quirks.html + val inputType = getMethodType(inputDT, false) + val outputType = getMethodType(outputDT, true) + if (inputType.isDefined && outputType.isDefined) { + s"apply$$mc${outputType.get}${inputType.get}$$sp" + } else { + "apply" + } + } + } +} + /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumns { def apply[T : Encoder, U : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce6969b..fdd1bcc94be25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState import org.apache.spark.sql.execution.streaming.KeyedStateImpl -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -219,7 +221,7 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - case _ => classOf[Any => Any] -> "apply" + case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 66d94d6016050..1a0672b8876da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -31,6 +31,49 @@ object DatasetBenchmark { case class Data(l: Long, s: String) + def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(0, numRows) + val ds = spark.range(0, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l + 1 + + val benchmark = new Benchmark("back-to-back map long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -72,6 +115,49 @@ object DatasetBenchmark { benchmark } + def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(1, numRows) + val ds = spark.range(1, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l % 2L == 0L + + val benchmark = new Benchmark("back-to-back filter Long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.filter(func) + 
i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % 2L === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -165,9 +251,22 @@ object DatasetBenchmark { val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(spark, numRows, numChains) - val benchmark2 = backToBackFilter(spark, numRows, numChains) - val benchmark3 = aggregate(spark, numRows) + val benchmark0 = backToBackMapLong(spark, numRows, numChains) + val benchmark1 = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilterLong(spark, numRows, numChains) + val benchmark3 = backToBackFilter(spark, numRows, numChains) + val benchmark4 = aggregate(spark, numRows) + + /* + OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 1883 / 1892 53.1 18.8 1.0X + DataFrame 502 / 642 199.1 5.0 3.7X + Dataset 657 / 784 152.2 6.6 2.9X + */ + benchmark0.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -178,7 +277,18 @@ object DatasetBenchmark { DataFrame 2647 / 3116 37.8 26.5 1.3X Dataset 4781 / 5155 20.9 47.8 0.7X */ - benchmark.run() + benchmark1.run() + + /* + OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 846 / 1120 118.1 8.5 1.0X + DataFrame 270 / 329 370.9 2.7 3.1X + Dataset 545 / 789 183.5 5.4 1.6X + */ + benchmark2.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -189,7 +299,7 @@ object DatasetBenchmark { DataFrame 59 / 72 1695.4 0.6 22.8X Dataset 2777 / 2805 36.0 27.8 0.5X */ - benchmark2.run() + benchmark3.run() /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 @@ -201,6 +311,6 @@ object DatasetBenchmark { Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X */ - benchmark3.run() + benchmark4.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b50cb3e48c76..82b707537e45f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,6 +62,40 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } + test("mapPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.map(_ > 1), false, true, true) + checkDataset(dsInt.map(_ + 1), 2, 3, 4) + checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsInt.map(_ + 1.23D), 2.23D, 
3.23D, 4.23D) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.map(_ > 1), false, true, true) + checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.map(_ > 1), false, true, true) + checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L) + checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.map(_ > 1), false, true, true) + checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsDouble.map(e => (e + 8589934592L).toLong), + 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F) + checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.map(e => !e), false, true) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( @@ -69,6 +103,23 @@ 2, 4) } + test("filterPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.filter(_ > 1), 2, 3) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.filter(_ > 1), 2L, 3L) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.filter(_ > 1), 2F, 3F) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.filter(_ > 1), 2D, 3D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.filter(e => !e), false) + } + test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.longAccumulator From 501b7111997bc74754663348967104181b43319b Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Thu, 9 Mar 2017 23:02:13 -0800 Subject: [PATCH 68/78] [SPARK-19891][SS] Await Batch Lock notified on stream execution exit ## What changes were proposed in this pull request? We need to notify the await batch lock when the stream exits early, e.g., when an exception has been thrown. ## How was this patch tested? Current tests that throw exceptions at runtime will finish faster as a result of this update. Author: Tyson Condie Closes #17231 from tcondie/kafka-writer. --- .../spark/sql/execution/streaming/StreamExecution.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 70912d13ae458..529263805c0aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -361,6 +361,13 @@ class StreamExecution( } } } finally { + awaitBatchLock.lock() + try { + // Wake up any threads that are waiting for the stream to progress.
+ awaitBatchLockCondition.signalAll() + } finally { + awaitBatchLock.unlock() + } terminationLatch.countDown() } } From fcb68e0f5d49234ac4527109887ff08cd4e1c29f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Mar 2017 18:04:37 +0100 Subject: [PATCH 69/78] [SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range() ## What changes were proposed in this pull request? This PR improves the performance of operations with `range()` by changing the Java code generated by Catalyst. This PR is inspired by the [blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html). This PR changes the generated code in the following two ways. 1. Replace a while-loop over long instance variables with a for-loop over int local variables 2. Suppress generation of the `shouldStop()` method when it is unnecessary (e.g. when no `append()` is generated). These changes facilitate loop optimizations in a JIT compiler by feeding it simplified Java code. The performance is improved by 7.6x. Benchmark program: ```java val N = 1 << 29 val iters = 2 val benchmark = new Benchmark("range.count", N * iters) benchmark.addCase(s"with this PR") { i => var n = 0 var len = 0 while (n < iters) { len += sparkSession.range(N).selectExpr("count(id)").collect.length n += 1 } } benchmark.run ``` Performance result without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o this PR 1349 / 1356 796.2 1.3 1.0X ``` Performance result with this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ with this PR 177 / 271 6065.3 0.2 1.0X ``` Here is a comparison of the code generated without and with this PR. Only the method ```agg_doAggregateWithoutKey``` is changed.
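[Editor's note: before the full listings, a stylized Scala sketch of the two loop shapes. This is illustrative only; the real change is in the Java emitted by codegen, and the helper names here are hypothetical.]
```scala
def consume(v: Long): Unit = ()

// Shape before the patch: an open-ended while-loop over mutable long state,
// mirroring the instance fields of the generated class. The JIT cannot
// prove a trip count, which limits unrolling and range-check elimination.
def beforeShape(start: Long, batchEnd: Long): Unit = {
  var number = start
  while (number != batchEnd) {
    val value = number
    number += 1L
    consume(value)
  }
}

// Shape after the patch: the trip count is computed once and the body runs
// as a counted loop over an int local, a form JIT compilers optimize well.
def afterShape(start: Long, batchEnd: Long): Unit = {
  val localEnd = (batchEnd - start).toInt
  var localIdx = 0
  while (localIdx < localEnd) {
    consume(start + localIdx)
    localIdx += 1
  }
}
```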
Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ while (range_number != range_batchEnd) { /* 071 */ long range_value = range_number; /* 072 */ 
range_number += 1L; /* 073 */ /* 074 */ // do aggregate /* 075 */ // common sub-expressions /* 076 */ /* 077 */ // evaluate aggregate function /* 078 */ boolean agg_isNull1 = false; /* 079 */ /* 080 */ long agg_value1 = -1L; /* 081 */ agg_value1 = agg_bufValue + 1L; /* 082 */ // update aggregation buffer /* 083 */ agg_bufIsNull = false; /* 084 */ agg_bufValue = agg_value1; /* 085 */ /* 086 */ if (shouldStop()) return; /* 087 */ } /* 088 */ /* 089 */ if (range_taskContext.isInterrupted()) { /* 090 */ throw new TaskKilledException(); /* 091 */ } /* 092 */ /* 093 */ long range_nextBatchTodo; /* 094 */ if (range_numElementsTodo > 1000L) { /* 095 */ range_nextBatchTodo = 1000L; /* 096 */ range_numElementsTodo -= 1000L; /* 097 */ } else { /* 098 */ range_nextBatchTodo = range_numElementsTodo; /* 099 */ range_numElementsTodo = 0; /* 100 */ if (range_nextBatchTodo == 0) break; /* 101 */ } /* 102 */ range_numOutputRows.add(range_nextBatchTodo); /* 103 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 104 */ /* 105 */ range_batchEnd += range_nextBatchTodo * 1L; /* 106 */ } /* 107 */ /* 108 */ } /* 109 */ /* 110 */ private void initRange(int idx) { /* 111 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 112 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 113 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 114 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 115 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 117 */ /* 118 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 119 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 120 */ range_number = Long.MAX_VALUE; /* 121 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 122 */ range_number = Long.MIN_VALUE; /* 123 */ } else { /* 124 */ range_number = st.longValue(); /* 125 */ } /* 126 */ range_batchEnd = range_number; /* 127 */ /* 128 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 129 */ .multiply(step).add(start); /* 130 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 131 */ partitionEnd = Long.MAX_VALUE; /* 132 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 133 */ partitionEnd = Long.MIN_VALUE; /* 134 */ } else { /* 135 */ partitionEnd = end.longValue(); /* 136 */ } /* 137 */ /* 138 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 139 */ java.math.BigInteger.valueOf(range_number)); /* 140 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 141 */ if (range_numElementsTodo < 0) { /* 142 */ range_numElementsTodo = 0; /* 143 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 144 */ range_numElementsTodo++; /* 145 */ } /* 146 */ } /* 147 */ /* 148 */ protected void processNext() throws java.io.IOException { /* 149 */ while (!agg_initAgg) { /* 150 */ agg_initAgg = true; /* 151 */ long agg_beforeAgg = System.nanoTime(); /* 152 */ agg_doAggregateWithoutKey(); /* 153 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 154 */ /* 155 */ // output the result /* 156 */ /* 157 */ agg_numOutputRows.add(1); /* 158 */ agg_rowWriter.zeroOutNullBytes(); /* 159 */ /* 160 */ if (agg_bufIsNull) { /* 161 */ agg_rowWriter.setNullAt(0); /* 162 */ } else { /* 163 */ agg_rowWriter.write(0, 
agg_bufValue); /* 164 */ } /* 165 */ append(agg_result); /* 166 */ } /* 167 */ } /* 168 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ 
long range_range = range_batchEnd - range_number; /* 071 */ if (range_range != 0L) { /* 072 */ int range_localEnd = (int)(range_range / 1L); /* 073 */ for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) { /* 074 */ long range_value = ((long)range_localIdx * 1L) + range_number; /* 075 */ /* 076 */ // do aggregate /* 077 */ // common sub-expressions /* 078 */ /* 079 */ // evaluate aggregate function /* 080 */ boolean agg_isNull1 = false; /* 081 */ /* 082 */ long agg_value1 = -1L; /* 083 */ agg_value1 = agg_bufValue + 1L; /* 084 */ // update aggregation buffer /* 085 */ agg_bufIsNull = false; /* 086 */ agg_bufValue = agg_value1; /* 087 */ /* 088 */ // shouldStop check is eliminated /* 089 */ } /* 090 */ range_number = range_batchEnd; /* 091 */ } /* 092 */ /* 093 */ if (range_taskContext.isInterrupted()) { /* 094 */ throw new TaskKilledException(); /* 095 */ } /* 096 */ /* 097 */ long range_nextBatchTodo; /* 098 */ if (range_numElementsTodo > 1000L) { /* 099 */ range_nextBatchTodo = 1000L; /* 100 */ range_numElementsTodo -= 1000L; /* 101 */ } else { /* 102 */ range_nextBatchTodo = range_numElementsTodo; /* 103 */ range_numElementsTodo = 0; /* 104 */ if (range_nextBatchTodo == 0) break; /* 105 */ } /* 106 */ range_numOutputRows.add(range_nextBatchTodo); /* 107 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 108 */ /* 109 */ range_batchEnd += range_nextBatchTodo * 1L; /* 110 */ } /* 111 */ /* 112 */ } /* 113 */ /* 114 */ private void initRange(int idx) { /* 115 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 116 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 117 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 118 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 119 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 120 */ long partitionEnd; /* 121 */ /* 122 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 123 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 124 */ range_number = Long.MAX_VALUE; /* 125 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 126 */ range_number = Long.MIN_VALUE; /* 127 */ } else { /* 128 */ range_number = st.longValue(); /* 129 */ } /* 130 */ range_batchEnd = range_number; /* 131 */ /* 132 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 133 */ .multiply(step).add(start); /* 134 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 135 */ partitionEnd = Long.MAX_VALUE; /* 136 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 137 */ partitionEnd = Long.MIN_VALUE; /* 138 */ } else { /* 139 */ partitionEnd = end.longValue(); /* 140 */ } /* 141 */ /* 142 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 143 */ java.math.BigInteger.valueOf(range_number)); /* 144 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 145 */ if (range_numElementsTodo < 0) { /* 146 */ range_numElementsTodo = 0; /* 147 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 148 */ range_numElementsTodo++; /* 149 */ } /* 150 */ } /* 151 */ /* 152 */ protected void processNext() throws java.io.IOException { /* 153 */ while (!agg_initAgg) { /* 154 */ agg_initAgg = true; /* 155 */ long agg_beforeAgg = System.nanoTime(); /* 156 
*/ agg_doAggregateWithoutKey(); /* 157 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 158 */ /* 159 */ // output the result /* 160 */ /* 161 */ agg_numOutputRows.add(1); /* 162 */ agg_rowWriter.zeroOutNullBytes(); /* 163 */ /* 164 */ if (agg_bufIsNull) { /* 165 */ agg_rowWriter.setNullAt(0); /* 166 */ } else { /* 167 */ agg_rowWriter.write(0, agg_bufValue); /* 168 */ } /* 169 */ append(agg_result); /* 170 */ } /* 171 */ } /* 172 */ } ``` A part of suppressing `shouldStop()` was originally developed by inouehrs ## How was this patch tested? Add new tests into `DataFrameRangeSuite` Author: Kazuaki Ishizaki Closes #17122 from kiszk/SPARK-19786. --- .../apache/spark/sql/execution/SortExec.scala | 2 ++ .../sql/execution/WholeStageCodegenExec.scala | 15 +++++++++++ .../aggregate/HashAggregateExec.scala | 2 ++ .../execution/basicPhysicalOperators.scala | 27 ++++++++++++++----- .../spark/sql/DataFrameRangeSuite.scala | 16 +++++++++++ 5 files changed, 55 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index cc576bbc4c802..f98ae82574d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -177,6 +177,8 @@ case class SortExec( """.stripMargin.trim } + protected override val shouldStopRequired = false + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c58474eba05d4..c31fd92447c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan { def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } + + /** + * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. + * Returning true means we need to insert shouldStop() into the loop producing rows, if any. + */ + def isShouldStopRequired: Boolean = { + return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + } + + /** + * Set to false if this plan consumes all rows produced by children but doesn't output row + * to buffer by calling append(), so the children don't require shouldStop() + * in the loop of producing rows. 
+ */ + protected def shouldStopRequired: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e565..68c8e6ce62cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -238,6 +238,8 @@ case class HashAggregateExec( """.stripMargin } + protected override val shouldStopRequired = false + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed685cca..d876688a8aabd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") - // The default size of a batch. - val batchSize = 1000L + // The default size of a batch, which must be positive integer + val batchSize = 1000 ctx.addNewFunction("initRange", s""" @@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val input = ctx.freshName("input") // Right now, Range is only used when there is one upstream. 
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val range = ctx.freshName("range") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $number = $value + ${step}L; return; }" + } else { + "// shouldStop check is eliminated" + } s""" | // initialize Range | if (!$initTerm) { @@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | while (true) { - | while ($number != $batchEnd) { - | long $value = $number; - | $number += ${step}L; - | ${consume(ctx, Seq(ev))} - | if (shouldStop()) return; + | long $range = $batchEnd - $number; + | if ($range != 0L) { + | int $localEnd = (int)($range / ${step}L); + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | long $value = ((long)$localIdx * ${step}L) + $number; + | ${consume(ctx, Seq(ev))} + | $shouldStop + | } + | $number = $batchEnd; | } | | if ($taskContext.isInterrupted()) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index acf393a9b0faf..5e323c02b253d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000 val res13 = spark.range(-n, n, n / 9).select("id") assert(res13.count == 18) + + // range with non aggregation operation + val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id") + val len14 = res14.collect.length + assert(len14 == 25) + + val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0") + val len15 = res15.collect.length + assert(len15 == 50) + + val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id") + val len16 = res16.collect.length + assert(len16 == 500) + + val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id") + assert(res17.collect === (1 to 10).map(i => Row(i)).toArray) } test("Range with randomized parameters") { From dd9049e0492cc70b629518fee9b3d1632374c612 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 10 Mar 2017 11:13:26 -0800 Subject: [PATCH 70/78] [SPARK-19620][SQL] Fix incorrect exchange coordinator id in the physical plan ## What changes were proposed in this pull request? When adaptive execution is enabled, an exchange coordinator is used in the Exchange operators. For Join, the same exchange coordinator is used for its two Exchanges. But the physical plan shows two different coordinator Ids which is confusing. This PR is to fix the incorrect exchange coordinator id in the physical plan. The coordinator object instead of the `Option[ExchangeCoordinator]` should be used to generate the identity hash code of the same coordinator. ## How was this patch tested? Before the patch, the physical plan shows two different exchange coordinator id for Join. 
``` == Physical Plan == *Project [key1#3L, value2#12L] +- *SortMergeJoin [key1#3L], [key2#11L], Inner :- *Sort [key1#3L ASC NULLS FIRST], false, 0 : +- Exchange(coordinator id: 1804587700) hashpartitioning(key1#3L, 10), coordinator[target post-shuffle partition size: 67108864] : +- *Project [(id#0L % 500) AS key1#3L] : +- *Filter isnotnull((id#0L % 500)) : +- *Range (0, 1000, step=1, splits=Some(10)) +- *Sort [key2#11L ASC NULLS FIRST], false, 0 +- Exchange(coordinator id: 793927319) hashpartitioning(key2#11L, 10), coordinator[target post-shuffle partition size: 67108864] +- *Project [(id#8L % 500) AS key2#11L, id#8L AS value2#12L] +- *Filter isnotnull((id#8L % 500)) +- *Range (0, 1000, step=1, splits=Some(10)) ``` After the patch, two exchange coordinator id are the same. Author: Carson Wang Closes #16952 from carsonwang/FixCoordinatorId. --- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 125a4930c6528..f06544ea8ed04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -46,7 +46,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) => - s"(coordinator id: ${System.identityHashCode(coordinator)})" + s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})" case None => "" } From 8f0490e22b4c7f1fdf381c70c5894d46b7f7e6fb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 10 Mar 2017 13:33:58 -0800 Subject: [PATCH 71/78] [SPARK-17979][SPARK-14453] Remove deprecated SPARK_YARN_USER_ENV and SPARK_JAVA_OPTS This fix removes deprecated support for config `SPARK_YARN_USER_ENV`, as is mentioned in SPARK-17979. This fix also removes deprecated support for the following: ``` SPARK_YARN_USER_ENV SPARK_JAVA_OPTS SPARK_CLASSPATH SPARK_WORKER_INSTANCES ``` Related JIRA: [SPARK-14453]: https://issues.apache.org/jira/browse/SPARK-14453 [SPARK-12344]: https://issues.apache.org/jira/browse/SPARK-12344 [SPARK-15781]: https://issues.apache.org/jira/browse/SPARK-15781 Existing tests should pass. Author: Yong Tang Closes #17212 from yongtang/SPARK-17979. 
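For readers migrating off these variables, a minimal sketch of the supported replacements (not part of the patch; the option values are placeholders, and the mappings follow the deprecation warnings removed below):

```scala
import org.apache.spark.SparkConf

// Hypothetical migration example: each removed environment variable has an
// explicit configuration equivalent.
val conf = new SparkConf()
  .set("spark.executor.extraJavaOptions", "-XX:+UseG1GC")      // instead of SPARK_JAVA_OPTS
  .set("spark.executor.extraClassPath", "/path/to/extra.jar")  // instead of SPARK_CLASSPATH
  .set("spark.executor.instances", "4")                        // instead of SPARK_WORKER_INSTANCES

// Driver-side JVM options and classpath must be set before the driver JVM
// starts, so they are passed to spark-submit rather than set in SparkConf:
//   ./bin/spark-submit --driver-java-options "-XX:+UseG1GC" \
//     --driver-class-path /path/to/extra.jar ...
```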
--- conf/spark-env.sh.template | 3 - .../scala/org/apache/spark/SparkConf.scala | 65 ------------------- .../spark/deploy/FaultToleranceTest.scala | 3 +- .../spark/launcher/WorkerCommandBuilder.scala | 1 - docs/rdd-programming-guide.md | 2 +- .../launcher/AbstractCommandBuilder.java | 1 - .../launcher/SparkClassCommandBuilder.java | 2 - .../launcher/SparkSubmitCommandBuilder.java | 1 - .../MesosCoarseGrainedSchedulerBackend.scala | 5 -- .../MesosFineGrainedSchedulerBackend.scala | 4 -- .../org/apache/spark/deploy/yarn/Client.scala | 39 +---------- .../spark/deploy/yarn/ExecutorRunnable.scala | 8 --- 12 files changed, 3 insertions(+), 131 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 5c1e876ef9afc..94bd2c477a35b 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -25,12 +25,10 @@ # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public dns name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # Options read by executors and drivers running inside the cluster # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos @@ -48,7 +46,6 @@ # - SPARK_WORKER_CORES, to set the number of cores to use on this machine # - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g) # - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT, to use non-default ports for the worker -# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index fe912e639bcbc..2a2ce0504dbbf 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -518,71 +518,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - // Check for legacy configs - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - | - SPARK_DAEMON_JAVA_OPTS to set java options for standalone daemons (master or worker) - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorOptsKey, driverOptsKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - sys.env.get("SPARK_CLASSPATH").foreach { value => - val warning = - s""" - |SPARK_CLASSPATH was detected (set to '$value'). 
- |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --driver-class-path to augment the driver classpath - | - spark.executor.extraClassPath to augment the executor classpath - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorClasspathKey, driverClassPathKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_CLASSPATH. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - if (!contains(sparkExecutorInstances)) { - sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => - val warning = - s""" - |SPARK_WORKER_INSTANCES was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --num-executors to specify the number of executors - | - Or set SPARK_EXECUTOR_INSTANCES - | - spark.executor.instances to configure the number of instances in the spark config. - """.stripMargin - logWarning(warning) - - set("spark.executor.instances", value) - } - } - if (contains("spark.master") && get("spark.master").startsWith("yarn-")) { val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " + "instead use \"yarn\" with specified deploy mode." diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 320af5cf97550..c6307da61c7eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -43,8 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} * Execute using * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS - * *and* SPARK_JAVA_OPTS: + * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port. diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 3fd812e9fcfe8..4216b2627309e 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -39,7 +39,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index cad9ff4e646e5..e2bf2d7ca77ca 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -457,7 +457,7 @@ If required, a Hadoop configuration can be passed in as a Python dict. 
Here is a Elasticsearch ESInputFormat: {% highlight python %} -$ SPARK_CLASSPATH=/path/to/elasticsearch-hadoop.jar ./bin/pyspark +$ ./bin/pyspark --jars /path/to/elasticsearch-hadoop.jar >>> conf = {"es.resource" : "index/type"} # assume Elasticsearch is running on localhost defaults >>> rdd = sc.newAPIHadoopRDD("org.elasticsearch.hadoop.mr.EsInputFormat", "org.apache.hadoop.io.NullWritable", diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index bc8d6037a367b..6c0c3ebcaebf4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -135,7 +135,6 @@ List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); Set cp = new LinkedHashSet<>(); - addToClassPath(cp, getenv("SPARK_CLASSPATH")); addToClassPath(cp, appClassPath); addToClassPath(cp, getConfDir()); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 81786841de224..7cf5b7379503f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -66,7 +66,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; case "org.apache.spark.executor.CoarseGrainedExecutorBackend": - javaOptsKeys.add("SPARK_JAVA_OPTS"); javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; break; @@ -84,7 +83,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; default: - javaOptsKeys.add("SPARK_JAVA_OPTS"); memKey = "SPARK_DRIVER_MEMORY"; break; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5e64fa7ed152c..5f2da036ff9f7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -240,7 +240,6 @@ private List buildSparkSubmitCommand(Map env) addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS")); } addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); - addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); // We don't want the client to specify Xmx. 
These have to be set by their corresponding // memory flag --driver-memory or configuration entry spark.driver.memory diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 85c2e9c76f4b0..c049a32eabf90 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -175,11 +175,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val environment = Environment.newBuilder() - val extraClassPath = conf.getOption("spark.executor.extraClassPath") - extraClassPath.foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") // Set the environment variable through a command prefix diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 215271302ec51..f198f8893b3db 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -106,10 +106,6 @@ private[spark] class MesosFineGrainedSchedulerBackend( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val environment = Environment.newBuilder() - sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e86bd5459311d..ccb0f8fdbbc21 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -748,14 +748,6 @@ private[spark] class Client( .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - // Allow users to specify some environment variables. - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. - env("SPARK_YARN_USER_ENV") = userEnvs - } - // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH // of the container processes too. Add all non-.py files directly to PYTHONPATH. 
// @@ -782,35 +774,7 @@ private[spark] class Client( sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } - // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to - // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's - // SparkContext will not let that set spark* system properties, which is expected behavior for - // Yarn clients. So propagate it through the environment. - // - // Note that to warn the user about the deprecation in cluster mode, some code from - // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition - // described above). if (isClusterMode) { - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - """.stripMargin - logWarning(warning) - for (proc <- Seq("driver", "executor")) { - val key = s"spark.$proc.extraJavaOptions" - if (sparkConf.contains(key)) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } - } - env("SPARK_JAVA_OPTS") = value - } // propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname => if (!env.contains(envname)) { @@ -883,8 +847,7 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS")) - driverOpts.foreach { opts => + sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ee85c043b8bc0..3f4d236571ffd 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -143,9 +143,6 @@ private[yarn] class ExecutorRunnable( sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) - } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } @@ -229,11 +226,6 @@ private[yarn] class ExecutorRunnable( YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, From bc30351404d8bc610cbae65fdc12ca613e7735c6 Mon Sep 17 00:00:00 2001 From: Budde Date: Fri, 10 Mar 2017 15:18:37 -0800 Subject: [PATCH 72/78] 
[SPARK-19611][SQL] Preserve metastore field order when merging inferred schema

## What changes were proposed in this pull request?

The ```HiveMetastoreCatalog.mergeWithMetastoreSchema()``` method added in #16944 may not preserve the same field order as the metastore schema in some cases, which can cause queries to fail. This change ensures that the metastore field order is preserved.

## How was this patch tested?

A test for ensuring that the metastore order is preserved was added to ```HiveSchemaInferenceSuite```. The particular failure use case from #16944 was tested manually as well.

Author: Budde

Closes #17249 from budde/PreserveMetastoreFieldOrder.
---
 .../spark/sql/hive/HiveMetastoreCatalog.scala |  5 +----
 .../sql/hive/HiveSchemaInferenceSuite.scala   | 21 +++++++++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 056af495590f7..9f0d1ceb28fca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -356,13 +356,10 @@ private[hive] object HiveMetastoreCatalog {
         .filterKeys(!inferredSchema.map(_.name.toLowerCase).contains(_))
         .values
         .filter(_.nullable)
-      // Merge missing nullable fields to inferred schema and build a case-insensitive field map.
       val inferredFields = StructType(inferredSchema ++ missingNullables)
         .map(f => f.name.toLowerCase -> f).toMap
-      StructType(metastoreFields.map { case(name, field) =>
-        field.copy(name = inferredFields(name).name)
-      }.toSeq)
+      StructType(metastoreSchema.map(f => f.copy(name = inferredFields(f.name).name)))
     } catch {
       case NonFatal(_) =>
         val msg = s"""Detected conflicting schemas when merging the schema obtained from the Hive

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
index 78955803819cf..e48ce2304d086 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
@@ -293,6 +293,27 @@ class HiveSchemaInferenceSuite
         StructField("firstField", StringType, nullable = true),
         StructField("secondField", StringType, nullable = true))))
     }.getMessage.contains("Detected conflicting schemas"))
+
+    // Schema merge should maintain metastore order.
+ assertResult( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true)))) { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true))), + StructType(Seq( + StructField("fifth_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true)))) + } } } From ffee4f1cefb0dfd8d9145ee3be82c6f7b799870b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 10 Mar 2017 15:19:32 -0800 Subject: [PATCH 73/78] [SPARK-19905][SQL] Bring back Dataset.inputFiles for Hive SerDe tables ## What changes were proposed in this pull request? `Dataset.inputFiles` works by matching `FileRelation`s in the query plan. In Spark 2.1, Hive SerDe tables are represented by `MetastoreRelation`, which inherits from `FileRelation`. However, in Spark 2.2, Hive SerDe tables are now represented by `CatalogRelation`, which doesn't inherit from `FileRelation` anymore, due to the unification of Hive SerDe tables and data source tables. This change breaks `Dataset.inputFiles` for Hive SerDe tables. This PR tries to fix this issue by explicitly matching `CatalogRelation`s that are Hive SerDe tables in `Dataset.inputFiles`. Note that we can't make `CatalogRelation` inherit from `FileRelation` since not all `CatalogRelation`s are file based (e.g., JDBC data source tables). ## How was this patch tested? New test case added in `HiveDDLSuite`. Author: Cheng Lian Closes #17247 from liancheng/spark-19905-hive-table-input-files. 
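As a usage sketch of the restored behavior (assuming a Hive-enabled `SparkSession` named `spark`; the table and view names are illustrative, mirroring the regression test below):

```scala
// After this change, inputFiles falls back to the table's storage location
// for Hive SerDe tables instead of silently returning an empty array.
spark.range(10).createOrReplaceTempView("src_view")
spark.sql("CREATE TABLE serde_tbl STORED AS RCFILE AS SELECT * FROM src_view")

val files: Array[String] = spark.table("serde_tbl").inputFiles
assert(files.nonEmpty)  // held on Spark 2.1, but was empty before this patch
```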
---
 .../src/main/scala/org/apache/spark/sql/Dataset.scala |  3 +++
 .../spark/sql/hive/execution/HiveDDLSuite.scala       | 11 +++++++++++
 2 files changed, 14 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0a4d3a93a07e8..520663f624408 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.CatalogRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -2734,6 +2735,8 @@ class Dataset[T] private[sql](
         fsBasedRelation.inputFiles
       case fr: FileRelation =>
         fr.inputFiles
+      case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) =>
+        r.tableMeta.storage.locationUri.map(_.toString).toArray
     }.flatten
     files.toSet.toArray
   }

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 23aea24697785..79ad156c55611 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -1865,4 +1865,15 @@ class HiveDDLSuite
       }
     }
   }
+
+  test("SPARK-19905: Hive SerDe table input paths") {
+    withTable("spark_19905") {
+      withTempView("spark_19905_view") {
+        spark.range(10).createOrReplaceTempView("spark_19905_view")
+        sql("CREATE TABLE spark_19905 STORED AS RCFILE AS SELECT * FROM spark_19905_view")
+        assert(spark.table("spark_19905").inputFiles.nonEmpty)
+        assert(sql("SELECT input_file_name() FROM spark_19905").count() > 0)
+      }
+    }
+  }
 }

From fb9beda54622e0c3190c6504fc468fa4e50eeb45 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 10 Mar 2017 16:14:22 -0800
Subject: [PATCH 74/78] [SPARK-19893][SQL] should not run DataFrame set
 operations with map type

## What changes were proposed in this pull request?

In Spark SQL, map type can't be used in equality tests/comparisons. Since `Intersect`/`Except`/`Distinct` need equality tests for all columns, we should not allow map type in `Intersect`/`Except`/`Distinct`.

## How was this patch tested?

New regression test.

Author: Wenchen Fan

Closes #17236 from cloud-fan/map.
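A short sketch of the behavior this enforces (assuming a `SparkSession` named `spark`; it mirrors the new regression test):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.{lit, map}
import spark.implicits._

val df = spark.range(1).select(map(lit("key"), $"id").as("m"))

// Dataset analysis is eager, so the rejection surfaces when the set operation
// is declared, before any job runs.
try {
  df.intersect(df)
} catch {
  case e: AnalysisException =>
    assert(e.message.contains("Cannot have map type columns"))
}
```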
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 25 ++++++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 19 ++++++++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 14 +++++------ 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7529f9028498c..d32fbeb4e91ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -44,6 +44,18 @@ trait CheckAnalysis extends PredicateHelper { }).length > 1 } + protected def hasMapType(dt: DataType): Boolean = { + dt.existsRecursively(_.isInstanceOf[MapType]) + } + + protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match { + case _: Intersect | _: Except | _: Distinct => + plan.output.find(a => hasMapType(a.dataType)) + case d: Deduplicate => + d.keys.find(a => hasMapType(a.dataType)) + case _ => None + } + private def checkLimitClause(limitExpr: Expression): Unit = { limitExpr match { case e if !e.foldable => failAnalysis( @@ -121,8 +133,7 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } - else if (conditions.nonEmpty) { + } else if (conditions.nonEmpty) { // Collect the columns from the subquery for further checking. var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) @@ -200,7 +211,7 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case f @ Filter(condition, child) => + case Filter(condition, _) => splitConjunctivePredicates(condition).foreach { case _: PredicateSubquery | Not(_: PredicateSubquery) => case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => @@ -374,6 +385,14 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) + // TODO: although map type is not orderable, technically map type should be able to be + // used in equality comparison, remove this type check once we support it. 
+ case o if mapColumnInSetOperation(o).isDefined => + val mapCol = mapColumnInSetOperation(o).get + failAnalysis("Cannot have map type columns in DataFrame which calls " + + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + + "is " + mapCol.dataType.simpleString) + case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 19c2d5532d088..52bd4e19f8952 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1703,4 +1703,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index f355a5200ce2f..0250a53fe2324 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -234,8 +234,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct) + DateType, TimestampType, ArrayType(IntegerType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, true) } @@ -244,10 +243,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create an RDD for the schema val rdd = - sparkContext.parallelize((1 to 10000), 10).map { i => + sparkContext.parallelize(1 to 10000, 10).map { i => Row( - s"str${i}: test cache.", - s"binary${i}: test cache.".getBytes(StandardCharsets.UTF_8), + s"str$i: test cache.", + s"binary$i: test cache.".getBytes(StandardCharsets.UTF_8), null, i % 2 == 0, i.toByte, @@ -255,13 +254,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { i, Long.MaxValue - i.toLong, (i + 0.25).toFloat, - (i + 0.75), + i + 0.75, BigDecimal(Long.MaxValue.toString + ".12345"), new 
java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
           new Date(i),
           new Timestamp(i * 1000000L),
-          (i to i + 10).toSeq,
-          (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
+          i to i + 10,
           Row((i - 0.25).toFloat, Seq(true, false, null)))
       }
     spark.createDataFrame(rdd, schema).createOrReplaceTempView("InMemoryCache_different_data_types")

From f6fdf92d0dce2cb3340f3e2ff768e09ef69176cd Mon Sep 17 00:00:00 2001
From: windpiger
Date: Fri, 10 Mar 2017 20:59:32 -0800
Subject: [PATCH 75/78] [SPARK-19723][SQL] create datasource table with a
 non-existent location should work

## What changes were proposed in this pull request?

This JIRA is a follow-up work after [SPARK-19583](https://issues.apache.org/jira/browse/SPARK-19583).

As we discussed in that [PR](https://github.com/apache/spark/pull/16938), the following DDL for a datasource table with a non-existent location should work:
```
CREATE TABLE ... (PARTITIONED BY ...) LOCATION path
```
Currently it will throw an exception that the path does not exist for a datasource table.

## How was this patch tested?

Unit test added.

Author: windpiger

Closes #17055 from windpiger/CTDataSourcePathNotExists.
---
 .../command/createDataSourceTables.scala  |   3 +-
 .../sql/execution/command/DDLSuite.scala  | 106 ++++++++++-------
 .../sql/hive/execution/HiveDDLSuite.scala | 111 ++++++++----------
 3 files changed, 115 insertions(+), 105 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 3da66afceda9c..2d890118ae0a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -73,7 +73,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
       className = table.provider.get,
       bucketSpec = table.bucketSpec,
       options = table.storage.properties ++ pathOption,
-      catalogTable = Some(tableWithDefaultOptions)).resolveRelation()
+      // As discussed in SPARK-19583, we don't check if the location is existed
+      catalogTable = Some(tableWithDefaultOptions)).resolveRelation(checkFilesExist = false)
 
     val partitionColumnNames = if (table.schema.nonEmpty) {
       table.partitionColumnNames

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 5f70a8ce8918b..0666f446f3b52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -230,7 +230,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
   }
 
   private def getDBPath(dbName: String): URI = {
-    val warehousePath = makeQualifiedPath(s"${spark.sessionState.conf.warehousePath}")
+    val warehousePath = makeQualifiedPath(spark.sessionState.conf.warehousePath)
     new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri
   }
 
@@ -1899,7 +1899,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("insert data to a data source table which has a not existed location should succeed") {
+  test("insert data to a data source table which has a non-existing location should succeed") {
     withTable("t") {
       withTempDir { dir =>
         spark.sql(
@@ -1939,7 +1939,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("insert into a
data source table with no existed partition location should succeed") { + test("insert into a data source table with a non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1966,7 +1966,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("read data from a data source table which has a not existed location should succeed") { + test("read data from a data source table which has a non-existing location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1994,7 +1994,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("read data from a data source table with no existed partition location should succeed") { + test("read data from a data source table with non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -2016,48 +2016,72 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create datasource table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" + val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING parquet - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t1 - |USING parquet - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + 
|CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 79ad156c55611..d29242bb47e36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1663,43 +1663,73 @@ class HiveDDLSuite } } + test("create hive table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE t1(a int, b int) + |USING hive + |PARTITIONED BY(a) + |LOCATION '$dir' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" - test(s"CTAS for external data source table with a $tcName location") { + val tcName = if (shouldDelete) "non-existing" else "existed" + + test(s"CTAS for external hive table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t - |USING parquet + |USING hive |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) - } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } + } + // partition table + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t1 - |USING parquet + |USING hive |PARTITIONED BY(a, b) |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1707,51 +1737,6 @@ class HiveDDLSuite assert(partDir.exists()) checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) - } - } - } - - test(s"CTAS for external hive table with a $tcName location") { - withTable("t", "t1") { - withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING hive - |LOCATION 
'$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) - } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t1 - |USING hive - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) } } } From e29a74d5b1fa3f9356b7af5dd7e3fce49bc8eb7d Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sun, 12 Mar 2017 08:29:37 +0000 Subject: [PATCH 76/78] [DOCS][SS] fix structured streaming python example ## What changes were proposed in this pull request? - SS python example: `TypeError: 'xxx' object is not callable` - some other doc issue. ## How was this patch tested? Jenkins. Author: uncleGen Closes #17257 from uncleGen/docs-ss-python. --- docs/structured-streaming-programming-guide.md | 18 +++++++++--------- .../execution/streaming/FileStreamSource.scala | 2 +- .../streaming/dstream/FileInputDStream.scala | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 995ac77a4fb3b..798847237866b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -539,7 +539,7 @@ spark = SparkSession. ... # Read text from socket socketDF = spark \ - .readStream() \ + .readStream \ .format("socket") \ .option("host", "localhost") \ .option("port", 9999) \ @@ -552,7 +552,7 @@ socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") csvDF = spark \ - .readStream() \ + .readStream \ .option("sep", ";") \ .schema(userSchema) \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") @@ -971,7 +971,7 @@ Here is the compatibility matrix.

Update mode uses watermark to drop old aggregation state.

- Complete mode does drop not old aggregation state since by definition this mode + Complete mode does not drop old aggregation state since by definition this mode preserves all data in the Result Table. @@ -1201,13 +1201,13 @@ noAggDF = deviceDataDf.select("device").where("signal > 10") # Print new data to console noAggDF \ - .writeStream() \ + .writeStream \ .format("console") \ .start() # Write new data to Parquet files noAggDF \ - .writeStream() \ + .writeStream \ .format("parquet") \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ @@ -1218,14 +1218,14 @@ aggDF = df.groupBy("device").count() # Print updated aggregations to console aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .format("console") \ .start() # Have all the aggregates in an in memory table. The query name will be the table name aggDF \ - .writeStream() \ + .writeStream \ .queryName("aggregates") \ .outputMode("complete") \ .format("memory") \ @@ -1313,7 +1313,7 @@ query.lastProgress(); // the most recent progress update of this streaming qu
{% highlight python %} -query = df.writeStream().format("console").start() # get the query object +query = df.writeStream.format("console").start() # get the query object query.id() # get the unique identifier of the running query that persists across restarts from checkpoint data @@ -1658,7 +1658,7 @@ aggDF {% highlight python %} aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .option("checkpointLocation", "path/to/HDFS/dir") \ .format("memory") \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 411a15ffceb6a..a9e64c640042a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -97,7 +97,7 @@ class FileStreamSource( } seenFiles.purge() - logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = $maxFileAgeMs") + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAgeMs = $maxFileAgeMs") /** * Returns the maximum offset that can be retrieved from the source. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index ed9305875cb77..905b1c52afa69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -230,7 +230,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * - It must pass the user-provided file filter. * - It must be newer than the ignore threshold. It is assumed that files older than the ignore * threshold have already been considered or are existing files before start - * (when newFileOnly = true). + * (when newFilesOnly = true). * - It must not be present in the recently selected files that this class remembers. * - It must not be newer than the time of the batch (i.e. `currentTime` for which this * file is being tested. This can occur if the driver was recovered, and the missing batches From 2f5187bde1544c452fe5116a2bd243653332a079 Mon Sep 17 00:00:00 2001 From: "xiaojian.fxj" Date: Sun, 12 Mar 2017 10:29:00 -0700 Subject: [PATCH 77/78] [SPARK-19831][CORE] Reuse the existing cleanupThreadExecutor to clean up the directories of finished applications to avoid blocking the worker Cleaning up a finished application's directories may take a long time on the worker, and it blocks the worker from sending heartbeats to the master because Worker extends ThreadSafeRpcEndpoint. If the heartbeat from a worker is blocked by the ApplicationFinished message, the master will think the worker is dead. If the worker has a driver, the master will schedule the driver again. It is better to reuse the existing cleanupThreadExecutor to clean up the directories of finished applications so that the cleanup does not block the endpoint. Author: xiaojian.fxj Closes #17189 from hustfxj/worker-hearbeat. 
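For readers unfamiliar with the pattern, a minimal self-contained sketch of the idea follows (illustrative only, not Spark's actual Worker code; names such as `AsyncCleanupSketch`, `cleanupPool`, and `handleApplicationFinished` are invented for the example): slow filesystem cleanup is submitted to a dedicated single-thread ExecutionContext so the message-handling thread returns immediately.

```
import java.io.File
import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object AsyncCleanupSketch {
  // Dedicated single-thread pool: cleanup work queues up here instead of
  // blocking the thread that handles RPC messages.
  private val cleanupPool =
    ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

  // Called from the message loop; returns immediately.
  def handleApplicationFinished(appId: String, dirs: Seq[File]): Unit = {
    Future {
      println(s"Cleaning up local directories for application $appId")
      dirs.foreach(deleteRecursively)
    }(cleanupPool).onComplete {
      case Success(_) => println(s"Finished cleaning up $appId")
      case Failure(e) => println(s"Clean up of $appId failed: ${e.getMessage}")
    }(cleanupPool)
  }

  private def deleteRecursively(f: File): Unit = {
    if (f.isDirectory) Option(f.listFiles()).getOrElse(Array.empty[File]).foreach(deleteRecursively)
    f.delete()
  }
}
```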
--- .../org/apache/spark/deploy/worker/Worker.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index e48817ebbafdd..00b9d1af373db 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -62,8 +62,8 @@ private[deploy] class Worker( private val forwordMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") - // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` - // methods. + // A separated thread to clean up the workDir and the directories of finished applications. + // Used to provide the implicit parameter of `Future` methods. private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) @@ -578,10 +578,15 @@ private[deploy] class Worker( if (shouldCleanup) { finishedApps -= id appDirectories.remove(id).foreach { dirList => - logInfo(s"Cleaning up local directories for application $id") - dirList.foreach { dir => - Utils.deleteRecursively(new File(dir)) - } + concurrent.Future { + logInfo(s"Cleaning up local directories for application $id") + dirList.foreach { dir => + Utils.deleteRecursively(new File(dir)) + } + }(cleanupThreadExecutor).onFailure { + case e: Throwable => + logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) + }(cleanupThreadExecutor) } shuffleService.applicationRemoved(id) } From 9f8ce4825e378b6a856ce65cb9986a5a0f0b624e Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Sun, 12 Mar 2017 12:15:19 -0700 Subject: [PATCH 78/78] [SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models ## What changes were proposed in this pull request? RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models. The following four R wrappers are changed: * `RandomForestClassificationWrapper` * `RandomForestRegressionWrapper` * `GBTClassificationWrapper` * `GBTRegressionWrapper` ## How was this patch tested? Tested manually on my local machine. Author: Xin Ren Closes #17207 from keypointt/SPARK-19282. 
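The mechanism behind the change, sketched in simplified form (`WrapperSketch` is a hypothetical stand-in, not one of the real MLlib wrapper classes): SparkR's `summary.treeEnsemble` reads each exposed value through a zero-argument JVM call such as `callJMethod(jobj, "maxDepth")`, so surfacing a new param on the R side only requires one more `lazy val` on the Scala wrapper.

```
import org.apache.spark.ml.regression.RandomForestRegressionModel

// Simplified stand-in for e.g. RandomForestRegressorWrapper: every field that
// the R summary() shows is a lazy val forwarding to the underlying model.
class WrapperSketch(model: RandomForestRegressionModel) {
  lazy val numTrees: Int = model.getNumTrees
  lazy val treeWeights: Array[Double] = model.treeWeights
  lazy val maxDepth: Int = model.getMaxDepth // the newly exposed param
}
```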
--- R/pkg/R/mllib_tree.R | 11 +++++++---- R/pkg/inst/tests/testthat/test_mllib_tree.R | 10 ++++++++++ .../apache/spark/ml/r/GBTClassificationWrapper.scala | 1 + .../org/apache/spark/ml/r/GBTRegressionWrapper.scala | 1 + .../ml/r/RandomForestClassificationWrapper.scala | 1 + .../spark/ml/r/RandomForestRegressionWrapper.scala | 1 + 6 files changed, 21 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 40a806c41bad0..82279be6fbe77 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) { numFeatures <- callJMethod(jobj, "numFeatures") features <- callJMethod(jobj, "features") featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") numTrees <- callJMethod(jobj, "numTrees") treeWeights <- callJMethod(jobj, "treeWeights") list(formula = formula, numFeatures = numFeatures, features = features, featureImportances = featureImportances, + maxDepth = maxDepth, numTrees = numTrees, treeWeights = treeWeights, jobj = jobj) @@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) { cat("\nNumber of features: ", x$numFeatures) cat("\nFeatures: ", unlist(x$features)) cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) cat("\nNumber of trees: ", x$numTrees) cat("\nTree weights: ", unlist(x$treeWeights)) @@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method #' @export @@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). 
#' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e6fda251ebea2..e0802a9b02d13 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -39,6 +39,7 @@ test_that("spark.gbt", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_equal(stats$formula, "Employed ~ .") expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) @@ -53,6 +54,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, stats2$numFeatures) expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$numTrees, stats2$numTrees) expect_equal(stats$treeWeights, stats2$treeWeights) @@ -66,6 +68,7 @@ test_that("spark.gbt", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) predictions <- collect(predict(model, data))$prediction @@ -93,6 +96,7 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), @@ -116,6 +120,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) @@ -129,6 +134,7 @@ test_that("spark.randomForest", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) @@ -141,6 +147,7 @@ test_that("spark.randomForest", { expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) @@ -153,6 +160,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) # Test string prediction values @@ -187,6 +195,8 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + # Test numeric prediction values predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index aacb41ee2659b..c07eadb30a4d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private ( lazy val featureImportances: 
Vector = gbtcModel.featureImportances lazy val numTrees: Int = gbtcModel.getNumTrees lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + lazy val maxDepth: Int = gbtcModel.getMaxDepth def summary: String = gbtcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala index 585077588eb9b..b568d7859221f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private ( lazy val featureImportances: Vector = gbtrModel.featureImportances lazy val numTrees: Int = gbtrModel.getNumTrees lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + lazy val maxDepth: Int = gbtrModel.getMaxDepth def summary: String = gbtrModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 366f375b58582..8a83d4e980f7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private ( lazy val featureImportances: Vector = rfcModel.featureImportances lazy val numTrees: Int = rfcModel.getNumTrees lazy val treeWeights: Array[Double] = rfcModel.treeWeights + lazy val maxDepth: Int = rfcModel.getMaxDepth def summary: String = rfcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index 4b9a3a731da9b..038bd79c7022b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private ( lazy val featureImportances: Vector = rfrModel.featureImportances lazy val numTrees: Int = rfrModel.getNumTrees lazy val treeWeights: Array[Double] = rfrModel.treeWeights + lazy val maxDepth: Int = rfrModel.getMaxDepth def summary: String = rfrModel.toDebugString
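As a usage note, here is a hypothetical end-to-end check (not part of this patch) showing that the value flowing through these wrappers is simply the estimator's `maxDepth` param, readable on the fitted Scala model:

```
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.sql.SparkSession

object MaxDepthCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("maxDepthCheck").getOrCreate()
    import spark.implicits._

    // Tiny toy dataset; column names match the estimator's defaults.
    val data = Seq(
      (1.0, Vectors.dense(0.0)),
      (2.0, Vectors.dense(1.0)),
      (3.0, Vectors.dense(2.0))
    ).toDF("label", "features")

    val model = new RandomForestRegressor()
      .setMaxDepth(5)
      .setNumTrees(20)
      .fit(data)

    // This is the value the R wrappers now surface through summary().
    println(model.getMaxDepth) // 5
    spark.stop()
  }
}
```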