diff --git a/assembly/pom.xml b/assembly/pom.xml index 41e6648aa3f3c..dd3a386bb514e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 53bd7f261eff3..18029c86442d4 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 1a7fa64021197..b5a4551a54112 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 93a2fc1fbb92d..dadf340e84f01 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ba009e0e7896a..905fb1ebf540d 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 5f2bb64898ab8..723ed3b2f3c61 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 8160ff20c0da1..24cd8283640ac 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f43e375d3cedf..40583281e13d6 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 1bd6465593594..6496cead95797 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index faf57b95bca0e..6393c8a8c5052 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 9cbeb976df591..20ee419f68840 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 7a4a1536a735c..d2f7030ec20fe 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index c7fdd7f77e1a7..5b56d4be606c4 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume-sink/pom.xml 
b/external/flume-sink/pom.xml index 4c56b41b7f92b..cb05074f4292f 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 0a5bfdf6178fc..53b334cd95ab3 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 86238385a9096..b9d2c0da42830 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index c2605595e2966..50321770056e1 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 1968e805154b1..26952ef2d5790 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index d3db0a3991826..ae82444e6e1cc 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 99ac54a738743..182b2096f8d0e 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 56ba91c4fce2f..0f89c300301fb 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 34bd5f3a98bad..05e439562ddeb 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index acf097f74b9e5..8c0acb7353896 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 5a57c65d70524..023a481d8e1df 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index a0b435c80046d..dc0ae525a27ad 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 2ce7e88594b34..d78328e9cbc82 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml 
@@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 754855096d594..94d4ce3857478 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index cb387b78f1149..42bec8fb67e55 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/pom.xml b/pom.xml index 23224a5069fcd..b97b9351c4b79 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 pom Spark Project Parent POM http://spark.apache.org/ diff --git a/repl/pom.xml b/repl/pom.xml index b6d49935024bd..0e09e0dbf354c 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 866c0cf552446..25f5c66e5d716 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index ec8334f1fccd2..adde969b1c5cc 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index ebd97a9d2d6cd..0fd3a1bce2522 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 841791a6e8a15..b3809904332a6 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5736432b2ec19..d04cf02f4173e 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9bf9452855f5f..faf7c2ecf4c21 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -279,7 +279,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null) { + if (value == null || !value.changePrecision(precision, value.scale())) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 82692334544e2..b0c4cc05d0ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -87,49 +87,56 @@ object DecimalPrecision extends TypeCoercionRule { case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } + private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = { + decimalAndDecimal(SQLConf.get.decimalOperationsAllowPrecisionLoss, !SQLConf.get.ansiEnabled) + } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean) + : PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e // Skip nodes who is already promoted case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultScale = max(s1, s2) - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } - CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), - resultType) + CheckOverflow( + a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultScale = max(s1, s2) - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } - CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), - resultType) + CheckOverflow( + s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), + resultType, nullOnOverflow) - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case m @ Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { DecimalType.bounded(p1 + p2 + 1, s1 + s2) } val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Divide(e1 @ 
DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) val intDig = p1 - s1 + s2 @@ -147,30 +154,33 @@ object DecimalPrecision extends TypeCoercionRule { DecimalType.bounded(intDig + decDig, decDig) } val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case r @ Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) - case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } // resultType may have lower precision, so we cast them into wider type first. 
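(Editorial aside, not part of the patch: the Divide result-type formula quoted in the hunk above — Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1), Scale: max(6, s1 + p2 + 1) — can be checked with a small standalone Scala sketch. The adjust helper below is an illustrative stand-in for DecimalType.adjustPrecisionScale as described here, not code taken from this patch; with precision loss allowed it shows that DECIMAL(38,18) / DECIMAL(38,18) comes out as DECIMAL(38,6).)

object DivideResultTypeSketch {
  val MaxPrecision = 38
  val MinimumAdjustedScale = 6

  // Illustrative stand-in for DecimalType.adjustPrecisionScale: when the raw precision
  // exceeds 38, keep the integer digits and cut the scale back, but never below 6.
  def adjust(precision: Int, scale: Int): (Int, Int) =
    if (precision <= MaxPrecision) (precision, scale)
    else {
      val intDigits = precision - scale
      val minScale  = math.min(scale, MinimumAdjustedScale)
      (MaxPrecision, math.max(MaxPrecision - intDigits, minScale))
    }

  // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1), Scale: max(6, s1 + p2 + 1)
  def divideResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val scale = math.max(MinimumAdjustedScale, s1 + p2 + 1)
    adjust(p1 - s1 + s2 + scale, scale)
  }

  def main(args: Array[String]): Unit = {
    // DECIMAL(38,18) / DECIMAL(38,18): the raw (95, 57) is adjusted down to (38, 6).
    println(divideResultType(38, 18, 38, 18))
  }
}

(End of aside; the diff continues below.)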
val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) + CheckOverflow( + p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), + resultType, nullOnOverflow) case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..c7b79df4035cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { collect(left, negate) ++ collect(right, !negate) case UnaryMinus(child) => collect(child, !negate) - case CheckOverflow(child, _) => + case CheckOverflow(child, _, _) => collect(child, negate) case PromotePrecision(child) => collect(child, negate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala new file mode 100644 index 0000000000000..b84643d3293dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.errors + +import org.apache.spark.sql.types._ + +/** + * Object for grouping error messages from (most) exceptions thrown during query execution. + * This does not include exceptions thrown during the eager execution of commands, which are + * grouped into [[QueryCompilationErrors]]. 
+ */ +object QueryExecutionErrors { + + def cannotChangeDecimalPrecisionError( + value: Decimal, decimalPrecision: Int, decimalScale: Int): ArithmeticException = { + new ArithmeticException(s"${value.toDebugString} cannot be represented as " + + s"Decimal($decimalPrecision, $decimalScale).") + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 5ecb77be5965e..8356d6b9d1b1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { @@ -57,8 +58,11 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { - case _: DecimalType => - DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) + case d: DecimalType => + DecimalPrecision.decimalAndDecimal()( + Divide( + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled), + count.cast(DecimalType.LongDecimal))).cast(resultType) case _ => sum.cast(resultType) / count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 761dba111c074..aca8c8e79cec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -21,10 +21,22 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.") + usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col); + 30 + > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col); + 25 + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + """, + extended = "agg_funcs", + since = "1.0.0") case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = child :: Nil @@ -46,38 +58,94 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = resultType + private lazy val sum = AttributeReference("sum", resultType)() - private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() - private lazy val 
zero = Cast(Literal(0), sumDataType) + private lazy val zero = Literal.default(resultType) - override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = resultType match { + case _: DecimalType => sum :: isEmpty :: Nil + case _ => sum :: Nil + } - override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) - ) + override lazy val initialValues: Seq[Expression] = resultType match { + case _: DecimalType => Seq(zero, Literal(true, BooleanType)) + case _ => Seq(Literal(null, resultType)) + } override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) - } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + resultType match { + case _: DecimalType => + // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if + // the input is null, as SUM function ignores null input. The `sum` can only be null if + // overflow happens under non-ansi mode. + val sumExpr = if (child.nullable) { + If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType)) + } else { + sum + child.cast(resultType) + } + // The buffer becomes non-empty after seeing the first not-null input. + val isEmptyExpr = if (child.nullable) { + isEmpty && child.isNull + } else { + Literal(false, BooleanType) + } + Seq(sumExpr, isEmptyExpr) + case _ => + // For non-decimal type, the initial value of `sum` is null, which indicates no value. + // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce` + // in case the input is nullable. The `sum` can only be null if there is no value, as + // non-decimal type can produce overflowed value under non-ansi mode. + if (child.nullable) { + Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum)) + } else { + Seq(coalesce(sum, zero) + child.cast(resultType)) + } } } + /** + * For decimal type: + * If isEmpty is false and if sum is null, then it means we have had an overflow. + * + * update of the sum is as follows: + * Check if either portion of the left.sum or right.sum has overflowed + * If it has, then the sum value will remain null. + * If it did not have overflow, then add the sum.left and sum.right + * + * isEmpty: Set to false if either one of the left or right is set to false. This + * means we have seen atleast a value that was not null. + */ override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - ) + resultType match { + case _: DecimalType => + val bufferOverflow = !isEmpty.left && sum.left.isNull + val inputOverflow = !isEmpty.right && sum.right.isNull + Seq( + If( + bufferOverflow || inputOverflow, + Literal.create(null, resultType), + // If both the buffer and the input do not overflow, just add them, as they can't be + // null. See the comments inside `updateExpressions`: `sum` can only be null if + // overflow happens. + KnownNotNull(sum.left) + KnownNotNull(sum.right)), + isEmpty.left && isEmpty.right) + case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) + } } - override lazy val evaluateExpression: Expression = sum + /** + * If the isEmpty is true, then it means there were no values to begin with or all the values + * were null, so the result will be null. + * If the isEmpty is false, then if sum is null that means an overflow has happened. 
+ * So now, if ansi is enabled, then throw exception, if not then return null. + * If sum is not null, then return the sum. + */ + override lazy val evaluateExpression: Expression = resultType match { + case d: DecimalType => + If(isEmpty, Literal.create(null, resultType), + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) + case _ => sum + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 17d4a0dc4e884..7bfaf0fd0c767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -224,7 +224,7 @@ object Block { } else { args.foreach { case _: ExprValue | _: Inline | _: Block => - case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String => case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 04de83343be71..7e4560ab8161b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -26,7 +28,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -44,25 +46,56 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. 
*/ -case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int, + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { + + def this(child: Expression, precision: Int, scale: Int) = { + this(child, precision, scale, !SQLConf.get.ansiEnabled) + } override def dataType: DataType = DecimalType(precision, scale) - override def nullable: Boolean = true + override def nullable: Boolean = child.nullable || nullOnOverflow override def toString: String = s"MakeDecimal($child,$precision,$scale)" - protected override def nullSafeEval(input: Any): Any = - Decimal(input.asInstanceOf[Long], precision, scale) + protected override def nullSafeEval(input: Any): Any = { + val longInput = input.asInstanceOf[Long] + val result = new Decimal() + if (nullOnOverflow) { + result.setOrNull(longInput, precision, scale) + } else { + result.set(longInput, precision, scale) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { + val setMethod = if (nullOnOverflow) { + "setOrNull" + } else { + "set" + } + val setNull = if (nullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } s""" - ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); - ${ev.isNull} = ${ev.value} == null; - """ + |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale); + |$setNull + |""".stripMargin }) } } +object MakeDecimal { + def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = { + new MakeDecimal(child, precision, scale) + } +} + /** * An expression used to wrap the children when promote the precision of DecimalType to avoid * promote multiple times. @@ -81,30 +114,85 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { /** * Rounds the decimal to given scale and check whether the decimal can fit in provided precision - * or not, returns null if not. + * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an + * `ArithmeticException` is thrown. */ -case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { +case class CheckOverflow( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { - val tmp = ctx.freshName("tmp") s""" - | Decimal $tmp = $eval.clone(); - | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { - | ${ev.value} = $tmp; - | } else { - | ${ev.isNull} = true; - | } + |${ev.value} = $eval.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + |${ev.isNull} = ${ev.value} == null; """.stripMargin }) } - override def toString: String = s"CheckOverflow($child, $dataType)" + override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql +} + +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. 
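(Editorial aside, not part of the patch: the overflow contract that MakeDecimal, CheckOverflow and the CheckOverflowInSum variant defined next all rely on — round to the target precision and scale, then either return null or raise an ArithmeticException depending on nullOnOverflow — can be sketched with plain java.math.BigDecimal. This is only an illustration of the behaviour described above, not Spark's Decimal implementation; the error message mirrors QueryExecutionErrors.cannotChangeDecimalPrecisionError but uses the plain value instead of toDebugString.)

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object OverflowContractSketch {
  // Rounds v to (precision, scale); returns null on overflow when nullOnOverflow is true,
  // otherwise throws, matching the null-vs-ANSI behaviour introduced in this patch.
  def toPrecision(v: JBigDecimal, precision: Int, scale: Int,
                  nullOnOverflow: Boolean): JBigDecimal = {
    val rounded = v.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision <= precision) rounded
    else if (nullOnOverflow) null
    else throw new ArithmeticException(
      s"$v cannot be represented as Decimal($precision, $scale).")
  }

  def main(args: Array[String]): Unit = {
    // 10.1 does not fit in DECIMAL(4, 3), as in the DecimalExpressionSuite case below.
    println(toPrecision(new JBigDecimal("10.1"), 4, 3, nullOnOverflow = true))  // null
    // toPrecision(new JBigDecimal("10.1"), 4, 3, nullOnOverflow = false)       // throws
  }
}

(End of aside; the diff continues below.)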
+case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") + } else { + value.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val nullHandling = if (nullOnOverflow) { + "" + } else { + s""" + |throw new ArithmeticException("Overflow in sum of decimals."); + |""".stripMargin + } + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + + ev.copy(code = code) + } + + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d7409c5efa372..c9c898a8de344 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -150,6 +150,16 @@ object SQLConf { } } + val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled") + .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " + + "For example, Spark will throw an exception at runtime instead of returning null results " + + "when the inputs to a SQL operator/function are invalid." + + "For full details of this dialect, you can find them in the section \"ANSI Compliance\" of " + + "Spark's documentation. Some ANSI dialect features may be not from the ANSI SQL " + + "standard directly, but their behaviors align with ANSI SQL's style") + .booleanConf + .createWithDefault(false) + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + "specified by their rule names and separated by comma. 
It is not guaranteed that all the " + @@ -1617,6 +1627,8 @@ class SQLConf extends Serializable with Logging { /** ************************ Spark SQL Params/Hints ******************* */ + def ansiEnabled: Boolean = getConf(ANSI_ENABLED) + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..33c0cd07c8d07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.errors.QueryExecutionErrors /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -242,9 +243,25 @@ final class Decimal extends Ordered[Decimal] with Serializable { private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { + roundMode: BigDecimal.RoundingMode.Value, + nullOnOverflow: Boolean): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) copy else null + if (copy.changePrecision(precision, scale, roundMode)) { + copy + } else { + if (nullOnOverflow) { + null + } else { + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(this, precision, scale) + } + } + } + + private[sql] def toPrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { + toPrecision(precision, scale, roundMode, true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index a8f758d625a02..941bab2ea01e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -45,18 +45,19 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("CheckOverflow") { val d1 = Decimal("10.1") - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null) val d2 = Decimal(101, 3, 1) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), 
true), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null) - checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) + checkEvaluation( + CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null) } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 0db539b4e8762..3f102462b6fd5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r70 + 2.4.1-kylin-r71 ../../pom.xml diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5075209d7454f..c19250b9325b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,39 +17,41 @@ package org.apache.spark.sql -import java.io.File +import java.io.{ByteArrayOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util.UUID - +import java.util.concurrent.atomic.AtomicLong +import scala.reflect.runtime.universe.TypeTag import scala.util.Random - -import org.scalatest.Matchers._ - +import org.scalatest.Matchers.{assert, intercept, _} import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} +import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -class DataFrameSuite extends QueryTest with SharedSQLContext { +class DataFrameSuite extends QueryTest with SharedSparkSession { import testImplicits._ test("analysis error should be eagerly reported") { - intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { testData.select("nonExistentName") } intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + testData.groupBy("key").agg(Map("nonExistentName" -> "sum")) } intercept[Exception] { testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) @@ -85,129 +87,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } - test("union all") { - val unionDF = 
testData.union(testData).union(testData) - .union(testData).union(testData) - - // Before optimizer, Union should be combined. - assert(unionDF.queryExecution.analyzed.collect { - case j: Union if j.children.size == 5 => j }.size === 1) - - checkAnswer( - unionDF.agg(avg('key), max('key), min('key), sum('key)), - Row(50.5, 100, 1, 25250) :: Nil - ) - } - - test("union should union DataFrames with UDTs (SPARK-13410)") { - val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) - val schema1 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) - val schema2 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val df1 = spark.createDataFrame(rowRDD1, schema1) - val df2 = spark.createDataFrame(rowRDD2, schema2) - - checkAnswer( - df1.union(df2).orderBy("label"), - Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) - ) - } - - test("union by name") { - var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") - val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") - val unionDf = df1.unionByName(df2.unionByName(df3)) - checkAnswer(unionDf, - Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil - ) - - // Check if adjacent unions are combined into a single one - assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) - - // Check failure cases - df1 = Seq((1, 2)).toDF("a", "c") - df2 = Seq((3, 4, 5)).toDF("a", "b", "c") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains( - "Union can only be performed on tables with the same number of columns, " + - "but the first table has 2 columns and the second table has 3 columns")) - - df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) - } - - test("union by name - type coercion") { - var df1 = Seq((1, "a")).toDF("c0", "c1") - var df2 = Seq((3, 1L)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) - - df1 = Seq((1, 1.0)).toDF("c0", "c1") - df2 = Seq((8L, 3.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) - - df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") - df2 = Seq(("a", 4.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) - - df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") - df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") - val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") - checkAnswer(df1.unionByName(df2.unionByName(df3)), - Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil - ) - } - - test("union by name - check case sensitivity") { - def checkCaseSensitiveTest(): Unit = { - val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") - val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") - checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val errMsg2 = intercept[AnalysisException] { - checkCaseSensitiveTest() - }.getMessage - assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - 
checkCaseSensitiveTest() - } - } - - test("union by name - check name duplication") { - Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - var df1 = Seq((1, 1)).toDF(c0, c1) - var df2 = Seq((1, 1)).toDF("c0", "c1") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) - df1 = Seq((1, 1)).toDF("c0", "c1") - df2 = Seq((1, 1)).toDF(c0, c1) - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) - } - } - } - test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) } - test("head and take") { + test("head, take") { assert(testData.take(2) === testData.collect().take(2)) assert(testData.head(2) === testData.collect().take(2)) assert(testData.head(2).head.schema === testData.schema) @@ -248,8 +133,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Star Expansion - CreateStruct and CreateArray") { val structDf = testData2.select("a", "b").as("record") // CreateStruct and CreateArray in aggregateExpressions - assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) - assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))). + sort("a").first() == Row(1, Row(1, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))). + sort("a").first() == Row(1, Seq(1, 1))) // CreateStruct and CreateArray in project list (unresolved alias) assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) @@ -279,7 +166,118 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { structDf.select(hash($"a", $"record.*"))) } - test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + private def assertDecimalSumOverflow( + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + if (!ansiEnabled) { + checkAnswer(df, expectedAnswer) + } else { + val e = intercept[SparkException] { + df.collect() + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals")) + } + } + + test("SPARK-28224: Aggregate sum big decimal overflow") { + val largeDecimals = spark.sparkContext.parallelize( + DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: + DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() + + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val structDf = largeDecimals.select("a").agg(sum("a")) + assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) + } + } + } + + test("SPARK-28067: sum of null decimal values") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq("true", "false").foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) { + val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + checkAnswer(df.agg(sum($"d")), Row(null)) + } + } + } + } + } + + test("SPARK-28067: Aggregate sum should not return 
wrong results for decimal overflow") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df0 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df1 = Seq( + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df = df0.union(df1) + val df2 = df.withColumnRenamed("decNum", "decNum2"). + join(df, "intNum").agg(sum("decNum")) + + val expectedAnswer = Row(null) + assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) + + val decStr = "1" + "0" * 19 + val d1 = spark.range(0, 12, 1, 1) + val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) + + val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) + + val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), + lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") + assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) + + val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + + val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). + toDF("d") + assertDecimalSumOverflow( + nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + + val df3 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("50000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df4 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df5 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum") + + val df6 = df3.union(df4).union(df5) + val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). 
+ filter("intNum == 1") + assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) + } + } + } + } + } + + test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { df.explode($"*") { case Row(prefix: String, csv: String) => @@ -300,6 +298,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("3", "7,8,9", "3:9") :: Nil) } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") + val e = intercept[AnalysisException] { + df.select(explode($"*")) + } + assert(e.getMessage.contains("Invalid usage of '*' in expression 'explode'")) + } + + test("explode on output of array-valued function") { + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") + checkAnswer( + df.select(explode(split($"csv", pattern = ","))), + Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil) + } + test("Star Expansion - explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") @@ -354,12 +367,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("repartition") { intercept[IllegalArgumentException] { - testData.select('key).repartition(0) + testData.select("key").repartition(0) } checkAnswer( - testData.select('key).repartition(10).select('key), - testData.select('key).collect().toSeq) + testData.select("key").repartition(10).select("key"), + testData.select("key").collect().toSeq) } test("repartition with SortOrder") { @@ -421,14 +434,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("coalesce") { intercept[IllegalArgumentException] { - testData.select('key).coalesce(0) + testData.select("key").coalesce(0) } - assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + assert(testData.select("key").coalesce(1).rdd.partitions.size === 1) checkAnswer( - testData.select('key).coalesce(1).select('key), - testData.select('key).collect().toSeq) + testData.select("key").coalesce(1).select("key"), + testData.select("key").collect().toSeq) assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } @@ -441,7 +454,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === lit(1)).select('value), + testData.where($"key" === lit(1)).select("value"), Row("1")) } @@ -453,17 +466,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("simple select") { checkAnswer( - testData.where('key === lit(1)).select('value), + testData.where($"key" === lit(1)).select("value"), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(lit(1))), + testData.select(sum("value"), avg("value"), count(lit(1))), Row(5050.0, 50.5, 100)) checkAnswer( - testData2.select('a + 'b, 'a < 'b), + testData2.select($"a" + $"b", $"a" < $"b"), Seq( Row(2, false), Row(3, true), @@ -473,31 +486,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(5, false))) checkAnswer( - testData2.select(sumDistinct('a)), + testData2.select(sumDistinct($"a")), Row(6)) } test("sorting with null ordering") { val data = Seq[java.lang.Integer](2, 1, null).toDF("key") - checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy($"key".asc), Row(null) :: Row(1) :: Row(2) :: Nil) 
checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil) - checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy($"key".asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil) - checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil) - checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".desc), Row(2) :: Row(1) :: Row(null) :: Nil) checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil) - checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy($"key".desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil) - checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy($"key".desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil) } test("global sorting") { checkAnswer( - testData2.orderBy('a.asc, 'b.asc), + testData2.orderBy($"a".asc, $"b".asc), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( @@ -505,31 +518,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - testData2.orderBy('a.asc, 'b.desc), + testData2.orderBy($"a".asc, $"b".desc), Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - testData2.orderBy('a.desc, 'b.desc), + testData2.orderBy($"a".desc, $"b".desc), Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( - testData2.orderBy('a.desc, 'b.asc), + testData2.orderBy($"a".desc, $"b".asc), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( - arrayData.toDF().orderBy('data.getItem(0).asc), + arrayData.toDF().orderBy($"data".getItem(0).asc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(0).desc), + arrayData.toDF().orderBy($"data".getItem(0).desc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(1).asc), + arrayData.toDF().orderBy($"data".getItem(1).asc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - arrayData.toDF().orderBy('data.getItem(1).desc), + arrayData.toDF().orderBy($"data".getItem(1).desc), arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } @@ -552,14 +565,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(0) :: Row(1) :: Nil ) } - test("except") { checkAnswer( lowerCaseData.except(upperCaseData), Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) @@ -584,8 +596,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( df.except(df.filter("0 = 1")), Row("id1", 1) :: - Row("id", 1) :: - 
Row("id1", 2) :: Nil) + Row("id", 1) :: + Row("id1", 2) :: Nil) // check if the empty set on the left side works checkAnswer( @@ -636,9 +648,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( lowerCaseData.exceptAll(upperCaseData), Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) @@ -661,16 +673,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // check that duplicates are retained. val df = spark.sparkContext.parallelize( NullStrings(1, "id1") :: - NullStrings(1, "id1") :: - NullStrings(2, "id1") :: - NullStrings(3, null) :: Nil).toDF("id", "value") + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") checkAnswer( df.exceptAll(df.filter("0 = 1")), Row(1, "id1") :: - Row(1, "id1") :: - Row(2, "id1") :: - Row(3, null) :: Nil) + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) // check if the empty set on the left side works checkAnswer( @@ -704,18 +716,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( lowerCaseData.intersect(lowerCaseData), Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) // check null equality checkAnswer( nullInts.intersect(nullInts), Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) + Row(2) :: + Row(3) :: + Row(null) :: Nil) // check if values are de-duplicated checkAnswer( @@ -727,8 +739,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( df.intersect(df), Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") { @@ -756,21 +768,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), Row(1, "a") :: - Row(2, "b") :: - Row(2, "b") :: - Row(3, "c") :: - Row(3, "c") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) // check null equality checkAnswer( nullInts.intersectAll(nullInts), Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) + Row(2) :: + Row(3) :: + Row(null) :: Nil) // Duplicate nulls are preserved. 
checkAnswer( @@ -811,7 +823,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select($"*", foo('key, 'value)).limit(3), + testData.select($"*", foo($"key", $"value")).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } @@ -914,7 +926,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("replace column using withColumns") { - val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y") + val df2 = sparkContext.parallelize(Seq((1, 2), (2, 3), (3, 4))).toDF("x", "y") val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), Seq(df2("x") + 1, df2("y"), df2("y") + 1)) checkAnswer( @@ -1140,7 +1152,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") - checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) + checkAnswer(df.select(df("key")), testData.select("key").collect().toSeq) } test("inputFiles") { @@ -1253,14 +1265,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = "-RECORD 0----------------------\n" + - " value | 1 \n" + - "-RECORD 1----------------------\n" + - " value | 111111111111111111111 \n" + " value | 1 \n" + + "-RECORD 1----------------------\n" + + " value | 111111111111111111111 \n" assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse) val expectedAnswerForTrue = "-RECORD 0---------------------\n" + - " value | 1 \n" + - "-RECORD 1---------------------\n" + - " value | 11111111111111111... \n" + " value | 1 \n" + + "-RECORD 1---------------------\n" + + " value | 11111111111111111... \n" assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue) } @@ -1289,14 +1301,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = "-RECORD 0----\n" + - " value | 1 \n" + - "-RECORD 1----\n" + - " value | 111 \n" + " value | 1 \n" + + "-RECORD 1----\n" + + " value | 111 \n" assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse) val expectedAnswerForTrue = "-RECORD 0------------------\n" + - " value | 1 \n" + - "-RECORD 1------------------\n" + - " value | 11111111111111... \n" + " value | 1 \n" + + "-RECORD 1------------------\n" + + " value | 11111111111111... 
\n" assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue) } @@ -1363,11 +1375,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { (Array(2, 3, 4), Array(2, 3, 4)) ).toDF() val expectedAnswer = "-RECORD 0--------\n" + - " _1 | [1, 2, 3] \n" + - " _2 | [1, 2, 3] \n" + - "-RECORD 1--------\n" + - " _1 | [2, 3, 4] \n" + - " _2 | [2, 3, 4] \n" + " _1 | [1, 2, 3] \n" + + " _2 | [1, 2, 3] \n" + + "-RECORD 1--------\n" + + " _1 | [2, 3, 4] \n" + + " _2 | [2, 3, 4] \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1392,11 +1404,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) ).toDF() val expectedAnswer = "-RECORD 0---------------\n" + - " _1 | [31 32] \n" + - " _2 | [41 42 43 2E] \n" + - "-RECORD 1---------------\n" + - " _1 | [33 34] \n" + - " _2 | [31 32 33 34 36] \n" + " _1 | [31 32] \n" + + " _2 | [41 42 43 2E] \n" + + "-RECORD 1---------------\n" + + " _1 | [33 34] \n" + + " _2 | [31 32 33 34 36] \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1421,11 +1433,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { (2, 2) ).toDF() val expectedAnswer = "-RECORD 0--\n" + - " _1 | 1 \n" + - " _2 | 1 \n" + - "-RECORD 1--\n" + - " _1 | 2 \n" + - " _2 | 2 \n" + " _1 | 1 \n" + + " _2 | 1 \n" + + "-RECORD 1--\n" + + " _1 | 2 \n" + + " _2 | 2 \n" assert(df.showString(10, vertical = true) === expectedAnswer) } @@ -1442,9 +1454,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-7319 showString, vertical = true") { val expectedAnswer = "-RECORD 0----\n" + - " key | 1 \n" + - " value | 1 \n" + - "only showing top 1 row\n" + " key | 1 \n" + + " value | 1 \n" + + "only showing top 1 row\n" assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } @@ -1501,7 +1513,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { |""".stripMargin assert(df.showString(1, truncate = 0) === expectedAnswer) - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { val expectedAnswer = """+----------+-------------------+ ||d |ts | @@ -1518,15 +1530,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val ts = Timestamp.valueOf("2016-12-01 00:00:00") val df = Seq((d, ts)).toDF("d", "ts") val expectedAnswer = "-RECORD 0------------------\n" + - " d | 2016-12-01 \n" + - " ts | 2016-12-01 00:00:00 \n" + " d | 2016-12-01 \n" + + " ts | 2016-12-01 00:00:00 \n" assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { val expectedAnswer = "-RECORD 0------------------\n" + - " d | 2016-12-01 \n" + - " ts | 2016-12-01 08:00:00 \n" + " d | 2016-12-01 \n" + + " ts | 2016-12-01 08:00:00 \n" assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) } } @@ -1539,7 +1551,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2))) + checkAnswer(decimalData.agg(avg("a")), Row(new java.math.BigDecimal(2))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -1581,10 +1593,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-7324 dropDuplicates") { val 
testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: - (1, 2, 1) :: (2, 1, 2) :: - (2, 2, 2) :: (2, 2, 1) :: - (2, 1, 1) :: (1, 1, 2) :: - (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") + (1, 2, 1) :: (2, 1, 2) :: + (2, 2, 2) :: (2, 2, 1) :: + (2, 1, 1) :: (1, 1, 2) :: + (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") checkAnswer( testData.dropDuplicates(), @@ -1669,47 +1681,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-6941: Better error message for inserting into RDD-based Table") { withTempDir { dir => + withTempView("parquet_base", "json_base", "rdd_base", "indirect_ds", "one_row") { + val tempParquetFile = new File(dir, "tmp_parquet") + val tempJsonFile = new File(dir, "tmp_json") + + val df = Seq(Tuple1(1)).toDF() + val insertion = Seq(Tuple1(2)).toDF("col") + + // pass case: parquet table (HadoopFsRelation) + df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) + pdf.createOrReplaceTempView("parquet_base") + + insertion.write.insertInto("parquet_base") + + // pass case: json table (InsertableRelation) + df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) + jdf.createOrReplaceTempView("json_base") + insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") + + // error cases: insert into an RDD + df.createOrReplaceTempView("rdd_base") + val e1 = intercept[AnalysisException] { + insertion.write.insertInto("rdd_base") + } + assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - val tempParquetFile = new File(dir, "tmp_parquet") - val tempJsonFile = new File(dir, "tmp_json") - - val df = Seq(Tuple1(1)).toDF() - val insertion = Seq(Tuple1(2)).toDF("col") - - // pass case: parquet table (HadoopFsRelation) - df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) - pdf.createOrReplaceTempView("parquet_base") - - insertion.write.insertInto("parquet_base") - - // pass case: json table (InsertableRelation) - df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = spark.read.json(tempJsonFile.getCanonicalPath) - jdf.createOrReplaceTempView("json_base") - insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") - - // error cases: insert into an RDD - df.createOrReplaceTempView("rdd_base") - val e1 = intercept[AnalysisException] { - insertion.write.insertInto("rdd_base") - } - assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - - // error case: insert into a logical plan that is not a LeafNode - val indirectDS = pdf.select("_1").filter($"_1" > 5) - indirectDS.createOrReplaceTempView("indirect_ds") - val e2 = intercept[AnalysisException] { - insertion.write.insertInto("indirect_ds") - } - assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + // error case: insert into a logical plan that is not a LeafNode + val indirectDS = pdf.select("_1").filter($"_1" > 5) + indirectDS.createOrReplaceTempView("indirect_ds") + val e2 = intercept[AnalysisException] { + insertion.write.insertInto("indirect_ds") + } + assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) - // error case: insert into an OneRowRelation - Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") - val e3 = 
intercept[AnalysisException] { - insertion.write.insertInto("one_row") + // error case: insert into an OneRowRelation + Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") + val e3 = intercept[AnalysisException] { + insertion.write.insertInto("one_row") + } + assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } - assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } @@ -1741,7 +1754,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Sorting columns are not in Filter and Project") { checkAnswer( - upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + upperCaseData.filter($"N" > 1).select("N").filter($"N" < 6).orderBy($"L".asc), Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) } @@ -1784,29 +1797,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") { val df = Seq(1 -> 2).toDF("i", "j") - val query1 = df.groupBy('i) - .agg(max('j).as("aggOrder")) - .orderBy(sum('j)) + val query1 = df.groupBy("i") + .agg(max("j").as("aggOrder")) + .orderBy(sum("j")) checkAnswer(query1, Row(1, 2)) // In the plan, there are two attributes having the same name 'havingCondition' // One is a user-provided alias name; another is an internally generated one. - val query2 = df.groupBy('i) - .agg(max('j).as("havingCondition")) - .where(sum('j) > 0) - .orderBy('havingCondition.asc) + val query2 = df.groupBy("i") + .agg(max("j").as("havingCondition")) + .where(sum("j") > 0) + .orderBy($"havingCondition".asc) checkAnswer(query2, Row(1, 2)) } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = spark.read.json((1 to 10).map(i => s"""{"id": $i}""").toDS()) + withTempDir { dir => + (1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath) + val input = spark.read.json(dir.getCanonicalPath) - val df = input.select($"id", rand(0).as('r)) - df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => - assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + val df = input.select($"id", rand(0).as("r")) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } } } - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val df1 = (1 to 100).map(Tuple1.apply).toDF("i") val df2 = (1 to 30).map(Tuple1.apply).toDF("i") @@ -1856,7 +1871,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expected(except) ) } - test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") @@ -2001,8 +2015,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) val df = spark.createDataFrame( rdd, - new StructType().add("f1", IntegerType).add("f2", IntegerType), - needsConversion = false).select($"F1", $"f2".as("f2")) + new StructType().add("f1", IntegerType).add("f2", IntegerType)) + .select($"F1", $"f2".as("f2")) val df1 = df.as("a") val df2 = df.as("b") checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) @@ -2017,7 +2031,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11725: correctly handle null inputs 
for ScalaUDF") { val df = sparkContext.parallelize(Seq( - new java.lang.Integer(22) -> "John", + java.lang.Integer.valueOf(22) -> "John", null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") // passing null into the UDF that could handle it @@ -2126,19 +2140,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col") // invalid table names Seq("11111", "t~", "#$@sum", "table!#").foreach { name => - val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage - assert(m.contains(s"Invalid view name: $name")) + withTempView(name) { + val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage + assert(m.contains(s"Invalid view name: $name")) + } } // valid table names Seq("table1", "`11111`", "`t~`", "`#$@sum`", "`table!#`").foreach { name => - df.createOrReplaceTempView(name) + withTempView(name) { + df.createOrReplaceTempView(name) + } } } test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select($"id" as "a", $"id" as "b").groupBy("a").agg($"b") } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) @@ -2160,7 +2178,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } - } + } test("SPARK-13774: Check error message for not existent globbed paths") { // Non-existent initial path component: @@ -2203,7 +2221,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val size = 201L val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size)))) val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true)) - val df = spark.createDataFrame(rdd, StructType(schemas), false) + val df = spark.createDataFrame(rdd, StructType(schemas)) assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } @@ -2233,9 +2251,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } private def verifyNullabilityInFilterExec( - df: DataFrame, - expr: String, - expectedNonNullableColumns: Seq[String]): Unit = { + df: DataFrame, + expr: String, + expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. 
any null input will @@ -2250,9 +2268,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: no change on nullability in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) @@ -2267,9 +2285,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: set nullability to false in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) @@ -2338,7 +2356,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { val df = spark.range(1).select(map(lit("key"), $"id").as("m")) val e = intercept[AnalysisException](df.intersect(df)) @@ -2389,7 +2406,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val e = intercept[SparkException] { df.filter(filter).count() }.getMessage - assert(e.contains("grows beyond 64 KB")) + assert(e.contains("grows beyond 64 KiB")) } } @@ -2405,7 +2422,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("order-by ordinal.") { checkAnswer( - testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), + testData2.select(lit(7), $"a", $"b").orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } @@ -2424,7 +2441,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") - val result = df.select('DecimalCol cast DecimalType(38, 33)) + val result = df.select($"DecimalCol" cast DecimalType(38, 33)) .select(col("DecimalCol")).describe() val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) @@ -2460,24 +2477,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val sourceDF = spark.createDataFrame(rows, schema) def structWhenDF: DataFrame = sourceDF - .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") - .select('res.getField("val1")) + .select(when($"cond", + struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as "res") + .select($"res".getField("val1")) def arrayWhenDF: DataFrame = sourceDF - .select(when('cond, array(lit("a"), 
lit("b"))).otherwise('a) as "res") - .select('res.getItem(0)) + .select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as "res") + .select($"res".getItem(0)) def mapWhenDF: DataFrame = sourceDF - .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") - .select('res.getItem(0)) + .select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res") + .select($"res".getItem(0)) def structIfDF: DataFrame = sourceDF .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") - .select('res.getField("val1")) + .select($"res".getField("val1")) def arrayIfDF: DataFrame = sourceDF .select(expr("if(cond, array('a', 'b'), a)") as "res") - .select('res.getItem(0)) + .select($"res".getItem(0)) def mapIfDF: DataFrame = sourceDF .select(expr("if(cond, map(0, 'a'), m)") as "res") - .select('res.getItem(0)) + .select($"res".getItem(0)) def checkResult(): Unit = { checkAnswer(structWhenDF, Seq(Row("a"), Row(null))) @@ -2517,7 +2535,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { - withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) @@ -2540,17 +2558,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // partitions. .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) - var numJobs = 0 + val numJobs = new AtomicLong(0) sparkContext.addSparkListener(new SparkListener { override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - numJobs += 1 + numJobs.incrementAndGet() } }) val df = spark.read.json(path.getCanonicalPath) assert(df.columns === Array("i", "p")) spark.sparkContext.listenerBus.waitUntilEmpty(10000) - assert(numJobs == 1) + assert(numJobs.get() == 1L) } } @@ -2622,4 +2640,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row("1-1", 6, 6)) } } + + test("SPARK-29442 Set `default` mode should override the existing mode") { + val df = Seq(Tuple1(1)).toDF() + val writer = df.write.mode("overwrite").mode("default") + val modeField = classOf[DataFrameWriter[Tuple1[Int]]].getDeclaredField("mode") + modeField.setAccessible(true) + assert(SaveMode.ErrorIfExists === modeField.get(writer).asInstanceOf[SaveMode]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a5f904c621e6e..9daa69ce9f155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite { // Makes sure hashCode on unsafe array won't crash unsafeRow.getArray(0).hashCode() } + + test("SPARK-32018: setDecimal with overflowed value") { + val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18) + val row = InternalRow.apply(d1) + val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row) + assert(unsafeRow.getDecimal(0, 38, 18) === d1) + val d2 = (d1 * Decimal(10)).toPrecision(39, 18) + unsafeRow.setDecimal(0, d2, 38) + assert(unsafeRow.getDecimal(0, 38, 18) === null) + } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d84421acf8390..fe102a1691305 100644 --- 
a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r70</version>
+    <version>2.4.1-kylin-r71</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 9f76d192524ee..8efb3799d2fb8 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r70</version>
+    <version>2.4.1-kylin-r71</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
diff --git a/streaming/pom.xml b/streaming/pom.xml
index e84cd5514784e..b47ae88c16916 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r70</version>
+    <version>2.4.1-kylin-r71</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/tools/pom.xml b/tools/pom.xml
index e009120bb52bc..25d9508c256c4 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r70</version>
+    <version>2.4.1-kylin-r71</version>
     <relativePath>../pom.xml</relativePath>
   </parent>