diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 0521270cee7c..646a096b72de 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -243,7 +243,7 @@ jobs: - name: Install Python packages (Python 3.8) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) run: | - python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==4.21.6' + python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==3.19.4' python3.8 -m pip list # Run the tests. - name: Run tests @@ -589,7 +589,7 @@ jobs: # See also https://issues.apache.org/jira/browse/SPARK-38279. python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' python3.9 -m pip install ipython_genutils # See SPARK-38517 - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==4.21.6' 'mypy-protobuf==3.3.0' + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.4' 'mypy-protobuf==3.3.0' python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 apt-get update -y apt-get install -y ruby ruby-dev diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index 66e27187153b..277da6b2431d 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -139,11 +139,7 @@ message ExecutePlanRequest { message ExecutePlanResponse { string client_id = 1; - // Result type - oneof result_type { - ArrowBatch arrow_batch = 2; - JSONBatch json_batch = 3; - } + ArrowBatch arrow_batch = 2; // Metrics for the query execution. Typically, this field is only present in the last // batch of results and then represent the overall state of the query execution. @@ -155,14 +151,6 @@ message ExecutePlanResponse { bytes data = 2; } - // Message type when the result is returned as JSON. This is essentially a bulk wrapper - // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format - // of `{col -> row}`. 
- message JSONBatch { - int64 row_count = 1; - bytes data = 2; - } - message Metrics { repeated MetricObject metrics = 1; diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 50ff08f997cb..092bdd00dc1c 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.connect.service import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver -import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging @@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ThreadUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) extends Logging { @@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(session) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) - try { - processAsArrowBatches(request.getClientId, dataframe) - } catch { - case e: Exception => - logWarning(e.getMessage) - processAsJsonBatches(request.getClientId, dataframe) - } - } - - def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { - // Only process up to 10MB of data. - val sb = new StringBuilder - var rowCount = 0 - dataframe.toJSON - .collect() - .foreach(row => { - - // There are a few cases to cover here. - // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE - // -> send the current batch and reset. - // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE - // -> append the row to the buffer. - // 3. The row in question is larger than the MAX_BATCH_SIZE - // -> fail the query. - - // Case 3. - Fail - if (row.size > MAX_BATCH_SIZE) { - throw SparkException.internalError( - s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}") - } - - // Case 1 - FLush and send. - if (sb.size + row.size > MAX_BATCH_SIZE) { - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.JSONBatch - .newBuilder() - .setData(ByteString.copyFromUtf8(sb.toString())) - .setRowCount(rowCount) - .build() - response.setJsonBatch(batch) - responseObserver.onNext(response.build()) - sb.clear() - sb.append(row) - rowCount = 1 - } else { - // Case 2 - Append. - // Make sure to put the newline delimiters only between items and not at the end. - if (rowCount > 0) { - sb.append("\n") - } - sb.append(row) - rowCount += 1 - } - }) - - // If the last batch is not empty, send out the data to the client. 
- if (sb.size > 0) { - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.JSONBatch - .newBuilder() - .setData(ByteString.copyFromUtf8(sb.toString())) - .setRowCount(rowCount) - .build() - response.setJsonBatch(batch) - responseObserver.onNext(response.build()) - } - - responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) - responseObserver.onCompleted() + processAsArrowBatches(request.getClientId, dataframe) } def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { @@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp var numSent = 0 if (numPartitions > 0) { - type Batch = (Array[Byte], Long) - val batches = rows.mapPartitionsInternal( SparkConnectStreamHandler .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)) - val signal = new Object - val partitions = collection.mutable.Map.empty[Int, Array[Batch]] - var error: Throwable = null - - val processPartition = (iter: Iterator[Batch]) => iter.toArray - - // This callback is executed by the DAGScheduler thread. - // After fetching a partition, it inserts the partition into the Map, and then - // wakes up the main thread. - val resultHandler = (partitionId: Int, partition: Array[Batch]) => { - signal.synchronized { - partitions(partitionId) = partition - signal.notify() - } - () - } - - val future = spark.sparkContext.submitJob( - rdd = batches, - processPartition = processPartition, - partitions = Seq.range(0, numPartitions), - resultHandler = resultHandler, - resultFunc = () => ()) - - // Collect errors and propagate them to the main thread. - future.onComplete { result => - result.failed.foreach { throwable => - signal.synchronized { - error = throwable - signal.notify() - } - } - }(ThreadUtils.sameThread) - - // The main thread will wait until 0-th partition is available, - // then send it to client and wait for the next partition. - // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends - // the arrow batches in main thread to avoid DAGScheduler thread been blocked for - // tasks not related to scheduling. This is particularly important if there are - // multiple users or clients running code at the same time. 
- var currentPartitionId = 0 - while (currentPartitionId < numPartitions) { - val partition = signal.synchronized { - var result = partitions.remove(currentPartitionId) - while (result.isEmpty && error == null) { - signal.wait() - result = partitions.remove(currentPartitionId) - } - error match { - case NonFatal(e) => - responseObserver.onError(error) - logError("Error while processing query.", e) - return - case fatal: Throwable => throw fatal - case null => result.get - } - } - - partition.foreach { case (bytes, count) => - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.ArrowBatch - .newBuilder() - .setRowCount(count) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) - numSent += 1 - } - - currentPartitionId += 1 + batches.collect().foreach { case (bytes, count) => + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.ArrowBatch + .newBuilder() + .setRowCount(count) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + numSent += 1 } } diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index dd81c860ba33..8cc1b865207d 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -285,6 +285,8 @@ private class PushBasedFetchHelper( * 2. There is a failure when fetching remote shuffle chunks. * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk * (local or remote). + * 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle chunk + * (local or remote). */ def initiateFallbackFetchForPushMergedBlock( blockId: BlockId, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b5f20522e91f..e35144756b59 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -782,7 +782,7 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } - if (buf.size == 0) { + val in = if (buf.size == 0) { // We will never legitimately receive a zero-size block. All blocks with zero records // have zero size and all zero-size blocks have no records (and hence should never // have been requested in the first place). This statement relies on behaviors of the @@ -798,38 +798,52 @@ final class ShuffleBlockFetcherIterator( // since the last call. 
val msg = s"Received a zero-size buffer for block $blockId from $address " + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" - throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) - } - - val in = try { - val bufIn = buf.createInputStream() - if (checksumEnabled) { - val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) - checkedIn = new CheckedInputStream(bufIn, checksum) - checkedIn + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null } else { - bufIn + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) } - } catch { - // The exception could only be throwed by local shuffle block - case e: IOException => - assert(buf.isInstanceOf[FileSegmentManagedBuffer]) - e match { - case ce: ClosedByInterruptException => - logError("Failed to create input stream from local block, " + - ce.getMessage) - case e: IOException => logError("Failed to create input stream from local block", e) - } - buf.release() - if (blockId.isShuffleChunk) { - pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) - // Set result to null to trigger another iteration of the while loop to get either. - result = null - null + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn } else { - throwFetchFailedException(blockId, mapIndex, address, e) + bufIn } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError("Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. 
+ result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } } + if (in != null) { try { input = streamWrapper(blockId, in) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index f8fe28c0512b..64b6c93bf52c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1814,4 +1814,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("SPARK-40872: fallback to original shuffle block when a push-merged shuffle chunk " + + "is zero-size") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("local-dir") + val blocksByAddress = prepareForFallbackToLocalBlocks( + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + val zeroSizeBuffer = createMockManagedBuffer(0) + doReturn(Seq({zeroSizeBuffer})).when(blockManager) + .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) + verifyLocalBlocksFromFallback(iterator) + } } diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index 41e8ff96acc0..7b7c3ac7fb36 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -268,6 +268,6 @@ xml-apis/1.4.01//xml-apis-1.4.01.jar xmlenc/0.52//xmlenc-0.52.jar xz/1.9//xz-1.9.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar -zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar -zookeeper/3.6.2//zookeeper-3.6.2.jar +zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar +zookeeper/3.6.3//zookeeper-3.6.3.jar zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 876147c8bca6..c648f8896c38 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -253,6 +253,6 @@ wildfly-openssl/1.0.7.Final//wildfly-openssl-1.0.7.Final.jar xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar xz/1.9//xz-1.9.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar -zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar -zookeeper/3.6.2//zookeeper-3.6.2.jar +zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar +zookeeper/3.6.3//zookeeper-3.6.3.jar zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar diff --git a/docs/sql-ref-syntax-qry-select.md b/docs/sql-ref-syntax-qry-select.md index ea5c4a69d9ab..22c4d78605b4 100644 --- a/docs/sql-ref-syntax-qry-select.md +++ b/docs/sql-ref-syntax-qry-select.md @@ -83,6 +83,8 @@ SELECT [ hints , ... ] [ ALL | DISTINCT ] { [ [ named_expression | regex_column_ Specifies a source of input for the query. 
It can be one of the following: * Table relation * [Join relation](sql-ref-syntax-qry-select-join.html) + * [Pivot relation](sql-ref-syntax-qry-select-pivot.md) + * [Unpivot relation](sql-ref-syntax-qry-select-unpivot.md) * [Table-value function](sql-ref-syntax-qry-select-tvf.html) * [Inline table](sql-ref-syntax-qry-select-inline-table.html) * [ [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html) ] ( Subquery ) diff --git a/pom.xml b/pom.xml index 2dd898b8787e..e07b01ea9557 100644 --- a/pom.xml +++ b/pom.xml @@ -120,7 +120,7 @@ 3.3.4 2.5.0 ${hadoop.version} - 3.6.2 + 3.6.3 2.13.0 org.apache.hive core diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 5bdf01afc99c..fdcf34b7a47e 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -16,7 +16,6 @@ # -import io import logging import os import typing @@ -446,13 +445,9 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeRes return AnalyzeResult.fromProto(resp) def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]: - import pandas as pd - if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: with pa.ipc.open_stream(b.arrow_batch.data) as rd: return rd.read_pandas() - elif b.json_batch is not None and len(b.json_batch.data) > 0: - return pd.read_json(io.BytesIO(b.json_batch.data), lines=True) return None def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]: diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 7d9f98b243e1..daa1c25cc8fe 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 
\x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xad\x07\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12M\n\njson_batch\x18\x03 \x01(\x0b\x32,.spark.connect.ExecutePlanResponse.JSONBatchH\x00R\tjsonBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 
\x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x8f\x06\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12N\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchR\narrowBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -48,7 +48,6 @@ _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"] _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"] _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"] -_EXECUTEPLANRESPONSE_JSONBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["JSONBatch"] _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"] _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[ "MetricObject" @@ -139,15 +138,6 @@ # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ArrowBatch) }, ), - "JSONBatch": _reflection.GeneratedProtocolMessageType( - "JSONBatch", - (_message.Message,), - { - "DESCRIPTOR": _EXECUTEPLANRESPONSE_JSONBATCH, - "__module__": "spark.connect.base_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.JSONBatch) - }, - ), "Metrics": _reflection.GeneratedProtocolMessageType( "Metrics", (_message.Message,), @@ -191,7 +181,6 @@ ) _sym_db.RegisterMessage(ExecutePlanResponse) _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch) -_sym_db.RegisterMessage(ExecutePlanResponse.JSONBatch) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry) @@ -219,19 +208,17 @@ _EXECUTEPLANREQUEST._serialized_start = 986 _EXECUTEPLANREQUEST._serialized_end = 1193 
_EXECUTEPLANRESPONSE._serialized_start = 1196 - _EXECUTEPLANRESPONSE._serialized_end = 2137 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540 - _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542 - _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122 - _SPARKCONNECTSERVICE._serialized_start = 2140 - _SPARKCONNECTSERVICE._serialized_end = 2339 + _EXECUTEPLANRESPONSE._serialized_end = 1979 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979 + _SPARKCONNECTSERVICE._serialized_start = 1982 + _SPARKCONNECTSERVICE._serialized_end = 2181 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 18b70de57a3c..64bb51d4c0b9 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -401,28 +401,6 @@ class ExecutePlanResponse(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] ) -> None: ... - class JSONBatch(google.protobuf.message.Message): - """Message type when the result is returned as JSON. This is essentially a bulk wrapper - for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format - of `{col -> row}`. - """ - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ROW_COUNT_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - row_count: builtins.int - data: builtins.bytes - def __init__( - self, - *, - row_count: builtins.int = ..., - data: builtins.bytes = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] - ) -> None: ... - class Metrics(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -530,14 +508,11 @@ class ExecutePlanResponse(google.protobuf.message.Message): CLIENT_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int - JSON_BATCH_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int client_id: builtins.str @property def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ... @property - def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ... 
- @property def metrics(self) -> global___ExecutePlanResponse.Metrics: """Metrics for the query execution. Typically, this field is only present in the last batch of results and then represent the overall state of the query execution. @@ -547,39 +522,17 @@ class ExecutePlanResponse(google.protobuf.message.Message): *, client_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., - json_batch: global___ExecutePlanResponse.JSONBatch | None = ..., metrics: global___ExecutePlanResponse.Metrics | None = ..., ) -> None: ... def HasField( self, - field_name: typing_extensions.Literal[ - "arrow_batch", - b"arrow_batch", - "json_batch", - b"json_batch", - "metrics", - b"metrics", - "result_type", - b"result_type", - ], + field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "arrow_batch", - b"arrow_batch", - "client_id", - b"client_id", - "json_batch", - b"json_batch", - "metrics", - b"metrics", - "result_type", - b"result_type", + "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics" ], ) -> None: ... - def WhichOneof( - self, oneof_group: typing_extensions.Literal["result_type", b"result_type"] - ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ... global___ExecutePlanResponse = ExecutePlanResponse diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index d3de94a379f8..9e7a5f2f4a54 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -221,7 +221,60 @@ def test_create_global_temp_view(self): with self.assertRaises(_MultiThreadedRendezvous): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") - @unittest.skip("test_fill_na is flaky") + def test_to_pandas(self): + # SPARK-41005: Test to pandas + query = """ + SELECT * FROM VALUES + (false, 1, NULL), + (false, NULL, float(2.0)), + (NULL, 3, float(3.0)) + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ + SELECT * FROM VALUES + (1, 1, NULL), + (2, NULL, float(2.0)), + (3, 3, float(3.0)) + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ + SELECT * FROM VALUES + (double(1.0), 1, "1"), + (NULL, NULL, NULL), + (double(2.0), 3, "3") + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ + SELECT * FROM VALUES + (float(1.0), double(1.0), 1, "1"), + (float(2.0), double(2.0), 2, "2"), + (float(3.0), double(3.0), 3, "3") + AS tab(a, b, c, d) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + def test_fill_na(self): # SPARK-41128: Test fill na query = """ diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index a3c5f4a7b070..21747a0a021f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -697,7 +697,13 @@ setQuantifier ; relation - : LATERAL? relationPrimary joinRelation* + : LATERAL? 
relationPrimary relationExtension* + ; + +relationExtension + : joinRelation + | pivotClause + | unpivotClause ; joinRelation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 104c5c1e0805..8cdf83f4be75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -113,9 +113,9 @@ object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog { * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the * depth of nested views. * @param maxNestedViewDepth The maximum allowed depth of nested view resolution. - * @param relationCache A mapping from qualified table names to resolved relations. This can ensure - * that the table is resolved only once if a table is used multiple times - * in a query. + * @param relationCache A mapping from qualified table names and time travel spec to resolved + * relations. This can ensure that the table is resolved only once if a table + * is used multiple times in a query. * @param referredTempViewNames All the temp view names referred by the current view we are * resolving. It's used to make sure the relation resolution is * consistent between view creation and view resolution. For example, @@ -129,7 +129,8 @@ case class AnalysisContext( catalogAndNamespace: Seq[String] = Nil, nestedViewDepth: Int = 0, maxNestedViewDepth: Int = -1, - relationCache: mutable.Map[Seq[String], LogicalPlan] = mutable.Map.empty, + relationCache: mutable.Map[(Seq[String], Option[TimeTravelSpec]), LogicalPlan] = + mutable.Map.empty, referredTempViewNames: Seq[Seq[String]] = Seq.empty, // 1. If we are resolving a view, this field will be restored from the view metadata, // by calling `AnalysisContext.withAnalysisContext(viewDesc)`. @@ -1239,7 +1240,7 @@ class Analyzer(override val catalogManager: CatalogManager) resolveTempView(u.multipartIdentifier, u.isStreaming, timeTravelSpec.isDefined).orElse { expandIdentifier(u.multipartIdentifier) match { case CatalogAndIdentifier(catalog, ident) => - val key = catalog.name +: ident.namespace :+ ident.name + val key = ((catalog.name +: ident.namespace :+ ident.name).toSeq, timeTravelSpec) AnalysisContext.get.relationCache.get(key).map(_.transform { case multi: MultiInstanceRelation => val newRelation = multi.newInstance() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 99e5f411bdb6..8dd28e9aaae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -464,8 +464,9 @@ object FileSourceMetadataAttribute { val FILE_SOURCE_METADATA_COL_ATTR_KEY = "__file_source_metadata_col" - def apply(name: String, dataType: DataType, nullable: Boolean = true): AttributeReference = - AttributeReference(name, dataType, nullable, + def apply(name: String, dataType: DataType): AttributeReference = + // Metadata column for file sources is always not nullable. 
+ AttributeReference(name, dataType, nullable = false, new MetadataBuilder() .putBoolean(METADATA_COL_ATTR_KEY, value = true) .putBoolean(FILE_SOURCE_METADATA_COL_ATTR_KEY, value = true).build())() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4adb70bc3909..d56ef28bcc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -929,7 +929,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit Join(left, right, Inner, None, JoinHint.NONE) } } - if (conf.ansiRelationPrecedence) join else withJoinRelations(join, relation) + if (conf.ansiRelationPrecedence) join else withRelationExtensions(relation, join) } if (ctx.pivotClause() != null) { if (ctx.unpivotClause() != null) { @@ -1263,60 +1263,71 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit * }}} */ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { - withJoinRelations(plan(ctx.relationPrimary), ctx) + withRelationExtensions(ctx, plan(ctx.relationPrimary)) + } + + private def withRelationExtensions(ctx: RelationContext, query: LogicalPlan): LogicalPlan = { + ctx.relationExtension().asScala.foldLeft(query) { (left, extension) => + if (extension.joinRelation() != null) { + withJoinRelation(extension.joinRelation(), left) + } else if (extension.pivotClause() != null) { + withPivot(extension.pivotClause(), left) + } else { + assert(extension.unpivotClause() != null) + withUnpivot(extension.unpivotClause(), left) + } + } } /** - * Join one more [[LogicalPlan]]s to the current logical plan. + * Join one more [[LogicalPlan]] to the current logical plan. 
*/ - private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { - ctx.joinRelation.asScala.foldLeft(base) { (left, join) => - withOrigin(join) { - val baseJoinType = join.joinType match { - case null => Inner - case jt if jt.CROSS != null => Cross - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.ANTI != null => LeftAnti - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } + private def withJoinRelation(ctx: JoinRelationContext, base: LogicalPlan): LogicalPlan = { + withOrigin(ctx) { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } - if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) { - throw QueryParsingErrors.invalidLateralJoinRelationError(join.right) - } + if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) { + throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right) + } - // Resolve the join type and join condition - val (joinType, condition) = Option(join.joinCriteria) match { - case Some(c) if c.USING != null => - if (join.LATERAL != null) { - throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) - } - (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case Some(c) => - throw new IllegalStateException(s"Unimplemented joinCriteria: $c") - case None if join.NATURAL != null => - if (join.LATERAL != null) { - throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx) - } - if (baseJoinType == Cross) { - throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx) - } - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - if (join.LATERAL != null) { - if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { - throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) } - LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition) - } else { - Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new IllegalStateException(s"Unimplemented joinCriteria: $c") + case None if ctx.NATURAL != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx) + } + if (baseJoinType == Cross) { + throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (ctx.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) } + LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition) + } else { + Join(base, plan(ctx.right), 
joinType, condition, JoinHint.NONE) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3c35ba9b6004..bbda9eb76b10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -137,11 +137,22 @@ object ScanOperation extends OperationHelper { val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline) // `collectProjectsAndFilters` transforms the plan bottom-up, so the bottom-most filter are - // placed at the beginning of `filters` list. According to the SQL semantic, we can only - // push down the bottom deterministic filters. - val filtersCanPushDown = filters.takeWhile(_.deterministic).flatMap(splitConjunctivePredicates) - val filtersStayUp = filters.dropWhile(_.deterministic) - Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child)) + // placed at the beginning of `filters` list. According to the SQL semantic, we cannot merge + // Filters if one or more of them are nondeterministic. This means we can only push down the + // bottom-most Filter, or more following deterministic Filters if the bottom-most Filter is + // also deterministic. + if (filters.isEmpty) { + Some((fields.getOrElse(child.output), Nil, Nil, child)) + } else if (filters.head.deterministic) { + val filtersCanPushDown = filters.takeWhile(_.deterministic) + .flatMap(splitConjunctivePredicates) + val filtersStayUp = filters.dropWhile(_.deterministic) + Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child)) + } else { + val filtersCanPushDown = splitConjunctivePredicates(filters.head) + val filtersStayUp = filters.drop(1) + Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 8608d4ff306e..9624a06d80a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -99,15 +99,15 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { ctx) } - def unpivotWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def unpivotWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException("UNPIVOT cannot be used together with PIVOT in FROM clause", ctx) } - def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def lateralWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0013", ctx) } - def lateralWithUnpivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def lateralWithUnpivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException("LATERAL cannot be used together with UNPIVOT in FROM clause", ctx) } @@ -164,7 +164,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { ctx) } - def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = { + def naturalCrossJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { new ParseException( 
errorClass = "UNSUPPORTED_FEATURE.NATURAL_CROSS_JOIN", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala index dd7e4ec4916f..c680e08c1c83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Unpivot} +import org.apache.spark.sql.internal.SQLConf class UnpivotParserSuite extends AnalysisTest { @@ -192,4 +193,151 @@ class UnpivotParserSuite extends AnalysisTest { ) } + test("unpivot - with joins") { + // unpivot the left table + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull).join(table("t2")).select(star())) + + // unpivot the join result + assertEqual( + "SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1").join(table("t2")) + ).where(coalesce($"val").isNotNull).select(star())) + + // unpivot the right table + assertEqual( + "SELECT * FROM t1 JOIN (t2 UNPIVOT (val FOR col in (a, b)))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + + test("unpivot - with implicit joins") { + // unpivot the left table + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull).join(table("t2")).select(star())) + + // unpivot the join result + assertEqual( + "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1").join(table("t2")) + ).where(coalesce($"val").isNotNull).select(star())) + + // unpivot the right table - same SQL as above but with ANSI mode + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") { + assertEqual( + "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + + // unpivot the right table + assertEqual( + "SELECT * FROM t1, (t2 UNPIVOT (val FOR col in (a, b)))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + + // mixed with explicit joins + assertEqual( + // unpivot the join result of t1, t2 and t3 + "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + 
Seq("val"), + table("t1").join(table("t2")).join(table("t3")) + ).where(coalesce($"val").isNotNull).select(star())) + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") { + assertEqual( + // unpivot the join result of t2 and t3 + "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2").join(table("t3")) + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + } + + test("unpivot - nested unpivot") { + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull) + ).where(coalesce($"val").isNotNull).select(star())) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a8ccc39ac478..6b3744fe02d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} *