diff --git a/assembly/pom.xml b/assembly/pom.xml
index 41e6648aa3f3c..dd3a386bb514e 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml
index 53bd7f261eff3..18029c86442d4 100644
--- a/common/kvstore/pom.xml
+++ b/common/kvstore/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 1a7fa64021197..b5a4551a54112 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 93a2fc1fbb92d..dadf340e84f01 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index ba009e0e7896a..905fb1ebf540d 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 5f2bb64898ab8..723ed3b2f3c61 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index 8160ff20c0da1..24cd8283640ac 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index f43e375d3cedf..40583281e13d6 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/core/pom.xml b/core/pom.xml
index 1bd6465593594..6496cead95797 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/examples/pom.xml b/examples/pom.xml
index faf57b95bca0e..6393c8a8c5052 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/external/avro/pom.xml b/external/avro/pom.xml
index 9cbeb976df591..20ee419f68840 100644
--- a/external/avro/pom.xml
+++ b/external/avro/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 7a4a1536a735c..d2f7030ec20fe 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml
index c7fdd7f77e1a7..5b56d4be606c4 100644
--- a/external/flume-assembly/pom.xml
+++ b/external/flume-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 4c56b41b7f92b..cb05074f4292f 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 0a5bfdf6178fc..53b334cd95ab3 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml
index 86238385a9096..b9d2c0da42830 100644
--- a/external/kafka-0-10-assembly/pom.xml
+++ b/external/kafka-0-10-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index c2605595e2966..50321770056e1 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index 1968e805154b1..26952ef2d5790 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml
index d3db0a3991826..ae82444e6e1cc 100644
--- a/external/kafka-0-8-assembly/pom.xml
+++ b/external/kafka-0-8-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml
index 99ac54a738743..182b2096f8d0e 100644
--- a/external/kafka-0-8/pom.xml
+++ b/external/kafka-0-8/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml
index 56ba91c4fce2f..0f89c300301fb 100644
--- a/external/kinesis-asl-assembly/pom.xml
+++ b/external/kinesis-asl-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml
index 34bd5f3a98bad..05e439562ddeb 100644
--- a/external/kinesis-asl/pom.xml
+++ b/external/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml
index acf097f74b9e5..8c0acb7353896 100644
--- a/external/spark-ganglia-lgpl/pom.xml
+++ b/external/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 5a57c65d70524..023a481d8e1df 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml
index a0b435c80046d..dc0ae525a27ad 100644
--- a/hadoop-cloud/pom.xml
+++ b/hadoop-cloud/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/launcher/pom.xml b/launcher/pom.xml
index 2ce7e88594b34..d78328e9cbc82 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index 754855096d594..94d4ce3857478 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/mllib/pom.xml b/mllib/pom.xml
index cb387b78f1149..42bec8fb67e55 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/pom.xml b/pom.xml
index 23224a5069fcd..b97b9351c4b79 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<packaging>pom</packaging>
<name>Spark Project Parent POM</name>
<url>http://spark.apache.org/</url>
diff --git a/repl/pom.xml b/repl/pom.xml
index b6d49935024bd..0e09e0dbf354c 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index 866c0cf552446..25f5c66e5d716 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index ec8334f1fccd2..adde969b1c5cc 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index ebd97a9d2d6cd..0fd3a1bce2522 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index 841791a6e8a15..b3809904332a6 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5736432b2ec19..d04cf02f4173e 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 9bf9452855f5f..faf7c2ecf4c21 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -279,7 +279,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
Platform.putLong(baseObject, baseOffset + cursor, 0L);
Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
- if (value == null) {
+ if (value == null || !value.changePrecision(precision, value.scale())) {
setNullAt(ordinal);
// keep the offset for future update
Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
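Note: the guard above relies on Decimal.changePrecision returning false when a value
cannot be represented at the declared precision; in that case the field is nulled out
instead of storing a truncated value. A minimal sketch of that contract (illustrative
only, not part of the patch):

import org.apache.spark.sql.types.Decimal

object SetDecimalGuardSketch {
  def main(args: Array[String]): Unit = {
    val fits = Decimal("12.34")
    // 12.34 fits in Decimal(4, 2): changePrecision succeeds in place.
    assert(fits.changePrecision(4, fits.scale))

    val tooWide = Decimal("123.45")
    // 123.45 needs precision 5, so changePrecision(4, 2) returns false; this is
    // exactly the case where the patched setDecimal now calls setNullAt(ordinal).
    assert(!tooWide.changePrecision(4, tooWide.scale))
  }
}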
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 82692334544e2..b0c4cc05d0ce8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -87,49 +87,56 @@ object DecimalPrecision extends TypeCoercionRule {
case q => q.transformExpressionsUp(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
}
+ private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
+ decimalAndDecimal(SQLConf.get.decimalOperationsAllowPrecisionLoss, !SQLConf.get.ansiEnabled)
+ }
/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
- private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
+ private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
+ : PartialFunction[Expression, Expression] = {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
// Skip nodes that are already promoted
case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
- case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultScale = max(s1, s2)
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
- CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
- resultType)
+ CheckOverflow(
+ a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
- case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultScale = max(s1, s2)
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
- CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
- resultType)
+ CheckOverflow(
+ s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
+ resultType, nullOnOverflow)
- case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ case m @ Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
} else {
DecimalType.bounded(p1 + p2 + 1, s1 + s2)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
- CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
- resultType)
+ CheckOverflow(
+ m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
- case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ val resultType = if (allowPrecisionLoss) {
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
val intDig = p1 - s1 + s2
@@ -147,30 +154,33 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(intDig + decDig, decDig)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
- CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
- resultType)
+ CheckOverflow(
+ d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
- case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ case r @ Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
- CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
- resultType)
+ CheckOverflow(
+ r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
- case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+ case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
- CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
- resultType)
+ CheckOverflow(
+ p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
+ resultType, nullOnOverflow)
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
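Note: for Add/Subtract the rules above compute precision = max(p1 - s1, p2 - s2) +
max(s1, s2) + 1 and scale = max(s1, s2), capped at 38 digits. A worked sketch of the
capping arithmetic in precision-loss mode, inlining the logic of the private[sql]
helper DecimalType.adjustPrecisionScale (illustrative, not part of the patch):

import org.apache.spark.sql.types.DecimalType

object AddResultTypeSketch {
  def main(args: Array[String]): Unit = {
    // Adding two values of the widest supported type, Decimal(38, 18).
    val (p1, s1, p2, s2) = (38, 18, 38, 18)
    val resultScale = math.max(s1, s2)                                     // 18
    val resultPrecision = math.max(p1 - s1, p2 - s2) + resultScale + 1     // 39, over the cap
    // adjustPrecisionScale gives integral digits priority and shrinks the scale,
    // but never below 6 fractional digits:
    val intDigits = resultPrecision - resultScale                          // 21
    val adjustedScale = math.max(38 - intDigits, math.min(resultScale, 6)) // 17
    println(DecimalType(38, adjustedScale)) // DecimalType(38,17)
    // With allowPrecisionLoss = false, DecimalType.bounded clamps both sides
    // instead, yielding DecimalType(38,18) at the risk of runtime overflow.
  }
}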
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index 7a0aa08289efa..c7b79df4035cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
collect(left, negate) ++ collect(right, !negate)
case UnaryMinus(child) =>
collect(child, !negate)
- case CheckOverflow(child, _) =>
+ case CheckOverflow(child, _, _) =>
collect(child, negate)
case PromotePrecision(child) =>
collect(child, negate)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala
new file mode 100644
index 0000000000000..b84643d3293dd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/QueryExecutionErrors.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.errors
+
+import org.apache.spark.sql.types._
+
+/**
+ * Object for grouping error messages from (most) exceptions thrown during query execution.
+ * This does not include exceptions thrown during the eager execution of commands, which are
+ * grouped into [[QueryCompilationErrors]].
+ */
+object QueryExecutionErrors {
+
+ def cannotChangeDecimalPrecisionError(
+ value: Decimal, decimalPrecision: Int, decimalScale: Int): ArithmeticException = {
+ new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
+ s"Decimal($decimalPrecision, $decimalScale).")
+ }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 5ecb77be5965e..8356d6b9d1b1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
@@ -57,8 +58,11 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
// If all input are nulls, count will be 0 and we will get null after the division.
override lazy val evaluateExpression = child.dataType match {
- case _: DecimalType =>
- DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
+ case d: DecimalType =>
+ DecimalPrecision.decimalAndDecimal()(
+ Divide(
+ CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+ count.cast(DecimalType.LongDecimal))).cast(resultType)
case _ =>
sum.cast(resultType) / count.cast(resultType)
}
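Note: the decimal branch of avg now routes the partial sum through CheckOverflowInSum
before dividing, so an overflowed sum surfaces as null (or as an error under ANSI mode)
rather than a possibly wrong value. A local-mode usage sketch (illustrative session
setup, not part of the patch):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AvgOverflowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("avg-overflow").getOrCreate()
    import spark.implicits._
    // Twelve copies of 10^19 as decimal(38,18): their sum needs 21 integral
    // digits and cannot fit, so avg evaluates to null under the default mode.
    val df = Seq.fill(12)(BigDecimal("1" + "0" * 19)).toDF("d")
      .select($"d".cast("decimal(38,18)").as("d"))
    df.agg(avg($"d")).show()
    spark.stop()
  }
}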
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 761dba111c074..aca8c8e79cec2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,10 +21,22 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.")
+ usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col);
+ 30
+ > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+ 25
+ > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
+ NULL
+ """,
+ since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = child :: Nil
@@ -46,38 +58,94 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}
- private lazy val sumDataType = resultType
+ private lazy val sum = AttributeReference("sum", resultType)()
- private lazy val sum = AttributeReference("sum", sumDataType)()
+ private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()
- private lazy val zero = Cast(Literal(0), sumDataType)
+ private lazy val zero = Literal.default(resultType)
- override lazy val aggBufferAttributes = sum :: Nil
+ override lazy val aggBufferAttributes = resultType match {
+ case _: DecimalType => sum :: isEmpty :: Nil
+ case _ => sum :: Nil
+ }
- override lazy val initialValues: Seq[Expression] = Seq(
- /* sum = */ Literal.create(null, sumDataType)
- )
+ override lazy val initialValues: Seq[Expression] = resultType match {
+ case _: DecimalType => Seq(zero, Literal(true, BooleanType))
+ case _ => Seq(Literal(null, resultType))
+ }
override lazy val updateExpressions: Seq[Expression] = {
- if (child.nullable) {
- Seq(
- /* sum = */
- coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
- )
- } else {
- Seq(
- /* sum = */
- coalesce(sum, zero) + child.cast(sumDataType)
- )
+ resultType match {
+ case _: DecimalType =>
+ // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if
+ // the input is null, as the SUM function ignores null inputs. The `sum` can only be null if
+ // overflow happens under non-ansi mode.
+ val sumExpr = if (child.nullable) {
+ If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType))
+ } else {
+ sum + child.cast(resultType)
+ }
+ // The buffer becomes non-empty after seeing the first not-null input.
+ val isEmptyExpr = if (child.nullable) {
+ isEmpty && child.isNull
+ } else {
+ Literal(false, BooleanType)
+ }
+ Seq(sumExpr, isEmptyExpr)
+ case _ =>
+ // For non-decimal type, the initial value of `sum` is null, which indicates no value.
+ // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce`
+ // in case the input is nullable. The `sum` can only be null if there is no value:
+ // on overflow, non-decimal types wrap around rather than producing null under non-ansi mode.
+ if (child.nullable) {
+ Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum))
+ } else {
+ Seq(coalesce(sum, zero) + child.cast(resultType))
+ }
}
}
+ /**
+ * For decimal type:
+ * If isEmpty is false and sum is null, an overflow has occurred.
+ *
+ * The sum is merged as follows:
+ * check whether either left.sum or right.sum has overflowed;
+ * if either has, the merged sum remains null;
+ * otherwise, add sum.left and sum.right.
+ *
+ * isEmpty: set to false if either left or right is false, meaning we have
+ * seen at least one value that was not null.
+ */
override lazy val mergeExpressions: Seq[Expression] = {
- Seq(
- /* sum = */
- coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
- )
+ resultType match {
+ case _: DecimalType =>
+ val bufferOverflow = !isEmpty.left && sum.left.isNull
+ val inputOverflow = !isEmpty.right && sum.right.isNull
+ Seq(
+ If(
+ bufferOverflow || inputOverflow,
+ Literal.create(null, resultType),
+ // If both the buffer and the input do not overflow, just add them, as they can't be
+ // null. See the comments inside `updateExpressions`: `sum` can only be null if
+ // overflow happens.
+ KnownNotNull(sum.left) + KnownNotNull(sum.right)),
+ isEmpty.left && isEmpty.right)
+ case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
+ }
}
- override lazy val evaluateExpression: Expression = sum
+ /**
+ * If isEmpty is true, there were no values to begin with or all the values
+ * were null, so the result is null.
+ * If isEmpty is false and sum is null, an overflow has happened: throw an
+ * exception if ANSI mode is enabled, otherwise return null.
+ * If sum is not null, return the sum.
+ */
+ override lazy val evaluateExpression: Expression = resultType match {
+ case d: DecimalType =>
+ If(isEmpty, Literal.create(null, resultType),
+ CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
+ case _ => sum
+ }
}
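Note: for decimal sums the aggregation buffer now carries two slots: the running sum
(initialized to 0, becoming null only on overflow) and an isEmpty flag. A pure-Scala
sketch of those buffer semantics (names are illustrative, not Spark API):

object SumBufferSketch {
  // sum == None encodes an overflowed partial sum once isEmpty is false.
  final case class Buf(sum: Option[BigDecimal], isEmpty: Boolean)

  val init = Buf(Some(BigDecimal(0)), isEmpty = true)

  // SUM ignores null inputs; a non-null input flips isEmpty to false and drops
  // the sum to None when the running total no longer fits the result type.
  def update(b: Buf, in: Option[BigDecimal], fits: BigDecimal => Boolean): Buf =
    in match {
      case None    => b
      case Some(v) => Buf(b.sum.map(_ + v).filter(fits), isEmpty = false)
    }

  // A non-empty buffer with a null sum means a partial sum already overflowed,
  // so the merged sum stays None; otherwise the two sums are simply added.
  def merge(l: Buf, r: Buf): Buf = {
    val overflowed = (!l.isEmpty && l.sum.isEmpty) || (!r.isEmpty && r.sum.isEmpty)
    val s = if (overflowed) None else for (a <- l.sum; b <- r.sum) yield a + b
    Buf(s, l.isEmpty && r.isEmpty)
  }

  // Empty => null result; overflowed => null (or an exception under ANSI mode).
  def eval(b: Buf): Option[BigDecimal] = if (b.isEmpty) None else b.sum
}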
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 17d4a0dc4e884..7bfaf0fd0c767 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -224,7 +224,7 @@ object Block {
} else {
args.foreach {
case _: ExprValue | _: Inline | _: Block =>
- case _: Int | _: Long | _: Float | _: Double | _: String =>
+ case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String =>
case other => throw new IllegalArgumentException(
s"Can not interpolate ${other.getClass.getName} into code block.")
}
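Note: this widening exists because the decimal codegen below interpolates a Boolean
(nullOnOverflow) into a code"..." block; previously such a block threw
IllegalArgumentException at codegen time. A minimal sketch, assuming access to the
catalyst-internal interpolator:

import org.apache.spark.sql.catalyst.expressions.codegen.Block._

object BlockInterpolationSketch {
  def main(args: Array[String]): Unit = {
    val nullOnOverflow = true
    // Accepted after this change; before it, interpolating a Boolean failed with
    // "Can not interpolate java.lang.Boolean into code block."
    val block = code"boolean nullOnOverflow = $nullOnOverflow;"
    println(block)
  }
}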
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 04de83343be71..7e4560ab8161b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
@@ -26,7 +28,7 @@ import org.apache.spark.sql.types._
* Note: this expression is internal and created only by the optimizer,
* we don't need to do type check for it.
*/
-case class UnscaledValue(child: Expression) extends UnaryExpression {
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"
@@ -44,25 +46,56 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
* Note: this expression is internal and created only by the optimizer,
* we don't need to do type check for it.
*/
-case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+case class MakeDecimal(
+ child: Expression,
+ precision: Int,
+ scale: Int,
+ nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
+
+ def this(child: Expression, precision: Int, scale: Int) = {
+ this(child, precision, scale, !SQLConf.get.ansiEnabled)
+ }
override def dataType: DataType = DecimalType(precision, scale)
- override def nullable: Boolean = true
+ override def nullable: Boolean = child.nullable || nullOnOverflow
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
- protected override def nullSafeEval(input: Any): Any =
- Decimal(input.asInstanceOf[Long], precision, scale)
+ protected override def nullSafeEval(input: Any): Any = {
+ val longInput = input.asInstanceOf[Long]
+ val result = new Decimal()
+ if (nullOnOverflow) {
+ result.setOrNull(longInput, precision, scale)
+ } else {
+ result.set(longInput, precision, scale)
+ }
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
+ val setMethod = if (nullOnOverflow) {
+ "setOrNull"
+ } else {
+ "set"
+ }
+ val setNull = if (nullable) {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
+ }
s"""
- ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
- ${ev.isNull} = ${ev.value} == null;
- """
+ |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+ |$setNull
+ |""".stripMargin
})
}
}
+object MakeDecimal {
+ def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = {
+ new MakeDecimal(child, precision, scale)
+ }
+}
+
/**
* An expression used to wrap the children when promote the precision of DecimalType to avoid
* promote multiple times.
@@ -81,30 +114,85 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
/**
* Rounds the decimal to given scale and check whether the decimal can fit in provided precision
- * or not, returns null if not.
+ * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
+ * `ArithmeticException` is thrown.
*/
-case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
+case class CheckOverflow(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean) extends UnaryExpression {
override def nullable: Boolean = true
override def nullSafeEval(input: Any): Any =
- input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
+ input.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
- val tmp = ctx.freshName("tmp")
s"""
- | Decimal $tmp = $eval.clone();
- | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
- | ${ev.value} = $tmp;
- | } else {
- | ${ev.isNull} = true;
- | }
+ |${ev.value} = $eval.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
+ |${ev.isNull} = ${ev.value} == null;
""".stripMargin
})
}
- override def toString: String = s"CheckOverflow($child, $dataType)"
+ override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"
+
+ override def sql: String = child.sql
+}
+
+// A variant of `CheckOverflow` that treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum(
+ child: Expression,
+ dataType: DecimalType,
+ nullOnOverflow: Boolean) extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.")
+ } else {
+ value.asInstanceOf[Decimal].toPrecision(
+ dataType.precision,
+ dataType.scale,
+ Decimal.ROUND_HALF_UP,
+ nullOnOverflow)
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val nullHandling = if (nullOnOverflow) {
+ ""
+ } else {
+ s"""
+ |throw new ArithmeticException("Overflow in sum of decimals.");
+ |""".stripMargin
+ }
+ val code = code"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = ${childGen.isNull};
+ |Decimal ${ev.value} = null;
+ |if (${childGen.isNull}) {
+ | $nullHandling
+ |} else {
+ | ${ev.value} = ${childGen.value}.toPrecision(
+ | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
+ | ${ev.isNull} = ${ev.value} == null;
+ |}
+ |""".stripMargin
+
+ ev.copy(code = code)
+ }
+
+ override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
override def sql: String = child.sql
}
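Note: the new flag is directly observable at the expression level: the same value
either rounds into the target type, returns null, or throws, depending on the target
precision/scale and nullOnOverflow. A small sketch against the catalyst API, mirroring
the updated test suite further below (illustrative, not part of the patch):

import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Literal}
import org.apache.spark.sql.types.{Decimal, DecimalType}

object CheckOverflowSketch {
  def main(args: Array[String]): Unit = {
    val d = Literal(Decimal("10.1"))
    // Fits: Decimal(4, 1) can represent 10.1.
    println(CheckOverflow(d, DecimalType(4, 1), nullOnOverflow = true).eval(null)) // 10.1
    // Does not fit: Decimal(4, 3) leaves room for only one integral digit.
    println(CheckOverflow(d, DecimalType(4, 3), nullOnOverflow = true).eval(null)) // null
    // With nullOnOverflow = false, the same overflow throws
    // java.lang.ArithmeticException: ... cannot be represented as Decimal(4, 3).
  }
}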
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d7409c5efa372..c9c898a8de344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -150,6 +150,16 @@ object SQLConf {
}
}
+ val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
+ .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " +
+ "For example, Spark will throw an exception at runtime instead of returning null results " +
+ "when the inputs to a SQL operator/function are invalid." +
+ "For full details of this dialect, you can find them in the section \"ANSI Compliance\" of " +
+ "Spark's documentation. Some ANSI dialect features may be not from the ANSI SQL " +
+ "standard directly, but their behaviors align with ANSI SQL's style")
+ .booleanConf
+ .createWithDefault(false)
+
val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules")
.doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " +
"specified by their rule names and separated by comma. It is not guaranteed that all the " +
@@ -1617,6 +1627,8 @@ class SQLConf extends Serializable with Logging {
/** ************************ Spark SQL Params/Hints ******************* */
+ def ansiEnabled: Boolean = getConf(ANSI_ENABLED)
+
def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES)
def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)
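Note: a usage sketch for the new flag (the default stays false, preserving the
null-on-overflow behavior). In a spark-shell session:

spark.conf.set("spark.sql.ansi.enabled", "true")
// A decimal SUM that overflows now fails the query with an ArithmeticException
// ("cannot be represented as Decimal(38, 18)" or "Overflow in sum of decimals")
// instead of silently returning null:
spark.range(0, 12)
  .selectExpr("cast('10000000000000000000' as decimal(38,18)) as d")
  .selectExpr("sum(d)")
  .show()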
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 9eed2eb202045..33c0cd07c8d07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode}
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.errors.QueryExecutionErrors
/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -242,9 +243,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
private[sql] def toPrecision(
precision: Int,
scale: Int,
- roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
+ roundMode: BigDecimal.RoundingMode.Value,
+ nullOnOverflow: Boolean): Decimal = {
val copy = clone()
- if (copy.changePrecision(precision, scale, roundMode)) copy else null
+ if (copy.changePrecision(precision, scale, roundMode)) {
+ copy
+ } else {
+ if (nullOnOverflow) {
+ null
+ } else {
+ throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(this, precision, scale)
+ }
+ }
+ }
+
+ private[sql] def toPrecision(
+ precision: Int,
+ scale: Int,
+ roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
+ toPrecision(precision, scale, roundMode, true)
}
/**
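Note: the original three-argument signature is kept as an overload that defaults to
nullOnOverflow = true, so existing callers retain the null-returning contract. A sketch
of both paths; since toPrecision is private[sql], this assumes the file is compiled
under the org.apache.spark.sql package (illustrative only):

package org.apache.spark.sql.sketch

import org.apache.spark.sql.types.Decimal

object ToPrecisionSketch {
  def main(args: Array[String]): Unit = {
    val d = Decimal("10.1")
    // Legacy contract: returns null because 10.1 cannot fit Decimal(4, 3).
    assert(d.toPrecision(4, 3, Decimal.ROUND_HALF_UP, nullOnOverflow = true) == null)
    // Strict path: the same overflow now throws.
    try d.toPrecision(4, 3, Decimal.ROUND_HALF_UP, nullOnOverflow = false)
    catch {
      case e: ArithmeticException =>
        println(e.getMessage) // "... cannot be represented as Decimal(4, 3)."
    }
  }
}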
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index a8f758d625a02..941bab2ea01e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -45,18 +45,19 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("CheckOverflow") {
val d1 = Decimal("10.1")
- checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
- checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
- checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
- checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
val d2 = Decimal(101, 3, 1)
- checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
- checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
- checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
- checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
- checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
+ checkEvaluation(
+ CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
}
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 0db539b4e8762..3f102462b6fd5 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5075209d7454f..c19250b9325b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -17,39 +17,41 @@
package org.apache.spark.sql
-import java.io.File
+import java.io.{ByteArrayOutputStream, File}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.util.UUID
-
+import java.util.concurrent.atomic.AtomicLong
+import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
-
-import org.scalatest.Matchers._
-
+import org.scalatest.Matchers.{assert, intercept, _}
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
-import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2}
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
-class DataFrameSuite extends QueryTest with SharedSQLContext {
+class DataFrameSuite extends QueryTest with SharedSparkSession {
import testImplicits._
test("analysis error should be eagerly reported") {
- intercept[Exception] { testData.select('nonExistentName) }
+ intercept[Exception] { testData.select("nonExistentName") }
intercept[Exception] {
- testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
+ testData.groupBy("key").agg(Map("nonExistentName" -> "sum"))
}
intercept[Exception] {
testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
@@ -85,129 +87,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.collect().toSeq)
}
- test("union all") {
- val unionDF = testData.union(testData).union(testData)
- .union(testData).union(testData)
-
- // Before optimizer, Union should be combined.
- assert(unionDF.queryExecution.analyzed.collect {
- case j: Union if j.children.size == 5 => j }.size === 1)
-
- checkAnswer(
- unionDF.agg(avg('key), max('key), min('key), sum('key)),
- Row(50.5, 100, 1, 25250) :: Nil
- )
- }
-
- test("union should union DataFrames with UDTs (SPARK-13410)") {
- val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0))))
- val schema1 = StructType(Array(StructField("label", IntegerType, false),
- StructField("point", new ExamplePointUDT(), false)))
- val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0))))
- val schema2 = StructType(Array(StructField("label", IntegerType, false),
- StructField("point", new ExamplePointUDT(), false)))
- val df1 = spark.createDataFrame(rowRDD1, schema1)
- val df2 = spark.createDataFrame(rowRDD2, schema2)
-
- checkAnswer(
- df1.union(df2).orderBy("label"),
- Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
- )
- }
-
- test("union by name") {
- var df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
- var df2 = Seq((3, 1, 2)).toDF("c", "a", "b")
- val df3 = Seq((2, 3, 1)).toDF("b", "c", "a")
- val unionDf = df1.unionByName(df2.unionByName(df3))
- checkAnswer(unionDf,
- Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil
- )
-
- // Check if adjacent unions are combined into a single one
- assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1)
-
- // Check failure cases
- df1 = Seq((1, 2)).toDF("a", "c")
- df2 = Seq((3, 4, 5)).toDF("a", "b", "c")
- var errMsg = intercept[AnalysisException] {
- df1.unionByName(df2)
- }.getMessage
- assert(errMsg.contains(
- "Union can only be performed on tables with the same number of columns, " +
- "but the first table has 2 columns and the second table has 3 columns"))
-
- df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
- df2 = Seq((4, 5, 6)).toDF("a", "c", "d")
- errMsg = intercept[AnalysisException] {
- df1.unionByName(df2)
- }.getMessage
- assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)"""))
- }
-
- test("union by name - type coercion") {
- var df1 = Seq((1, "a")).toDF("c0", "c1")
- var df2 = Seq((3, 1L)).toDF("c1", "c0")
- checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil)
-
- df1 = Seq((1, 1.0)).toDF("c0", "c1")
- df2 = Seq((8L, 3.0)).toDF("c1", "c0")
- checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil)
-
- df1 = Seq((2.0f, 7.4)).toDF("c0", "c1")
- df2 = Seq(("a", 4.0)).toDF("c1", "c0")
- checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil)
-
- df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2")
- df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1")
- val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0")
- checkAnswer(df1.unionByName(df2.unionByName(df3)),
- Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil
- )
- }
-
- test("union by name - check case sensitivity") {
- def checkCaseSensitiveTest(): Unit = {
- val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef")
- val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB")
- checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil)
- }
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- val errMsg2 = intercept[AnalysisException] {
- checkCaseSensitiveTest()
- }.getMessage
- assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)"""))
- }
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
- checkCaseSensitiveTest()
- }
- }
-
- test("union by name - check name duplication") {
- Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
- var df1 = Seq((1, 1)).toDF(c0, c1)
- var df2 = Seq((1, 1)).toDF("c0", "c1")
- var errMsg = intercept[AnalysisException] {
- df1.unionByName(df2)
- }.getMessage
- assert(errMsg.contains("Found duplicate column(s) in the left attributes:"))
- df1 = Seq((1, 1)).toDF("c0", "c1")
- df2 = Seq((1, 1)).toDF(c0, c1)
- errMsg = intercept[AnalysisException] {
- df1.unionByName(df2)
- }.getMessage
- assert(errMsg.contains("Found duplicate column(s) in the right attributes:"))
- }
- }
- }
-
test("empty data frame") {
assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(spark.emptyDataFrame.count() === 0)
}
- test("head and take") {
+ test("head, take") {
assert(testData.take(2) === testData.collect().take(2))
assert(testData.head(2) === testData.collect().take(2))
assert(testData.head(2).head.schema === testData.schema)
@@ -248,8 +133,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("Star Expansion - CreateStruct and CreateArray") {
val structDf = testData2.select("a", "b").as("record")
// CreateStruct and CreateArray in aggregateExpressions
- assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
- assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1)))
+ assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).
+ sort("a").first() == Row(1, Row(1, 1)))
+ assert(structDf.groupBy($"a").agg(min(array($"record.*"))).
+ sort("a").first() == Row(1, Seq(1, 1)))
// CreateStruct and CreateArray in project list (unresolved alias)
assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
@@ -279,7 +166,118 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
structDf.select(hash($"a", $"record.*")))
}
- test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
+ private def assertDecimalSumOverflow(
+ df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+ if (!ansiEnabled) {
+ checkAnswer(df, expectedAnswer)
+ } else {
+ val e = intercept[SparkException] {
+ df.collect()
+ }
+ assert(e.getCause.isInstanceOf[ArithmeticException])
+ assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+ e.getCause.getMessage.contains("Overflow in sum of decimals"))
+ }
+ }
+
+ test("SPARK-28224: Aggregate sum big decimal overflow") {
+ val largeDecimals = spark.sparkContext.parallelize(
+ DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
+ DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val structDf = largeDecimals.select("a").agg(sum("a"))
+ assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+ }
+ }
+ }
+
+ test("SPARK-28067: sum of null decimal values") {
+ Seq("true", "false").foreach { wholeStageEnabled =>
+ withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+ Seq("true", "false").foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
+ val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+ checkAnswer(df.agg(sum($"d")), Row(null))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+ Seq("true", "false").foreach { wholeStageEnabled =>
+ withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val df0 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df1 = Seq(
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df = df0.union(df1)
+ val df2 = df.withColumnRenamed("decNum", "decNum2").
+ join(df, "intNum").agg(sum("decNum"))
+
+ val expectedAnswer = Row(null)
+ assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+
+ val decStr = "1" + "0" * 19
+ val d1 = spark.range(0, 12, 1, 1)
+ val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+ assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+
+ val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+ val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+ assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+
+ val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
+ lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
+ assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+
+ val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+
+ val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
+ toDF("d")
+ assertDecimalSumOverflow(
+ nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+
+ val df3 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("50000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df4 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df5 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df6 = df3.union(df4).union(df5)
+ val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
+ filter("intNum == 1")
+ assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
+ }
+ }
+ }
+ }
+ }
+
+ test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
val e = intercept[AnalysisException] {
df.explode($"*") { case Row(prefix: String, csv: String) =>
@@ -300,6 +298,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("3", "7,8,9", "3:9") :: Nil)
}
+ test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
+ val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
+ val e = intercept[AnalysisException] {
+ df.select(explode($"*"))
+ }
+ assert(e.getMessage.contains("Invalid usage of '*' in expression 'explode'"))
+ }
+
+ test("explode on output of array-valued function") {
+ val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
+ checkAnswer(
+ df.select(explode(split($"csv", pattern = ","))),
+ Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil)
+ }
+
test("Star Expansion - explode alias and star") {
val df = Seq((Array("a"), 1)).toDF("a", "b")
@@ -354,12 +367,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("repartition") {
intercept[IllegalArgumentException] {
- testData.select('key).repartition(0)
+ testData.select("key").repartition(0)
}
checkAnswer(
- testData.select('key).repartition(10).select('key),
- testData.select('key).collect().toSeq)
+ testData.select("key").repartition(10).select("key"),
+ testData.select("key").collect().toSeq)
}
test("repartition with SortOrder") {
@@ -421,14 +434,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("coalesce") {
intercept[IllegalArgumentException] {
- testData.select('key).coalesce(0)
+ testData.select("key").coalesce(0)
}
- assert(testData.select('key).coalesce(1).rdd.partitions.size === 1)
+ assert(testData.select("key").coalesce(1).rdd.partitions.size === 1)
checkAnswer(
- testData.select('key).coalesce(1).select('key),
- testData.select('key).collect().toSeq)
+ testData.select("key").coalesce(1).select("key"),
+ testData.select("key").collect().toSeq)
assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
}
@@ -441,7 +454,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
- testData.where('key === lit(1)).select('value),
+ testData.where($"key" === lit(1)).select("value"),
Row("1"))
}
@@ -453,17 +466,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("simple select") {
checkAnswer(
- testData.where('key === lit(1)).select('value),
+ testData.where($"key" === lit(1)).select("value"),
Row("1"))
}
test("select with functions") {
checkAnswer(
- testData.select(sum('value), avg('value), count(lit(1))),
+ testData.select(sum("value"), avg("value"), count(lit(1))),
Row(5050.0, 50.5, 100))
checkAnswer(
- testData2.select('a + 'b, 'a < 'b),
+ testData2.select($"a" + $"b", $"a" < $"b"),
Seq(
Row(2, false),
Row(3, true),
@@ -473,31 +486,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(5, false)))
checkAnswer(
- testData2.select(sumDistinct('a)),
+ testData2.select(sumDistinct($"a")),
Row(6))
}
test("sorting with null ordering") {
val data = Seq[java.lang.Integer](2, 1, null).toDF("key")
- checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil)
+ checkAnswer(data.orderBy($"key".asc), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
- checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil)
+ checkAnswer(data.orderBy($"key".asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
- checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil)
+ checkAnswer(data.orderBy($"key".asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil)
checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil)
- checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil)
+ checkAnswer(data.orderBy($"key".desc), Row(2) :: Row(1) :: Row(null) :: Nil)
checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
- checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil)
+ checkAnswer(data.orderBy($"key".desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil)
checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil)
- checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil)
+ checkAnswer(data.orderBy($"key".desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil)
checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
}
test("global sorting") {
checkAnswer(
- testData2.orderBy('a.asc, 'b.asc),
+ testData2.orderBy($"a".asc, $"b".asc),
Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
checkAnswer(
@@ -505,31 +518,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
- testData2.orderBy('a.asc, 'b.desc),
+ testData2.orderBy($"a".asc, $"b".desc),
Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
- testData2.orderBy('a.desc, 'b.desc),
+ testData2.orderBy($"a".desc, $"b".desc),
Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1)))
checkAnswer(
- testData2.orderBy('a.desc, 'b.asc),
+ testData2.orderBy($"a".desc, $"b".asc),
Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
checkAnswer(
- arrayData.toDF().orderBy('data.getItem(0).asc),
+ arrayData.toDF().orderBy($"data".getItem(0).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
- arrayData.toDF().orderBy('data.getItem(0).desc),
+ arrayData.toDF().orderBy($"data".getItem(0).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- arrayData.toDF().orderBy('data.getItem(1).asc),
+ arrayData.toDF().orderBy($"data".getItem(1).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- arrayData.toDF().orderBy('data.getItem(1).desc),
+ arrayData.toDF().orderBy($"data".getItem(1).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
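getItem(i) on an array column yields an ordinary Column, so it composes with asc/desc exactly like a named column; a sketch, again assuming spark.implicits._:
  val arrays = Seq(Seq(3, 1), Seq(1, 2), Seq(2, 9)).toDF("data")
  arrays.orderBy($"data".getItem(0).asc).collect()  // ordered by data(0): 1, 2, 3
  arrays.orderBy($"data".getItem(1).desc).collect() // ordered by data(1): 9, 2, 1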
@@ -552,14 +565,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(0) :: Row(1) :: Nil
)
}
-
test("except") {
checkAnswer(
lowerCaseData.except(upperCaseData),
Row(1, "a") ::
- Row(2, "b") ::
- Row(3, "c") ::
- Row(4, "d") :: Nil)
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
@@ -584,8 +596,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.except(df.filter("0 = 1")),
Row("id1", 1) ::
- Row("id", 1) ::
- Row("id1", 2) :: Nil)
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
// check if the empty set on the left side works
checkAnswer(
@@ -636,9 +648,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
lowerCaseData.exceptAll(upperCaseData),
Row(1, "a") ::
- Row(2, "b") ::
- Row(3, "c") ::
- Row(4, "d") :: Nil)
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil)
checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil)
@@ -661,16 +673,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// check that duplicates are retained.
val df = spark.sparkContext.parallelize(
NullStrings(1, "id1") ::
- NullStrings(1, "id1") ::
- NullStrings(2, "id1") ::
- NullStrings(3, null) :: Nil).toDF("id", "value")
+ NullStrings(1, "id1") ::
+ NullStrings(2, "id1") ::
+ NullStrings(3, null) :: Nil).toDF("id", "value")
checkAnswer(
df.exceptAll(df.filter("0 = 1")),
Row(1, "id1") ::
- Row(1, "id1") ::
- Row(2, "id1") ::
- Row(3, null) :: Nil)
+ Row(1, "id1") ::
+ Row(2, "id1") ::
+ Row(3, null) :: Nil)
// check if the empty set on the left side works
checkAnswer(
@@ -704,18 +716,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
lowerCaseData.intersect(lowerCaseData),
Row(1, "a") ::
- Row(2, "b") ::
- Row(3, "c") ::
- Row(4, "d") :: Nil)
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
// check null equality
checkAnswer(
nullInts.intersect(nullInts),
Row(1) ::
- Row(2) ::
- Row(3) ::
- Row(null) :: Nil)
+ Row(2) ::
+ Row(3) ::
+ Row(null) :: Nil)
// check if values are de-duplicated
checkAnswer(
@@ -727,8 +739,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.intersect(df),
Row("id1", 1) ::
- Row("id", 1) ::
- Row("id1", 2) :: Nil)
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
}
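The distinction these tests pin down: except and intersect de-duplicate their result, while the exceptAll and intersectAll variants added in 2.4 keep multiplicities; null compares equal to null in all four. A compact sketch:
  val left  = Seq(1, 1, 2, 3).toDF("v")
  val right = Seq(1, 2, 2).toDF("v")
  left.except(right).collect()       // Row(3)          -- distinct difference
  left.exceptAll(right).collect()    // Row(1), Row(3)  -- multiset difference
  left.intersect(right).collect()    // Row(1), Row(2)  -- distinct intersection
  left.intersectAll(right).collect() // Row(1), Row(2)  -- min multiplicity per value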
test("intersect - nullability") {
@@ -756,21 +768,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates),
Row(1, "a") ::
- Row(2, "b") ::
- Row(2, "b") ::
- Row(3, "c") ::
- Row(3, "c") ::
- Row(3, "c") ::
- Row(4, "d") :: Nil)
+ Row(2, "b") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(3, "c") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil)
// check null equality
checkAnswer(
nullInts.intersectAll(nullInts),
Row(1) ::
- Row(2) ::
- Row(3) ::
- Row(null) :: Nil)
+ Row(2) ::
+ Row(3) ::
+ Row(null) :: Nil)
// Duplicate nulls are preserved.
checkAnswer(
@@ -811,7 +823,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
- testData.select($"*", foo('key, 'value)).limit(3),
+ testData.select($"*", foo($"key", $"value")).limit(3),
Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
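`foo` is a UDF defined earlier in the suite, outside this hunk; a hypothetical equivalent consistent with the expected rows would be:
  import org.apache.spark.sql.functions.udf
  val foo = udf((a: Int, b: String) => a.toString + b) // assumed body: (1, "1") -> "11"
  testData.select($"*", foo($"key", $"value")).limit(3)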
@@ -914,7 +926,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("replace column using withColumns") {
- val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
+ val df2 = sparkContext.parallelize(Seq((1, 2), (2, 3), (3, 4))).toDF("x", "y")
val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
Seq(df2("x") + 1, df2("y"), df2("y") + 1))
checkAnswer(
@@ -1140,7 +1152,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("apply on query results (SPARK-5462)") {
val df = testData.sparkSession.sql("select key from testData")
- checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
+ checkAnswer(df.select(df("key")), testData.select("key").collect().toSeq)
}
test("inputFiles") {
@@ -1253,14 +1265,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = "-RECORD 0----------------------\n" +
- " value | 1 \n" +
- "-RECORD 1----------------------\n" +
- " value | 111111111111111111111 \n"
+ " value | 1 \n" +
+ "-RECORD 1----------------------\n" +
+ " value | 111111111111111111111 \n"
assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse)
val expectedAnswerForTrue = "-RECORD 0---------------------\n" +
- " value | 1 \n" +
- "-RECORD 1---------------------\n" +
- " value | 11111111111111111... \n"
+ " value | 1 \n" +
+ "-RECORD 1---------------------\n" +
+ " value | 11111111111111111... \n"
assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue)
}
@@ -1289,14 +1301,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = "-RECORD 0----\n" +
- " value | 1 \n" +
- "-RECORD 1----\n" +
- " value | 111 \n"
+ " value | 1 \n" +
+ "-RECORD 1----\n" +
+ " value | 111 \n"
assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse)
val expectedAnswerForTrue = "-RECORD 0------------------\n" +
- " value | 1 \n" +
- "-RECORD 1------------------\n" +
- " value | 11111111111111... \n"
+ " value | 1 \n" +
+ "-RECORD 1------------------\n" +
+ " value | 11111111111111... \n"
assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue)
}
@@ -1363,11 +1375,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
(Array(2, 3, 4), Array(2, 3, 4))
).toDF()
val expectedAnswer = "-RECORD 0--------\n" +
- " _1 | [1, 2, 3] \n" +
- " _2 | [1, 2, 3] \n" +
- "-RECORD 1--------\n" +
- " _1 | [2, 3, 4] \n" +
- " _2 | [2, 3, 4] \n"
+ " _1 | [1, 2, 3] \n" +
+ " _2 | [1, 2, 3] \n" +
+ "-RECORD 1--------\n" +
+ " _1 | [2, 3, 4] \n" +
+ " _2 | [2, 3, 4] \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
@@ -1392,11 +1404,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
).toDF()
val expectedAnswer = "-RECORD 0---------------\n" +
- " _1 | [31 32] \n" +
- " _2 | [41 42 43 2E] \n" +
- "-RECORD 1---------------\n" +
- " _1 | [33 34] \n" +
- " _2 | [31 32 33 34 36] \n"
+ " _1 | [31 32] \n" +
+ " _2 | [41 42 43 2E] \n" +
+ "-RECORD 1---------------\n" +
+ " _1 | [33 34] \n" +
+ " _2 | [31 32 33 34 36] \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
@@ -1421,11 +1433,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
(2, 2)
).toDF()
val expectedAnswer = "-RECORD 0--\n" +
- " _1 | 1 \n" +
- " _2 | 1 \n" +
- "-RECORD 1--\n" +
- " _1 | 2 \n" +
- " _2 | 2 \n"
+ " _1 | 1 \n" +
+ " _2 | 1 \n" +
+ "-RECORD 1--\n" +
+ " _1 | 2 \n" +
+ " _2 | 2 \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
@@ -1442,9 +1454,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-7319 showString, vertical = true") {
val expectedAnswer = "-RECORD 0----\n" +
- " key | 1 \n" +
- " value | 1 \n" +
- "only showing top 1 row\n"
+ " key | 1 \n" +
+ " value | 1 \n" +
+ "only showing top 1 row\n"
assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer)
}
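showString is private[sql], which is why these assertions live inside this package; application code reaches the same vertical layout through the public show overload. A sketch, assuming a DataFrame df:
  df.show(numRows = 10, truncate = 20, vertical = true) // prints -RECORD n-- blocks instead of a grid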
@@ -1501,7 +1513,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|""".stripMargin
assert(df.showString(1, truncate = 0) === expectedAnswer)
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") {
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
val expectedAnswer = """+----------+-------------------+
||d |ts |
@@ -1518,15 +1530,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val ts = Timestamp.valueOf("2016-12-01 00:00:00")
val df = Seq((d, ts)).toDF("d", "ts")
val expectedAnswer = "-RECORD 0------------------\n" +
- " d | 2016-12-01 \n" +
- " ts | 2016-12-01 00:00:00 \n"
+ " d | 2016-12-01 \n" +
+ " ts | 2016-12-01 00:00:00 \n"
assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") {
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
val expectedAnswer = "-RECORD 0------------------\n" +
- " d | 2016-12-01 \n" +
- " ts | 2016-12-01 08:00:00 \n"
+ " d | 2016-12-01 \n" +
+ " ts | 2016-12-01 08:00:00 \n"
assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
}
}
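The GMT-to-UTC swap is cosmetic (both name the same zone; UTC is the canonical ID), but the mechanism matters: spark.sql.session.timeZone only changes how timestamps are rendered, since they are stored internally as UTC microseconds. Spark's test JVMs are pinned to America/Los_Angeles, hence the 8-hour shift asserted above. A sketch:
  spark.conf.set("spark.sql.session.timeZone", "UTC")
  Seq(java.sql.Timestamp.valueOf("2016-12-01 00:00:00")).toDF("ts").show(false)
  // shows 2016-12-01 08:00:00 when the JVM default zone is America/Los_Angeles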
@@ -1539,7 +1551,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-6899: type should match when using codegen") {
- checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2)))
+ checkAnswer(decimalData.agg(avg("a")), Row(new java.math.BigDecimal(2)))
}
test("SPARK-7133: Implement struct, array, and map field accessor") {
@@ -1581,10 +1593,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-7324 dropDuplicates") {
val testData = sparkContext.parallelize(
(2, 1, 2) :: (1, 1, 1) ::
- (1, 2, 1) :: (2, 1, 2) ::
- (2, 2, 2) :: (2, 2, 1) ::
- (2, 1, 1) :: (1, 1, 2) ::
- (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
+ (1, 2, 1) :: (2, 1, 2) ::
+ (2, 2, 2) :: (2, 2, 1) ::
+ (2, 1, 1) :: (1, 1, 2) ::
+ (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
checkAnswer(
testData.dropDuplicates(),
@@ -1669,47 +1681,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-6941: Better error message for inserting into RDD-based Table") {
withTempDir { dir =>
+ withTempView("parquet_base", "json_base", "rdd_base", "indirect_ds", "one_row") {
+ val tempParquetFile = new File(dir, "tmp_parquet")
+ val tempJsonFile = new File(dir, "tmp_json")
+
+ val df = Seq(Tuple1(1)).toDF()
+ val insertion = Seq(Tuple1(2)).toDF("col")
+
+ // pass case: parquet table (HadoopFsRelation)
+ df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
+ val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath)
+ pdf.createOrReplaceTempView("parquet_base")
+
+ insertion.write.insertInto("parquet_base")
+
+ // pass case: json table (InsertableRelation)
+ df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
+ val jdf = spark.read.json(tempJsonFile.getCanonicalPath)
+ jdf.createOrReplaceTempView("json_base")
+ insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
+
+ // error cases: insert into an RDD
+ df.createOrReplaceTempView("rdd_base")
+ val e1 = intercept[AnalysisException] {
+ insertion.write.insertInto("rdd_base")
+ }
+ assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed."))
- val tempParquetFile = new File(dir, "tmp_parquet")
- val tempJsonFile = new File(dir, "tmp_json")
-
- val df = Seq(Tuple1(1)).toDF()
- val insertion = Seq(Tuple1(2)).toDF("col")
-
- // pass case: parquet table (HadoopFsRelation)
- df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
- val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath)
- pdf.createOrReplaceTempView("parquet_base")
-
- insertion.write.insertInto("parquet_base")
-
- // pass case: json table (InsertableRelation)
- df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
- val jdf = spark.read.json(tempJsonFile.getCanonicalPath)
- jdf.createOrReplaceTempView("json_base")
- insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
-
- // error cases: insert into an RDD
- df.createOrReplaceTempView("rdd_base")
- val e1 = intercept[AnalysisException] {
- insertion.write.insertInto("rdd_base")
- }
- assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed."))
-
- // error case: insert into a logical plan that is not a LeafNode
- val indirectDS = pdf.select("_1").filter($"_1" > 5)
- indirectDS.createOrReplaceTempView("indirect_ds")
- val e2 = intercept[AnalysisException] {
- insertion.write.insertInto("indirect_ds")
- }
- assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
+ // error case: insert into a logical plan that is not a LeafNode
+ val indirectDS = pdf.select("_1").filter($"_1" > 5)
+ indirectDS.createOrReplaceTempView("indirect_ds")
+ val e2 = intercept[AnalysisException] {
+ insertion.write.insertInto("indirect_ds")
+ }
+ assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
- // error case: insert into an OneRowRelation
- Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row")
- val e3 = intercept[AnalysisException] {
- insertion.write.insertInto("one_row")
+ // error case: insert into an OneRowRelation
+ Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row")
+ val e3 = intercept[AnalysisException] {
+ insertion.write.insertInto("one_row")
+ }
+ assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
}
- assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
}
}
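withTempView (from SQLTestUtils) drops the named temp views when the enclosed block exits, so a failing assertion no longer leaks views into later tests; the shape of the pattern, sketched with a hypothetical view name:
  withTempView("my_view") {
    Seq(1).toDF("c").createOrReplaceTempView("my_view")
    assert(spark.table("my_view").count() == 1)
  } // view dropped here, pass or fail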
@@ -1741,7 +1754,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("Sorting columns are not in Filter and Project") {
checkAnswer(
- upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
+ upperCaseData.filter($"N" > 1).select("N").filter($"N" < 6).orderBy($"L".asc),
Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
}
@@ -1784,29 +1797,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") {
val df = Seq(1 -> 2).toDF("i", "j")
- val query1 = df.groupBy('i)
- .agg(max('j).as("aggOrder"))
- .orderBy(sum('j))
+ val query1 = df.groupBy("i")
+ .agg(max("j").as("aggOrder"))
+ .orderBy(sum("j"))
checkAnswer(query1, Row(1, 2))
// In the plan, there are two attributes having the same name 'havingCondition'
// One is a user-provided alias name; another is an internally generated one.
- val query2 = df.groupBy('i)
- .agg(max('j).as("havingCondition"))
- .where(sum('j) > 0)
- .orderBy('havingCondition.asc)
+ val query2 = df.groupBy("i")
+ .agg(max("j").as("havingCondition"))
+ .where(sum("j") > 0)
+ .orderBy($"havingCondition".asc)
checkAnswer(query2, Row(1, 2))
}
test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
- val input = spark.read.json((1 to 10).map(i => s"""{"id": $i}""").toDS())
+ withTempDir { dir =>
+ (1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath)
+ val input = spark.read.json(dir.getCanonicalPath)
- val df = input.select($"id", rand(0).as('r))
- df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row =>
- assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001)
+ val df = input.select($"id", rand(0).as("r"))
+ df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row =>
+ assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001)
+ }
}
}
-
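The SPARK-10316 rewrite above stops json-reading an in-memory Dataset and round-trips through a directory instead; withTempDir (also from SQLTestUtils) supplies and removes the scratch directory. A sketch:
  withTempDir { dir =>
    (1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath)
    assert(spark.read.json(dir.getCanonicalPath).count() == 10)
  } // directory deleted on exit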
test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
val df1 = (1 to 100).map(Tuple1.apply).toDF("i")
val df2 = (1 to 30).map(Tuple1.apply).toDF("i")
@@ -1856,7 +1871,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
expected(except)
)
}
-
test("SPARK-10743: keep the name of expression if possible when do cast") {
val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src")
assert(df.select($"src.i".cast(StringType)).columns.head === "i")
@@ -2001,8 +2015,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1)))
val df = spark.createDataFrame(
rdd,
- new StructType().add("f1", IntegerType).add("f2", IntegerType),
- needsConversion = false).select($"F1", $"f2".as("f2"))
+ new StructType().add("f1", IntegerType).add("f2", IntegerType))
+ .select($"F1", $"f2".as("f2"))
val df1 = df.as("a")
val df2 = df.as("b")
checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil)
@@ -2017,7 +2031,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val df = sparkContext.parallelize(Seq(
- new java.lang.Integer(22) -> "John",
+ java.lang.Integer.valueOf(22) -> "John",
null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name")
// passing null into the UDF that could handle it
@@ -2126,19 +2140,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col")
// invalid table names
Seq("11111", "t~", "#$@sum", "table!#").foreach { name =>
- val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage
- assert(m.contains(s"Invalid view name: $name"))
+ withTempView(name) {
+ val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage
+ assert(m.contains(s"Invalid view name: $name"))
+ }
}
// valid table names
Seq("table1", "`11111`", "`t~`", "`#$@sum`", "`table!#`").foreach { name =>
- df.createOrReplaceTempView(name)
+ withTempView(name) {
+ df.createOrReplaceTempView(name)
+ }
}
}
test("assertAnalyzed shouldn't replace original stack trace") {
val e = intercept[AnalysisException] {
- spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b)
+ spark.range(1).select($"id" as "a", $"id" as "b").groupBy("a").agg($"b")
}
assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName)
@@ -2160,7 +2178,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
- }
+ }
test("SPARK-13774: Check error message for not existent globbed paths") {
// Non-existent initial path component:
@@ -2203,7 +2221,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val size = 201L
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size))))
val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true))
- val df = spark.createDataFrame(rdd, StructType(schemas), false)
+ val df = spark.createDataFrame(rdd, StructType(schemas))
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}
@@ -2233,9 +2251,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
private def verifyNullabilityInFilterExec(
- df: DataFrame,
- expr: String,
- expectedNonNullableColumns: Seq[String]): Unit = {
+ df: DataFrame,
+ expr: String,
+ expectedNonNullableColumns: Seq[String]): Unit = {
val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
dfWithFilter.queryExecution.executedPlan.collect {
// When the child expression in isnotnull is null-intolerant (i.e. any null input will
@@ -2250,9 +2268,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-17957: no change on nullability in FilterExec output") {
val df = sparkContext.parallelize(Seq(
- null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
- new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
- new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+ null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3),
+ java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer],
+ java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF()
verifyNullabilityInFilterExec(df,
expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
@@ -2267,9 +2285,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-17957: set nullability to false in FilterExec output") {
val df = sparkContext.parallelize(Seq(
- null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
- new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
- new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+ null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3),
+ java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer],
+ java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF()
verifyNullabilityInFilterExec(df,
expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
@@ -2338,7 +2356,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
checkAnswer(df, Row(BigDecimal(0)) :: Nil)
}
-
test("SPARK-19893: cannot run set operations with map type") {
val df = spark.range(1).select(map(lit("key"), $"id").as("m"))
val e = intercept[AnalysisException](df.intersect(df))
@@ -2389,7 +2406,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val e = intercept[SparkException] {
df.filter(filter).count()
}.getMessage
- assert(e.contains("grows beyond 64 KB"))
+ assert(e.contains("grows beyond 64 KiB"))
}
}
@@ -2405,7 +2422,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("order-by ordinal.") {
checkAnswer(
- testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
+ testData2.select(lit(7), $"a", $"b").orderBy(lit(1), lit(2), lit(3)),
Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
}
@@ -2424,7 +2441,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-22271: mean overflows and returns null for some decimal variables") {
val d = 0.034567890
val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")
- val result = df.select('DecimalCol cast DecimalType(38, 33))
+ val result = df.select($"DecimalCol" cast DecimalType(38, 33))
.select(col("DecimalCol")).describe()
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
@@ -2460,24 +2477,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val sourceDF = spark.createDataFrame(rows, schema)
def structWhenDF: DataFrame = sourceDF
- .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res")
- .select('res.getField("val1"))
+ .select(when($"cond",
+ struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as "res")
+ .select($"res".getField("val1"))
def arrayWhenDF: DataFrame = sourceDF
- .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res")
- .select('res.getItem(0))
+ .select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as "res")
+ .select($"res".getItem(0))
def mapWhenDF: DataFrame = sourceDF
- .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res")
- .select('res.getItem(0))
+ .select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res")
+ .select($"res".getItem(0))
def structIfDF: DataFrame = sourceDF
.select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res")
- .select('res.getField("val1"))
+ .select($"res".getField("val1"))
def arrayIfDF: DataFrame = sourceDF
.select(expr("if(cond, array('a', 'b'), a)") as "res")
- .select('res.getItem(0))
+ .select($"res".getItem(0))
def mapIfDF: DataFrame = sourceDF
.select(expr("if(cond, map(0, 'a'), m)") as "res")
- .select('res.getItem(0))
+ .select($"res".getItem(0))
def checkResult(): Unit = {
checkAnswer(structWhenDF, Seq(Row("a"), Row(null)))
@@ -2517,7 +2535,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
- withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+ withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
@@ -2540,17 +2558,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// partitions.
.write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath)
- var numJobs = 0
+ val numJobs = new AtomicLong(0)
sparkContext.addSparkListener(new SparkListener {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
- numJobs += 1
+ numJobs.incrementAndGet()
}
})
val df = spark.read.json(path.getCanonicalPath)
assert(df.columns === Array("i", "p"))
spark.sparkContext.listenerBus.waitUntilEmpty(10000)
- assert(numJobs == 1)
+ assert(numJobs.get() == 1L)
}
}
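The var-to-AtomicLong change is a thread-safety fix: SparkListener callbacks fire on the listener-bus thread, not the test thread, so the counter must be safely published. A minimal sketch of the pattern:
  import java.util.concurrent.atomic.AtomicLong
  import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
  val jobs = new AtomicLong(0)
  spark.sparkContext.addSparkListener(new SparkListener {
    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobs.incrementAndGet()
  })
  spark.range(10).count()                              // runs one job
  spark.sparkContext.listenerBus.waitUntilEmpty(10000) // package-private; drains pending events
  assert(jobs.get() == 1L)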
@@ -2622,4 +2640,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row("1-1", 6, 6))
}
}
+
+ test("SPARK-29442 Set `default` mode should override the existing mode") {
+ val df = Seq(Tuple1(1)).toDF()
+ val writer = df.write.mode("overwrite").mode("default")
+ val modeField = classOf[DataFrameWriter[Tuple1[Int]]].getDeclaredField("mode")
+ modeField.setAccessible(true)
+ assert(SaveMode.ErrorIfExists === modeField.get(writer).asInstanceOf[SaveMode])
+ }
}
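The new SPARK-29442 test pins down that each mode(...) call on a DataFrameWriter replaces the previous one, and that the string "default" maps to SaveMode.ErrorIfExists rather than being ignored. In user terms:
  val writer = Seq(1).toDF().write.mode("overwrite").mode("default")
  // writer now errors if the target exists, exactly as if mode had never been set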
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index a5f904c621e6e..9daa69ce9f155 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite {
// Makes sure hashCode on unsafe array won't crash
unsafeRow.getArray(0).hashCode()
}
+
+ test("SPARK-32018: setDecimal with overflowed value") {
+ val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
+ val row = InternalRow.apply(d1)
+ val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
+ assert(unsafeRow.getDecimal(0, 38, 18) === d1)
+ val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
+ unsafeRow.setDecimal(0, d2, 38)
+ assert(unsafeRow.getDecimal(0, 38, 18) === null)
+ }
}
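What the new decimal test exercises: d1 occupies the full 38 digits of the column type, d2 needs 39, and setDecimal records the overflowed value as null, so the follow-up getDecimal read returns null instead of a truncated number. Restated as a sketch using the suite's internal imports (Decimal and the private[sql] toPrecision):
  val d = Decimal(BigDecimal("1" + "0" * 19)).toPrecision(38, 18) // 20 + 18 digits: fits
  val tooBig = (d * Decimal(10)).toPrecision(39, 18)              // 39 digits: overflows (38, 18)
  // writing tooBig into a (38, 18) slot nulls the field rather than corrupting the row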
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index d84421acf8390..fe102a1691305 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 9f76d192524ee..8efb3799d2fb8 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/streaming/pom.xml b/streaming/pom.xml
index e84cd5514784e..b47ae88c16916 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>
diff --git a/tools/pom.xml b/tools/pom.xml
index e009120bb52bc..25d9508c256c4 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.4.1-kylin-r70</version>
+ <version>2.4.1-kylin-r71</version>
<relativePath>../pom.xml</relativePath>