diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a4aca07d69f7..56b34d6c5e7b 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -253,7 +253,7 @@ jobs: - name: Install Python packages (Python 3.9) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.59.3' 'grpcio-status==1.59.3' 'protobuf==4.25.1' + python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.59.3' 'grpcio-status==1.59.3' 'protobuf==4.25.1' python3.9 -m pip list # Run the tests. - name: Run tests diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 700b1ed07513..9f68d4c5a53e 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -875,12 +875,6 @@ ], "sqlState" : "42K01" }, - "DATA_SOURCE_ALREADY_EXISTS" : { - "message" : [ - "Data source '' already exists in the registry. Please use a different name for the new data source." - ], - "sqlState" : "42710" - }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ "Data source '' not found. Please make sure the data source is registered." @@ -1480,12 +1474,6 @@ }, "sqlState" : "42K0B" }, - "INCORRECT_END_OFFSET" : { - "message" : [ - "Max offset with rowsPerSecond is , but it's now." - ], - "sqlState" : "22003" - }, "INCORRECT_RAMP_UP_RATE" : { "message" : [ "Max offset with rowsPerSecond is , but 'rampUpTimeSeconds' is ." @@ -1906,11 +1894,6 @@ "Operation not found." ] }, - "SESSION_ALREADY_EXISTS" : { - "message" : [ - "Session already exists." - ] - }, "SESSION_CLOSED" : { "message" : [ "Session was closed." @@ -6065,11 +6048,6 @@ "." ] }, - "_LEGACY_ERROR_TEMP_2142" : { - "message" : [ - "Attributes for type is not supported." - ] - }, "_LEGACY_ERROR_TEMP_2144" : { "message" : [ "Unable to find constructor for . This could happen if is an interface, or a trait without companion object constructor." @@ -6920,11 +6898,6 @@ ": " ] }, - "_LEGACY_ERROR_TEMP_3066" : { - "message" : [ - "" - ] - }, "_LEGACY_ERROR_TEMP_3067" : { "message" : [ "Streaming aggregation doesn't support group aggregate pandas UDF" @@ -6980,11 +6953,6 @@ "More than one event time columns are available. Please ensure there is at most one event time column per stream. event time columns: " ] }, - "_LEGACY_ERROR_TEMP_3078" : { - "message" : [ - "Can not match ParquetTable in the query." - ] - }, "_LEGACY_ERROR_TEMP_3079" : { "message" : [ "Dynamic partition cannot be the parent of a static partition." 
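[Editor's note] The Spark Connect hunks that follow (GrpcExceptionConverter.scala, Connect.scala, ErrorUtils.scala) propagate error message parameters from the server to the client by serializing them as JSON into the gRPC `ErrorInfo` metadata, bounded by the new `spark.connect.grpc.maxMetadataSize` setting (default 1024 bytes). The snippet below is only a minimal, illustrative sketch of that round trip, assuming json4s (already a Spark dependency); the object and method names (`MessageParametersRoundTripSketch`, `encode`, `decode`) are hypothetical and simplified from the actual `ErrorUtils` / `GrpcExceptionConverter` code shown later in this patch.

// Illustrative sketch only; not the code added by this patch.
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

object MessageParametersRoundTripSketch {
  implicit val formats: Formats = DefaultFormats

  // Server side: serialize the message parameters to JSON and attach them only
  // if they fit within the metadata budget (spark.connect.grpc.maxMetadataSize).
  def encode(params: Map[String, String], maxMetadataSize: Long): Option[String] = {
    val json = JsonMethods.compact(JsonMethods.render(map2jvalue(params)))
    if (json.length <= maxMetadataSize) Some(json) else None
  }

  // Client side: parse the metadata value back into a Map (empty if absent).
  def decode(metadataValue: String): Map[String, String] =
    JsonMethods.parse(metadataValue).extract[Map[String, String]]

  def main(args: Array[String]): Unit = {
    // Parameter values taken from the ClientE2ETestSuite assertion in this patch.
    val params = Map(
      "datetime" -> "'02-29'",
      "config" -> "\"spark.sql.legacy.timeParserPolicy\"")
    val roundTripped = encode(params, maxMetadataSize = 1024).map(decode)
    assert(roundTripped.contains(params))
  }
}
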
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index c947d948b4cf..0740334724e8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -85,7 +85,13 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM |""".stripMargin) .collect() } - assert(ex.getErrorClass != null) + assert( + ex.getErrorClass === + "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER") + assert( + ex.getMessageParameters.asScala == Map( + "datetime" -> "'02-29'", + "config" -> "\"spark.sql.legacy.timeParserPolicy\"")) if (enrichErrorEnabled) { assert(ex.getCause.isInstanceOf[DateTimeException]) } else { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index c5c917ebfa95..0a7768aa488b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -516,6 +516,10 @@ class PlanGenerationTestSuite simple.where("a + id < 1000") } + test("between expr") { + simple.selectExpr("rand(123) BETWEEN 0.1 AND 0.2") + } + test("unpivot values") { simple.unpivot( ids = Array(fn.col("id"), fn.col("a")), diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 075526e7521d..cc47924de3b0 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -372,10 +372,14 @@ private[client] object GrpcExceptionConverter { .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) if (errorClass != null) { + val messageParameters = JsonMethods + .parse(info.getMetadataOrDefault("messageParameters", "{}")) + .extract[Map[String, String]] builder.setSparkThrowable( FetchErrorDetailsResponse.SparkThrowable .newBuilder() .setErrorClass(errorClass) + .putAllMessageParameters(messageParameters.asJava) .build()) } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/between_expr.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/between_expr.explain new file mode 100644 index 000000000000..e21ebd5c5da1 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/between_expr.explain @@ -0,0 +1,3 @@ +Project [((_common_expr_0#0 >= cast(0.1 as double)) AND (_common_expr_0#0 <= cast(0.2 as double))) AS between(rand(123), 0.1, 0.2)#0] ++- Project [id#0L, a#0, b#0, rand(123) AS _common_expr_0#0] + +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/between_expr.json b/connector/connect/common/src/test/resources/query-tests/queries/between_expr.json new file mode 100644 index 000000000000..bf24a26601dc --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/between_expr.json @@ -0,0 +1,20 @@ +{ + "common": { + 
"planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "expressions": [{ + "expressionString": { + "expression": "rand(123) BETWEEN 0.1 AND 0.2" + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin new file mode 100644 index 000000000000..03a8aba2719d Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/between_expr.proto.bin differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index ab4f06d508a0..39bf1a630af6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -256,4 +256,13 @@ object Connect { .version("4.0.0") .booleanConf .createWithDefault(true) + + val CONNECT_GRPC_MAX_METADATA_SIZE = + buildStaticConf("spark.connect.grpc.maxMetadataSize") + .doc( + "Sets the maximum size of metadata fields. For instance, it restricts metadata fields " + + "in `ErrorInfo`.") + .version("4.0.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 703b11c0c736..f489551a1dba 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -172,6 +172,7 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) + val maxMetadataSize = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MAX_METADATA_SIZE) // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject. 
st match { case e: SparkThrowable => @@ -181,7 +182,12 @@ private[connect] object ErrorUtils extends Logging { } val errorClass = e.getErrorClass if (errorClass != null && errorClass.nonEmpty) { - errorInfo.putMetadata("errorClass", errorClass) + val messageParameters = JsonMethods.compact( + JsonMethods.render(map2jvalue(e.getMessageParameters.asScala.toMap))) + if (messageParameters.length <= maxMetadataSize) { + errorInfo.putMetadata("errorClass", errorClass) + errorInfo.putMetadata("messageParameters", messageParameters) + } } case _ => } @@ -200,8 +206,10 @@ private[connect] object ErrorUtils extends Logging { val withStackTrace = if (sessionHolderOpt.exists( _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { - val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) - errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) + val maxSize = Math.min( + SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE), + maxMetadataSize) + errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize.toInt)) } else { errorInfo } diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index ac8d2990c0e6..3f73177d7dd4 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -55,7 +55,7 @@ com.google.guava guava - 18.0 + 19.0 org.apache.spark diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 5b4567aa2881..cee0d9a3dd72 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1789,7 +1789,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { CheckAnswer(data: _*), Execute { query => // The rate limit is 1, so there must be some delay in offsets per partition. - val progressWithDelay = query.recentProgress.map(_.sources.head).reverse.find { progress => + val progressWithDelay = query.recentProgress.map(_.sources.head).findLast { progress => // find the metrics that has non-zero average offsetsBehindLatest greater than 0. 
!progress.metrics.isEmpty && progress.metrics.get("avgOffsetsBehindLatest").toDouble > 0 } diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index f5383af55c3a..e0928df7f3c6 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 672 681 10 0.0 67171.0 1.0X -Compression 10000 times at level 2 without buffer pool 715 718 4 0.0 71458.8 0.9X -Compression 10000 times at level 3 without buffer pool 831 835 4 0.0 83139.1 0.8X -Compression 10000 times at level 1 with buffer pool 609 611 2 0.0 60881.5 1.1X -Compression 10000 times at level 2 with buffer pool 648 649 1 0.0 64791.0 1.0X -Compression 10000 times at level 3 with buffer pool 744 751 6 0.0 74392.4 0.9X +Compression 10000 times at level 1 without buffer pool 674 920 293 0.0 67406.4 1.0X +Compression 10000 times at level 2 without buffer pool 882 884 3 0.0 88195.1 0.8X +Compression 10000 times at level 3 without buffer pool 973 978 4 0.0 97301.3 0.7X +Compression 10000 times at level 1 with buffer pool 955 955 1 0.0 95452.0 0.7X +Compression 10000 times at level 2 with buffer pool 994 996 2 0.0 99432.1 0.7X +Compression 10000 times at level 3 with buffer pool 1093 1101 12 0.0 109300.9 0.6X -OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 842 849 12 0.0 84240.0 1.0X -Decompression 10000 times from level 2 without buffer pool 842 846 6 0.0 84185.2 1.0X -Decompression 10000 times from level 3 without buffer pool 843 844 1 0.0 84285.4 1.0X -Decompression 10000 times from level 1 with buffer pool 770 771 1 0.0 77024.9 1.1X -Decompression 10000 times from level 2 with buffer pool 771 771 0 0.0 77120.4 1.1X -Decompression 10000 times from level 3 with buffer pool 770 771 0 0.0 77031.9 1.1X +Decompression 10000 times from level 1 without buffer pool 826 829 3 0.0 82591.4 1.0X +Decompression 10000 times from level 2 without buffer pool 825 826 1 0.0 82533.4 1.0X +Decompression 10000 times from level 3 without buffer pool 827 830 5 0.0 82715.3 1.0X +Decompression 10000 times from level 1 with buffer pool 763 764 1 0.0 76271.6 1.1X +Decompression 10000 times from level 2 with buffer pool 763 777 23 0.0 76321.2 1.1X +Decompression 10000 times from level 3 with buffer pool 763 765 2 0.0 76286.1 1.1X -OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 48 50 3 0.0 376597.0 1.0X -Parallel Compression with 1 workers 41 42 3 0.0 318927.3 1.2X -Parallel Compression with 2 workers 38 40 2 0.0 297410.2 1.3X -Parallel Compression with 4 workers 37 39 1 0.0 287605.8 1.3X -Parallel Compression with 8 workers 39 40 1 0.0 301948.1 1.2X -Parallel Compression with 16 workers 41 43 1 0.0 317095.6 1.2X +Parallel Compression with 0 workers 49 50 1 0.0 384188.1 1.0X +Parallel Compression with 1 workers 42 44 4 0.0 328139.4 1.2X +Parallel Compression with 2 workers 40 42 1 0.0 309013.2 1.2X +Parallel Compression with 4 workers 40 41 1 0.0 309732.2 1.2X +Parallel Compression with 8 workers 41 43 2 0.0 319730.2 1.2X +Parallel Compression with 16 workers 43 45 1 0.0 337944.2 1.1X -OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 174 175 1 0.0 1360596.3 1.0X -Parallel Compression with 1 workers 189 228 24 0.0 1477060.7 0.9X -Parallel Compression with 2 workers 109 118 15 0.0 851455.9 1.6X -Parallel Compression with 4 workers 114 118 3 0.0 891964.9 1.5X -Parallel Compression with 8 workers 115 122 4 0.0 899748.7 1.5X -Parallel Compression with 16 workers 119 123 2 0.0 931210.7 1.5X +Parallel Compression with 0 workers 160 161 1 0.0 1250203.7 1.0X +Parallel Compression with 1 workers 196 197 2 0.0 1529028.2 0.8X +Parallel Compression with 2 workers 114 121 10 0.0 892592.4 1.4X +Parallel Compression with 4 workers 111 113 1 0.0 865617.7 1.4X +Parallel Compression with 8 workers 112 117 2 0.0 878723.8 1.4X +Parallel Compression with 16 workers 114 117 2 0.0 889199.7 1.4X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 64375b7a379a..bbedcda0f160 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 666 669 3 0.0 66598.6 1.0X -Compression 10000 times at level 2 without buffer pool 711 711 1 0.0 71077.5 0.9X -Compression 10000 times at level 3 without buffer pool 816 816 0 0.0 81575.8 0.8X -Compression 10000 times at level 1 with buffer pool 591 592 1 0.0 59095.6 1.1X -Compression 10000 times at level 2 with buffer pool 630 632 1 0.0 62995.1 1.1X -Compression 10000 times at level 3 with buffer pool 742 742 0 0.0 74180.7 0.9X +Compression 10000 times at level 1 without buffer pool 666 669 3 0.0 66613.9 1.0X +Compression 10000 times at level 2 without buffer 
pool 708 709 1 0.0 70817.8 0.9X +Compression 10000 times at level 3 without buffer pool 818 819 1 0.0 81828.4 0.8X +Compression 10000 times at level 1 with buffer pool 591 596 9 0.0 59119.9 1.1X +Compression 10000 times at level 2 with buffer pool 629 630 1 0.0 62943.8 1.1X +Compression 10000 times at level 3 with buffer pool 746 747 2 0.0 74593.2 0.9X -OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 600 602 1 0.0 60024.1 1.0X -Decompression 10000 times from level 2 without buffer pool 600 603 2 0.0 59973.0 1.0X -Decompression 10000 times from level 3 without buffer pool 601 602 1 0.0 60075.9 1.0X -Decompression 10000 times from level 1 with buffer pool 553 553 0 0.0 55316.4 1.1X -Decompression 10000 times from level 2 with buffer pool 553 554 1 0.0 55271.5 1.1X -Decompression 10000 times from level 3 with buffer pool 553 553 0 0.0 55261.4 1.1X +Decompression 10000 times from level 1 without buffer pool 600 603 2 0.0 60027.9 1.0X +Decompression 10000 times from level 2 without buffer pool 603 604 1 0.0 60270.8 1.0X +Decompression 10000 times from level 3 without buffer pool 602 604 2 0.0 60224.9 1.0X +Decompression 10000 times from level 1 with buffer pool 548 548 0 0.0 54774.5 1.1X +Decompression 10000 times from level 2 with buffer pool 548 548 1 0.0 54763.3 1.1X +Decompression 10000 times from level 3 with buffer pool 548 548 0 0.0 54751.8 1.1X -OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 49 50 1 0.0 380070.1 1.0X -Parallel Compression with 1 workers 41 42 4 0.0 319807.1 1.2X -Parallel Compression with 2 workers 38 41 2 0.0 297706.4 1.3X -Parallel Compression with 4 workers 38 40 2 0.0 296505.8 1.3X -Parallel Compression with 8 workers 39 41 1 0.0 305793.6 1.2X -Parallel Compression with 16 workers 43 44 1 0.0 332233.1 1.1X +Parallel Compression with 0 workers 49 50 1 0.0 382234.6 1.0X +Parallel Compression with 1 workers 42 43 1 0.0 327419.1 1.2X +Parallel Compression with 2 workers 38 41 2 0.0 299638.6 1.3X +Parallel Compression with 4 workers 38 40 2 0.0 294671.3 1.3X +Parallel Compression with 8 workers 40 42 1 0.0 308641.0 1.2X +Parallel Compression with 16 workers 43 45 2 0.0 335800.5 1.1X -OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1051-azure +OpenJDK 64-Bit Server VM 17.0.9+9-LTS on Linux 5.15.0-1053-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 175 175 0 0.0 1363800.8 1.0X -Parallel Compression with 1 workers 186 187 3 0.0 1455096.4 0.9X -Parallel Compression with 2 workers 110 115 6 0.0 863272.6 1.6X -Parallel Compression with 4 workers 
104 108 2 0.0 810721.1 1.7X -Parallel Compression with 8 workers 110 112 2 0.0 859303.5 1.6X -Parallel Compression with 16 workers 109 112 2 0.0 847863.6 1.6X +Parallel Compression with 0 workers 159 161 1 0.0 1242039.9 1.0X +Parallel Compression with 1 workers 201 203 4 0.0 1568116.4 0.8X +Parallel Compression with 2 workers 115 122 12 0.0 900801.9 1.4X +Parallel Compression with 4 workers 111 114 3 0.0 868535.8 1.4X +Parallel Compression with 8 workers 113 117 2 0.0 886200.6 1.4X +Parallel Compression with 16 workers 113 117 2 0.0 880790.0 1.4X diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fafde3cf12c6..e85f98ff55c5 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -34,7 +34,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.{classTag, ClassTag} -import scala.sys.process.{Process, ProcessLogger} +import scala.sys.process.Process import scala.util.Try import com.google.common.io.{ByteStreams, Files} @@ -204,16 +204,7 @@ private[spark] object TestUtils extends SparkTestUtils { /** * Test if a command is available. */ - def testCommandAvailable(command: String): Boolean = { - val attempt = if (Utils.isWindows) { - Try(Process(Seq( - "cmd.exe", "/C", s"where $command")).run(ProcessLogger(_ => ())).exitValue()) - } else { - Try(Process(Seq( - "sh", "-c", s"command -v $command")).run(ProcessLogger(_ => ())).exitValue()) - } - attempt.isSuccess && attempt.get == 0 - } + def testCommandAvailable(command: String): Boolean = Utils.checkCommandAvailable(command) // SPARK-40053: This string needs to be updated when the // minimum python supported version changes. diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 0e61e38ff2b0..26c790a12447 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -18,14 +18,19 @@ package org.apache.spark.api.python import java.io.File +import java.nio.file.Paths import java.util.{List => JList} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import scala.sys.process.Process import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging +import org.apache.spark.util.ArrayImplicits.SparkArrayOps +import org.apache.spark.util.Utils private[spark] object PythonUtils extends Logging { val PY4J_ZIP_NAME = "py4j-0.10.9.7-src.zip" @@ -113,11 +118,10 @@ private[spark] object PythonUtils extends Logging { } val pythonVersionCMD = Seq(pythonExec, "-VV") - val PYTHONPATH = "PYTHONPATH" val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, - sys.env.getOrElse(PYTHONPATH, "")) - val environment = Map(PYTHONPATH -> pythonPath) + sys.env.getOrElse("PYTHONPATH", "")) + val environment = Map("PYTHONPATH" -> pythonPath) logInfo(s"Python path $pythonPath") val processPythonVer = Process(pythonVersionCMD, None, environment.toSeq: _*) @@ -145,4 +149,48 @@ private[spark] object PythonUtils extends Logging { listOfPackages.foreach(x => logInfo(s"List of Python packages :- ${formatOutput(x)}")) } } + + // Only for testing. 
+ private[spark] var additionalTestingPath: Option[String] = None + + private[spark] val defaultPythonExec: String = sys.env.getOrElse( + "PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3")) + + private[spark] def createPythonFunction(command: Array[Byte]): SimplePythonFunction = { + val sourcePython = if (Utils.isTesting) { + // Put PySpark source code instead of the build zip archive so we don't need + // to build PySpark every time during development. + val sparkHome: String = { + require( + sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"), + "spark.test.home or SPARK_HOME is not set.") + sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) + } + val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath + val py4jPath = Paths.get( + sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath + val merged = mergePythonPaths(sourcePath.toString, py4jPath.toString) + // Adds a additional path to search Python packages for testing. + additionalTestingPath.map(mergePythonPaths(_, merged)).getOrElse(merged) + } else { + PythonUtils.sparkPythonPath + } + val pythonPath = PythonUtils.mergePythonPaths( + sourcePython, sys.env.getOrElse("PYTHONPATH", "")) + + val pythonVer: String = + Process( + Seq(defaultPythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"), + None, + "PYTHONPATH" -> pythonPath).!!.trim() + + SimplePythonFunction( + command = command.toImmutableArraySeq, + envVars = mutable.Map("PYTHONPATH" -> pythonPath).asJava, + pythonIncludes = List.empty.asJava, + pythonExec = defaultPythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty.asJava, + accumulator = null) + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aa5bd73e8535..a55539c0a235 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2851,7 +2851,7 @@ private[spark] object Utils else { // The last char is a dollar sign // Find last non-dollar char - val lastNonDollarChar = s.reverse.find(_ != '$') + val lastNonDollarChar = s.findLast(_ != '$') lastNonDollarChar match { case None => s case Some(c) => @@ -3027,6 +3027,23 @@ private[spark] object Utils } } + /** + * Check if a command is available. 
+ */ + def checkCommandAvailable(command: String): Boolean = { + // To avoid conflicts with java.lang.Process + import scala.sys.process.{Process, ProcessLogger} + + val attempt = if (Utils.isWindows) { + Try(Process(Seq( + "cmd.exe", "/C", s"where $command")).run(ProcessLogger(_ => ())).exitValue()) + } else { + Try(Process(Seq( + "sh", "-c", s"command -v $command")).run(ProcessLogger(_ => ())).exitValue()) + } + attempt.isSuccess && attempt.get == 0 + } + /** * Return whether we are using G1GC or not */ diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 1ed23f257d73..bebc63bae72b 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -267,4 +267,4 @@ xz/1.9//xz-1.9.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.9.1//zookeeper-jute-3.9.1.jar zookeeper/3.9.1//zookeeper-3.9.1.jar -zstd-jni/1.5.5-10//zstd-jni-1.5.5-10.jar +zstd-jni/1.5.5-11//zstd-jni-1.5.5-11.jar diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4e24eb7b71f1..8595e7ec0e67 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -627,6 +627,7 @@ def __hash__(self): "pyspark.ml.tuning", # unittests "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_als", "pyspark.ml.tests.test_base", "pyspark.ml.tests.test_evaluation", "pyspark.ml.tests.test_feature", @@ -816,9 +817,11 @@ def __hash__(self): "pyspark.pandas.tests.indexes.test_datetime_round", "pyspark.pandas.tests.indexes.test_align", "pyspark.pandas.tests.indexes.test_indexing", + "pyspark.pandas.tests.indexes.test_indexing_adv", "pyspark.pandas.tests.indexes.test_indexing_basic", "pyspark.pandas.tests.indexes.test_indexing_iloc", "pyspark.pandas.tests.indexes.test_indexing_loc", + "pyspark.pandas.tests.indexes.test_indexing_loc_2d", "pyspark.pandas.tests.indexes.test_indexing_loc_multi_idx", "pyspark.pandas.tests.indexes.test_reindex", "pyspark.pandas.tests.indexes.test_rename", @@ -878,7 +881,6 @@ def __hash__(self): "pyspark.pandas.tests.groupby.test_stat_func", "pyspark.pandas.tests.groupby.test_stat_prod", "pyspark.pandas.tests.groupby.test_value_counts", - "pyspark.pandas.tests.test_indexing", "pyspark.pandas.tests.diff_frames_ops.test_align", "pyspark.pandas.tests.diff_frames_ops.test_arithmetic", "pyspark.pandas.tests.diff_frames_ops.test_arithmetic_ext", @@ -1092,9 +1094,11 @@ def __hash__(self): "pyspark.pandas.tests.connect.indexes.test_parity_map", "pyspark.pandas.tests.connect.indexes.test_parity_align", "pyspark.pandas.tests.connect.indexes.test_parity_indexing", + "pyspark.pandas.tests.connect.indexes.test_parity_indexing_adv", "pyspark.pandas.tests.connect.indexes.test_parity_indexing_basic", "pyspark.pandas.tests.connect.indexes.test_parity_indexing_iloc", "pyspark.pandas.tests.connect.indexes.test_parity_indexing_loc", + "pyspark.pandas.tests.connect.indexes.test_parity_indexing_loc_2d", "pyspark.pandas.tests.connect.indexes.test_parity_indexing_loc_multi_idx", "pyspark.pandas.tests.connect.indexes.test_parity_reindex", "pyspark.pandas.tests.connect.indexes.test_parity_rename", diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md index 14526cd53724..8df8e54a8d9d 100644 --- a/docs/sql-error-conditions-invalid-handle-error-class.md +++ b/docs/sql-error-conditions-invalid-handle-error-class.md @@ -41,10 +41,6 @@ Operation already exists. Operation not found. -## SESSION_ALREADY_EXISTS - -Session already exists. 
- ## SESSION_CLOSED Session was closed. diff --git a/docs/sql-error-conditions-sqlstates.md b/docs/sql-error-conditions-sqlstates.md index 49cfb56b3662..85f1c5c69c33 100644 --- a/docs/sql-error-conditions-sqlstates.md +++ b/docs/sql-error-conditions-sqlstates.md @@ -71,7 +71,7 @@ Spark SQL uses the following `SQLSTATE` classes: - ARITHMETIC_OVERFLOW, CAST_OVERFLOW, CAST_OVERFLOW_IN_TABLE_INSERT, DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION, INVALID_INDEX_OF_ZERO, INCORRECT_END_OFFSET, INCORRECT_RAMP_UP_RATE, INVALID_ARRAY_INDEX, INVALID_ARRAY_INDEX_IN_ELEMENT_AT, NUMERIC_OUT_OF_SUPPORTED_RANGE, NUMERIC_VALUE_OUT_OF_RANGE + ARITHMETIC_OVERFLOW, CAST_OVERFLOW, CAST_OVERFLOW_IN_TABLE_INSERT, DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION, INVALID_INDEX_OF_ZERO, INCORRECT_RAMP_UP_RATE, INVALID_ARRAY_INDEX, INVALID_ARRAY_INDEX_IN_ELEMENT_AT, NUMERIC_OUT_OF_SUPPORTED_RANGE, NUMERIC_VALUE_OUT_OF_RANGE diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index a8d2b6c894bc..248839666ef2 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -474,12 +474,6 @@ For more details see [DATATYPE_MISMATCH](sql-error-conditions-datatype-mismatch- DataType `` requires a length parameter, for example ``(10). Please specify the length. -### DATA_SOURCE_ALREADY_EXISTS - -[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) - -Data source '``' already exists in the registry. Please use a different name for the new data source. - ### DATA_SOURCE_NOT_EXIST [SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -886,12 +880,6 @@ You may get a different result due to the upgrading to For more details see [INCONSISTENT_BEHAVIOR_CROSS_VERSION](sql-error-conditions-inconsistent-behavior-cross-version-error-class.html) -### INCORRECT_END_OFFSET - -[SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception) - -Max offset with `` rowsPerSecond is ``, but it's `` now. - ### INCORRECT_RAMP_UP_RATE [SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 65c7d399a88b..1e6be16ef62b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -324,13 +324,22 @@ class ALSModel private[ml] ( // create a new column named map(predictionCol) by running the predict UDF. 
val validatedUsers = checkIntegers(dataset, $(userCol)) val validatedItems = checkIntegers(dataset, $(itemCol)) + + val validatedInputAlias = Identifiable.randomUID("__als_validated_input") + val itemFactorsAlias = Identifiable.randomUID("__als_item_factors") + val userFactorsAlias = Identifiable.randomUID("__als_user_factors") + val predictions = dataset - .join(userFactors, - validatedUsers === userFactors("id"), "left") - .join(itemFactors, - validatedItems === itemFactors("id"), "left") - .select(dataset("*"), - predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) + .withColumns(Seq($(userCol), $(itemCol)), Seq(validatedUsers, validatedItems)) + .alias(validatedInputAlias) + .join(userFactors.alias(userFactorsAlias), + col(s"${validatedInputAlias}.${$(userCol)}") === col(s"${userFactorsAlias}.id"), "left") + .join(itemFactors.alias(itemFactorsAlias), + col(s"${validatedInputAlias}.${$(itemCol)}") === col(s"${itemFactorsAlias}.id"), "left") + .select(col(s"${validatedInputAlias}.*"), + predict(col(s"${userFactorsAlias}.features"), col(s"${itemFactorsAlias}.features")) + .alias($(predictionCol))) + getColdStartStrategy match { case ALSModel.Drop => predictions.na.drop("all", Seq($(predictionCol))) diff --git a/pom.xml b/pom.xml index 09a1d3d49998..1938708fee95 100644 --- a/pom.xml +++ b/pom.xml @@ -223,6 +223,7 @@ 1.6.0 1.77 1.9.0 + 3.3.0 4.1.100.Final 2.0.61.Final + + org.apache.datasketches + datasketches-java + ${datasketches.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index db546dcdd5bb..ac86aeee3d28 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -946,7 +946,7 @@ object Unsafe { object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "18.0" + dependencyOverrides += "com.google.guava" % "guava" % "19.0" ) } diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index ba172135cb64..0cffe7268753 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -69,6 +69,9 @@ def convert_exception( if "errorClass" in info.metadata: error_class = info.metadata["errorClass"] + if "messageParameters" in info.metadata: + message_parameters = json.loads(info.metadata["messageParameters"]) + stacktrace: Optional[str] = None if resp is not None and resp.HasField("root_error_idx"): message = resp.errors[resp.root_error_idx].message diff --git a/python/pyspark/ml/tests/test_als.py b/python/pyspark/ml/tests/test_als.py new file mode 100644 index 000000000000..8eec0d937768 --- /dev/null +++ b/python/pyspark/ml/tests/test_als.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest + +import pyspark.sql.functions as sf +from pyspark.ml.recommendation import ALS, ALSModel +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ALSTest(ReusedSQLTestCase): + def test_ambiguous_column(self): + data = self.spark.createDataFrame( + [[1, 15, 1], [1, 2, 2], [2, 3, 4], [2, 2, 5]], + ["user", "item", "rating"], + ) + model = ALS( + userCol="user", + itemCol="item", + ratingCol="rating", + numUserBlocks=10, + numItemBlocks=10, + maxIter=1, + seed=42, + ).fit(data) + + with tempfile.TemporaryDirectory() as d: + model.write().overwrite().save(d) + loaded_model = ALSModel().load(d) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": False}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + with self.sql_conf({"spark.sql.analyzer.failAmbiguousSelfJoin": True}): + users = loaded_model.userFactors.select(sf.col("id").alias("user")) + items = loaded_model.itemFactors.select(sf.col("id").alias("item")) + predictions = loaded_model.transform(users.crossJoin(items)) + self.assertTrue(predictions.count() > 0) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_als import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_adv.py b/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_adv.py new file mode 100644 index 000000000000..fe1e6b6c745a --- /dev/null +++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_adv.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.pandas.tests.indexes.test_indexing_adv import IndexingAdvMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class IndexingAdvParityTests( + IndexingAdvMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.indexes.test_parity_indexing import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_loc_2d.py b/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_loc_2d.py new file mode 100644 index 000000000000..18e0f9088223 --- /dev/null +++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing_loc_2d.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.pandas.tests.indexes.test_indexing_loc_2d import IndexingLoc2DMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class IndexingLoc2DParityTests( + IndexingLoc2DMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.indexes.test_parity_indexing_loc_2d import * # noqa + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/indexes/test_indexing_adv.py similarity index 62% rename from python/pyspark/pandas/tests/test_indexing.py rename to python/pyspark/pandas/tests/indexes/test_indexing_adv.py index ef496c3b5565..2ad6523f5926 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/indexes/test_indexing_adv.py @@ -22,11 +22,12 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError -from pyspark.testing.pandasutils import ComparisonTestBase, compare_both +from pyspark.pandas.exceptions import SparkPandasNotImplementedError +from pyspark.testing.pandasutils import PandasOnSparkTestCase, compare_both +from pyspark.testing.sqlutils import SQLTestUtils -class IndexingTest(ComparisonTestBase): +class IndexingAdvMixin: @property def pdf(self): return pd.DataFrame( @@ -41,6 +42,10 @@ def pdf2(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) + @property + def psdf(self): + return ps.from_pandas(self.pdf) + @property def psdf2(self): return ps.from_pandas(self.pdf2) @@ -196,184 +201,6 @@ def test_iat_multiindex_columns(self): with self.assertRaises(KeyError): psdf.iat[99, 0] - def test_loc2d_multiindex(self): - psdf = self.psdf - psdf = psdf.set_index("b", append=True) - pdf = self.pdf - pdf = pdf.set_index("b", append=True) - - self.assert_eq(psdf.loc[:, :], pdf.loc[:, :]) - self.assert_eq(psdf.loc[:, "a"], pdf.loc[:, "a"]) - self.assert_eq(psdf.loc[5:5, "a"], pdf.loc[5:5, "a"]) - - self.assert_eq(psdf.loc[:, "a":"a"], pdf.loc[:, "a":"a"]) - self.assert_eq(psdf.loc[:, "a":"c"], pdf.loc[:, "a":"c"]) - self.assert_eq(psdf.loc[:, "b":"c"], pdf.loc[:, "b":"c"]) - - def test_loc2d(self): - psdf = self.psdf - pdf = self.pdf - - # index indexer is always regarded as slice for duplicated values - self.assert_eq(psdf.loc[5:5, "a"], pdf.loc[5:5, "a"]) - self.assert_eq(psdf.loc[[5], "a"], pdf.loc[[5], "a"]) - self.assert_eq(psdf.loc[5:5, ["a"]], pdf.loc[5:5, ["a"]]) - self.assert_eq(psdf.loc[[5], ["a"]], pdf.loc[[5], ["a"]]) - self.assert_eq(psdf.loc[:, :], pdf.loc[:, :]) - - self.assert_eq(psdf.loc[3:8, "a"], pdf.loc[3:8, "a"]) - self.assert_eq(psdf.loc[:8, "a"], pdf.loc[:8, "a"]) - self.assert_eq(psdf.loc[3:, "a"], pdf.loc[3:, "a"]) - self.assert_eq(psdf.loc[[8], "a"], pdf.loc[[8], "a"]) - - self.assert_eq(psdf.loc[3:8, ["a"]], pdf.loc[3:8, ["a"]]) - self.assert_eq(psdf.loc[:8, ["a"]], pdf.loc[:8, ["a"]]) - self.assert_eq(psdf.loc[3:, ["a"]], pdf.loc[3:, ["a"]]) - # TODO?: self.assert_eq(psdf.loc[[3, 4, 3], ['a']], pdf.loc[[3, 4, 3], ['a']]) - - self.assertRaises(SparkPandasIndexingError, lambda: psdf.loc[3, 3, 3]) - self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[3, 3]) - 
self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[3:, 3]) - self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[psdf.a % 2 == 0, 3]) - - self.assert_eq(psdf.loc[5, "a"], pdf.loc[5, "a"]) - self.assert_eq(psdf.loc[9, "a"], pdf.loc[9, "a"]) - self.assert_eq(psdf.loc[5, ["a"]], pdf.loc[5, ["a"]]) - self.assert_eq(psdf.loc[9, ["a"]], pdf.loc[9, ["a"]]) - - self.assert_eq(psdf.loc[:, "a":"a"], pdf.loc[:, "a":"a"]) - self.assert_eq(psdf.loc[:, "a":"d"], pdf.loc[:, "a":"d"]) - self.assert_eq(psdf.loc[:, "c":"d"], pdf.loc[:, "c":"d"]) - - # bool list-like column select - bool_list = [True, False] - self.assert_eq(psdf.loc[:, bool_list], pdf.loc[:, bool_list]) - self.assert_eq(psdf.loc[:, np.array(bool_list)], pdf.loc[:, np.array(bool_list)]) - - pser = pd.Series(bool_list, index=pdf.columns) - self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) - pser = pd.Series(list(reversed(bool_list)), index=list(reversed(pdf.columns))) - self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) - - self.assertRaises(IndexError, lambda: psdf.loc[:, bool_list[:-1]]) - self.assertRaises(IndexError, lambda: psdf.loc[:, np.array(bool_list + [True])]) - self.assertRaises(SparkPandasIndexingError, lambda: psdf.loc[:, pd.Series(bool_list)]) - - # non-string column names - psdf = self.psdf2 - pdf = self.pdf2 - - self.assert_eq(psdf.loc[5:5, 0], pdf.loc[5:5, 0]) - self.assert_eq(psdf.loc[5:5, [0]], pdf.loc[5:5, [0]]) - self.assert_eq(psdf.loc[3:8, 0], pdf.loc[3:8, 0]) - self.assert_eq(psdf.loc[3:8, [0]], pdf.loc[3:8, [0]]) - - self.assert_eq(psdf.loc[:, 0:0], pdf.loc[:, 0:0]) - self.assert_eq(psdf.loc[:, 0:3], pdf.loc[:, 0:3]) - self.assert_eq(psdf.loc[:, 2:3], pdf.loc[:, 2:3]) - - def test_loc2d_multiindex_columns(self): - arrays = [np.array(["bar", "bar", "baz", "baz"]), np.array(["one", "two", "one", "two"])] - - pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "B", "C"], columns=arrays) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.loc["B":"B", "bar"], pdf.loc["B":"B", "bar"]) - self.assert_eq(psdf.loc["B":"B", ["bar"]], pdf.loc["B":"B", ["bar"]]) - - self.assert_eq(psdf.loc[:, "bar":"bar"], pdf.loc[:, "bar":"bar"]) - self.assert_eq(psdf.loc[:, "bar":("baz", "one")], pdf.loc[:, "bar":("baz", "one")]) - self.assert_eq( - psdf.loc[:, ("bar", "two"):("baz", "one")], pdf.loc[:, ("bar", "two"):("baz", "one")] - ) - self.assert_eq(psdf.loc[:, ("bar", "two"):"bar"], pdf.loc[:, ("bar", "two"):"bar"]) - self.assert_eq(psdf.loc[:, "a":"bax"], pdf.loc[:, "a":"bax"]) - self.assert_eq( - psdf.loc[:, ("bar", "x"):("baz", "a")], - pdf.loc[:, ("bar", "x"):("baz", "a")], - almost=True, - ) - - pdf = pd.DataFrame( - np.random.randn(3, 4), - index=["A", "B", "C"], - columns=pd.MultiIndex.from_tuples( - [("bar", "two"), ("bar", "one"), ("baz", "one"), ("baz", "two")] - ), - ) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.loc[:, "bar":"baz"], pdf.loc[:, "bar":"baz"]) - - self.assertRaises(KeyError, lambda: psdf.loc[:, "bar":("baz", "one")]) - self.assertRaises(KeyError, lambda: psdf.loc[:, ("bar", "two"):"bar"]) - - # bool list-like column select - bool_list = [True, False, True, False] - self.assert_eq(psdf.loc[:, bool_list], pdf.loc[:, bool_list]) - self.assert_eq(psdf.loc[:, np.array(bool_list)], pdf.loc[:, np.array(bool_list)]) - - pser = pd.Series(bool_list, index=pdf.columns) - self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) - - pser = pd.Series(list(reversed(bool_list)), index=list(reversed(pdf.columns))) - self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) - - # non-string column names 
- arrays = [np.array([0, 0, 1, 1]), np.array([1, 2, 1, 2])] - - pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "B", "C"], columns=arrays) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.loc["B":"B", 0], pdf.loc["B":"B", 0]) - self.assert_eq(psdf.loc["B":"B", [0]], pdf.loc["B":"B", [0]]) - self.assert_eq(psdf.loc[:, 0:0], pdf.loc[:, 0:0]) - self.assert_eq(psdf.loc[:, 0:(1, 1)], pdf.loc[:, 0:(1, 1)]) - self.assert_eq(psdf.loc[:, (0, 2):(1, 1)], pdf.loc[:, (0, 2):(1, 1)]) - self.assert_eq(psdf.loc[:, (0, 2):0], pdf.loc[:, (0, 2):0]) - self.assert_eq(psdf.loc[:, -1:2], pdf.loc[:, -1:2]) - - def test_loc2d_with_known_divisions(self): - pdf = pd.DataFrame( - np.random.randn(20, 5), index=list("abcdefghijklmnopqrst"), columns=list("ABCDE") - ) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.loc[["a"], "A"], pdf.loc[["a"], "A"]) - self.assert_eq(psdf.loc[["a"], ["A"]], pdf.loc[["a"], ["A"]]) - self.assert_eq(psdf.loc["a":"o", "A"], pdf.loc["a":"o", "A"]) - self.assert_eq(psdf.loc["a":"o", ["A"]], pdf.loc["a":"o", ["A"]]) - self.assert_eq(psdf.loc[["n"], ["A"]], pdf.loc[["n"], ["A"]]) - self.assert_eq(psdf.loc[["a", "c", "n"], ["A"]], pdf.loc[["a", "c", "n"], ["A"]]) - # TODO?: self.assert_eq(psdf.loc[['t', 'b'], ['A']], pdf.loc[['t', 'b'], ['A']]) - # TODO?: self.assert_eq(psdf.loc[['r', 'r', 'c', 'g', 'h'], ['A']], - # TODO?: pdf.loc[['r', 'r', 'c', 'g', 'h'], ['A']]) - - @unittest.skip("TODO: should handle duplicated columns properly") - def test_loc2d_duplicated_columns(self): - pdf = pd.DataFrame( - np.random.randn(20, 5), index=list("abcdefghijklmnopqrst"), columns=list("AABCD") - ) - psdf = ps.from_pandas(pdf) - - # TODO?: self.assert_eq(psdf.loc[['a'], 'A'], pdf.loc[['a'], 'A']) - # TODO?: self.assert_eq(psdf.loc[['a'], ['A']], pdf.loc[['a'], ['A']]) - self.assert_eq(psdf.loc[["j"], "B"], pdf.loc[["j"], "B"]) - self.assert_eq(psdf.loc[["j"], ["B"]], pdf.loc[["j"], ["B"]]) - - # TODO?: self.assert_eq(psdf.loc['a':'o', 'A'], pdf.loc['a':'o', 'A']) - # TODO?: self.assert_eq(psdf.loc['a':'o', ['A']], pdf.loc['a':'o', ['A']]) - self.assert_eq(psdf.loc["j":"q", "B"], pdf.loc["j":"q", "B"]) - self.assert_eq(psdf.loc["j":"q", ["B"]], pdf.loc["j":"q", ["B"]]) - - # TODO?: self.assert_eq(psdf.loc['a':'o', 'B':'D'], pdf.loc['a':'o', 'B':'D']) - # TODO?: self.assert_eq(psdf.loc['a':'o', 'B':'D'], pdf.loc['a':'o', 'B':'D']) - # TODO?: self.assert_eq(psdf.loc['j':'q', 'B':'A'], pdf.loc['j':'q', 'B':'A']) - # TODO?: self.assert_eq(psdf.loc['j':'q', 'B':'A'], pdf.loc['j':'q', 'B':'A']) - - self.assert_eq(psdf.loc[psdf.B > 0, "B"], pdf.loc[pdf.B > 0, "B"]) - # TODO?: self.assert_eq(psdf.loc[psdf.B > 0, ['A', 'C']], pdf.loc[pdf.B > 0, ['A', 'C']]) - def test_getitem(self): pdf = pd.DataFrame( { @@ -558,8 +385,16 @@ def test_index_operator_int(self): psdf.iloc[[1, 1]] +class IndexingAdvTests( + IndexingAdvMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + if __name__ == "__main__": - from pyspark.pandas.tests.test_indexing import * # noqa: F401 + from pyspark.pandas.tests.indexes.test_indexing_adv import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/pandas/tests/indexes/test_indexing_loc_2d.py b/python/pyspark/pandas/tests/indexes/test_indexing_loc_2d.py new file mode 100644 index 000000000000..88f41d1aade3 --- /dev/null +++ b/python/pyspark/pandas/tests/indexes/test_indexing_loc_2d.py @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +import numpy as np +import pandas as pd + +from pyspark import pandas as ps +from pyspark.pandas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + + +class IndexingLoc2DMixin: + @property + def pdf(self): + return pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]}, + index=[0, 1, 3, 5, 6, 8, 9, 9, 9], + ) + + @property + def pdf2(self): + return pd.DataFrame( + {0: [1, 2, 3, 4, 5, 6, 7, 8, 9], 1: [4, 5, 6, 3, 2, 1, 0, 0, 0]}, + index=[0, 1, 3, 5, 6, 8, 9, 9, 9], + ) + + @property + def psdf(self): + return ps.from_pandas(self.pdf) + + @property + def psdf2(self): + return ps.from_pandas(self.pdf2) + + def test_loc2d_multiindex(self): + psdf = self.psdf + psdf = psdf.set_index("b", append=True) + pdf = self.pdf + pdf = pdf.set_index("b", append=True) + + self.assert_eq(psdf.loc[:, :], pdf.loc[:, :]) + self.assert_eq(psdf.loc[:, "a"], pdf.loc[:, "a"]) + self.assert_eq(psdf.loc[5:5, "a"], pdf.loc[5:5, "a"]) + + self.assert_eq(psdf.loc[:, "a":"a"], pdf.loc[:, "a":"a"]) + self.assert_eq(psdf.loc[:, "a":"c"], pdf.loc[:, "a":"c"]) + self.assert_eq(psdf.loc[:, "b":"c"], pdf.loc[:, "b":"c"]) + + def test_loc2d(self): + psdf = self.psdf + pdf = self.pdf + + # index indexer is always regarded as slice for duplicated values + self.assert_eq(psdf.loc[5:5, "a"], pdf.loc[5:5, "a"]) + self.assert_eq(psdf.loc[[5], "a"], pdf.loc[[5], "a"]) + self.assert_eq(psdf.loc[5:5, ["a"]], pdf.loc[5:5, ["a"]]) + self.assert_eq(psdf.loc[[5], ["a"]], pdf.loc[[5], ["a"]]) + self.assert_eq(psdf.loc[:, :], pdf.loc[:, :]) + + self.assert_eq(psdf.loc[3:8, "a"], pdf.loc[3:8, "a"]) + self.assert_eq(psdf.loc[:8, "a"], pdf.loc[:8, "a"]) + self.assert_eq(psdf.loc[3:, "a"], pdf.loc[3:, "a"]) + self.assert_eq(psdf.loc[[8], "a"], pdf.loc[[8], "a"]) + + self.assert_eq(psdf.loc[3:8, ["a"]], pdf.loc[3:8, ["a"]]) + self.assert_eq(psdf.loc[:8, ["a"]], pdf.loc[:8, ["a"]]) + self.assert_eq(psdf.loc[3:, ["a"]], pdf.loc[3:, ["a"]]) + # TODO?: self.assert_eq(psdf.loc[[3, 4, 3], ['a']], pdf.loc[[3, 4, 3], ['a']]) + + self.assertRaises(SparkPandasIndexingError, lambda: psdf.loc[3, 3, 3]) + self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[3, 3]) + self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[3:, 3]) + self.assertRaises(SparkPandasIndexingError, lambda: psdf.a.loc[psdf.a % 2 == 0, 3]) + + self.assert_eq(psdf.loc[5, "a"], pdf.loc[5, "a"]) + self.assert_eq(psdf.loc[9, "a"], pdf.loc[9, "a"]) + self.assert_eq(psdf.loc[5, ["a"]], pdf.loc[5, ["a"]]) + self.assert_eq(psdf.loc[9, ["a"]], pdf.loc[9, ["a"]]) + + self.assert_eq(psdf.loc[:, "a":"a"], pdf.loc[:, "a":"a"]) + self.assert_eq(psdf.loc[:, "a":"d"], pdf.loc[:, "a":"d"]) + self.assert_eq(psdf.loc[:, 
"c":"d"], pdf.loc[:, "c":"d"]) + + # bool list-like column select + bool_list = [True, False] + self.assert_eq(psdf.loc[:, bool_list], pdf.loc[:, bool_list]) + self.assert_eq(psdf.loc[:, np.array(bool_list)], pdf.loc[:, np.array(bool_list)]) + + pser = pd.Series(bool_list, index=pdf.columns) + self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) + pser = pd.Series(list(reversed(bool_list)), index=list(reversed(pdf.columns))) + self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) + + self.assertRaises(IndexError, lambda: psdf.loc[:, bool_list[:-1]]) + self.assertRaises(IndexError, lambda: psdf.loc[:, np.array(bool_list + [True])]) + self.assertRaises(SparkPandasIndexingError, lambda: psdf.loc[:, pd.Series(bool_list)]) + + # non-string column names + psdf = self.psdf2 + pdf = self.pdf2 + + self.assert_eq(psdf.loc[5:5, 0], pdf.loc[5:5, 0]) + self.assert_eq(psdf.loc[5:5, [0]], pdf.loc[5:5, [0]]) + self.assert_eq(psdf.loc[3:8, 0], pdf.loc[3:8, 0]) + self.assert_eq(psdf.loc[3:8, [0]], pdf.loc[3:8, [0]]) + + self.assert_eq(psdf.loc[:, 0:0], pdf.loc[:, 0:0]) + self.assert_eq(psdf.loc[:, 0:3], pdf.loc[:, 0:3]) + self.assert_eq(psdf.loc[:, 2:3], pdf.loc[:, 2:3]) + + def test_loc2d_multiindex_columns(self): + arrays = [np.array(["bar", "bar", "baz", "baz"]), np.array(["one", "two", "one", "two"])] + + pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "B", "C"], columns=arrays) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.loc["B":"B", "bar"], pdf.loc["B":"B", "bar"]) + self.assert_eq(psdf.loc["B":"B", ["bar"]], pdf.loc["B":"B", ["bar"]]) + + self.assert_eq(psdf.loc[:, "bar":"bar"], pdf.loc[:, "bar":"bar"]) + self.assert_eq(psdf.loc[:, "bar":("baz", "one")], pdf.loc[:, "bar":("baz", "one")]) + self.assert_eq( + psdf.loc[:, ("bar", "two"):("baz", "one")], pdf.loc[:, ("bar", "two"):("baz", "one")] + ) + self.assert_eq(psdf.loc[:, ("bar", "two"):"bar"], pdf.loc[:, ("bar", "two"):"bar"]) + self.assert_eq(psdf.loc[:, "a":"bax"], pdf.loc[:, "a":"bax"]) + self.assert_eq( + psdf.loc[:, ("bar", "x"):("baz", "a")], + pdf.loc[:, ("bar", "x"):("baz", "a")], + almost=True, + ) + + pdf = pd.DataFrame( + np.random.randn(3, 4), + index=["A", "B", "C"], + columns=pd.MultiIndex.from_tuples( + [("bar", "two"), ("bar", "one"), ("baz", "one"), ("baz", "two")] + ), + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.loc[:, "bar":"baz"], pdf.loc[:, "bar":"baz"]) + + self.assertRaises(KeyError, lambda: psdf.loc[:, "bar":("baz", "one")]) + self.assertRaises(KeyError, lambda: psdf.loc[:, ("bar", "two"):"bar"]) + + # bool list-like column select + bool_list = [True, False, True, False] + self.assert_eq(psdf.loc[:, bool_list], pdf.loc[:, bool_list]) + self.assert_eq(psdf.loc[:, np.array(bool_list)], pdf.loc[:, np.array(bool_list)]) + + pser = pd.Series(bool_list, index=pdf.columns) + self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) + + pser = pd.Series(list(reversed(bool_list)), index=list(reversed(pdf.columns))) + self.assert_eq(psdf.loc[:, pser], pdf.loc[:, pser]) + + # non-string column names + arrays = [np.array([0, 0, 1, 1]), np.array([1, 2, 1, 2])] + + pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "B", "C"], columns=arrays) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.loc["B":"B", 0], pdf.loc["B":"B", 0]) + self.assert_eq(psdf.loc["B":"B", [0]], pdf.loc["B":"B", [0]]) + self.assert_eq(psdf.loc[:, 0:0], pdf.loc[:, 0:0]) + self.assert_eq(psdf.loc[:, 0:(1, 1)], pdf.loc[:, 0:(1, 1)]) + self.assert_eq(psdf.loc[:, (0, 2):(1, 1)], pdf.loc[:, (0, 2):(1, 1)]) + self.assert_eq(psdf.loc[:, (0, 2):0], 
pdf.loc[:, (0, 2):0]) + self.assert_eq(psdf.loc[:, -1:2], pdf.loc[:, -1:2]) + + def test_loc2d_with_known_divisions(self): + pdf = pd.DataFrame( + np.random.randn(20, 5), index=list("abcdefghijklmnopqrst"), columns=list("ABCDE") + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.loc[["a"], "A"], pdf.loc[["a"], "A"]) + self.assert_eq(psdf.loc[["a"], ["A"]], pdf.loc[["a"], ["A"]]) + self.assert_eq(psdf.loc["a":"o", "A"], pdf.loc["a":"o", "A"]) + self.assert_eq(psdf.loc["a":"o", ["A"]], pdf.loc["a":"o", ["A"]]) + self.assert_eq(psdf.loc[["n"], ["A"]], pdf.loc[["n"], ["A"]]) + self.assert_eq(psdf.loc[["a", "c", "n"], ["A"]], pdf.loc[["a", "c", "n"], ["A"]]) + # TODO?: self.assert_eq(psdf.loc[['t', 'b'], ['A']], pdf.loc[['t', 'b'], ['A']]) + # TODO?: self.assert_eq(psdf.loc[['r', 'r', 'c', 'g', 'h'], ['A']], + # TODO?: pdf.loc[['r', 'r', 'c', 'g', 'h'], ['A']]) + + @unittest.skip("TODO: should handle duplicated columns properly") + def test_loc2d_duplicated_columns(self): + pdf = pd.DataFrame( + np.random.randn(20, 5), index=list("abcdefghijklmnopqrst"), columns=list("AABCD") + ) + psdf = ps.from_pandas(pdf) + + # TODO?: self.assert_eq(psdf.loc[['a'], 'A'], pdf.loc[['a'], 'A']) + # TODO?: self.assert_eq(psdf.loc[['a'], ['A']], pdf.loc[['a'], ['A']]) + self.assert_eq(psdf.loc[["j"], "B"], pdf.loc[["j"], "B"]) + self.assert_eq(psdf.loc[["j"], ["B"]], pdf.loc[["j"], ["B"]]) + + # TODO?: self.assert_eq(psdf.loc['a':'o', 'A'], pdf.loc['a':'o', 'A']) + # TODO?: self.assert_eq(psdf.loc['a':'o', ['A']], pdf.loc['a':'o', ['A']]) + self.assert_eq(psdf.loc["j":"q", "B"], pdf.loc["j":"q", "B"]) + self.assert_eq(psdf.loc["j":"q", ["B"]], pdf.loc["j":"q", ["B"]]) + + # TODO?: self.assert_eq(psdf.loc['a':'o', 'B':'D'], pdf.loc['a':'o', 'B':'D']) + # TODO?: self.assert_eq(psdf.loc['a':'o', 'B':'D'], pdf.loc['a':'o', 'B':'D']) + # TODO?: self.assert_eq(psdf.loc['j':'q', 'B':'A'], pdf.loc['j':'q', 'B':'A']) + # TODO?: self.assert_eq(psdf.loc['j':'q', 'B':'A'], pdf.loc['j':'q', 'B':'A']) + + self.assert_eq(psdf.loc[psdf.B > 0, "B"], pdf.loc[pdf.B > 0, "B"]) + # TODO?: self.assert_eq(psdf.loc[psdf.B > 0, ['A', 'C']], pdf.loc[pdf.B > 0, ['A', 'C']]) + + +class IndexingLoc2DTests( + IndexingLoc2DMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.indexes.test_indexing_loc_2d import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index cf2e2e0c7344..baf8dc82fd84 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -3684,6 +3684,14 @@ def sha1(col: "ColumnOrName") -> Column: def sha2(col: "ColumnOrName", numBits: int) -> Column: + if numBits not in [0, 224, 256, 384, 512]: + raise PySparkValueError( + error_class="VALUE_NOT_ALLOWED", + message_parameters={ + "arg_name": "numBits", + "allowed_values": "[0, 224, 256, 384, 512]", + }, + ) return _invoke_function("sha2", _to_col(col), lit(numBits)) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index fae4de36638c..0ff1ee2a7394 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -9112,6 +9112,14 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Bob 
|cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ """ + if numBits not in [0, 224, 256, 384, 512]: + raise PySparkValueError( + error_class="VALUE_NOT_ALLOWED", + message_parameters={ + "arg_name": "numBits", + "allowed_values": "[0, 224, 256, 384, 512]", + }, + ) return _invoke_function("sha2", _to_java_column(col), numBits) @@ -12963,7 +12971,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: @_try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Remove all elements that equal to element from the given array. + Array function: Remove all elements that equal to element from the given array. .. versionadded:: 2.4.0 @@ -12980,13 +12988,69 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: Returns ------- :class:`~pyspark.sql.Column` - an array excluding given value. + A new column that is an array excluding the given value from the input column. Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) - >>> df.select(array_remove(df.data, 1)).collect() - [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + Example 1: Removing a specific value from a simple array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],)], ['data']) + >>> df.select(sf.array_remove(df.data, 1)).show() + +---------------------+ + |array_remove(data, 1)| + +---------------------+ + | [2, 3]| + +---------------------+ + + Example 2: Removing a specific value from multiple arrays + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(sf.array_remove(df.data, 5)).show() + +---------------------+ + |array_remove(data, 5)| + +---------------------+ + | [1, 2, 3, 1, 1]| + | [4, 4]| + +---------------------+ + + Example 3: Removing a value that does not exist in the array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3],)], ['data']) + >>> df.select(sf.array_remove(df.data, 4)).show() + +---------------------+ + |array_remove(data, 4)| + +---------------------+ + | [1, 2, 3]| + +---------------------+ + + Example 4: Removing a value from an array with all identical values + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 1, 1],)], ['data']) + >>> df.select(sf.array_remove(df.data, 1)).show() + +---------------------+ + |array_remove(data, 1)| + +---------------------+ + | []| + +---------------------+ + + Example 5: Removing a value from an empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", ArrayType(IntegerType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema) + >>> df.select(sf.array_remove(df.data, 1)).show() + +---------------------+ + |array_remove(data, 1)| + +---------------------+ + | []| + +---------------------+ """ return _invoke_function("array_remove", _to_java_column(col), element) @@ -12994,7 +13058,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: @_try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ - Collection function: removes duplicate values from the array. + Array function: removes duplicate values from the array. .. 
versionadded:: 2.4.0 @@ -13009,13 +13073,69 @@ def array_distinct(col: "ColumnOrName") -> Column: Returns ------- :class:`~pyspark.sql.Column` - an array of unique values. + A new column that is an array of unique values from the input column. Examples -------- + Example 1: Removing duplicate values from a simple array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3, 2],)], ['data']) + >>> df.select(sf.array_distinct(df.data)).show() + +--------------------+ + |array_distinct(data)| + +--------------------+ + | [1, 2, 3]| + +--------------------+ + + Example 2: Removing duplicate values from multiple arrays + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) - >>> df.select(array_distinct(df.data)).collect() - [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + >>> df.select(sf.array_distinct(df.data)).show() + +--------------------+ + |array_distinct(data)| + +--------------------+ + | [1, 2, 3]| + | [4, 5]| + +--------------------+ + + Example 3: Removing duplicate values from an array with all identical values + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 1, 1],)], ['data']) + >>> df.select(sf.array_distinct(df.data)).show() + +--------------------+ + |array_distinct(data)| + +--------------------+ + | [1]| + +--------------------+ + + Example 4: Removing duplicate values from an array with no duplicate values + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3],)], ['data']) + >>> df.select(sf.array_distinct(df.data)).show() + +--------------------+ + |array_distinct(data)| + +--------------------+ + | [1, 2, 3]| + +--------------------+ + + Example 5: Removing duplicate values from an empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", ArrayType(IntegerType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema) + >>> df.select(sf.array_distinct(df.data)).show() + +--------------------+ + |array_distinct(data)| + +--------------------+ + | []| + +--------------------+ """ return _invoke_function_over_columns("array_distinct", col) @@ -13399,7 +13519,7 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: @_try_remote_functions def array_compact(col: "ColumnOrName") -> Column: """ - Collection function: removes null values from the array. + Array function: removes null values from the array. .. versionadded:: 3.4.0 @@ -13411,7 +13531,7 @@ def array_compact(col: "ColumnOrName") -> Column: Returns ------- :class:`~pyspark.sql.Column` - an array by excluding the null values. + A new column that is an array excluding the null values from the input column. 
Notes ----- @@ -13419,9 +13539,69 @@ def array_compact(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Removing null values from a simple array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, None, 2, 3],)], ['data']) + >>> df.select(sf.array_compact(df.data)).show() + +-------------------+ + |array_compact(data)| + +-------------------+ + | [1, 2, 3]| + +-------------------+ + + Example 2: Removing null values from multiple arrays + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) - >>> df.select(array_compact(df.data)).collect() - [Row(array_compact(data)=[1, 2, 3]), Row(array_compact(data)=[4, 5, 4])] + >>> df.select(sf.array_compact(df.data)).show() + +-------------------+ + |array_compact(data)| + +-------------------+ + | [1, 2, 3]| + | [4, 5, 4]| + +-------------------+ + + Example 3: Removing null values from an array with all null values + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType + >>> schema = StructType([ + ... StructField("data", ArrayType(StringType()), True) + ... ]) + >>> df = spark.createDataFrame([([None, None, None],)], schema) + >>> df.select(sf.array_compact(df.data)).show() + +-------------------+ + |array_compact(data)| + +-------------------+ + | []| + +-------------------+ + + Example 4: Removing null values from an array with no null values + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([1, 2, 3],)], ['data']) + >>> df.select(sf.array_compact(df.data)).show() + +-------------------+ + |array_compact(data)| + +-------------------+ + | [1, 2, 3]| + +-------------------+ + + Example 5: Removing null values from an empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType + >>> schema = StructType([ + ... StructField("data", ArrayType(StringType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema) + >>> df.select(sf.array_compact(df.data)).show() + +-------------------+ + |array_compact(data)| + +-------------------+ + | []| + +-------------------+ """ return _invoke_function_over_columns("array_compact", col) @@ -14597,7 +14777,7 @@ def size(col: "ColumnOrName") -> Column: @_try_remote_functions def array_min(col: "ColumnOrName") -> Column: """ - Collection function: returns the minimum value of the array. + Array function: returns the minimum value of the array. .. versionadded:: 2.4.0 @@ -14607,18 +14787,74 @@ def array_min(col: "ColumnOrName") -> Column: Parameters ---------- col : :class:`~pyspark.sql.Column` or str - name of column or expression + The name of the column or an expression that represents the array. Returns ------- :class:`~pyspark.sql.Column` - minimum value of array. + A new column that contains the minimum value of each array. 
Examples -------- + Example 1: Basic usage with integer array + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_min(df.data).alias('min')).collect() - [Row(min=1), Row(min=-1)] + >>> df.select(sf.array_min(df.data)).show() + +---------------+ + |array_min(data)| + +---------------+ + | 1| + | -1| + +---------------+ + + Example 2: Usage with string array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 'banana', 'cherry'],)], ['data']) + >>> df.select(sf.array_min(df.data)).show() + +---------------+ + |array_min(data)| + +---------------+ + | apple| + +---------------+ + + Example 3: Usage with mixed type array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 1, 'cherry'],)], ['data']) + >>> df.select(sf.array_min(df.data)).show() + +---------------+ + |array_min(data)| + +---------------+ + | 1| + +---------------+ + + Example 4: Usage with array of arrays + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([[2, 1], [3, 4]],)], ['data']) + >>> df.select(sf.array_min(df.data)).show() + +---------------+ + |array_min(data)| + +---------------+ + | [2, 1]| + +---------------+ + + Example 5: Usage with empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", ArrayType(IntegerType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema=schema) + >>> df.select(sf.array_min(df.data)).show() + +---------------+ + |array_min(data)| + +---------------+ + | NULL| + +---------------+ """ return _invoke_function_over_columns("array_min", col) @@ -14626,7 +14862,7 @@ def array_min(col: "ColumnOrName") -> Column: @_try_remote_functions def array_max(col: "ColumnOrName") -> Column: """ - Collection function: returns the maximum value of the array. + Array function: returns the maximum value of the array. .. versionadded:: 2.4.0 @@ -14636,18 +14872,74 @@ def array_max(col: "ColumnOrName") -> Column: Parameters ---------- col : :class:`~pyspark.sql.Column` or str - name of column or expression + The name of the column or an expression that represents the array. Returns ------- :class:`~pyspark.sql.Column` - maximum value of an array. + A new column that contains the maximum value of each array. 
Examples -------- + Example 1: Basic usage with integer array + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_max(df.data).alias('max')).collect() - [Row(max=3), Row(max=10)] + >>> df.select(sf.array_max(df.data)).show() + +---------------+ + |array_max(data)| + +---------------+ + | 3| + | 10| + +---------------+ + + Example 2: Usage with string array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 'banana', 'cherry'],)], ['data']) + >>> df.select(sf.array_max(df.data)).show() + +---------------+ + |array_max(data)| + +---------------+ + | cherry| + +---------------+ + + Example 3: Usage with mixed type array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 1, 'cherry'],)], ['data']) + >>> df.select(sf.array_max(df.data)).show() + +---------------+ + |array_max(data)| + +---------------+ + | cherry| + +---------------+ + + Example 4: Usage with array of arrays + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([[2, 1], [3, 4]],)], ['data']) + >>> df.select(sf.array_max(df.data)).show() + +---------------+ + |array_max(data)| + +---------------+ + | [3, 4]| + +---------------+ + + Example 5: Usage with empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", ArrayType(IntegerType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema=schema) + >>> df.select(sf.array_max(df.data)).show() + +---------------+ + |array_max(data)| + +---------------+ + | NULL| + +---------------+ """ return _invoke_function_over_columns("array_max", col) @@ -14655,25 +14947,82 @@ def array_max(col: "ColumnOrName") -> Column: @_try_remote_functions def array_size(col: "ColumnOrName") -> Column: """ - Returns the total number of elements in the array. The function returns null for null input. + Array function: returns the total number of elements in the array. + The function returns null for null input. .. versionadded:: 3.5.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to compute on. + The name of the column or an expression that represents the array. Returns ------- :class:`~pyspark.sql.Column` - total number of elements in the array. + A new column that contains the size of each array. 
Examples -------- + Example 1: Basic usage with integer array + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) - >>> df.select(array_size(df.data).alias('r')).collect() - [Row(r=3), Row(r=None)] + >>> df.select(sf.array_size(df.data)).show() + +----------------+ + |array_size(data)| + +----------------+ + | 3| + | NULL| + +----------------+ + + Example 2: Usage with string array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 'banana', 'cherry'],)], ['data']) + >>> df.select(sf.array_size(df.data)).show() + +----------------+ + |array_size(data)| + +----------------+ + | 3| + +----------------+ + + Example 3: Usage with mixed type array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 1, 'cherry'],)], ['data']) + >>> df.select(sf.array_size(df.data)).show() + +----------------+ + |array_size(data)| + +----------------+ + | 3| + +----------------+ + + Example 4: Usage with array of arrays + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([([[2, 1], [3, 4]],)], ['data']) + >>> df.select(sf.array_size(df.data)).show() + +----------------+ + |array_size(data)| + +----------------+ + | 2| + +----------------+ + + Example 5: Usage with empty array + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", ArrayType(IntegerType()), True) + ... ]) + >>> df = spark.createDataFrame([([],)], schema=schema) + >>> df.select(sf.array_size(df.data)).show() + +----------------+ + |array_size(data)| + +----------------+ + | 0| + +----------------+ """ return _invoke_function_over_columns("array_size", col) @@ -15096,7 +15445,7 @@ def map_from_entries(col: "ColumnOrName") -> Column: @_try_remote_functions def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: """ - Collection function: creates an array containing a column repeated count times. + Array function: creates an array containing a column repeated count times. .. versionadded:: 2.4.0 @@ -15106,20 +15455,65 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu Parameters ---------- col : :class:`~pyspark.sql.Column` or str - column name or column that contains the element to be repeated + The name of the column or an expression that represents the element to be repeated. count : :class:`~pyspark.sql.Column` or str or int - column name, column, or int containing the number of times to repeat the first argument + The name of the column, an expression, + or an integer that represents the number of times to repeat the element. Returns ------- :class:`~pyspark.sql.Column` - an array of repeated elements. + A new column that contains an array of repeated elements. 
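One signature variant the patch documents but the examples below do not exercise is passing the repeat count as a column rather than a literal integer. A minimal sketch of that usage, assuming an active SparkSession named `spark` (the DataFrame and column names here are illustrative only, not part of the patch):

    from pyspark.sql import functions as sf

    df = spark.createDataFrame([("ab", 2), ("cd", 3)], ["data", "n"])
    # Each value is repeated according to the per-row count in column `n`,
    # e.g. roughly [ab, ab] for the first row and [cd, cd, cd] for the second.
    df.select(sf.array_repeat(df.data, df.n)).show()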
Examples -------- + Example 1: Usage with string + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([('ab',)], ['data']) - >>> df.select(array_repeat(df.data, 3).alias('r')).collect() - [Row(r=['ab', 'ab', 'ab'])] + >>> df.select(sf.array_repeat(df.data, 3)).show() + +---------------------+ + |array_repeat(data, 3)| + +---------------------+ + | [ab, ab, ab]| + +---------------------+ + + Example 2: Usage with integer + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(3,)], ['data']) + >>> df.select(sf.array_repeat(df.data, 2)).show() + +---------------------+ + |array_repeat(data, 2)| + +---------------------+ + | [3, 3]| + +---------------------+ + + Example 3: Usage with array + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(['apple', 'banana'],)], ['data']) + >>> df.select(sf.array_repeat(df.data, 2)).show(truncate=False) + +----------------------------------+ + |array_repeat(data, 2) | + +----------------------------------+ + |[[apple, banana], [apple, banana]]| + +----------------------------------+ + + Example 4: Usage with null + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import IntegerType, StructType, StructField + >>> schema = StructType([ + ... StructField("data", IntegerType(), True) + ... ]) + >>> df = spark.createDataFrame([(None, )], schema=schema) + >>> df.select(sf.array_repeat(df.data, 3)).show() + +---------------------+ + |array_repeat(data, 3)| + +---------------------+ + | [NULL, NULL, NULL]| + +---------------------+ """ count = lit(count) if isinstance(count, int) else count diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 32cd4ed62495..045ba8f0060d 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3452,7 +3452,7 @@ def test_error_stack_trace(self): self.spark.stop() spark = ( PySparkSession.builder.config(conf=self.conf()) - .config("spark.connect.jvmStacktrace.maxSize", 128) + .config("spark.connect.grpc.maxMetadataSize", 128) .remote("local[4]") .getOrCreate() ) @@ -3486,6 +3486,17 @@ def test_can_create_multiple_sessions_to_different_remotes(self): PySparkSession.builder.create() self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) + def test_get_message_parameters_without_enriched_error(self): + with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): + exception = None + try: + self.spark.sql("""SELECT a""") + except AnalysisException as e: + exception = e + + self.assertIsNotNone(exception) + self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"}) + class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py index 5f5f401cc626..917cb58057f7 100644 --- a/python/pyspark/sql/tests/connect/test_utils.py +++ b/python/pyspark/sql/tests/connect/test_utils.py @@ -14,16 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import unittest from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.sql.tests.test_utils import UtilsTestsMixin class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin): - @unittest.skip("SPARK-46397: Different exception thrown") - def test_capture_illegalargument_exception(self): - super().test_capture_illegalargument_exception() + pass if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 692cf77d9afb..f1d690751ead 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -963,6 +963,24 @@ def test_unpivot_negative(self): ): df.unpivot("id", ["int", "str"], "var", "val").collect() + def test_melt_groupby(self): + df = self.spark.createDataFrame( + [(1, 2, 3, 4, 5, 6)], + ["f1", "f2", "label", "pred", "model_version", "ts"], + ) + self.assertEqual( + df.melt( + "model_version", + ["label", "f2"], + "f1", + "f2", + ) + .groupby("f1") + .count() + .count(), + 2, + ) + def test_observe(self): # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method from pyspark.sql import Observation diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index b1bba584d85f..32333a8ccee9 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -225,10 +225,12 @@ def reader(self, schema) -> "DataSourceReader": assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")]) self.assertEqual(df.rdd.getNumPartitions(), 2) - def test_custom_json_data_source(self): + def _get_test_json_data_source(self): import json + import os + from dataclasses import dataclass - class JsonDataSourceReader(DataSourceReader): + class TestJsonReader(DataSourceReader): def __init__(self, options): self.options = options @@ -242,18 +244,39 @@ def read(self, partition): data = json.loads(line) yield data.get("name"), data.get("age") - class JsonDataSourceWriter(DataSourceWriter): + @dataclass + class TestCommitMessage(WriterCommitMessage): + count: int + + class TestJsonWriter(DataSourceWriter): def __init__(self, options): self.options = options + self.path = self.options.get("path") def write(self, iterator): - path = self.options.get("path") - with open(path, "w") as file: + from pyspark import TaskContext + + context = TaskContext.get() + output_path = os.path.join(self.path, f"{context.partitionId}.json") + count = 0 + with open(output_path, "w") as file: for row in iterator: + count += 1 + if "id" in row and row.id > 5: + raise Exception("id > 5") file.write(json.dumps(row.asDict()) + "\n") - return WriterCommitMessage() + return TestCommitMessage(count=count) - class JsonDataSource(DataSource): + def commit(self, messages): + total_count = sum(message.count for message in messages) + with open(os.path.join(self.path, "_success.txt"), "a") as file: + file.write(f"count: {total_count}\n") + + def abort(self, messages): + with open(os.path.join(self.path, "_failed.txt"), "a") as file: + file.write("failed") + + class TestJsonDataSource(DataSource): @classmethod def name(cls): return "my-json" @@ -262,13 +285,16 @@ def schema(self): return "name STRING, age INT" def reader(self, schema) -> "DataSourceReader": - return JsonDataSourceReader(self.options) + return TestJsonReader(self.options) def writer(self, schema, overwrite): - return JsonDataSourceWriter(self.options) + return TestJsonWriter(self.options) + + return 
TestJsonDataSource - self.spark.dataSource.register(JsonDataSource) - # Test data source read. + def test_custom_json_data_source_read(self): + data_source = self._get_test_json_data_source() + self.spark.dataSource.register(data_source) path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json") assertDataFrameEqual( @@ -279,18 +305,34 @@ def writer(self, schema, overwrite): self.spark.read.format("my-json").load(path2), [Row(name="Jonathan", age=None)], ) - # Test data source write. - df = self.spark.read.json(path1) + + def test_custom_json_data_source_write(self): + data_source = self._get_test_json_data_source() + self.spark.dataSource.register(data_source) + input_path = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") + df = self.spark.read.json(input_path) + with tempfile.TemporaryDirectory() as d: + df.write.format("my-json").mode("append").save(d) + assertDataFrameEqual(self.spark.read.json(d), self.spark.read.json(input_path)) + + def test_custom_json_data_source_commit(self): + data_source = self._get_test_json_data_source() + self.spark.dataSource.register(data_source) + with tempfile.TemporaryDirectory() as d: + self.spark.range(0, 5, 1, 3).write.format("my-json").mode("append").save(d) + with open(os.path.join(d, "_success.txt"), "r") as file: + text = file.read() + assert text == "count: 5\n" + + def test_custom_json_data_source_abort(self): + data_source = self._get_test_json_data_source() + self.spark.dataSource.register(data_source) with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "res.json") - df.write.format("my-json").mode("append").save(path) - with open(path, "r") as file: + with self.assertRaises(PythonException): + self.spark.range(0, 8, 1, 3).write.format("my-json").mode("append").save(d) + with open(os.path.join(d, "_failed.txt"), "r") as file: text = file.read() - assert text == ( - '{"age": null, "name": "Michael"}\n' - '{"age": 30, "name": "Andy"}\n' - '{"age": 19, "name": "Justin"}\n' - ) + assert text == "failed" class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index d54db78d4b65..66b5c19fc975 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -20,11 +20,11 @@ from itertools import zip_longest from pyspark.errors import QueryContextType -from pyspark.sql.functions import sha2, to_timestamp from pyspark.errors import ( AnalysisException, ParseException, PySparkAssertionError, + PySparkValueError, IllegalArgumentException, SparkUpgradeException, ) @@ -590,8 +590,8 @@ def test_assert_equal_timestamp(self): data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"] ) - df1 = df1.withColumn("timestamp", to_timestamp("timestamp")) - df2 = df2.withColumn("timestamp", to_timestamp("timestamp")) + df1 = df1.withColumn("timestamp", F.to_timestamp("timestamp")) + df2 = df2.withColumn("timestamp", F.to_timestamp("timestamp")) assertDataFrameEqual(df1, df2, checkRowOrder=False) assertDataFrameEqual(df1, df2, checkRowOrder=True) @@ -1729,17 +1729,14 @@ def test_capture_illegalargument_exception(self): "Setting negative mapred.reduce.tasks", lambda: self.spark.sql("SET mapred.reduce.tasks=-1"), ) + + def test_capture_pyspark_value_exception(self): df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) self.assertRaisesRegex( - IllegalArgumentException, - "1024 is not in the 
permitted values", - lambda: df.select(sha2(df.a, 1024)).collect(), + PySparkValueError, + "Value for `numBits` has to be amongst the following values", + lambda: df.select(F.sha2(df.a, 1024)).collect(), ) - try: - df.select(sha2(df.a, 1024)).collect() - except IllegalArgumentException as e: - self.assertRegex(e._desc, "1024 is not in the permitted values") - self.assertRegex(e._stackTrace, "org.apache.spark.sql.functions") def test_get_error_class_state(self): # SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable) diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py new file mode 100644 index 000000000000..afba7d467854 --- /dev/null +++ b/python/pyspark/sql/worker/commit_data_source_write.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys +from typing import IO + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.errors import PySparkAssertionError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_bool, + read_int, + write_int, + SpecialLengths, +) +from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage +from pyspark.util import handle_worker_exception +from pyspark.worker_util import ( + check_python_version, + pickleSer, + send_accumulator_updates, + setup_broadcasts, + setup_memory_limits, + setup_spark_files, +) + + +def main(infile: IO, outfile: IO) -> None: + """ + Main method for committing or aborting a data source write operation. + + This process is invoked from the `UserDefinedPythonDataSourceCommitRunner.runInPython` + method in the BatchWrite implementation of the PythonTableProvider. It is + responsible for invoking either the `commit` or the `abort` method on a data source + writer instance, given a list of commit messages. + """ + try: + check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + setup_spark_files(infile) + setup_broadcasts(infile) + + _accumulatorRegistry.clear() + + # Receive the data source writer instance. + writer = pickleSer._read_with_length(infile) + if not isinstance(writer, DataSourceWriter): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an instance of DataSourceWriter", + "actual": f"'{type(writer).__name__}'", + }, + ) + + # Receive the commit messages. 
+ num_messages = read_int(infile) + commit_messages = [] + for _ in range(num_messages): + message = pickleSer._read_with_length(infile) + if message is not None and not isinstance(message, WriterCommitMessage): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an instance of WriterCommitMessage", + "actual": f"'{type(message).__name__}'", + }, + ) + commit_messages.append(message) + + # Receive a boolean to indicate whether to invoke `abort`. + abort = read_bool(infile) + + # Commit or abort the Python data source write. + # Note the commit messages can be None if there are failed tasks. + if abort: + writer.abort(commit_messages) # type: ignore[arg-type] + else: + writer.commit(commit_messages) # type: ignore[arg-type] + + # Send a status code back to JVM. + write_int(0, outfile) + + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py new file mode 100644 index 000000000000..91963658ee61 --- /dev/null +++ b/python/pyspark/sql/worker/lookup_data_sources.py @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from importlib import import_module +from pkgutil import iter_modules +import os +import sys +from typing import IO + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_int, + write_int, + write_with_length, + SpecialLengths, +) +from pyspark.sql.datasource import DataSource +from pyspark.util import handle_worker_exception +from pyspark.worker_util import ( + check_python_version, + pickleSer, + send_accumulator_updates, + setup_broadcasts, + setup_memory_limits, + setup_spark_files, +) + + +def main(infile: IO, outfile: IO) -> None: + """ + Main method for looking up the available Python Data Sources in Python path. 
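For illustration, a package that this lookup would pick up follows the convention visible in the loop further below: a top-level module whose name starts with `pyspark_` and which exposes a `DefaultSource` subclass of `DataSource`. A minimal sketch under those assumptions (the module, class, and format names here are hypothetical, not part of the patch):

    # pyspark_myformat/__init__.py -- hypothetical module; only the "pyspark_"
    # name prefix and the DefaultSource attribute matter to this lookup worker.
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class DefaultReader(DataSourceReader):
        def read(self, partition):
            # Yield rows as plain tuples matching the schema below.
            yield (0,)

    class DefaultSource(DataSource):
        @classmethod
        def name(cls):
            # Registered under this name, e.g. spark.read.format("myformat").
            return "myformat"

        def schema(self):
            return "id INT"

        def reader(self, schema):
            return DefaultReader()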
+ + This process is invoked from the `UserDefinedPythonDataSourceLookupRunner.runInPython` + method in `UserDefinedPythonDataSource.lookupAllDataSourcesInPython` when the first + call related to Python Data Source happens via `DataSourceManager`. + + This is responsible for searching the available Python Data Sources so they can be + statically registered automatically. + """ + try: + check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + setup_spark_files(infile) + setup_broadcasts(infile) + + _accumulatorRegistry.clear() + + infos = {} + for info in iter_modules(): + if info.name.startswith("pyspark_"): + mod = import_module(info.name) + if hasattr(mod, "DefaultSource") and issubclass(mod.DefaultSource, DataSource): + infos[mod.DefaultSource.name()] = mod.DefaultSource + + # Writes name -> pickled data source to JVM side to be registered + # as a Data Source. + write_int(len(infos), outfile) + for name, dataSource in infos.items(): + write_with_length(name.encode("utf-8"), outfile) + pickleSer._write_with_length(dataSource, outfile) + + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index eea4a75e3be4..7db2744e16f8 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -211,6 +211,9 @@ def batch_to_rows() -> Iterator[Row]: command = (data_source_write_func, return_type) pickleSer._write_with_length(command, outfile) + # Return the picked writer. + pickleSer._write_with_length(writer, outfile) + except BaseException as e: handle_worker_exception(e, outfile) sys.exit(-1) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index b30f7b7a00e9..456a311efda2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -27,7 +27,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * Object for grouping error messages from (most) exceptions thrown during query execution. * This does not include exceptions thrown during the eager execution of commands, which are - * grouped into [[QueryCompilationErrors]]. + * grouped into [[CompilationErrors]]. 
*/ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = { diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index e7f8cbe0fe68..a84f8e54ec52 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -120,7 +120,6 @@ org.apache.datasketches datasketches-java - 3.3.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 94f6d3346265..a57fd7a31d30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -765,7 +765,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p - case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } @@ -829,9 +829,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg) - newProject.copyTagsFrom(p) - newProject + Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(e: Expression) = { @@ -865,9 +863,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(filteredAggregate, outputName(value, aggregate))() } } - val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) - newAggregate.copyTagsFrom(p) - newAggregate + Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } @@ -3264,9 +3260,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Finally, generate output columns according to the original projectList. val finalProjectList = aggregateExprs.map(_.toAttribute) - val newProject = Project(finalProjectList, withWindow) - newProject.copyTagsFrom(f) - newProject + Project(finalProjectList, withWindow) case p: LogicalPlan if !p.childrenResolved => p @@ -3284,9 +3278,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Finally, generate output columns according to the original projectList. val finalProjectList = aggregateExprs.map(_.toAttribute) - val newProject = Project(finalProjectList, withWindow) - newProject.copyTagsFrom(a) - newProject + Project(finalProjectList, withWindow) // We only extract Window Expressions after all expressions of the Project // have been resolved, and lateral column aliases are properly handled first. @@ -3303,9 +3295,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Finally, generate output columns according to the original projectList. 
val finalProjectList = projectList.map(_.toAttribute) - val newProject = Project(finalProjectList, withWindow) - newProject.copyTagsFrom(p) - newProject + Project(finalProjectList, withWindow) } } @@ -3461,14 +3451,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor _.containsPattern(NATURAL_LIKE_JOIN), ruleId) { case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint) if left.resolved && right.resolved && j.duplicateResolved => - commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint, - j.getTagValue(LogicalPlan.PLAN_ID_TAG)) + val project = commonNaturalJoinProcessing( + left, right, joinType, usingCols, None, hint) + j.getTagValue(LogicalPlan.PLAN_ID_TAG) + .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _)) + project case j @ Join(left, right, NaturalJoin(joinType), condition, hint) if j.resolvedExceptNatural => // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) - commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint, - j.getTagValue(LogicalPlan.PLAN_ID_TAG)) + val project = commonNaturalJoinProcessing( + left, right, joinType, joinNames, condition, hint) + j.getTagValue(LogicalPlan.PLAN_ID_TAG) + .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _)) + project } } @@ -3516,8 +3512,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor joinType: JoinType, joinNames: Seq[String], condition: Option[Expression], - hint: JoinHint, - planId: Option[Long] = None): LogicalPlan = { + hint: JoinHint): LogicalPlan = { import org.apache.spark.sql.catalyst.util._ val leftKeys = joinNames.map { keyName => @@ -3570,13 +3565,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType) } - val newJoin = Join(left, right, joinType, newCondition, hint) - // retain the plan id used in Spark Connect - planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _)) - // use Project to hide duplicated common keys // propagate hidden columns from nested USING/NATURAL JOINs - val project = Project(projectList, newJoin) + val project = Project(projectList, Join(left, right, joinType, newCondition, hint)) project.setTagValue( Project.hiddenOutputTag, hiddenList.map(_.markAsQualifiedAccessOnly()) ++ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 173c9d44a2b3..2982d8477fcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -149,12 +149,10 @@ object CTESubstitution extends Rule[LogicalPlan] { plan: LogicalPlan, cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = { plan.resolveOperatorsUp { - case cte @ UnresolvedWith(child, relations) => + case UnresolvedWith(child, relations) => val resolvedCTERelations = resolveCTERelations(relations, isLegacy = true, forceInline = false, Seq.empty, cteDefs) - val substituted = substituteCTE(child, alwaysInline = true, resolvedCTERelations) - substituted.copyTagsFrom(cte) - substituted + substituteCTE(child, alwaysInline = true, resolvedCTERelations) } } @@ -204,7 +202,7 @@ object CTESubstitution extends Rule[LogicalPlan] { var firstSubstituted: Option[LogicalPlan] = None val newPlan = 
plan.resolveOperatorsDownWithPruning( _.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) { - case cte @ UnresolvedWith(child: LogicalPlan, relations) => + case UnresolvedWith(child: LogicalPlan, relations) => val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs) ++ outerCTEDefs @@ -215,7 +213,6 @@ object CTESubstitution extends Rule[LogicalPlan] { if (firstSubstituted.isEmpty) { firstSubstituted = Some(substituted) } - substituted.copyTagsFrom(cte) substituted case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 24e2bb767ab4..d45f0fa4bf78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -771,6 +771,7 @@ object FunctionRegistry { expression[PercentRank]("percent_rank"), // predicates + expression[Between]("between"), expression[And]("and"), expression[In]("in"), expression[Not]("not"), @@ -876,9 +877,6 @@ object FunctionRegistry { "expr1 <> expr2 - Returns true if `expr1` is not equal to `expr2`."), "!=" -> makeExprInfoForVirtualOperator("!=", "expr1 != expr2 - Returns true if `expr1` is not equal to `expr2`."), - "between" -> makeExprInfoForVirtualOperator("between", - "expr1 [NOT] BETWEEN expr2 AND expr3 - " + - "evaluate if `expr1` is [not] in between `expr2` and `expr3`."), "case" -> makeExprInfoForVirtualOperator("case", "CASE expr1 WHEN expr2 THEN expr3 [WHEN expr4 THEN expr5]* [ELSE expr6] END " + "- When `expr1` = `expr2`, returns `expr3`; when `expr1` = `expr4`, return `expr5`; " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 380172c1a131..aad0c012a0d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -101,6 +101,9 @@ object StreamingJoinHelper extends PredicateHelper with Logging { case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) case GreaterThan(l, r) => getStateWatermarkSafely(r, l) case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case Between(input, lower, upper, _) => + getStateWatermarkSafely(lower, input).map(_ - 1) + .orElse(getStateWatermarkSafely(input, upper).map(_ - 1)) case _ => None } if (stateWatermark.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c5e98683c749..56e8843fda53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -266,10 +266,8 @@ abstract class TypeCoercionBase { s -> Nil } else { assert(newChildren.length == 2) - val newExcept = Except(newChildren.head, newChildren.last, isAll) - newExcept.copyTagsFrom(s) val attrMapping = left.output.zip(newChildren.head.output) - newExcept -> attrMapping + Except(newChildren.head, newChildren.last, isAll) -> attrMapping } case s @ Intersect(left, right, isAll) if s.childrenResolved && @@ 
-279,10 +277,8 @@ abstract class TypeCoercionBase { s -> Nil } else { assert(newChildren.length == 2) - val newIntersect = Intersect(newChildren.head, newChildren.last, isAll) - newIntersect.copyTagsFrom(s) val attrMapping = left.output.zip(newChildren.head.output) - newIntersect -> attrMapping + Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping } case s: Union if s.childrenResolved && !s.byName && @@ -292,9 +288,7 @@ abstract class TypeCoercionBase { s -> Nil } else { val attrMapping = s.children.head.output.zip(newChildren.head.output) - val newUnion = s.copy(children = newChildren) - newUnion.copyTagsFrom(s) - newUnion -> attrMapping + s.copy(children = newChildren) -> attrMapping } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index b20c62438986..eb649c4d4796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -64,7 +64,7 @@ object ExternalCatalogUtils { } def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) + c < charToEscape.size() && charToEscape.get(c) } def escapePathName(path: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala new file mode 100644 index 000000000000..e5bb31bc34f1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "Usage: input [NOT] BETWEEN lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`", + examples = """ + Examples: + > SELECT 0.5 _FUNC_ 0.1 AND 1.0; + true + """, + arguments = """ + Arguments: + * input - An expression that is being compared with lower and upper bound. + * lower - Lower bound of the between check. + * upper - Upper bound of the between check. 
+ """, + since = "4.0.0", + group = "conditional_funcs") +case class Between private(input: Expression, lower: Expression, upper: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { + def this(input: Expression, lower: Expression, upper: Expression) = { + this(input, lower, upper, { + val commonExpr = CommonExpressionDef(input) + val ref = new CommonExpressionRef(commonExpr) + val replacement = And(GreaterThanOrEqual(ref, lower), LessThanOrEqual(ref, upper)) + With(replacement, Seq(commonExpr)) + }) + }; + + override def parameters: Seq[Expression] = Seq(input, lower, upper) + + override protected def withNewChildInternal(newChild: Expression): Between = { + copy(replacement = newChild) + } +} + +object Between { + def apply(input: Expression, lower: Expression, upper: Expression): Between = { + new Between(input, lower, upper) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bb63d874aa1a..7609e96eba6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2034,10 +2034,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => - // BETWEEN is translated to lower <= e && e <= upper - invertIfNotDefined(And( - GreaterThanOrEqual(e, expression(ctx.lower)), - LessThanOrEqual(e, expression(ctx.upper)))) + invertIfNotDefined(UnresolvedFunction( + "between", Seq(e, expression(ctx.lower), expression(ctx.upper)), isDistinct = false)) case SqlBaseParser.IN if ctx.query != null => invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ef7cd7401f25..2a62ea1feb03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -361,6 +361,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } else { transferAttrMapping ++ newOtherAttrMapping } + if (!(plan eq planAfterRule)) { + planAfterRule.copyTagsFrom(plan) + } planAfterRule -> resultAttrMapping.toSeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala index 56f6b116759a..45a20cbe3aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -177,10 +177,14 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => self.markRuleAsIneffective(ruleId) self } else { + rewritten_plan.copyTagsFrom(self) rewritten_plan } } else { - afterRule.mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule)) + val newPlan = afterRule + .mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule)) + newPlan.copyTagsFrom(self) + newPlan } } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 32fdb3e5faf2..11edce8140f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -25,6 +25,7 @@ import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try @@ -35,7 +36,21 @@ import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{ + ArrayBasedMapData, + BadRecordException, + CaseInsensitiveMap, + DateFormatter, + DropMalformedMode, + FailureSafeParser, + GenericArrayData, + MapData, + ParseMode, + PartialResultArrayException, + PartialResultException, + PermissiveMode, + TimestampFormatter +} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream import org.apache.spark.sql.errors.QueryExecutionErrors @@ -69,6 +84,7 @@ class StaxXmlParser( private val decimalParser = ExprUtils.getDecimalParser(options.locale) + private val caseSensitive = SQLConf.get.caseSensitiveAnalysis /** * Parses a single XML string and turns it into either one resulting row or no row (if the @@ -85,7 +101,7 @@ class StaxXmlParser( } private def getFieldNameToIndex(schema: StructType): Map[String, Int] = { - if (SQLConf.get.caseSensitiveAnalysis) { + if (caseSensitive) { schema.map(_.name).zipWithIndex.toMap } else { CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap) @@ -201,27 +217,30 @@ class StaxXmlParser( case (_: EndElement, _: DataType) => null case (c: Characters, ArrayType(st, _)) => // For `ArrayType`, it needs to return the type of element. The values are merged later. + parser.next convertTo(c.getData, st) case (c: Characters, st: StructType) => - // If a value tag is present, this can be an attribute-only element whose values is in that - // value tag field. Or, it can be a mixed-type element with both some character elements - // and other complex structure. Character elements are ignored. - val attributesOnly = st.fields.forall { f => - f.name == options.valueTag || f.name.startsWith(options.attributePrefix) - } - if (attributesOnly) { - // If everything else is an attribute column, there's no complex structure. - // Just return the value of the character element, or null if we don't have a value tag - st.find(_.name == options.valueTag).map( - valueTag => convertTo(c.getData, valueTag.dataType)).orNull - } else { - // Otherwise, ignore this character element, and continue parsing the following complex - // structure - parser.next - parser.peek match { - case _: EndElement => null // no struct here at all; done - case _ => convertObject(parser, st) - } + parser.next + parser.peek match { + case _: EndElement => + // It couldn't be an array of value tags + // as the opening tag is immediately followed by a closing tag. 
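+ // Whitespace-only text between the start and end tag carries no data, so no value-tag entry is produced.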
+ if (c.isWhiteSpace) { + return null + } + val indexOpt = getFieldNameToIndex(st).get(options.valueTag) + indexOpt match { + case Some(index) => + convertTo(c.getData, st.fields(index).dataType) + case None => null + } + case _ => + val row = convertObject(parser, st) + if (!c.isWhiteSpace) { + addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, c.getData, addToTail = false) + } else { + row + } } case (_: Characters, _: StringType) => convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType) @@ -237,6 +256,7 @@ class StaxXmlParser( case _ => convertField(parser, dataType, attributes) } case (c: Characters, dt: DataType) => + parser.next convertTo(c.getData, dt) case (e: XMLEvent, dt: DataType) => throw new IllegalArgumentException( @@ -262,7 +282,12 @@ class StaxXmlParser( case e: StartElement => kvPairs += (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> - convertField(parser, valueType)) + convertField(parser, valueType)) + case c: Characters if !c.isWhiteSpace => + // Create a value tag field for it + kvPairs += + // TODO: We don't support an array value tags in map yet. + (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType)) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) case _ => // do nothing @@ -343,8 +368,9 @@ class StaxXmlParser( val row = new Array[Any](schema.length) val nameToIndex = getFieldNameToIndex(schema) // If there are attributes, then we process them first. - convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) => - nameToIndex.get(f).foreach { row(_) = v } + convertAttributes(rootAttributes, schema).toSeq.foreach { + case (f, v) => + nameToIndex.get(f).foreach { row(_) = v } } val wildcardColName = options.wildcardColName @@ -405,15 +431,11 @@ class StaxXmlParser( badRecordException = badRecordException.orElse(Some(e)) } - case c: Characters if !c.isWhiteSpace && isRootAttributesOnly => - nameToIndex.get(options.valueTag) match { - case Some(index) => - row(index) = convertTo(c.getData, schema(index).dataType) - case None => // do nothing - } + case c: Characters if !c.isWhiteSpace => + addOrUpdate(row, schema, options.valueTag, c.getData) case _: EndElement => - shouldStop = StaxXmlParserUtils.checkEndElement(parser) + shouldStop = parseAndCheckEndElement(row, schema, parser) case _ => // do nothing } @@ -576,6 +598,54 @@ class StaxXmlParser( castTo(data, FloatType).asInstanceOf[Float] } } + + @tailrec + private def parseAndCheckEndElement( + row: Array[Any], + schema: StructType, + parser: XMLEventReader): Boolean = { + parser.peek match { + case _: EndElement | _: EndDocument => true + case _: StartElement => false + case c: Characters if !c.isWhiteSpace => + parser.nextEvent() + addOrUpdate(row, schema, options.valueTag, c.getData) + parseAndCheckEndElement(row, schema, parser) + case _ => + parser.nextEvent() + parseAndCheckEndElement(row, schema, parser) + } + } + + private def addOrUpdate( + row: Array[Any], + schema: StructType, + name: String, + data: String, + addToTail: Boolean = true): InternalRow = { + schema.getFieldIndex(name) match { + case Some(index) => + schema(index).dataType match { + case ArrayType(elementType, _) => + val value = convertTo(data, elementType) + val result = if (row(index) == null) { + ArrayBuffer(value) + } else { + val genericArrayData = row(index).asInstanceOf[GenericArrayData] + if (addToTail) { + genericArrayData.toArray(elementType) :+ value + } else { + value +: 
genericArrayData.toArray(elementType) + } + } + row(index) = new GenericArrayData(result) + case dataType => + row(index) = convertTo(data, dataType) + } + case None => // do nothing + } + InternalRow.fromSeq(row.toIndexedSeq) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index de8ec33de0ce..9d0c16d95e46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -164,7 +164,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } - @tailrec private def inferField(parser: XMLEventReader): DataType = { parser.peek match { case _: EndElement => NullType @@ -182,18 +181,25 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case _ => inferField(parser) } case c: Characters if !c.isWhiteSpace => - // This could be the characters of a character-only element, or could have mixed - // characters and other complex structure val characterType = inferFrom(c.getData) parser.nextEvent() parser.peek match { case _: StartElement => - // Some more elements follow; so ignore the characters. - // Use the schema of the rest - inferObject(parser).asInstanceOf[StructType] + // Some more elements follow; + // This is a mix of values and other elements + val innerType = inferObject(parser).asInstanceOf[StructType] + addOrUpdateValueTagType(innerType, characterType) case _ => - // That's all, just the character-only body; use that as the type - characterType + val fieldType = inferField(parser) + fieldType match { + case st: StructType => addOrUpdateValueTagType(st, characterType) + case _: NullType => characterType + case _: DataType => + // The field type couldn't be an array type + new StructType() + .add(options.valueTag, addOrUpdateType(Some(characterType), fieldType)) + + } } case e: XMLEvent => throw new IllegalArgumentException(s"Failed to parse data with unexpected event $e") @@ -229,17 +235,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) val nameToDataType = collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering) - def addOrUpdateType(fieldName: String, newType: DataType): Unit = { - val oldTypeOpt = nameToDataType.get(fieldName) - oldTypeOpt match { - // If the field name exists in the map, - // merge the type and infer the combined field as an array type if necessary - case Some(oldType) if !oldType.isInstanceOf[ArrayType] => - nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType))) - case Some(oldType) => - nameToDataType.update(fieldName, compatibleType(oldType, newType)) - case None => - nameToDataType.put(fieldName, newType) + @tailrec + def inferAndCheckEndElement(parser: XMLEventReader): Boolean = { + parser.peek match { + case _: EndElement | _: EndDocument => true + case _: StartElement => false + case c: Characters if !c.isWhiteSpace => + val characterType = inferFrom(c.getData) + parser.nextEvent() + addOrUpdateType(nameToDataType, options.valueTag, characterType) + inferAndCheckEndElement(parser) + case _ => + parser.nextEvent() + inferAndCheckEndElement(parser) } } @@ -248,7 +256,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) rootValuesMap.foreach { case (f, v) => - addOrUpdateType(f, inferFrom(v)) + 
addOrUpdateType(nameToDataType, f, inferFrom(v)) } var shouldStop = false while (!shouldStop) { @@ -281,29 +289,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } // Add the field and datatypes so that we can check if this is ArrayType. val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) - addOrUpdateType(field, inferredType) + addOrUpdateType(nameToDataType, field, inferredType) case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) - addOrUpdateType(options.valueTag, valueTagType) + addOrUpdateType(nameToDataType, options.valueTag, valueTagType) case _: EndElement => - shouldStop = StaxXmlParserUtils.checkEndElement(parser) + shouldStop = inferAndCheckEndElement(parser) case _ => // do nothing } } - // A structure object is an attribute-only element - // if it only consists of attributes and valueTags. - // If not, we will remove the valueTag field from the schema - val attributesOnly = nameToDataType.forall { - case (fieldName, _) => - fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix) - } - if (!attributesOnly) { - nameToDataType -= options.valueTag - } // Note: other code relies on this sorting for correctness, so don't remove it! StructType(nameToDataType.map{ @@ -534,4 +532,75 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } } + + /** + * This helper function merges the data type of value tags and inner elements. + * It could only be structure data. Consider the following case, + * + * value1 + * 1 + * value2 + * + * Input: ''a struct'' and ''_VALUE string'' + * Return: ''a struct>'' + * @param objectType inner elements' type + * @param valueTagType value tag's type + */ + private[xml] def addOrUpdateValueTagType( + objectType: DataType, + valueTagType: DataType): DataType = { + (objectType, valueTagType) match { + case (st: StructType, _) => + val valueTagIndexOpt = st.getFieldIndex(options.valueTag) + + valueTagIndexOpt match { + // If the field name exists in the inner elements, + // merge the type and infer the combined field as an array type if necessary + case Some(index) if !st(index).dataType.isInstanceOf[ArrayType] => + updateStructField( + st, + index, + ArrayType(compatibleType(st(index).dataType, valueTagType))) + case Some(index) => + updateStructField(st, index, compatibleType(st(index).dataType, valueTagType)) + case None => + st.add(options.valueTag, valueTagType) + } + case _ => + throw new IllegalStateException( + "illegal state when merging value tags types in schema inference" + ) + } + } + + private def updateStructField( + structType: StructType, + index: Int, + newType: DataType): StructType = { + val newFields: Array[StructField] = + structType.fields.updated(index, structType.fields(index).copy(dataType = newType)) + StructType(newFields) + } + + private def addOrUpdateType( + nameToDataType: collection.mutable.TreeMap[String, DataType], + fieldName: String, + newType: DataType): Unit = { + val oldTypeOpt = nameToDataType.get(fieldName) + val mergedType = addOrUpdateType(oldTypeOpt, newType) + nameToDataType.put(fieldName, mergedType) + } + + private def addOrUpdateType(oldTypeOpt: Option[DataType], newType: DataType): DataType = { + oldTypeOpt match { + // If the field name already exists, + // merge the type and infer the combined field as an array type if necessary + case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] => + ArrayType(compatibleType(oldType, 
newType)) + case Some(oldType) => + compatibleType(oldType, newType) + case None => + newType + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ee41cbe2f50e..e8235fd10466 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3856,12 +3856,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "reason" -> reason)) } - def dataSourceAlreadyExists(name: String): Throwable = { - new AnalysisException( - errorClass = "DATA_SOURCE_ALREADY_EXISTS", - messageParameters = Map("provider" -> name)) - } - def dataSourceDoesNotExist(name: String): Throwable = { new AnalysisException( errorClass = "DATA_SOURCE_NOT_EXIST", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 1ee20a98cfd1..ba01f9559161 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -569,8 +570,10 @@ abstract class ExternalCatalogSuite extends SparkFunSuite { // then be caught and converted to a RuntimeException with a descriptive message. 
case ex: RuntimeException if ex.getMessage.contains("MetaException") => throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3066", - messageParameters = Map("msg" -> ex.getMessage)) + errorClass = "_LEGACY_ERROR_TEMP_2193", + messageParameters = Map( + "hiveMetastorePartitionPruningFallbackOnException" -> + SQLConf.HIVE_METASTORE_PARTITION_PRUNING_FALLBACK_ON_EXCEPTION.key)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index fe5d024a6b37..50510861c639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -157,8 +157,10 @@ class ExpressionParserSuite extends AnalysisTest { } test("between expressions") { - assertEqual("a between b and c", $"a" >= $"b" && $"a" <= $"c") - assertEqual("a not between b and c", !($"a" >= $"b" && $"a" <= $"c")) + assertEqual("a between b and c", + UnresolvedFunction("between", Seq($"a", $"b", $"c"), isDistinct = false)) + assertEqual("a not between b and c", + !UnresolvedFunction("between", Seq($"a", $"b", $"c"), isDistinct = false)) } test("in expressions") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index eb5b38d42881..a2cfad800e00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -114,8 +114,11 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val qe = sparkSession.sessionState.executePlan(c, CommandExecutionMode.NON_ROOT) - val result = SQLExecution.withNewExecutionId(qe, Some(commandExecutionName(c))) { - qe.executedPlan.executeCollect() + val name = commandExecutionName(c) + val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") { + SQLExecution.withNewExecutionId(qe, Some(name)) { + qe.executedPlan.executeCollect() + } } CommandResult( qe.analyzed.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index e6c4749df60a..4fc636a59e5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -17,12 +17,18 @@ package org.apache.spark.sql.execution.datasources +import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern +import scala.jdk.CollectionConverters._ + +import org.apache.spark.api.python.PythonUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource +import org.apache.spark.util.Utils /** @@ -30,9 +36,13 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource * their short names or fully qualified names. 
*/ class DataSourceManager extends Logging { - // TODO(SPARK-45917): Statically load Python Data Source so idempotently Python - // Data Sources can be loaded even when the Driver is restarted. - private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]() + // Lazy to avoid being invoked during Session initialization. + // Otherwise, it goes infinite loop, session -> Python runner -> SQLConf -> session. + private lazy val dataSourceBuilders = { + val builders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]() + builders.putAll(DataSourceManager.initialDataSourceBuilders.asJava) + builders + } private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) @@ -73,3 +83,42 @@ class DataSourceManager extends Logging { manager } } + + +object DataSourceManager extends Logging { + // Visible for testing + private[spark] var dataSourceBuilders: Option[Map[String, UserDefinedPythonDataSource]] = None + private lazy val shouldLoadPythonDataSources: Boolean = { + Utils.checkCommandAvailable(PythonUtils.defaultPythonExec) && + // Make sure PySpark zipped files also exist. + PythonUtils.sparkPythonPath + .split(Pattern.quote(File.separator)).forall(new File(_).exists()) + } + + private def initialDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = { + if (Utils.isTesting || shouldLoadPythonDataSources) this.synchronized { + if (dataSourceBuilders.isEmpty) { + val maybeResult = try { + Some(UserDefinedPythonDataSource.lookupAllDataSourcesInPython()) + } catch { + case e: Throwable => + // Even if it fails for whatever reason, we shouldn't make the whole + // application fail. + logWarning( + s"Skipping the lookup of Python Data Sources due to the failure: $e") + None + } + + dataSourceBuilders = maybeResult.map { result => + result.names.zip(result.dataSources).map { case (name, dataSource) => + name -> + UserDefinedPythonDataSource(PythonUtils.createPythonFunction(dataSource)) + }.toMap + } + } + dataSourceBuilders.getOrElse(Map.empty) + } else { + Map.empty + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala index 61ccda3fc954..2683d8d547f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala @@ -53,7 +53,7 @@ case class DescribeColumnExec( read.newScanBuilder(CaseInsensitiveStringMap.empty()).build() match { case s: SupportsReportStatistics => val stats = s.estimateStatistics() - Some(stats.columnStats().get(FieldReference.column(column.name))) + Option(stats.columnStats().get(FieldReference.column(column.name))) case _ => None } case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 778f55595aee..2812e31e7a8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{JobArtifactSet, SparkException} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonUtils, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDF @@ -128,6 +128,9 @@ class PythonTableProvider extends TableProvider { override def toBatch: BatchWrite = new BatchWrite { + // Store the pickled data source writer instance. + private var pythonDataSourceWriter: Array[Byte] = _ + override def createBatchWriterFactory( physicalInfo: PhysicalWriteInfo): DataWriterFactory = { @@ -136,14 +139,21 @@ class PythonTableProvider extends TableProvider { info.schema(), info.options(), isTruncate) + + pythonDataSourceWriter = writeInfo.writer + PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID) } - // TODO(SPARK-45914): Support commit protocol - override def commit(messages: Array[WriterCommitMessage]): Unit = {} + override def commit(messages: Array[WriterCommitMessage]): Unit = { + source.commitWriteInPython(pythonDataSourceWriter, messages) + } - // TODO(SPARK-45914): Support commit protocol - override def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = { + source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true) + } + + override def toString: String = shortName } override def description: String = "(Python)" @@ -333,6 +343,17 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { overwrite).runInPython() } + /** + * (Driver-side) Run Python process to either commit or abort a write operation. + */ + def commitWriteInPython( + writer: Array[Byte], + messages: Array[WriterCommitMessage], + abort: Boolean = false): Unit = { + new UserDefinedPythonDataSourceCommitRunner( + dataSourceCls, writer, messages, abort).runInPython() + } + /** * (Executor-side) Create an iterator that execute the Python function. */ @@ -404,6 +425,59 @@ object UserDefinedPythonDataSource { * The schema of the output to the Python data source write function. */ val writeOutputSchema: StructType = new StructType().add("message", BinaryType) + + /** + * (Driver-side) Look up all available Python Data Sources. + */ + def lookupAllDataSourcesInPython(): PythonLookupAllDataSourcesResult = { + new UserDefinedPythonDataSourceLookupRunner( + PythonUtils.createPythonFunction(Array.empty[Byte])).runInPython() + } +} + +/** + * All Data Sources in Python + */ +case class PythonLookupAllDataSourcesResult( + names: Array[String], dataSources: Array[Array[Byte]]) + +/** + * A runner used to look up Python Data Sources available in Python path. + */ +class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction) + extends PythonPlannerRunner[PythonLookupAllDataSourcesResult](lookupSources) { + + override val workerModule = "pyspark.sql.worker.lookup_data_sources" + + override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { + // No input needed. + } + + override protected def receiveFromPython( + dataIn: DataInputStream): PythonLookupAllDataSourcesResult = { + // Receive the pickled data source or an exception raised in Python worker. 
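+ // The first int is either the number of data sources found or a special length signalling a Python exception.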
+ val length = dataIn.readInt() + if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.failToPlanDataSourceError( + action = "lookup", tpe = "instance", msg = msg) + } + + val shortNames = ArrayBuffer.empty[String] + val pickledDataSources = ArrayBuffer.empty[Array[Byte]] + val numDataSources = length + + for (_ <- 0 until numDataSources) { + val shortName = PythonWorkerUtils.readUTF(dataIn) + val pickledDataSource: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) + shortNames.append(shortName) + pickledDataSources.append(pickledDataSource) + } + + PythonLookupAllDataSourcesResult( + names = shortNames.toArray, + dataSources = pickledDataSources.toArray) + } } /** @@ -537,7 +611,7 @@ class UserDefinedPythonDataSourceReadRunner( /** * Hold the results of running [[UserDefinedPythonDataSourceWriteRunner]]. */ -case class PythonDataSourceWriteInfo(func: Array[Byte]) +case class PythonDataSourceWriteInfo(func: Array[Byte], writer: Array[Byte]) /** * A runner that creates a Python data source writer instance and returns a Python function @@ -587,9 +661,55 @@ class UserDefinedPythonDataSourceWriteRunner( action = "plan", tpe = "write", msg = msg) } - // Receive the pickled data source. + // Receive the pickled data source write function. val writeUdf: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn) - PythonDataSourceWriteInfo(func = writeUdf) + // Receive the pickled instance of the data source writer. + val writer: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) + + PythonDataSourceWriteInfo(func = writeUdf, writer = writer) + } +} + +/** + * A runner that takes a Python data source writer and a list of commit messages, + * and invokes the `commit` or `abort` method of the writer in Python. + */ +class UserDefinedPythonDataSourceCommitRunner( + dataSourceCls: PythonFunction, + writer: Array[Byte], + messages: Array[WriterCommitMessage], + abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) { + override val workerModule: String = "pyspark.sql.worker.commit_data_source_write" + + override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { + // Send the Python data source writer. + PythonWorkerUtils.writeBytes(writer, dataOut) + + // Send the commit messages. + dataOut.writeInt(messages.length) + messages.foreach { message => + // Commit messages can be null if there are task failures. + if (message == null) { + dataOut.writeInt(SpecialLengths.NULL) + } else { + PythonWorkerUtils.writeBytes( + message.asInstanceOf[PythonWriterCommitMessage].pickledMessage, dataOut) + } + } + + // Send whether to invoke `abort` instead of `commit`. + dataOut.writeBoolean(abort) + } + + override protected def receiveFromPython(dataIn: DataInputStream): Unit = { + // Receive any exceptions thrown in the Python worker. 
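+ // A code of 0 means the commit or abort ran successfully; PYTHON_EXCEPTION_THROWN is followed by the error message.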
+ val code = dataIn.readInt() + if (code == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.failToPlanDataSourceError( + action = "commit or abort", tpe = "write", msg = msg) + } + assert(code == 0, s"Python commit job should run successfully, but got exit code: $code") } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 6cc42ba9c902..1a04fd57090d 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -44,6 +44,7 @@ | org.apache.spark.sql.catalyst.expressions.Atanh | atanh | SELECT atanh(0) | struct | | org.apache.spark.sql.catalyst.expressions.BRound | bround | SELECT bround(2.5, 0) | struct | | org.apache.spark.sql.catalyst.expressions.Base64 | base64 | SELECT base64('Spark SQL') | struct | +| org.apache.spark.sql.catalyst.expressions.Between | between | SELECT 0.5 between 0.1 AND 1.0 | struct | | org.apache.spark.sql.catalyst.expressions.Bin | bin | SELECT bin(13) | struct | | org.apache.spark.sql.catalyst.expressions.BitLength | bit_length | SELECT bit_length('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.BitmapBitPosition | bitmap_bit_position | SELECT bitmap_bit_position(1) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out index 0a74ec87eb83..bbf0127f0ef6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out @@ -846,7 +846,7 @@ CreateViewCommand `spark_catalog`.`testviewschm2`.`pubview`, SELECT * FROM tbl1 BETWEEN (SELECT d FROM tbl2 WHERE c = 1) AND (SELECT e FROM tbl3 WHERE f = 2) AND EXISTS (SELECT g FROM tbl4 LEFT JOIN tbl3 ON tbl4.h = tbl3.f), false, false, PersistedView, true +- Project [a#x, b#x] - +- Filter (((a#x >= scalar-subquery#x []) AND (a#x <= scalar-subquery#x [])) AND exists#x []) + +- Filter (between(a#x, scalar-subquery#x [], scalar-subquery#x []) AND exists#x []) : :- Project [d#x] : : +- Filter (c#x = 1) : : +- SubqueryAlias spark_catalog.testviewschm2.tbl2 @@ -882,7 +882,7 @@ BETWEEN (SELECT d FROM tbl2 WHERE c = 1) AND (SELECT e FROM tbl3 WHERE f = 2) AND EXISTS (SELECT g FROM tbl4 LEFT JOIN tbl3 ON tbl4.h = tbl3.f) AND NOT EXISTS (SELECT g FROM tbl4 LEFT JOIN tmptbl ON tbl4.h = tmptbl.j), false, false, PersistedView, true +- Project [a#x, b#x] - +- Filter ((((a#x >= scalar-subquery#x []) AND (a#x <= scalar-subquery#x [])) AND exists#x []) AND NOT exists#x []) + +- Filter ((between(a#x, scalar-subquery#x [], scalar-subquery#x []) AND exists#x []) AND NOT exists#x []) : :- Project [d#x] : : +- Filter (c#x = 1) : : +- SubqueryAlias spark_catalog.testviewschm2.tbl2 diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/date.sql.out index 00813e42d7a4..d3ac6a3eb2b5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/date.sql.out @@ -147,7 +147,7 @@ SELECT f1 AS `Three` FROM DATE_TBL WHERE f1 BETWEEN '2000-01-01' AND '2001-01-01' -- !query analysis Project [f1#x AS Three#x] -+- 
Filter ((f1#x >= cast(2000-01-01 as date)) AND (f1#x <= cast(2001-01-01 as date))) ++- Filter between(f1#x, 2000-01-01, 2001-01-01) +- SubqueryAlias spark_catalog.default.date_tbl +- Relation spark_catalog.default.date_tbl[f1#x] parquet diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/timestamp.sql.out index cc30ebddcbb9..a6c3c2782969 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/timestamp.sql.out @@ -205,7 +205,7 @@ SELECT '' AS `54`, d1 as `timestamp`, FROM TIMESTAMP_TBL WHERE d1 BETWEEN '1902-01-01' AND '2038-01-01' -- !query analysis Project [ AS 54#x, d1#x AS timestamp#x, date_part(year, d1#x) AS year#x, date_part(month, d1#x) AS month#x, date_part(day, d1#x) AS day#x, date_part(hour, d1#x) AS hour#x, date_part(minute, d1#x) AS minute#x, date_part(second, d1#x) AS second#x] -+- Filter ((d1#x >= cast(1902-01-01 as timestamp)) AND (d1#x <= cast(2038-01-01 as timestamp))) ++- Filter between(d1#x, 1902-01-01, 2038-01-01) +- SubqueryAlias spark_catalog.default.timestamp_tbl +- Relation spark_catalog.default.timestamp_tbl[d1#x] parquet diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/union.sql.out index 343767d49b0d..c9f8513adac5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/union.sql.out @@ -409,7 +409,7 @@ Sort [five#x ASC NULLS FIRST], true +- Distinct +- Union false, false :- Project [f1#x AS five#x] - : +- Filter ((f1#x >= -1000000.0) AND (f1#x <= 1000000.0)) + : +- Filter between(f1#x, -1000000.0, 1000000.0) : +- SubqueryAlias float8_tbl : +- View (`FLOAT8_TBL`, [f1#x]) : +- Project [cast(f1#x as double) AS f1#x] @@ -419,7 +419,7 @@ Sort [five#x ASC NULLS FIRST], true : +- LocalRelation [col1#x] +- Project [cast(f1#x as double) AS f1#x] +- Project [f1#x] - +- Filter ((f1#x >= 0) AND (f1#x <= 1000000)) + +- Filter between(f1#x, 0, 1000000) +- SubqueryAlias int4_tbl +- View (`INT4_TBL`, [f1#x]) +- Project [cast(f1#x as int) AS f1#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/predicate-functions.sql.out index a04ee20fa799..772e643027b1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/predicate-functions.sql.out @@ -368,3 +368,85 @@ select null not in (1, 2, null) -- !query analysis Project [NOT cast(null as int) IN (cast(1 as int),cast(2 as int),cast(null as int)) AS (NOT (NULL IN (1, 2, NULL)))#x] +- OneRowRelation + + +-- !query +select 1 between 0 and 2 +-- !query analysis +Project [between(1, 0, 2) AS between(1, 0, 2)#x] ++- OneRowRelation + + +-- !query +select 0.5 between 0 and 1 +-- !query analysis +Project [between(0.5, 0, 1) AS between(0.5, 0, 1)#x] ++- OneRowRelation + + +-- !query +select 2.0 between '1.0' and '3.0' +-- !query analysis +Project [between(2.0, 1.0, 3.0) AS between(2.0, 1.0, 3.0)#x] ++- OneRowRelation + + +-- !query +select 'b' between 'a' and 'c' +-- !query analysis +Project [between(b, a, c) AS between(b, a, c)#x] ++- OneRowRelation + + +-- !query +select to_timestamp('2022-12-26 
00:00:01') between to_date('2022-03-01') and to_date('2022-12-31') +-- !query analysis +Project [between(to_timestamp(2022-12-26 00:00:01, None, TimestampType, Some(America/Los_Angeles), false), to_date(2022-03-01, None, Some(America/Los_Angeles), false), to_date(2022-12-31, None, Some(America/Los_Angeles), false)) AS between(to_timestamp(2022-12-26 00:00:01), to_date(2022-03-01), to_date(2022-12-31))#x] ++- OneRowRelation + + +-- !query +select rand(123) between 0.1 AND 0.2 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +select 1 not between 0 and 2 +-- !query analysis +Project [NOT between(1, 0, 2) AS (NOT between(1, 0, 2))#x] ++- OneRowRelation + + +-- !query +select 0.5 not between 0 and 1 +-- !query analysis +Project [NOT between(0.5, 0, 1) AS (NOT between(0.5, 0, 1))#x] ++- OneRowRelation + + +-- !query +select 2.0 not between '1.0' and '3.0' +-- !query analysis +Project [NOT between(2.0, 1.0, 3.0) AS (NOT between(2.0, 1.0, 3.0))#x] ++- OneRowRelation + + +-- !query +select 'b' not between 'a' and 'c' +-- !query analysis +Project [NOT between(b, a, c) AS (NOT between(b, a, c))#x] ++- OneRowRelation + + +-- !query +select to_timestamp('2022-12-26 00:00:01') not between to_date('2022-03-01') and to_date('2022-12-31') +-- !query analysis +Project [NOT between(to_timestamp(2022-12-26 00:00:01, None, TimestampType, Some(America/Los_Angeles), false), to_date(2022-03-01, None, Some(America/Los_Angeles), false), to_date(2022-12-31, None, Some(America/Los_Angeles), false)) AS (NOT between(to_timestamp(2022-12-26 00:00:01), to_date(2022-03-01), to_date(2022-12-31)))#x] ++- OneRowRelation + + +-- !query +select rand(123) not between 0.1 AND 0.2 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index d19120cfbdc5..6f64b0da6502 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -66,3 +66,19 @@ select 1 not in ('2', '3', '4'); select 1 not in ('2', '3', '4', null); select null not in (1, 2, 3); select null not in (1, 2, null); + +-- Between +select 1 between 0 and 2; +select 0.5 between 0 and 1; +select 2.0 between '1.0' and '3.0'; +select 'b' between 'a' and 'c'; +select to_timestamp('2022-12-26 00:00:01') between to_date('2022-03-01') and to_date('2022-12-31'); +select rand(123) between 0.1 AND 0.2; + +-- Not(Between) +select 1 not between 0 and 2; +select 0.5 not between 0 and 1; +select 2.0 not between '1.0' and '3.0'; +select 'b' not between 'a' and 'c'; +select to_timestamp('2022-12-26 00:00:01') not between to_date('2022-03-01') and to_date('2022-12-31'); +select rand(123) not between 0.1 AND 0.2; diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index 1cdf26d6eac9..71c342054ae4 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -421,3 +421,99 @@ select null not in (1, 2, null) struct<(NOT (NULL IN (1, 2, NULL))):boolean> -- !query output NULL + + +-- !query +select 1 between 0 and 2 +-- !query schema +struct +-- !query output +true + + +-- !query +select 0.5 between 0 and 1 +-- !query schema +struct +-- !query output +true + + +-- !query 
+select 2.0 between '1.0' and '3.0' +-- !query schema +struct +-- !query output +true + + +-- !query +select 'b' between 'a' and 'c' +-- !query schema +struct +-- !query output +true + + +-- !query +select to_timestamp('2022-12-26 00:00:01') between to_date('2022-03-01') and to_date('2022-12-31') +-- !query schema +struct +-- !query output +true + + +-- !query +select rand(123) between 0.1 AND 0.2 +-- !query schema +struct +-- !query output +true + + +-- !query +select 1 not between 0 and 2 +-- !query schema +struct<(NOT between(1, 0, 2)):boolean> +-- !query output +false + + +-- !query +select 0.5 not between 0 and 1 +-- !query schema +struct<(NOT between(0.5, 0, 1)):boolean> +-- !query output +false + + +-- !query +select 2.0 not between '1.0' and '3.0' +-- !query schema +struct<(NOT between(2.0, 1.0, 3.0)):boolean> +-- !query output +false + + +-- !query +select 'b' not between 'a' and 'c' +-- !query schema +struct<(NOT between(b, a, c)):boolean> +-- !query output +false + + +-- !query +select to_timestamp('2022-12-26 00:00:01') not between to_date('2022-03-01') and to_date('2022-12-31') +-- !query schema +struct<(NOT between(to_timestamp(2022-12-26 00:00:01), to_date(2022-03-01), to_date(2022-12-31))):boolean> +-- !query output +false + + +-- !query +select rand(123) not between 0.1 AND 0.2 +-- !query schema +struct<(NOT between(rand(123), 0.1, 0.2)):boolean> +-- !query output +false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 583d7fd7ee3b..7fa34cfddbf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -160,7 +160,9 @@ class QueryExecutionSuite extends SharedSparkSession { // Throw an AnalysisException - this should be captured. spark.experimental.extraStrategies = Seq[SparkStrategy]( - (_: LogicalPlan) => throw new AnalysisException("_LEGACY_ERROR_TEMP_3078", Map.empty)) + (_: LogicalPlan) => throw new AnalysisException( + "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY", + messageParameters = Map("dataSourceType" -> "XXX"))) assert(qe.toString.contains("org.apache.spark.sql.AnalysisException")) // Throw an Error - this should not be captured. 
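A minimal sketch, assuming a locally built Spark that includes this change; the snippet is not part of the patch and is meant for spark-shell, where `spark` is predefined:

// BETWEEN is now parsed into the single runtime-replaceable `between` function, so the
// analyzed plan shows between(rand(123), 0.1, 0.2) and rand(123) is evaluated only once
// through a common expression reference (compare between_expr.explain above).
spark.sql("SELECT rand(123) BETWEEN 0.1 AND 0.2").explain(true)
// NOT BETWEEN is simply the negated function, matching the predicate-functions golden files.
spark.sql("SELECT 2.0 NOT BETWEEN '1.0' AND '3.0'").show()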
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala index e2f2aee56115..a21baebe24d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala @@ -175,4 +175,25 @@ class DescribeTableSuite extends command.DescribeTableSuiteBase Row("max_col_len", "NULL"))) } } + + test("SPARK-46535: describe extended (formatted) a column without col stats") { + withNamespaceAndTable("ns", "tbl") { tbl => + sql( + s""" + |CREATE TABLE $tbl + |(key INT COMMENT 'column_comment', col STRING) + |$defaultUsing""".stripMargin) + + val descriptionDf = sql(s"DESCRIBE TABLE EXTENDED $tbl key") + assert(descriptionDf.schema.map(field => (field.name, field.dataType)) === Seq( + ("info_name", StringType), + ("info_value", StringType))) + QueryTest.checkAnswer( + descriptionDf, + Seq( + Row("col_name", "key"), + Row("data_type", "int"), + Row("comment", "column_comment"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index da2705f7c72b..c27b71ac8278 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -2344,9 +2344,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { checker(stripSparkFilter(query), expected) - case _ => - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3078", messageParameters = Map.empty) + case _ => assert(false, "Can not match ParquetTable in the query.") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index b3e8e3c79384..4b9a95856afb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import java.time.{Instant, LocalDateTime} import java.util.TimeZone +import scala.collection.immutable.ArraySeq import scala.collection.mutable import scala.io.Source import scala.jdk.CollectionConverters._ @@ -1145,7 +1146,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children.xml")) val mixedRow = mixedDF.head() - assert(mixedRow.getAs[Row](0).toSeq === Seq(" lorem ")) + assert(mixedRow.getAs[Row](0) === Row(List(" issue ", " text ignored "), " lorem ")) assert(mixedRow.getString(1) === " ipsum ") } @@ -1729,9 +1730,15 @@ class XmlSuite extends QueryTest with SharedSparkSession { val TAG_NAME = "tag" val VALUETAG_NAME = "_VALUE" val schema = buildSchema( + field(VALUETAG_NAME), field(ATTRIBUTE_NAME), - field(TAG_NAME, LongType), - field(VALUETAG_NAME)) + field(TAG_NAME, LongType)) + val expectedAns = Seq( + Row("value1", null, null), + Row("value2", "attr1", null), + Row("4", null, 5L), + Row("7", null, 6L), + Row(null, "8", null)) val dfs = Seq( // user specified schema spark.read @@ -1744,25 +1751,7 @@ class XmlSuite extends QueryTest with 
SharedSparkSession { .xml(getTestResourcePath(resDir + "root-level-value-none.xml")) ) dfs.foreach { df => - val result = df.collect() - assert(result.length === 5) - assert(result(0).get(0) == null && result(0).get(1) == null) - assert( - result(1).getAs[String](ATTRIBUTE_NAME) == "attr1" - && result(1).getAs[Any](TAG_NAME) == null - ) - assert( - result(2).getAs[Long](TAG_NAME) == 5L - && result(2).getAs[Any](ATTRIBUTE_NAME) == null - ) - assert( - result(3).getAs[Long](TAG_NAME) == 6L - && result(3).getAs[Any](ATTRIBUTE_NAME) == null - ) - assert( - result(4).getAs[String](ATTRIBUTE_NAME) == "8" - && result(4).getAs[Any](TAG_NAME) == null - ) + checkAnswer(df, expectedAns) } } @@ -2371,4 +2360,248 @@ class XmlSuite extends QueryTest with SharedSparkSession { } } } + + test("capture values interspersed between elements - simple") { + val xmlString = + s""" + | + | value1 + | + | value2 + | 1 + | value3 + | + | value4 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Array("value1", "value4"), Row(Array("value2", "value3"), 1)))) + } + + test("capture values interspersed between elements - array") { + val xmlString = + s""" + | + | value1 + | + | value2 + | 1 + | value3 + | + | + | value4 + | 2 + | value5 + | 3 + | value6 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val expectedAns = Seq( + Row( + "value1", + Array( + Row(List("value2", "value3"), 1, null), + Row(List("value4", "value5", "value6"), 2, 3)))) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, expectedAns) + + } + + test("capture values interspersed between elements - long and double") { + val xmlString = + s""" + | + | + | 1 + | 2 + | 3 + | 4 + | 5.0 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1.0, 3.0, 5.0), Array(2, 4))))) + } + + test("capture values interspersed between elements - comments") { + val xmlString = + s""" + | + | 1 2 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1, 2))))) + } + + test("capture values interspersed between elements - whitespaces with quotes") { + val xmlString = + s""" + | + | " " + | " "1 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", false) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq( + Row("\" \"", Row(1, "\" \""), Row(Row(null, " "))))) + } + + test("capture values interspersed between elements - nested comments") { + val xmlString = + s""" + | + | 1 + | 2 + | 1 + | 3 + | 2 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1, 2, 3), Array(1, 2))))) + } + + test("capture values interspersed between elements - nested struct") { + 
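+ // Interspersed character data at each nesting level is expected to be collected under the value tag of the enclosing struct.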
val xmlString = + s""" + | + | + | + | 1 + | value1 + | 2 + | value2 + | 3 + | + | value4 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer( + df, + Seq( + Row( + Row( + "value4", + Row( + Array("value1", "value2"), + Array(1, 2), + 3))))) + } + + test("capture values interspersed between elements - deeply nested") { + val xmlString = + s""" + | + | value1 + | + | value2 + | + | value3 + | + | value4 + | + | value5 + | 1 + | value6 + | 2 + | value7 + | + | value8 + | string + | value9 + | + | value10 + | + | + | 3 + | value11 + | 4 + | + | string + | value12 + | + | value13 + | 3 + | value14 + | + | value15 + | + | value16 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("ignoreSurroundingSpaces", true) + .option("rowTag", "ROW") + .option("multiLine", "true") + .xml(input) + + val expectedAns = Seq(Row( + ArraySeq("value1", "value16"), + Row( + ArraySeq("value2", "value15"), + Row( + ArraySeq("value3", "value10", "value13", "value14"), + Array( + Row( + ArraySeq("value4", "value8", "value9"), + "string", + Row(ArraySeq("value5", "value6", "value7"), ArraySeq(1, 2))), + Row( + ArraySeq("value12"), + "string", + Row(ArraySeq("value11"), ArraySeq(3, 4)))), + 3)))) + + checkAnswer(df, expectedAns) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index c261f1d529fd..080f57aa08a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.python +import java.io.{File, FileWriter} + import org.apache.spark.SparkException +import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.DataSourceManager import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils class PythonDataSourceSuite extends QueryTest with SharedSparkSession { import IntegratedUDFTestUtils._ @@ -29,7 +34,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { setupTestData() private def dataSourceName = "SimpleDataSource" - private def simpleDataSourceReaderScript: String = + private val simpleDataSourceReaderScript: String = """ |from pyspark.sql.datasource import DataSourceReader, InputPartition |class SimpleDataSourceReader(DataSourceReader): @@ -40,6 +45,56 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | yield (1, partition.value) | yield (2, partition.value) |""".stripMargin + private val staticSourceName = "custom_source" + private var tempDir: File = _ + + override def beforeAll(): Unit = { + // Create a Python Data Source package before starting up the Spark Session + // that triggers automatic registration of the Python Data Source. 
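+ // The temporary package below defines a DefaultSource whose name() ("custom_source") is picked up by DataSourceManager's Python lookup.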
+ val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource, DataSourceReader + |$simpleDataSourceReaderScript + | + |class DefaultSource(DataSource): + | def schema(self) -> str: + | return "id INT, partition INT" + | + | def reader(self, schema): + | return SimpleDataSourceReader() + | + | @classmethod + | def name(cls): + | return "$staticSourceName" + |""".stripMargin + tempDir = Utils.createTempDir() + // Write a temporary package to test. + // tmp/my_source + // tmp/my_source/__init__.py + val packageDir = new File(tempDir, "pyspark_mysource") + assert(packageDir.mkdir()) + Utils.tryWithResource( + new FileWriter(new File(packageDir, "__init__.py")))(_.write(dataSourceScript)) + // So Spark Session initialization can lookup this temporary directory. + DataSourceManager.dataSourceBuilders = None + PythonUtils.additionalTestingPath = Some(tempDir.toString) + super.beforeAll() + } + + override def afterAll(): Unit = { + try { + Utils.deleteRecursively(tempDir) + PythonUtils.additionalTestingPath = None + } finally { + super.afterAll() + } + } + + test("SPARK-45917: automatic registration of Python Data Source") { + assume(shouldTestPandasUDFs) + val df = spark.read.format(staticSourceName).load() + checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) + } test("simple data source") { assume(shouldTestPandasUDFs) @@ -600,4 +655,95 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { Seq(Row(2))) } } + + test("data source write commit and abort") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |import json + |import os + |from dataclasses import dataclass + |from pyspark import TaskContext + |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage + | + |@dataclass + |class SimpleCommitMessage(WriterCommitMessage): + | partition_id: int + | count: int + | + |class SimpleDataSourceWriter(DataSourceWriter): + | def __init__(self, options): + | self.options = options + | self.path = self.options.get("path") + | assert self.path is not None + | + | def write(self, iterator): + | context = TaskContext.get() + | partition_id = context.partitionId() + | output_path = os.path.join(self.path, f"{partition_id}.json") + | cnt = 0 + | with open(output_path, "w") as file: + | for row in iterator: + | if row.id >= 10: + | raise Exception("invalid value") + | file.write(json.dumps(row.asDict()) + "\\n") + | cnt += 1 + | return SimpleCommitMessage(partition_id=partition_id, count=cnt) + | + | def commit(self, messages) -> None: + | status = dict(num_files=len(messages), count=sum(m.count for m in messages)) + | + | with open(os.path.join(self.path, "success.json"), "a") as file: + | file.write(json.dumps(status) + "\\n") + | + | def abort(self, messages) -> None: + | with open(os.path.join(self.path, "failed.txt"), "a") as file: + | file.write("failed") + | + |class SimpleDataSource(DataSource): + | def writer(self, schema, saveMode): + | return SimpleDataSourceWriter(self.options) + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + withTempDir { dir => + val path = dir.getAbsolutePath + + withClue("commit") { + sql("SELECT * FROM range(0, 5, 1, 3)") + .write.format(dataSourceName) + .mode("append") + .save(path) + checkAnswer( + spark.read.format("json") + .schema("num_files bigint, count bigint") + .load(path + "/success.json"), + Seq(Row(3, 5))) + } + + 
withClue("commit again") { + sql("SELECT * FROM range(5, 7, 1, 1)") + .write.format(dataSourceName) + .mode("append") + .save(path) + checkAnswer( + spark.read.format("json") + .schema("num_files bigint, count bigint") + .load(path + "/success.json"), + Seq(Row(3, 5), Row(1, 2))) + } + + withClue("abort") { + intercept[SparkException] { + sql("SELECT * FROM range(8, 12, 1, 4)") + .write.format(dataSourceName) + .mode("append") + .save(path) + } + checkAnswer( + spark.read.text(path + "/failed.txt"), + Seq(Row("failed"))) + } + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index af11b817d65b..b8739ce56e41 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -61,11 +61,10 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { spark.sql("show tables") }) for (msg <- Seq( - "show tables", "Could not connect to meta store", "org.apache.thrift.transport.TTransportException", "Connection refused")) { - exceptionString.contains(msg) + assert(exceptionString.contains(msg)) } } finally { Thread.currentThread().setContextClassLoader(originalClassLoader) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c1d4fb364d3e..1a0c2f6165a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -246,8 +246,8 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkKeywordsExist(sql("describe function `between`"), "Function: between", - "Usage: expr1 [NOT] BETWEEN expr2 AND expr3 - " + - "evaluate if `expr1` is [not] in between `expr2` and `expr3`") + "Usage: input [NOT] BETWEEN lower AND upper - " + + "evaluate if `input` is [not] in between `lower` and `upper`") checkKeywordsExist(sql("describe function `case`"), "Function: case",