From 36b6c83f5f46ed985699aab07931c50218821d5c Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 2 Jun 2020 11:30:30 +0000 Subject: [PATCH 1/5] KE-24858 [SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow --- .../analysis/StreamingJoinHelper.scala | 2 +- .../errors/QueryExecutionErrors.scala | 35 + .../catalyst/expressions/aggregate/Sum.scala | 90 +- .../expressions/decimalExpressions.scala | 126 +- .../apache/spark/sql/internal/SQLConf.scala | 12 + .../org/apache/spark/sql/types/Decimal.scala | 14 +- .../org/apache/spark/sql/DataFrameSuite.scala | 1134 ++++++++--------- 7 files changed, 754 insertions(+), 659 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..c7b79df4035cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { collect(left, negate) ++ collect(right, !negate) case UnaryMinus(child) => collect(child, !negate) - case CheckOverflow(child, _) => + case CheckOverflow(child, _, _) => collect(child, negate) case PromotePrecision(child) => collect(child, negate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala new file mode 100644 index 0000000000000..b84643d3293dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.errors + +import org.apache.spark.sql.types._ + +/** + * Object for grouping error messages from (most) exceptions thrown during query execution. + * This does not include exceptions thrown during the eager execution of commands, which are + * grouped into [[QueryCompilationErrors]]. + */ +object QueryExecutionErrors { + + def cannotChangeDecimalPrecisionError( + value: Decimal, decimalPrecision: Int, decimalScale: Int): ArithmeticException = { + new ArithmeticException(s"${value.toDebugString} cannot be represented as " + + s"Decimal($decimalPrecision, $decimalScale).") + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 761dba111c074..13facbdee64eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -21,10 +21,22 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.") + usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col); + 30 + > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col); + 25 + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + """, + extended = "agg_funcs", + since = "1.0.0") case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = child :: Nil @@ -50,34 +62,74 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val zero = Cast(Literal(0), sumDataType) + private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() - override lazy val aggBufferAttributes = sum :: Nil + private lazy val zero = Literal.default(sumDataType) - override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) - ) + override lazy val aggBufferAttributes = resultType match { + case _: DecimalType => sum :: isEmpty :: Nil + case _ => sum :: Nil + } + + override lazy val initialValues: Seq[Expression] = resultType match { + case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _ => Seq(Literal(null, resultType)) + } override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) + val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + resultType match { + case _: DecimalType => + Seq(updateSumExpr, isEmpty && child.isNull) + case _ => Seq(updateSumExpr) + } } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) + resultType match { + case _: DecimalType => + Seq(updateSumExpr, Literal(false, BooleanType)) + case _ => Seq(updateSumExpr) + } } } + /** + * For decimal type: + * If isEmpty is false and if sum is null, then it means we have had an overflow. + * + * update of the sum is as follows: + * Check if either portion of the left.sum or right.sum has overflowed + * If it has, then the sum value will remain null. + * If it did not have overflow, then add the sum.left and sum.right + * + * isEmpty: Set to false if either one of the left or right is set to false. This + * means we have seen atleast a value that was not null. + */ override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - ) + val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + resultType match { + case _: DecimalType => + val inputOverflow = !isEmpty.right && sum.right.isNull + val bufferOverflow = !isEmpty.left && sum.left.isNull + Seq( + If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + isEmpty.left && isEmpty.right) + case _ => Seq(mergeSumExpr) + } } - override lazy val evaluateExpression: Expression = sum + /** + * If the isEmpty is true, then it means there were no values to begin with or all the values + * were null, so the result will be null. + * If the isEmpty is false, then if sum is null that means an overflow has happened. + * So now, if ansi is enabled, then throw exception, if not then return null. + * If sum is not null, then return the sum. + */ + override lazy val evaluateExpression: Expression = resultType match { + case d: DecimalType => + If(isEmpty, Literal.create(null, sumDataType), + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) + case _ => sum + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 04de83343be71..7e4560ab8161b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -26,7 +28,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -44,25 +46,56 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. */ -case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int, + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { + + def this(child: Expression, precision: Int, scale: Int) = { + this(child, precision, scale, !SQLConf.get.ansiEnabled) + } override def dataType: DataType = DecimalType(precision, scale) - override def nullable: Boolean = true + override def nullable: Boolean = child.nullable || nullOnOverflow override def toString: String = s"MakeDecimal($child,$precision,$scale)" - protected override def nullSafeEval(input: Any): Any = - Decimal(input.asInstanceOf[Long], precision, scale) + protected override def nullSafeEval(input: Any): Any = { + val longInput = input.asInstanceOf[Long] + val result = new Decimal() + if (nullOnOverflow) { + result.setOrNull(longInput, precision, scale) + } else { + result.set(longInput, precision, scale) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { + val setMethod = if (nullOnOverflow) { + "setOrNull" + } else { + "set" + } + val setNull = if (nullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } s""" - ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); - ${ev.isNull} = ${ev.value} == null; - """ + |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale); + |$setNull + |""".stripMargin }) } } +object MakeDecimal { + def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = { + new MakeDecimal(child, precision, scale) + } +} + /** * An expression used to wrap the children when promote the precision of DecimalType to avoid * promote multiple times. @@ -81,30 +114,85 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { /** * Rounds the decimal to given scale and check whether the decimal can fit in provided precision - * or not, returns null if not. + * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an + * `ArithmeticException` is thrown. */ -case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { +case class CheckOverflow( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - val tmp = ctx.freshName("tmp") s""" - | Decimal $tmp = $eval.clone(); - | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { - | ${ev.value} = $tmp; - | } else { - | ${ev.isNull} = true; - | } + |${ev.value} = $eval.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + |${ev.isNull} = ${ev.value} == null; """.stripMargin }) } - override def toString: String = s"CheckOverflow($child, $dataType)" + override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql +} + +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. +case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") + } else { + value.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val nullHandling = if (nullOnOverflow) { + "" + } else { + s""" + |throw new ArithmeticException("Overflow in sum of decimals."); + |""".stripMargin + } + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + + ev.copy(code = code) + } + + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d7409c5efa372..c9c898a8de344 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -150,6 +150,16 @@ object SQLConf { } } + val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled") + .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " + + "For example, Spark will throw an exception at runtime instead of returning null results " + + "when the inputs to a SQL operator/function are invalid." + + "For full details of this dialect, you can find them in the section \"ANSI Compliance\" of " + + "Spark's documentation. Some ANSI dialect features may be not from the ANSI SQL " + + "standard directly, but their behaviors align with ANSI SQL's style") + .booleanConf + .createWithDefault(false) + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + "specified by their rule names and separated by comma. It is not guaranteed that all the " + @@ -1617,6 +1627,8 @@ class SQLConf extends Serializable with Logging { /** ************************ Spark SQL Params/Hints ******************* */ + def ansiEnabled: Boolean = getConf(ANSI_ENABLED) + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..63696658fd3ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.errors.QueryExecutionErrors /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -242,9 +243,18 @@ final class Decimal extends Ordered[Decimal] with Serializable { private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, + nullOnOverflow: Boolean = true): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) copy else null + if (copy.changePrecision(precision, scale, roundMode)) { + copy + } else { + if (nullOnOverflow) { + null + } else { + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(this, precision, scale) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5075209d7454f..ac97288ddb6b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,39 +17,41 @@ package org.apache.spark.sql -import java.io.File +import java.io.{ByteArrayOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util.UUID - +import java.util.concurrent.atomic.AtomicLong +import scala.reflect.runtime.universe.TypeTag import scala.util.Random - -import org.scalatest.Matchers._ - +import org.scalatest.Matchers.{assert, _} import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} +import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -class DataFrameSuite extends QueryTest with SharedSQLContext { +class DataFrameSuite extends QueryTest with SharedSparkSession { import testImplicits._ test("analysis error should be eagerly reported") { - intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { testData.select("nonExistentName") } intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + testData.groupBy("key").agg(Map("nonExistentName" -> "sum")) } intercept[Exception] { testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) @@ -85,129 +87,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } - test("union all") { - val unionDF = testData.union(testData).union(testData) - .union(testData).union(testData) - - // Before optimizer, Union should be combined. - assert(unionDF.queryExecution.analyzed.collect { - case j: Union if j.children.size == 5 => j }.size === 1) - - checkAnswer( - unionDF.agg(avg('key), max('key), min('key), sum('key)), - Row(50.5, 100, 1, 25250) :: Nil - ) - } - - test("union should union DataFrames with UDTs (SPARK-13410)") { - val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) - val schema1 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) - val schema2 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val df1 = spark.createDataFrame(rowRDD1, schema1) - val df2 = spark.createDataFrame(rowRDD2, schema2) - - checkAnswer( - df1.union(df2).orderBy("label"), - Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) - ) - } - - test("union by name") { - var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") - val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") - val unionDf = df1.unionByName(df2.unionByName(df3)) - checkAnswer(unionDf, - Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil - ) - - // Check if adjacent unions are combined into a single one - assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) - - // Check failure cases - df1 = Seq((1, 2)).toDF("a", "c") - df2 = Seq((3, 4, 5)).toDF("a", "b", "c") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains( - "Union can only be performed on tables with the same number of columns, " + - "but the first table has 2 columns and the second table has 3 columns")) - - df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) - } - - test("union by name - type coercion") { - var df1 = Seq((1, "a")).toDF("c0", "c1") - var df2 = Seq((3, 1L)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) - - df1 = Seq((1, 1.0)).toDF("c0", "c1") - df2 = Seq((8L, 3.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) - - df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") - df2 = Seq(("a", 4.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) - - df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") - df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") - val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") - checkAnswer(df1.unionByName(df2.unionByName(df3)), - Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil - ) - } - - test("union by name - check case sensitivity") { - def checkCaseSensitiveTest(): Unit = { - val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") - val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") - checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val errMsg2 = intercept[AnalysisException] { - checkCaseSensitiveTest() - }.getMessage - assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkCaseSensitiveTest() - } - } - - test("union by name - check name duplication") { - Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - var df1 = Seq((1, 1)).toDF(c0, c1) - var df2 = Seq((1, 1)).toDF("c0", "c1") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) - df1 = Seq((1, 1)).toDF("c0", "c1") - df2 = Seq((1, 1)).toDF(c0, c1) - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) - } - } - } - test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) } - test("head and take") { + test("head, take") { assert(testData.take(2) === testData.collect().take(2)) assert(testData.head(2) === testData.collect().take(2)) assert(testData.head(2).head.schema === testData.schema) @@ -248,8 +133,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Star Expansion - CreateStruct and CreateArray") { val structDf = testData2.select("a", "b").as("record") // CreateStruct and CreateArray in aggregateExpressions - assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) - assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))). + sort("a").first() == Row(1, Row(1, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))). + sort("a").first() == Row(1, Seq(1, 1))) // CreateStruct and CreateArray in project list (unresolved alias) assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) @@ -279,7 +166,126 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { structDf.select(hash($"a", $"record.*"))) } - test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + private def assertDecimalSumOverflow( + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + if (!ansiEnabled) { + try { + checkAnswer(df, expectedAnswer) + } catch { + case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => + // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail + // to read it. + assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } else { + val e = intercept[SparkException] { + df.collect + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals") || + e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } + + test("SPARK-28224: Aggregate sum big decimal overflow") { + val largeDecimals = spark.sparkContext.parallelize( + DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: + DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() + + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val structDf = largeDecimals.select("a").agg(sum("a")) + assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) + } + } + } + + test("SPARK-28067: sum of null decimal values") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq("true", "false").foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) { + val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + checkAnswer(df.agg(sum($"d")), Row(null)) + } + } + } + } + } + + test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df0 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df1 = Seq( + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df = df0.union(df1) + val df2 = df.withColumnRenamed("decNum", "decNum2"). + join(df, "intNum").agg(sum("decNum")) + + val expectedAnswer = Row(null) + assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) + + val decStr = "1" + "0" * 19 + val d1 = spark.range(0, 12, 1, 1) + val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) + + val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) + + val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), + lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") + assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) + + val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + + val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). + toDF("d") + assertDecimalSumOverflow( + nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + + val df3 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("50000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df4 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df5 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum") + + val df6 = df3.union(df4).union(df5) + val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). + filter("intNum == 1") + assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) + } + } + } + } + } + + test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { df.explode($"*") { case Row(prefix: String, csv: String) => @@ -300,6 +306,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("3", "7,8,9", "3:9") :: Nil) } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") + val e = intercept[AnalysisException] { + df.select(explode($"*")) + } + assert(e.getMessage.contains("Invalid usage of '*' in expression 'explode'")) + } + + test("explode on output of array-valued function") { + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") + checkAnswer( + df.select(explode(split($"csv", pattern = ","))), + Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil) + } + test("Star Expansion - explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") @@ -354,12 +375,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("repartition") { intercept[IllegalArgumentException] { - testData.select('key).repartition(0) + testData.select("key").repartition(0) } checkAnswer( - testData.select('key).repartition(10).select('key), - testData.select('key).collect().toSeq) + testData.select("key").repartition(10).select("key"), + testData.select("key").collect().toSeq) } test("repartition with SortOrder") { @@ -421,16 +442,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("coalesce") { intercept[IllegalArgumentException] { - testData.select('key).coalesce(0) + testData.select("key").coalesce(0) } - assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + assert(testData.select("key").coalesce(1).rdd.partitions.size === 1) checkAnswer( - testData.select('key).coalesce(1).select('key), - testData.select('key).collect().toSeq) + testData.select("key").coalesce(1).select("key"), + testData.select("key").collect().toSeq) - assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0) } test("convert $\"attribute name\" into unresolved attribute") { @@ -441,7 +462,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === lit(1)).select('value), + testData.where($"key" === lit(1)).select("value"), Row("1")) } @@ -453,17 +474,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("simple select") { checkAnswer( - testData.where('key === lit(1)).select('value), + testData.where($"key" === lit(1)).select("value"), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(lit(1))), + testData.select(sum("value"), avg("value"), count(lit(1))), Row(5050.0, 50.5, 100)) checkAnswer( - testData2.select('a + 'b, 'a < 'b), + testData2.select($"a" + $"b", $"a" < $"b"), Seq( Row(2, false), Row(3, true), @@ -473,31 +494,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(5, false))) checkAnswer( - testData2.select(sumDistinct('a)), + testData2.select(sumDistinct($"a")), Row(6)) } test("sorting with null ordering") { val data = Seq[java.lang.Integer](2, 1, null).toDF("key") - checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy($"key".asc), Row(null) :: Row(1) :: Row(2) :: Nil) checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil) - checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy($"key".asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil) - checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil) - checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".desc), Row(2) :: Row(1) :: Row(null) :: Nil) checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil) - checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy($"key".desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil) - checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil) } test("global sorting") { checkAnswer( - testData2.orderBy('a.asc, 'b.asc), + testData2.orderBy($"a".asc, $"b".asc), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( @@ -505,31 +526,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - testData2.orderBy('a.asc, 'b.desc), + testData2.orderBy($"a".asc, $"b".desc), Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - testData2.orderBy('a.desc, 'b.desc), + testData2.orderBy($"a".desc, $"b".desc), Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( - testData2.orderBy('a.desc, 'b.asc), + testData2.orderBy($"a".desc, $"b".asc), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( - arrayData.toDF().orderBy('data.getItem(0).asc), + arrayData.toDF().orderBy($"data".getItem(0).asc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(0).desc), + arrayData.toDF().orderBy($"data".getItem(0).desc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(1).asc), + arrayData.toDF().orderBy($"data".getItem(1).asc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(1).desc), + arrayData.toDF().orderBy($"data".getItem(1).desc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } @@ -553,265 +574,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.except(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.except(nullInts), - Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.except(allNulls.filter("0 = 1")), - Row(null) :: Nil) - checkAnswer( - allNulls.except(allNulls), - Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.except(df.filter("0 = 1")), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").except(allNulls), - Nil) - } - - test("SPARK-23274: except between two projects without references used in filter") { - val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") - val df1 = df.filter($"a" === 1) - val df2 = df.filter($"a" === 2) - checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) - checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) - } - - test("except distinct - SQL compliance") { - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 3).toDF("id") - - checkAnswer( - df_left.except(df_right), - Row(2) :: Row(4) :: Nil - ) - } - - test("except - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.except(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.except(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.except(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.except(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("except all") { - checkAnswer( - lowerCaseData.exceptAll(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) - checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.exceptAll(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.exceptAll(nullInts), - Nil) - - // check that duplicate values are preserved - checkAnswer( - allNulls.exceptAll(allNulls.filter("0 = 1")), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - checkAnswer( - allNulls.exceptAll(allNulls.limit(2)), - Row(null) :: Row(null) :: Nil) - - // check that duplicates are retained. - val df = spark.sparkContext.parallelize( - NullStrings(1, "id1") :: - NullStrings(1, "id1") :: - NullStrings(2, "id1") :: - NullStrings(3, null) :: Nil).toDF("id", "value") - - checkAnswer( - df.exceptAll(df.filter("0 = 1")), - Row(1, "id1") :: - Row(1, "id1") :: - Row(2, "id1") :: - Row(3, null) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").exceptAll(allNulls), - Nil) - - } - - test("exceptAll - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.exceptAll(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.exceptAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.exceptAll(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.exceptAll(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersect(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.intersect(allNulls), - Row(null) :: Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.intersect(df), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - } - - test("intersect - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersect(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersect(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersect(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersect(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersectAll") { - checkAnswer( - lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), - Row(1, "a") :: - Row(2, "b") :: - Row(2, "b") :: - Row(3, "c") :: - Row(3, "c") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersectAll(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // Duplicate nulls are preserved. - checkAnswer( - allNulls.intersectAll(allNulls), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 2, 2, 3).toDF("id") - - checkAnswer( - df_left.intersectAll(df_right), - Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) - } - - test("intersectAll - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersectAll(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersectAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersectAll(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersectAll(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select($"*", foo('key, 'value)).limit(3), + testData.select($"*", foo($"key", $"value")).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } @@ -914,7 +682,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("replace column using withColumns") { - val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y") + val df2 = sparkContext.parallelize(Seq((1, 2), (2, 3), (3, 4))).toDF("x", "y") val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), Seq(df2("x") + 1, df2("y"), df2("y") + 1)) checkAnswer( @@ -954,6 +722,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("value")) } + test("SPARK-28189 drop column using drop with column reference with case-insensitive names") { + // With SQL config caseSensitive OFF, case insensitive column name should work + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val col1 = testData("KEY") + val df1 = testData.drop(col1) + checkAnswer(df1, testData.selectExpr("value")) + assert(df1.schema.map(_.name) === Seq("value")) + + val col2 = testData("Key") + val df2 = testData.drop(col2) + checkAnswer(df2, testData.selectExpr("value")) + assert(df2.schema.map(_.name) === Seq("value")) + } + } + test("drop unknown column (no-op) with column reference") { val col = Column("random") val df = testData.drop(col) @@ -1140,7 +923,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") - checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) + checkAnswer(df.select(df("key")), testData.select("key").collect().toSeq) } test("inputFiles") { @@ -1501,7 +1284,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { |""".stripMargin assert(df.showString(1, truncate = 0) === expectedAnswer) - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { val expectedAnswer = """+----------+-------------------+ ||d |ts | @@ -1522,7 +1305,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { " ts | 2016-12-01 00:00:00 \n" assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { val expectedAnswer = "-RECORD 0------------------\n" + " d | 2016-12-01 \n" + @@ -1539,7 +1322,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2))) + checkAnswer(decimalData.agg(avg("a")), Row(new java.math.BigDecimal(2))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -1669,47 +1452,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-6941: Better error message for inserting into RDD-based Table") { withTempDir { dir => + withTempView("parquet_base", "json_base", "rdd_base", "indirect_ds", "one_row") { + val tempParquetFile = new File(dir, "tmp_parquet") + val tempJsonFile = new File(dir, "tmp_json") + + val df = Seq(Tuple1(1)).toDF() + val insertion = Seq(Tuple1(2)).toDF("col") + + // pass case: parquet table (HadoopFsRelation) + df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) + pdf.createOrReplaceTempView("parquet_base") + + insertion.write.insertInto("parquet_base") + + // pass case: json table (InsertableRelation) + df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) + jdf.createOrReplaceTempView("json_base") + insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") + + // error cases: insert into an RDD + df.createOrReplaceTempView("rdd_base") + val e1 = intercept[AnalysisException] { + insertion.write.insertInto("rdd_base") + } + assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - val tempParquetFile = new File(dir, "tmp_parquet") - val tempJsonFile = new File(dir, "tmp_json") - - val df = Seq(Tuple1(1)).toDF() - val insertion = Seq(Tuple1(2)).toDF("col") - - // pass case: parquet table (HadoopFsRelation) - df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) - pdf.createOrReplaceTempView("parquet_base") - - insertion.write.insertInto("parquet_base") - - // pass case: json table (InsertableRelation) - df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = spark.read.json(tempJsonFile.getCanonicalPath) - jdf.createOrReplaceTempView("json_base") - insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") - - // error cases: insert into an RDD - df.createOrReplaceTempView("rdd_base") - val e1 = intercept[AnalysisException] { - insertion.write.insertInto("rdd_base") - } - assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - - // error case: insert into a logical plan that is not a LeafNode - val indirectDS = pdf.select("_1").filter($"_1" > 5) - indirectDS.createOrReplaceTempView("indirect_ds") - val e2 = intercept[AnalysisException] { - insertion.write.insertInto("indirect_ds") - } - assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + // error case: insert into a logical plan that is not a LeafNode + val indirectDS = pdf.select("_1").filter($"_1" > 5) + indirectDS.createOrReplaceTempView("indirect_ds") + val e2 = intercept[AnalysisException] { + insertion.write.insertInto("indirect_ds") + } + assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - // error case: insert into an OneRowRelation - Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") - val e3 = intercept[AnalysisException] { - insertion.write.insertInto("one_row") + // error case: insert into an OneRowRelation + Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") + val e3 = intercept[AnalysisException] { + insertion.write.insertInto("one_row") + } + assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } - assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } @@ -1741,7 +1525,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Sorting columns are not in Filter and Project") { checkAnswer( - upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + upperCaseData.filter($"N" > 1).select("N").filter($"N" < 6).orderBy($"L".asc), Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) } @@ -1784,77 +1568,30 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") { val df = Seq(1 -> 2).toDF("i", "j") - val query1 = df.groupBy('i) - .agg(max('j).as("aggOrder")) - .orderBy(sum('j)) + val query1 = df.groupBy("i") + .agg(max("j").as("aggOrder")) + .orderBy(sum("j")) checkAnswer(query1, Row(1, 2)) // In the plan, there are two attributes having the same name 'havingCondition' // One is a user-provided alias name; another is an internally generated one. - val query2 = df.groupBy('i) - .agg(max('j).as("havingCondition")) - .where(sum('j) > 0) - .orderBy('havingCondition.asc) + val query2 = df.groupBy("i") + .agg(max("j").as("havingCondition")) + .where(sum("j") > 0) + .orderBy($"havingCondition".asc) checkAnswer(query2, Row(1, 2)) } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = spark.read.json((1 to 10).map(i => s"""{"id": $i}""").toDS()) - - val df = input.select($"id", rand(0).as('r)) - df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => - assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) - } - } - - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val df1 = (1 to 100).map(Tuple1.apply).toDF("i") - val df2 = (1 to 30).map(Tuple1.apply).toDF("i") - val intersect = df1.intersect(df2) - val except = df1.except(df2) - assert(intersect.count() === 30) - assert(except.count() === 70) - } - - test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { - val df1 = (1 to 20).map(Tuple1.apply).toDF("i") - val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + withTempDir { dir => + (1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath) + val input = spark.read.json(dir.getCanonicalPath) - // When generating expected results at here, we need to follow the implementation of - // Rand expression. - def expected(df: DataFrame): Seq[Row] = { - df.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.filter(_.getInt(0) < rng.nextDouble() * 10) + val df = input.select($"id", rand(0).as("r")) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } - - val union = df1.union(df2) - checkAnswer( - union.filter('i < rand(7) * 10), - expected(union) - ) - checkAnswer( - union.select(rand(7)), - union.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.map(_ => rng.nextDouble()).map(i => Row(i)) - } - ) - - val intersect = df1.intersect(df2) - checkAnswer( - intersect.filter('i < rand(7) * 10), - expected(intersect) - ) - - val except = df1.except(df2) - checkAnswer( - except.filter('i < rand(7) * 10), - expected(except) - ) } test("SPARK-10743: keep the name of expression if possible when do cast") { @@ -2001,8 +1738,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) val df = spark.createDataFrame( rdd, - new StructType().add("f1", IntegerType).add("f2", IntegerType), - needsConversion = false).select($"F1", $"f2".as("f2")) + new StructType().add("f1", IntegerType).add("f2", IntegerType)) + .select($"F1", $"f2".as("f2")) val df1 = df.as("a") val df2 = df.as("b") checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) @@ -2017,7 +1754,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11725: correctly handle null inputs for ScalaUDF") { val df = sparkContext.parallelize(Seq( - new java.lang.Integer(22) -> "John", + java.lang.Integer.valueOf(22) -> "John", null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") // passing null into the UDF that could handle it @@ -2074,25 +1811,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("reuse exchange") { - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) + collect(join.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => true }.size === 1) assert( - join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) + collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) + collect(join2.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => true }.size == 1) assert( - join2.queryExecution.executedPlan - .collect { case e: BroadcastExchangeExec => true }.size === 1) + collect(join2.queryExecution.executedPlan) { + case e: BroadcastExchangeExec => true }.size === 1) assert( - join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) + collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4) } } @@ -2126,19 +1865,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col") // invalid table names Seq("11111", "t~", "#$@sum", "table!#").foreach { name => - val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage - assert(m.contains(s"Invalid view name: $name")) + withTempView(name) { + val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage + assert(m.contains(s"Invalid view name: $name")) + } } // valid table names Seq("table1", "`11111`", "`t~`", "`#$@sum`", "`table!#`").foreach { name => - df.createOrReplaceTempView(name) + withTempView(name) { + df.createOrReplaceTempView(name) + } } } test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select($"id" as "a", $"id" as "b").groupBy("a").agg($"b") } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) @@ -2203,7 +1946,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val size = 201L val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size)))) val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true)) - val df = spark.createDataFrame(rdd, StructType(schemas), false) + val df = spark.createDataFrame(rdd, StructType(schemas)) assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } @@ -2250,9 +1993,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: no change on nullability in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) @@ -2267,9 +2010,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: set nullability to false in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) @@ -2305,21 +2048,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-17123: Performing set operations that combine non-scala native types") { - val dates = Seq( - (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), - (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) - ).toDF("date", "timestamp", "decimal") - - val widenTypedRows = Seq( - (new Timestamp(2), 10.5D, "string") - ).toDF("date", "timestamp", "decimal") - - dates.union(widenTypedRows).collect() - dates.except(widenTypedRows).collect() - dates.intersect(widenTypedRows).collect() - } - test("SPARK-18070 binary operator should not consider nullability when comparing input types") { val rows = Seq(Row(Seq(1), Seq(1))) val schema = new StructType() @@ -2339,25 +2067,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } - } - test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) @@ -2389,7 +2098,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val e = intercept[SparkException] { df.filter(filter).count() }.getMessage - assert(e.contains("grows beyond 64 KB")) + assert(e.contains("grows beyond 64 KiB")) } } @@ -2405,26 +2114,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("order-by ordinal.") { checkAnswer( - testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), + testData2.select(lit(7), $"a", $"b").orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } - test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { - val colNumber = 10000 - val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) - val df = sqlContext.createDataFrame(input, StructType( - (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) - val newCols = (1 to colNumber).flatMap { colIndex => - Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), - expr(s"sqrt(_$colIndex)")) - } - df.select(newCols: _*).collect() - } - test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") - val result = df.select('DecimalCol cast DecimalType(38, 33)) + val result = df.select($"DecimalCol" cast DecimalType(38, 33)) .select(col("DecimalCol")).describe() val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) @@ -2460,24 +2157,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val sourceDF = spark.createDataFrame(rows, schema) def structWhenDF: DataFrame = sourceDF - .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") - .select('res.getField("val1")) + .select(when($"cond", + struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as "res") + .select($"res".getField("val1")) def arrayWhenDF: DataFrame = sourceDF - .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") - .select('res.getItem(0)) + .select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as "res") + .select($"res".getItem(0)) def mapWhenDF: DataFrame = sourceDF - .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") - .select('res.getItem(0)) + .select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res") + .select($"res".getItem(0)) def structIfDF: DataFrame = sourceDF .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") - .select('res.getField("val1")) + .select($"res".getField("val1")) def arrayIfDF: DataFrame = sourceDF .select(expr("if(cond, array('a', 'b'), a)") as "res") - .select('res.getItem(0)) + .select($"res".getItem(0)) def mapIfDF: DataFrame = sourceDF .select(expr("if(cond, map(0, 'a'), m)") as "res") - .select('res.getItem(0)) + .select($"res".getItem(0)) def checkResult(): Unit = { checkAnswer(structWhenDF, Seq(Row("a"), Row(null))) @@ -2540,36 +2238,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // partitions. .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) - var numJobs = 0 + val numJobs = new AtomicLong(0) sparkContext.addSparkListener(new SparkListener { override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - numJobs += 1 + numJobs.incrementAndGet() } }) val df = spark.read.json(path.getCanonicalPath) assert(df.columns === Array("i", "p")) - spark.sparkContext.listenerBus.waitUntilEmpty(10000) - assert(numJobs == 1) - } - } - - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { - val df1 = spark.createDataFrame(Seq( - (1, 1) - )).toDF("a", "b").withColumn("c", newCol) - - val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) - checkAnswer(df2, result) + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(numJobs.get() == 1L) } - - check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) - check(lit(null).cast("int"), $"c".isNotNull, Seq()) - check(lit(2).cast("int"), $"c".isNull, Seq()) - check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" =!= 2, Seq()) } test("SPARK-25402 Null handling in BooleanSimplification") { @@ -2622,4 +2302,222 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row("1-1", 6, 6)) } } + + test("SPARK-27671: Fix analysis exception when casting null in nested field in struct") { + val df = sql("SELECT * FROM VALUES (('a', (10, null))), (('b', (10, 50))), " + + "(('c', null)) AS tab(x, y)") + checkAnswer(df, Row("a", Row(10, null)) :: Row("b", Row(10, 50)) :: Row("c", null) :: Nil) + + val cast = sql("SELECT cast(struct(1, null) AS struct)") + checkAnswer(cast, Row(Row(1, null)) :: Nil) + } + + test("SPARK-27439: Explain result should match collected result after view change") { + withTempView("test", "test2", "tmp") { + spark.range(10).createOrReplaceTempView("test") + spark.range(5).createOrReplaceTempView("test2") + spark.sql("select * from test").createOrReplaceTempView("tmp") + val df = spark.sql("select * from tmp") + spark.sql("select * from test2").createOrReplaceTempView("tmp") + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + df.explain(extended = true) + } + checkAnswer(df, spark.range(10).toDF) + val output = captured.toString + assert(output.contains( + """== Parsed Logical Plan == + |'Project [*] + |+- 'UnresolvedRelation [tmp]""".stripMargin)) + assert(output.contains( + """== Physical Plan == + |*(1) Range (0, 10, step=1, splits=2)""".stripMargin)) + } + } + + test("SPARK-29442 Set `default` mode should override the existing mode") { + val df = Seq(Tuple1(1)).toDF() + val writer = df.write.mode("overwrite").mode("default") + val modeField = classOf[DataFrameWriter[Tuple1[Int]]].getDeclaredField("mode") + modeField.setAccessible(true) + assert(SaveMode.ErrorIfExists === modeField.get(writer).asInstanceOf[SaveMode]) + } + + test("sample should not duplicated the input data") { + val df1 = spark.range(10).select($"id" as "id1", $"id" % 5 as "key1") + val df2 = spark.range(10).select($"id" as "id2", $"id" % 5 as "key2") + val sampled = df1.join(df2, $"key1" === $"key2") + .sample(0.5, 42) + .select("id1", "id2") + val idTuples = sampled.collect().map(row => row.getLong(0) -> row.getLong(1)) + assert(idTuples.length == idTuples.toSet.size) + } + + test("groupBy.as") { + val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") + .repartition($"a", $"b").sortWithinPartitions("a", "b") + val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c") + .repartition($"a", $"b").sortWithinPartitions("a", "b") + + implicit val valueEncoder = RowEncoder(df1.schema) + + val df3 = df1.groupBy("a", "b").as[GroupByKey, Row] + .cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) => + data1.zip(data2).map { p => + p._1.getInt(2) + p._2.getInt(2) + } + }.toDF + + checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) + + // Assert that no extra shuffle introduced by cogroup. + val exchanges = collect(df3.queryExecution.executedPlan) { + case h: ShuffleExchangeExec => h + } + assert(exchanges.size == 2) + } + + test("groupBy.as: custom grouping expressions") { + val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a1", "b", "c") + .repartition($"a1", $"b").sortWithinPartitions("a1", "b") + val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c") + .repartition($"a1", $"b").sortWithinPartitions("a1", "b") + + implicit val valueEncoder = RowEncoder(df1.schema) + + val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] + val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] + + val df3 = groupedDataset1 + .cogroup(groupedDataset2) { case (_, data1, data2) => + data1.zip(data2).map { p => + p._1.getInt(2) + p._2.getInt(2) + } + }.toDF + + checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) + } + + test("groupBy.as: throw AnalysisException for unresolved grouping expr") { + val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") + + implicit val valueEncoder = RowEncoder(df.schema) + + val err = intercept[AnalysisException] { + df.groupBy($"d", $"b").as[GroupByKey, Row] + } + assert(err.getMessage.contains("cannot resolve '`d`'")) + } + + test("emptyDataFrame should be foldable") { + val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L)) + val joined = spark.range(10).join(emptyDf, "id") + joined.queryExecution.optimizedPlan match { + case LocalRelation(Seq(id), Nil, _) => + assert(id.name == "id") + case _ => + fail("emptyDataFrame should be foldable") + } + } + + test("SPARK-30811: CTE should not cause stack overflow when " + + "it refers to non-existent table with same name") { + val e = intercept[AnalysisException] { + sql("WITH t AS (SELECT 1 FROM nonexist.t) SELECT * FROM t") + } + assert(e.getMessage.contains("Table or view not found:")) + } + + test("CalendarInterval reflection support") { + val df = Seq((1, new CalendarInterval(1, 2, 3))).toDF("a", "b") + checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3))) + } + + test("SPARK-31552: array encoder with different types") { + // primitives + val booleans = Array(true, false) + checkAnswer(Seq(booleans).toDF(), Row(booleans)) + + val bytes = Array(1.toByte, 2.toByte) + checkAnswer(Seq(bytes).toDF(), Row(bytes)) + val shorts = Array(1.toShort, 2.toShort) + checkAnswer(Seq(shorts).toDF(), Row(shorts)) + val ints = Array(1, 2) + checkAnswer(Seq(ints).toDF(), Row(ints)) + val longs = Array(1L, 2L) + checkAnswer(Seq(longs).toDF(), Row(longs)) + + val floats = Array(1.0F, 2.0F) + checkAnswer(Seq(floats).toDF(), Row(floats)) + val doubles = Array(1.0D, 2.0D) + checkAnswer(Seq(doubles).toDF(), Row(doubles)) + + val strings = Array("2020-04-24", "2020-04-25") + checkAnswer(Seq(strings).toDF(), Row(strings)) + + // tuples + val decOne = Decimal(1, 38, 18) + val decTwo = Decimal(2, 38, 18) + val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22")) + val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22")) + checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF()) + + // case classes + val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5)) + checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5)))) + + // We can move this implicit def to [[SQLImplicits]] when we eventually make fully + // support for array encoder like Seq and Set + // For now cases below, decimal/datetime/interval/binary/nested types, etc, + // are not supported by array + implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + + // decimals + val decSpark = Array(decOne, decTwo) + val decScala = decSpark.map(_.toBigDecimal) + val decJava = decSpark.map(_.toJavaBigDecimal) + checkAnswer(Seq(decSpark).toDF(), Row(decJava)) + checkAnswer(Seq(decScala).toDF(), Row(decJava)) + checkAnswer(Seq(decJava).toDF(), Row(decJava)) + + // datetimes and intervals + val dates = strings.map(Date.valueOf) + checkAnswer(Seq(dates).toDF(), Row(dates)) + val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d))) + checkAnswer(Seq(localDates).toDF(), Row(dates)) + + val timestamps = + Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33")) + checkAnswer(Seq(timestamps).toDF(), Row(timestamps)) + val instants = + timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t))) + checkAnswer(Seq(instants).toDF(), Row(timestamps)) + + val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6)) + checkAnswer(Seq(intervals).toDF(), Row(intervals)) + + // binary + val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte)) + checkAnswer(Seq(bins).toDF(), Row(bins)) + + // nested + val nestedIntArray = Array(Array(1), Array(2)) + checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray))) + val nestedDecArray = Array(decSpark) + checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) + } + + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { + withTempPath { f => + sql("select cast(1 as decimal(38, 0)) as d") + .write.mode("overwrite") + .parquet(f.getAbsolutePath) + + val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal] + assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0)))) + } + } } + +case class GroupByKey(a: Int, b: Int) From 4d9f176c602d2337e4d69d0f41a6331df7398ace Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2020 15:56:40 +0900 Subject: [PATCH 2/5] [SPARK-28067][SPARK-32018] Fix decimal overflow issues ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/27627 to fix the remaining issues. There are 2 issues fixed in this PR: 1. `UnsafeRow.setDecimal` can set an overflowed decimal and causes an error when reading it. The expected behavior is to return null. 2. The update/merge expression for decimal type in `Sum` is wrong. We shouldn't turn the `sum` value back to 0 after it becomes null due to overflow. This issue was hidden because: 2.1 for hash aggregate, the buffer is unsafe row. Due to the first bug, we fail when overflow happens, so there is no chance to mistakenly turn null back to 0. 2.2 for sort-based aggregate, the buffer is generic row. The decimal can overflow (the Decimal class has unlimited precision) and we don't have the null problem. If we only fix the first bug, then the second bug is exposed and test fails. If we only fix the second bug, there is no way to test it. This PR fixes these 2 bugs together. ### Why are the changes needed? Fix issues during decimal sum when overflow happens ### Does this PR introduce _any_ user-facing change? Yes. Now decimal sum can return null correctly for overflow under non-ansi mode. ### How was this patch tested? new test and updated test Closes #29026 from cloud-fan/decimal. Authored-by: Wenchen Fan Signed-off-by: HyukjinKwon --- .../sql/catalyst/expressions/UnsafeRow.java | 2 +- .../catalyst/expressions/aggregate/Sum.scala | 64 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +--- .../org/apache/spark/sql/UnsafeRowSuite.scala | 10 +++ 4 files changed, 54 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9bf9452855f5f..faf7c2ecf4c21 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -279,7 +279,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null) { + if (value == null || !value.changePrecision(precision, value.scale())) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 13facbdee64eb..aca8c8e79cec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -58,13 +58,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = resultType - - private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", resultType)() private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() - private lazy val zero = Literal.default(sumDataType) + private lazy val zero = Literal.default(resultType) override lazy val aggBufferAttributes = resultType match { case _: DecimalType => sum :: isEmpty :: Nil @@ -72,25 +70,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } override lazy val initialValues: Seq[Expression] = resultType match { - case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _: DecimalType => Seq(zero, Literal(true, BooleanType)) case _ => Seq(Literal(null, resultType)) } override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - resultType match { - case _: DecimalType => - Seq(updateSumExpr, isEmpty && child.isNull) - case _ => Seq(updateSumExpr) - } - } else { - val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) - resultType match { - case _: DecimalType => - Seq(updateSumExpr, Literal(false, BooleanType)) - case _ => Seq(updateSumExpr) - } + resultType match { + case _: DecimalType => + // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if + // the input is null, as SUM function ignores null input. The `sum` can only be null if + // overflow happens under non-ansi mode. + val sumExpr = if (child.nullable) { + If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType)) + } else { + sum + child.cast(resultType) + } + // The buffer becomes non-empty after seeing the first not-null input. + val isEmptyExpr = if (child.nullable) { + isEmpty && child.isNull + } else { + Literal(false, BooleanType) + } + Seq(sumExpr, isEmptyExpr) + case _ => + // For non-decimal type, the initial value of `sum` is null, which indicates no value. + // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce` + // in case the input is nullable. The `sum` can only be null if there is no value, as + // non-decimal type can produce overflowed value under non-ansi mode. + if (child.nullable) { + Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum)) + } else { + Seq(coalesce(sum, zero) + child.cast(resultType)) + } } } @@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast * means we have seen atleast a value that was not null. */ override lazy val mergeExpressions: Seq[Expression] = { - val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) resultType match { case _: DecimalType => - val inputOverflow = !isEmpty.right && sum.right.isNull val bufferOverflow = !isEmpty.left && sum.left.isNull + val inputOverflow = !isEmpty.right && sum.right.isNull Seq( - If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + If( + bufferOverflow || inputOverflow, + Literal.create(null, resultType), + // If both the buffer and the input do not overflow, just add them, as they can't be + // null. See the comments inside `updateExpressions`: `sum` can only be null if + // overflow happens. + KnownNotNull(sum.left) + KnownNotNull(sum.right)), isEmpty.left && isEmpty.right) - case _ => Seq(mergeSumExpr) + case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) } } @@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(isEmpty, Literal.create(null, sumDataType), + If(isEmpty, Literal.create(null, resultType), CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ac97288ddb6b9..4bdc5194f4ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -169,22 +169,14 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { private def assertDecimalSumOverflow( df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { if (!ansiEnabled) { - try { - checkAnswer(df, expectedAnswer) - } catch { - case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => - // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail - // to read it. - assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) - } + checkAnswer(df, expectedAnswer) } else { val e = intercept[SparkException] { - df.collect + df.collect() } assert(e.getCause.isInstanceOf[ArithmeticException]) assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || - e.getCause.getMessage.contains("Overflow in sum of decimals") || - e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + e.getCause.getMessage.contains("Overflow in sum of decimals")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a5f904c621e6e..9daa69ce9f155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite { // Makes sure hashCode on unsafe array won't crash unsafeRow.getArray(0).hashCode() } + + test("SPARK-32018: setDecimal with overflowed value") { + val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18) + val row = InternalRow.apply(d1) + val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row) + assert(unsafeRow.getDecimal(0, 38, 18) === d1) + val d2 = (d1 * Decimal(10)).toPrecision(39, 18) + unsafeRow.setDecimal(0, d2, 38) + assert(unsafeRow.getDecimal(0, 38, 18) === null) + } } From 3126a285d685442684e8cf0817d86f6b9ef6e25b Mon Sep 17 00:00:00 2001 From: "longfei.jiang" Date: Fri, 16 Jul 2021 17:31:36 +0800 Subject: [PATCH 3/5] KE-24858 fix error: java.lang.IllegalArgumentException: Can not interpolate java.lang.Boolean into code block. --- .../catalyst/analysis/DecimalPrecision.scala | 60 +++-- .../expressions/aggregate/Average.scala | 8 +- .../expressions/codegen/javaCode.scala | 2 +- .../expressions/DecimalExpressionSuite.scala | 19 +- .../org/apache/spark/sql/DataFrameSuite.scala | 241 ++++-------------- 5 files changed, 104 insertions(+), 226 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 82692334544e2..b0c4cc05d0ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -87,49 +87,56 @@ object DecimalPrecision extends TypeCoercionRule { case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } + private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = { + decimalAndDecimal(SQLConf.get.decimalOperationsAllowPrecisionLoss, !SQLConf.get.ansiEnabled) + } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean) + : PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e // Skip nodes who is already promoted case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultScale = max(s1, s2) - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } - CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), - resultType) + CheckOverflow( + a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultScale = max(s1, s2) - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } - CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), - resultType) + CheckOverflow( + s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case m @ Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { DecimalType.bounded(p1 + p2 + 1, s1 + s2) } val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) val intDig = p1 - s1 + s2 @@ -147,30 +154,33 @@ object DecimalPrecision extends TypeCoercionRule { DecimalType.bounded(intDig + decDig, decDig) } val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case r @ Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 5ecb77be5965e..8356d6b9d1b1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { @@ -57,8 +58,11 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { - case _: DecimalType => - DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) + case d: DecimalType => + DecimalPrecision.decimalAndDecimal()( + Divide( + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled), + count.cast(DecimalType.LongDecimal))).cast(resultType) case _ => sum.cast(resultType) / count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 17d4a0dc4e884..7bfaf0fd0c767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -224,7 +224,7 @@ object Block { } else { args.foreach { case _: ExprValue | _: Inline | _: Block => - case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String => case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index a8f758d625a02..941bab2ea01e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -45,18 +45,19 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("CheckOverflow") { val d1 = Decimal("10.1") - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null) val d2 = Decimal(101, 3, 1) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null) - checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) + checkEvaluation( + CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4bdc5194f4ceb..8a4a3af0f281c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -167,7 +167,7 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } private def assertDecimalSumOverflow( - df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { if (!ansiEnabled) { checkAnswer(df, expectedAnswer) } else { @@ -1028,14 +1028,14 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = "-RECORD 0----------------------\n" + - " value | 1 \n" + - "-RECORD 1----------------------\n" + - " value | 111111111111111111111 \n" + " value | 1 \n" + + "-RECORD 1----------------------\n" + + " value | 111111111111111111111 \n" assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse) val expectedAnswerForTrue = "-RECORD 0---------------------\n" + - " value | 1 \n" + - "-RECORD 1---------------------\n" + - " value | 11111111111111111... \n" + " value | 1 \n" + + "-RECORD 1---------------------\n" + + " value | 11111111111111111... \n" assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue) } @@ -1064,14 +1064,14 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = "-RECORD 0----\n" + - " value | 1 \n" + - "-RECORD 1----\n" + - " value | 111 \n" + " value | 1 \n" + + "-RECORD 1----\n" + + " value | 111 \n" assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse) val expectedAnswerForTrue = "-RECORD 0------------------\n" + - " value | 1 \n" + - "-RECORD 1------------------\n" + - " value | 11111111111111... \n" + " value | 1 \n" + + "-RECORD 1------------------\n" + + " value | 11111111111111... \n" assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue) } @@ -1138,11 +1138,11 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { (Array(2, 3, 4), Array(2, 3, 4)) ).toDF() val expectedAnswer = "-RECORD 0--------\n" + - " _1 | [1, 2, 3] \n" + - " _2 | [1, 2, 3] \n" + - "-RECORD 1--------\n" + - " _1 | [2, 3, 4] \n" + - " _2 | [2, 3, 4] \n" + " _1 | [1, 2, 3] \n" + + " _2 | [1, 2, 3] \n" + + "-RECORD 1--------\n" + + " _1 | [2, 3, 4] \n" + + " _2 | [2, 3, 4] \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1167,11 +1167,11 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) ).toDF() val expectedAnswer = "-RECORD 0---------------\n" + - " _1 | [31 32] \n" + - " _2 | [41 42 43 2E] \n" + - "-RECORD 1---------------\n" + - " _1 | [33 34] \n" + - " _2 | [31 32 33 34 36] \n" + " _1 | [31 32] \n" + + " _2 | [41 42 43 2E] \n" + + "-RECORD 1---------------\n" + + " _1 | [33 34] \n" + + " _2 | [31 32 33 34 36] \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1196,11 +1196,11 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { (2, 2) ).toDF() val expectedAnswer = "-RECORD 0--\n" + - " _1 | 1 \n" + - " _2 | 1 \n" + - "-RECORD 1--\n" + - " _1 | 2 \n" + - " _2 | 2 \n" + " _1 | 1 \n" + + " _2 | 1 \n" + + "-RECORD 1--\n" + + " _1 | 2 \n" + + " _2 | 2 \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1217,9 +1217,9 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { test("SPARK-7319 showString, vertical = true") { val expectedAnswer = "-RECORD 0----\n" + - " key | 1 \n" + - " value | 1 \n" + - "only showing top 1 row\n" + " key | 1 \n" + + " value | 1 \n" + + "only showing top 1 row\n" assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } @@ -1293,15 +1293,15 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { val ts = Timestamp.valueOf("2016-12-01 00:00:00") val df = Seq((d, ts)).toDF("d", "ts") val expectedAnswer = "-RECORD 0------------------\n" + - " d | 2016-12-01 \n" + - " ts | 2016-12-01 00:00:00 \n" + " d | 2016-12-01 \n" + + " ts | 2016-12-01 00:00:00 \n" assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { val expectedAnswer = "-RECORD 0------------------\n" + - " d | 2016-12-01 \n" + - " ts | 2016-12-01 08:00:00 \n" + " d | 2016-12-01 \n" + + " ts | 2016-12-01 08:00:00 \n" assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) } } @@ -1356,10 +1356,10 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { test("SPARK-7324 dropDuplicates") { val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: - (1, 2, 1) :: (2, 1, 2) :: - (2, 2, 2) :: (2, 2, 1) :: - (2, 1, 1) :: (1, 1, 2) :: - (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") + (1, 2, 1) :: (2, 1, 2) :: + (2, 2, 2) :: (2, 2, 1) :: + (2, 1, 1) :: (1, 1, 2) :: + (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") checkAnswer( testData.dropDuplicates(), @@ -1803,27 +1803,25 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } test("reuse exchange") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - collect(join.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) assert( - collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - collect(join2.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => true }.size == 1) + join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( - collect(join2.queryExecution.executedPlan) { - case e: BroadcastExchangeExec => true }.size === 1) + join2.queryExecution.executedPlan + .collect { case e: BroadcastExchangeExec => true }.size === 1) assert( - collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4) + join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } } @@ -1895,7 +1893,7 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } - } + } test("SPARK-13774: Check error message for not existent globbed paths") { // Non-existent initial path component: @@ -1968,9 +1966,9 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } private def verifyNullabilityInFilterExec( - df: DataFrame, - expr: String, - expectedNonNullableColumns: Seq[String]): Unit = { + df: DataFrame, + expr: String, + expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will @@ -2207,7 +2205,7 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { - withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) @@ -2239,7 +2237,7 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { val df = spark.read.json(path.getCanonicalPath) assert(df.columns === Array("i", "p")) - spark.sparkContext.listenerBus.waitUntilEmpty() + spark.sparkContext.listenerBus.waitUntilEmpty(10000) assert(numJobs.get() == 1L) } } @@ -2346,62 +2344,6 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { assert(idTuples.length == idTuples.toSet.size) } - test("groupBy.as") { - val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") - .repartition($"a", $"b").sortWithinPartitions("a", "b") - val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c") - .repartition($"a", $"b").sortWithinPartitions("a", "b") - - implicit val valueEncoder = RowEncoder(df1.schema) - - val df3 = df1.groupBy("a", "b").as[GroupByKey, Row] - .cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) => - data1.zip(data2).map { p => - p._1.getInt(2) + p._2.getInt(2) - } - }.toDF - - checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) - - // Assert that no extra shuffle introduced by cogroup. - val exchanges = collect(df3.queryExecution.executedPlan) { - case h: ShuffleExchangeExec => h - } - assert(exchanges.size == 2) - } - - test("groupBy.as: custom grouping expressions") { - val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a1", "b", "c") - .repartition($"a1", $"b").sortWithinPartitions("a1", "b") - val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c") - .repartition($"a1", $"b").sortWithinPartitions("a1", "b") - - implicit val valueEncoder = RowEncoder(df1.schema) - - val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] - val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] - - val df3 = groupedDataset1 - .cogroup(groupedDataset2) { case (_, data1, data2) => - data1.zip(data2).map { p => - p._1.getInt(2) + p._2.getInt(2) - } - }.toDF - - checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) - } - - test("groupBy.as: throw AnalysisException for unresolved grouping expr") { - val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") - - implicit val valueEncoder = RowEncoder(df.schema) - - val err = intercept[AnalysisException] { - df.groupBy($"d", $"b").as[GroupByKey, Row] - } - assert(err.getMessage.contains("cannot resolve '`d`'")) - } - test("emptyDataFrame should be foldable") { val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L)) val joined = spark.range(10).join(emptyDf, "id") @@ -2421,85 +2363,6 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { assert(e.getMessage.contains("Table or view not found:")) } - test("CalendarInterval reflection support") { - val df = Seq((1, new CalendarInterval(1, 2, 3))).toDF("a", "b") - checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3))) - } - - test("SPARK-31552: array encoder with different types") { - // primitives - val booleans = Array(true, false) - checkAnswer(Seq(booleans).toDF(), Row(booleans)) - - val bytes = Array(1.toByte, 2.toByte) - checkAnswer(Seq(bytes).toDF(), Row(bytes)) - val shorts = Array(1.toShort, 2.toShort) - checkAnswer(Seq(shorts).toDF(), Row(shorts)) - val ints = Array(1, 2) - checkAnswer(Seq(ints).toDF(), Row(ints)) - val longs = Array(1L, 2L) - checkAnswer(Seq(longs).toDF(), Row(longs)) - - val floats = Array(1.0F, 2.0F) - checkAnswer(Seq(floats).toDF(), Row(floats)) - val doubles = Array(1.0D, 2.0D) - checkAnswer(Seq(doubles).toDF(), Row(doubles)) - - val strings = Array("2020-04-24", "2020-04-25") - checkAnswer(Seq(strings).toDF(), Row(strings)) - - // tuples - val decOne = Decimal(1, 38, 18) - val decTwo = Decimal(2, 38, 18) - val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22")) - val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22")) - checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF()) - - // case classes - val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5)) - checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5)))) - - // We can move this implicit def to [[SQLImplicits]] when we eventually make fully - // support for array encoder like Seq and Set - // For now cases below, decimal/datetime/interval/binary/nested types, etc, - // are not supported by array - implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder() - - // decimals - val decSpark = Array(decOne, decTwo) - val decScala = decSpark.map(_.toBigDecimal) - val decJava = decSpark.map(_.toJavaBigDecimal) - checkAnswer(Seq(decSpark).toDF(), Row(decJava)) - checkAnswer(Seq(decScala).toDF(), Row(decJava)) - checkAnswer(Seq(decJava).toDF(), Row(decJava)) - - // datetimes and intervals - val dates = strings.map(Date.valueOf) - checkAnswer(Seq(dates).toDF(), Row(dates)) - val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d))) - checkAnswer(Seq(localDates).toDF(), Row(dates)) - - val timestamps = - Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33")) - checkAnswer(Seq(timestamps).toDF(), Row(timestamps)) - val instants = - timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t))) - checkAnswer(Seq(instants).toDF(), Row(timestamps)) - - val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6)) - checkAnswer(Seq(intervals).toDF(), Row(intervals)) - - // binary - val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte)) - checkAnswer(Seq(bins).toDF(), Row(bins)) - - // nested - val nestedIntArray = Array(Array(1), Array(2)) - checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray))) - val nestedDecArray = Array(decSpark) - checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) - } - test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { withTempPath { f => sql("select cast(1 as decimal(38, 0)) as d") From 57dd3f5148ddf6fe601fb6b419d277b6258210ac Mon Sep 17 00:00:00 2001 From: "longfei.jiang" Date: Wed, 21 Jul 2021 18:09:06 +0800 Subject: [PATCH 4/5] KE-24858 fix ci error --- .../org/apache/spark/sql/types/Decimal.scala | 11 +- .../org/apache/spark/sql/DataFrameSuite.scala | 459 ++++++++++++++---- 2 files changed, 375 insertions(+), 95 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 63696658fd3ac..33c0cd07c8d07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -243,8 +243,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, - nullOnOverflow: Boolean = true): Decimal = { + roundMode: BigDecimal.RoundingMode.Value, + nullOnOverflow: Boolean): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -257,6 +257,13 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + private[sql] def toPrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { + toPrecision(precision, scale, roundMode, true) + } + /** * Update precision and scale while keeping our value the same, and return true if successful. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8a4a3af0f281c..c19250b9325b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -24,7 +24,7 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicLong import scala.reflect.runtime.universe.TypeTag import scala.util.Random -import org.scalatest.Matchers.{assert, _} +import org.scalatest.Matchers.{assert, intercept, _} import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} -import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2} +import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -443,7 +443,7 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { testData.select("key").coalesce(1).select("key"), testData.select("key").collect().toSeq) - assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0) + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } test("convert $\"attribute name\" into unresolved attribute") { @@ -565,6 +565,258 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { Row(0) :: Row(1) :: Nil ) } + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.except(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.except(nullInts), + Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.except(allNulls.filter("0 = 1")), + Row(null) :: Nil) + checkAnswer( + allNulls.except(allNulls), + Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.except(df.filter("0 = 1")), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").except(allNulls), + Nil) + } + + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + + test("except distinct - SQL compliance") { + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 3).toDF("id") + + checkAnswer( + df_left.except(df_right), + Row(2) :: Row(4) :: Nil + ) + } + + test("except - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.except(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.except(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.except(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.except(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + } + + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. + checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -714,21 +966,6 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { assert(df.schema.map(_.name) === Seq("value")) } - test("SPARK-28189 drop column using drop with column reference with case-insensitive names") { - // With SQL config caseSensitive OFF, case insensitive column name should work - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val col1 = testData("KEY") - val df1 = testData.drop(col1) - checkAnswer(df1, testData.selectExpr("value")) - assert(df1.schema.map(_.name) === Seq("value")) - - val col2 = testData("Key") - val df2 = testData.drop(col2) - checkAnswer(df2, testData.selectExpr("value")) - assert(df2.schema.map(_.name) === Seq("value")) - } - } - test("drop unknown column (no-op) with column reference") { val col = Column("random") val df = testData.drop(col) @@ -1585,7 +1822,55 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } } } + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + val union = df1.union(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") @@ -2038,6 +2323,21 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-17123: Performing set operations that combine non-scala native types") { + val dates = Seq( + (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), + (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("date", "timestamp", "decimal") + + val widenTypedRows = Seq( + (new Timestamp(2), 10.5D, "string") + ).toDF("date", "timestamp", "decimal") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } + test("SPARK-18070 binary operator should not consider nullability when comparing input types") { val rows = Seq(Row(Seq(1), Seq(1))) val schema = new StructType() @@ -2056,6 +2356,24 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") checkAnswer(df, Row(BigDecimal(0)) :: Nil) } + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") @@ -2108,6 +2426,18 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { + val colNumber = 10000 + val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) + val df = sqlContext.createDataFrame(input, StructType( + (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) + val newCols = (1 to colNumber).flatMap { colIndex => + Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), + expr(s"sqrt(_$colIndex)")) + } + df.select(newCols: _*).collect() + } + test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") @@ -2242,6 +2572,24 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } + test("SPARK-25402 Null handling in BooleanSimplification") { val schema = StructType.fromDDL("a boolean, b int") val rows = Seq(Row(null, 1)) @@ -2293,39 +2641,6 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { } } - test("SPARK-27671: Fix analysis exception when casting null in nested field in struct") { - val df = sql("SELECT * FROM VALUES (('a', (10, null))), (('b', (10, 50))), " + - "(('c', null)) AS tab(x, y)") - checkAnswer(df, Row("a", Row(10, null)) :: Row("b", Row(10, 50)) :: Row("c", null) :: Nil) - - val cast = sql("SELECT cast(struct(1, null) AS struct)") - checkAnswer(cast, Row(Row(1, null)) :: Nil) - } - - test("SPARK-27439: Explain result should match collected result after view change") { - withTempView("test", "test2", "tmp") { - spark.range(10).createOrReplaceTempView("test") - spark.range(5).createOrReplaceTempView("test2") - spark.sql("select * from test").createOrReplaceTempView("tmp") - val df = spark.sql("select * from tmp") - spark.sql("select * from test2").createOrReplaceTempView("tmp") - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - df.explain(extended = true) - } - checkAnswer(df, spark.range(10).toDF) - val output = captured.toString - assert(output.contains( - """== Parsed Logical Plan == - |'Project [*] - |+- 'UnresolvedRelation [tmp]""".stripMargin)) - assert(output.contains( - """== Physical Plan == - |*(1) Range (0, 10, step=1, splits=2)""".stripMargin)) - } - } - test("SPARK-29442 Set `default` mode should override the existing mode") { val df = Seq(Tuple1(1)).toDF() val writer = df.write.mode("overwrite").mode("default") @@ -2333,46 +2648,4 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { modeField.setAccessible(true) assert(SaveMode.ErrorIfExists === modeField.get(writer).asInstanceOf[SaveMode]) } - - test("sample should not duplicated the input data") { - val df1 = spark.range(10).select($"id" as "id1", $"id" % 5 as "key1") - val df2 = spark.range(10).select($"id" as "id2", $"id" % 5 as "key2") - val sampled = df1.join(df2, $"key1" === $"key2") - .sample(0.5, 42) - .select("id1", "id2") - val idTuples = sampled.collect().map(row => row.getLong(0) -> row.getLong(1)) - assert(idTuples.length == idTuples.toSet.size) - } - - test("emptyDataFrame should be foldable") { - val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L)) - val joined = spark.range(10).join(emptyDf, "id") - joined.queryExecution.optimizedPlan match { - case LocalRelation(Seq(id), Nil, _) => - assert(id.name == "id") - case _ => - fail("emptyDataFrame should be foldable") - } - } - - test("SPARK-30811: CTE should not cause stack overflow when " + - "it refers to non-existent table with same name") { - val e = intercept[AnalysisException] { - sql("WITH t AS (SELECT 1 FROM nonexist.t) SELECT * FROM t") - } - assert(e.getMessage.contains("Table or view not found:")) - } - - test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { - withTempPath { f => - sql("select cast(1 as decimal(38, 0)) as d") - .write.mode("overwrite") - .parquet(f.getAbsolutePath) - - val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal] - assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0)))) - } - } } - -case class GroupByKey(a: Int, b: Int) From 9c651ac871dbfbb427732a63fa26d617cc1a2bac Mon Sep 17 00:00:00 2001 From: "longfei.jiang" Date: Thu, 22 Jul 2021 15:51:17 +0800 Subject: [PATCH 5/5] KE-24858 update pom version --- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/avro/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 40 files changed, 40 insertions(+), 40 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 41e6648aa3f3c..dd3a386bb514e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 53bd7f261eff3..18029c86442d4 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 1a7fa64021197..b5a4551a54112 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 93a2fc1fbb92d..dadf340e84f01 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ba009e0e7896a..905fb1ebf540d 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 5f2bb64898ab8..723ed3b2f3c61 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 8160ff20c0da1..24cd8283640ac 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f43e375d3cedf..40583281e13d6 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 1bd6465593594..6496cead95797 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index faf57b95bca0e..6393c8a8c5052 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 9cbeb976df591..20ee419f68840 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 7a4a1536a735c..d2f7030ec20fe 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index c7fdd7f77e1a7..5b56d4be606c4 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 4c56b41b7f92b..cb05074f4292f 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 0a5bfdf6178fc..53b334cd95ab3 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 86238385a9096..b9d2c0da42830 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index c2605595e2966..50321770056e1 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 1968e805154b1..26952ef2d5790 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index d3db0a3991826..ae82444e6e1cc 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 99ac54a738743..182b2096f8d0e 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 56ba91c4fce2f..0f89c300301fb 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 34bd5f3a98bad..05e439562ddeb 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index acf097f74b9e5..8c0acb7353896 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 5a57c65d70524..023a481d8e1df 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index a0b435c80046d..dc0ae525a27ad 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 2ce7e88594b34..d78328e9cbc82 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 754855096d594..94d4ce3857478 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index cb387b78f1149..42bec8fb67e55 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/pom.xml b/pom.xml index 23224a5069fcd..b97b9351c4b79 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 pom Spark Project Parent POM http://spark.apache.org/ diff --git a/repl/pom.xml b/repl/pom.xml index b6d49935024bd..0e09e0dbf354c 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 866c0cf552446..25f5c66e5d716 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index ec8334f1fccd2..adde969b1c5cc 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index ebd97a9d2d6cd..0fd3a1bce2522 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 841791a6e8a15..b3809904332a6 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5736432b2ec19..d04cf02f4173e 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 0db539b4e8762..3f102462b6fd5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d84421acf8390..fe102a1691305 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 9f76d192524ee..8efb3799d2fb8 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index e84cd5514784e..b47ae88c16916 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index e009120bb52bc..25d9508c256c4 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml