diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 0521270cee7c..646a096b72de 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -243,7 +243,7 @@ jobs:
- name: Install Python packages (Python 3.8)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
run: |
- python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==4.21.6'
+ python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==3.19.4'
python3.8 -m pip list
# Run the tests.
- name: Run tests
@@ -589,7 +589,7 @@ jobs:
# See also https://issues.apache.org/jira/browse/SPARK-38279.
python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
python3.9 -m pip install ipython_genutils # See SPARK-38517
- python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==4.21.6' 'mypy-protobuf==3.3.0'
+ python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.4' 'mypy-protobuf==3.3.0'
python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
apt-get update -y
apt-get install -y ruby ruby-dev
diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 66e27187153b..277da6b2431d 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -139,11 +139,7 @@ message ExecutePlanRequest {
message ExecutePlanResponse {
string client_id = 1;
- // Result type
- oneof result_type {
- ArrowBatch arrow_batch = 2;
- JSONBatch json_batch = 3;
- }
+ ArrowBatch arrow_batch = 2;
// Metrics for the query execution. Typically, this field is only present in the last
// batch of results and then represent the overall state of the query execution.
@@ -155,14 +151,6 @@ message ExecutePlanResponse {
bytes data = 2;
}
- // Message type when the result is returned as JSON. This is essentially a bulk wrapper
- // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
- // of `{col -> row}`.
- message JSONBatch {
- int64 row_count = 1;
- bytes data = 2;
- }
-
message Metrics {
repeated MetricObject metrics = 1;
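With the result_type oneof and JSONBatch removed, arrow_batch is a plain field and every result is delivered as Arrow IPC bytes. As a hedged illustration (not part of the patch), a server-side response could be assembled from the generated classes like this, mirroring the handler change further below and assuming the spark.connect proto classes are on the classpath:

import com.google.protobuf.ByteString
import org.apache.spark.connect.proto

// Wrap one Arrow IPC payload into an ExecutePlanResponse for a given client.
def arrowResponse(
    clientId: String,
    bytes: Array[Byte],
    rowCount: Long): proto.ExecutePlanResponse = {
  val batch = proto.ExecutePlanResponse.ArrowBatch
    .newBuilder()
    .setRowCount(rowCount)
    .setData(ByteString.copyFrom(bytes))
    .build()
  proto.ExecutePlanResponse
    .newBuilder()
    .setClientId(clientId)
    .setArrowBatch(batch)
    .build()
}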
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 50ff08f997cb..092bdd00dc1c 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,12 +18,10 @@
package org.apache.spark.sql.connect.service
import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
-import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
@@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
extends Logging {
@@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
- try {
- processAsArrowBatches(request.getClientId, dataframe)
- } catch {
- case e: Exception =>
- logWarning(e.getMessage)
- processAsJsonBatches(request.getClientId, dataframe)
- }
- }
-
- def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
- // Only process up to 10MB of data.
- val sb = new StringBuilder
- var rowCount = 0
- dataframe.toJSON
- .collect()
- .foreach(row => {
-
- // There are a few cases to cover here.
- // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE
- // -> send the current batch and reset.
- // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE
- // -> append the row to the buffer.
- // 3. The row in question is larger than the MAX_BATCH_SIZE
- // -> fail the query.
-
- // Case 3. - Fail
- if (row.size > MAX_BATCH_SIZE) {
- throw SparkException.internalError(
- s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}")
- }
-
- // Case 1 - FLush and send.
- if (sb.size + row.size > MAX_BATCH_SIZE) {
- val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
- val batch = proto.ExecutePlanResponse.JSONBatch
- .newBuilder()
- .setData(ByteString.copyFromUtf8(sb.toString()))
- .setRowCount(rowCount)
- .build()
- response.setJsonBatch(batch)
- responseObserver.onNext(response.build())
- sb.clear()
- sb.append(row)
- rowCount = 1
- } else {
- // Case 2 - Append.
- // Make sure to put the newline delimiters only between items and not at the end.
- if (rowCount > 0) {
- sb.append("\n")
- }
- sb.append(row)
- rowCount += 1
- }
- })
-
- // If the last batch is not empty, send out the data to the client.
- if (sb.size > 0) {
- val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
- val batch = proto.ExecutePlanResponse.JSONBatch
- .newBuilder()
- .setData(ByteString.copyFromUtf8(sb.toString()))
- .setRowCount(rowCount)
- .build()
- response.setJsonBatch(batch)
- responseObserver.onNext(response.build())
- }
-
- responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
- responseObserver.onCompleted()
+ processAsArrowBatches(request.getClientId, dataframe)
}
def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
@@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
var numSent = 0
if (numPartitions > 0) {
- type Batch = (Array[Byte], Long)
-
val batches = rows.mapPartitionsInternal(
SparkConnectStreamHandler
.rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
- val signal = new Object
- val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
- var error: Throwable = null
-
- val processPartition = (iter: Iterator[Batch]) => iter.toArray
-
- // This callback is executed by the DAGScheduler thread.
- // After fetching a partition, it inserts the partition into the Map, and then
- // wakes up the main thread.
- val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
- signal.synchronized {
- partitions(partitionId) = partition
- signal.notify()
- }
- ()
- }
-
- val future = spark.sparkContext.submitJob(
- rdd = batches,
- processPartition = processPartition,
- partitions = Seq.range(0, numPartitions),
- resultHandler = resultHandler,
- resultFunc = () => ())
-
- // Collect errors and propagate them to the main thread.
- future.onComplete { result =>
- result.failed.foreach { throwable =>
- signal.synchronized {
- error = throwable
- signal.notify()
- }
- }
- }(ThreadUtils.sameThread)
-
- // The main thread will wait until 0-th partition is available,
- // then send it to client and wait for the next partition.
- // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
- // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
- // tasks not related to scheduling. This is particularly important if there are
- // multiple users or clients running code at the same time.
- var currentPartitionId = 0
- while (currentPartitionId < numPartitions) {
- val partition = signal.synchronized {
- var result = partitions.remove(currentPartitionId)
- while (result.isEmpty && error == null) {
- signal.wait()
- result = partitions.remove(currentPartitionId)
- }
- error match {
- case NonFatal(e) =>
- responseObserver.onError(error)
- logError("Error while processing query.", e)
- return
- case fatal: Throwable => throw fatal
- case null => result.get
- }
- }
-
- partition.foreach { case (bytes, count) =>
- val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
- val batch = proto.ExecutePlanResponse.ArrowBatch
- .newBuilder()
- .setRowCount(count)
- .setData(ByteString.copyFrom(bytes))
- .build()
- response.setArrowBatch(batch)
- responseObserver.onNext(response.build())
- numSent += 1
- }
-
- currentPartitionId += 1
+ batches.collect().foreach { case (bytes, count) =>
+ val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
+ val batch = proto.ExecutePlanResponse.ArrowBatch
+ .newBuilder()
+ .setRowCount(count)
+ .setData(ByteString.copyFrom(bytes))
+ .build()
+ response.setArrowBatch(batch)
+ responseObserver.onNext(response.build())
+ numSent += 1
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index dd81c860ba33..8cc1b865207d 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -285,6 +285,8 @@ private class PushBasedFetchHelper(
* 2. There is a failure when fetching remote shuffle chunks.
* 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
* (local or remote).
+ * 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle chunk
+ * (local or remote).
*/
def initiateFallbackFetchForPushMergedBlock(
blockId: BlockId,
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index b5f20522e91f..e35144756b59 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -782,7 +782,7 @@ final class ShuffleBlockFetcherIterator(
logDebug("Number of requests in flight " + reqsInFlight)
}
- if (buf.size == 0) {
+ val in = if (buf.size == 0) {
// We will never legitimately receive a zero-size block. All blocks with zero records
// have zero size and all zero-size blocks have no records (and hence should never
// have been requested in the first place). This statement relies on behaviors of the
@@ -798,38 +798,52 @@ final class ShuffleBlockFetcherIterator(
// since the last call.
val msg = s"Received a zero-size buffer for block $blockId from $address " +
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
- throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
- }
-
- val in = try {
- val bufIn = buf.createInputStream()
- if (checksumEnabled) {
- val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
- checkedIn = new CheckedInputStream(bufIn, checksum)
- checkedIn
+ if (blockId.isShuffleChunk) {
+        // A zero-size block may come from nodes with hardware failures. For shuffle chunks,
+        // the original shuffle blocks that belong to that zero-size shuffle chunk are still
+        // available, so we can opt to fall back immediately.
+ logWarning(msg)
+ pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+ // Set result to null to trigger another iteration of the while loop to get either.
+ result = null
+ null
} else {
- bufIn
+ throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
}
- } catch {
- // The exception could only be throwed by local shuffle block
- case e: IOException =>
- assert(buf.isInstanceOf[FileSegmentManagedBuffer])
- e match {
- case ce: ClosedByInterruptException =>
- logError("Failed to create input stream from local block, " +
- ce.getMessage)
- case e: IOException => logError("Failed to create input stream from local block", e)
- }
- buf.release()
- if (blockId.isShuffleChunk) {
- pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
- // Set result to null to trigger another iteration of the while loop to get either.
- result = null
- null
+ } else {
+ try {
+ val bufIn = buf.createInputStream()
+ if (checksumEnabled) {
+ val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+ checkedIn = new CheckedInputStream(bufIn, checksum)
+ checkedIn
} else {
- throwFetchFailedException(blockId, mapIndex, address, e)
+ bufIn
}
+ } catch {
+          // This exception can only be thrown by a local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ e match {
+ case ce: ClosedByInterruptException =>
+ logError("Failed to create input stream from local block, " +
+ ce.getMessage)
+ case e: IOException =>
+ logError("Failed to create input stream from local block", e)
+ }
+ buf.release()
+ if (blockId.isShuffleChunk) {
+ pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+ // Set result to null to trigger another iteration of the while loop to get
+ // either.
+ result = null
+ null
+ } else {
+ throwFetchFailedException(blockId, mapIndex, address, e)
+ }
+ }
}
+
if (in != null) {
try {
input = streamWrapper(blockId, in)
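The change above means a zero-size buffer is no longer unconditionally fatal: for a push-merged shuffle chunk the iterator now falls back to the original shuffle blocks instead of failing the fetch. A condensed, hypothetical sketch of that decision, where fallback and fail stand in for initiateFallbackFetchForPushMergedBlock and throwFetchFailedException:

// Minimal sketch of the new zero-size handling; not the real iterator code.
def onZeroSizeBuffer(
    isShuffleChunk: Boolean,
    msg: String,
    fallback: () => Unit,
    fail: String => Nothing): Unit = {
  if (isShuffleChunk) {
    // Push-merged chunk: the original, un-merged shuffle blocks still exist,
    // so re-fetch those instead of failing the whole fetch.
    fallback()
  } else {
    // Regular shuffle block: a zero-size buffer is never legitimate, keep failing fast.
    fail(msg)
  }
}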
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index f8fe28c0512b..64b6c93bf52c 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -1814,4 +1814,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}
+ test("SPARK-40872: fallback to original shuffle block when a push-merged shuffle chunk " +
+ "is zero-size") {
+ val blockManager = mock(classOf[BlockManager])
+ val localDirs = Array("local-dir")
+ val blocksByAddress = prepareForFallbackToLocalBlocks(
+ blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+ val zeroSizeBuffer = createMockManagedBuffer(0)
+    doReturn(Seq(zeroSizeBuffer)).when(blockManager)
+ .getLocalMergedBlockData(ShuffleMergedBlockId(0, 0, 2), localDirs)
+ val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+ blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
+ verifyLocalBlocksFromFallback(iterator)
+ }
}
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3
index 41e8ff96acc0..7b7c3ac7fb36 100644
--- a/dev/deps/spark-deps-hadoop-2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-2-hive-2.3
@@ -268,6 +268,6 @@ xml-apis/1.4.01//xml-apis-1.4.01.jar
xmlenc/0.52//xmlenc-0.52.jar
xz/1.9//xz-1.9.jar
zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar
-zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar
-zookeeper/3.6.2//zookeeper-3.6.2.jar
+zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar
+zookeeper/3.6.3//zookeeper-3.6.3.jar
zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 876147c8bca6..c648f8896c38 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -253,6 +253,6 @@ wildfly-openssl/1.0.7.Final//wildfly-openssl-1.0.7.Final.jar
xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar
xz/1.9//xz-1.9.jar
zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar
-zookeeper-jute/3.6.2//zookeeper-jute-3.6.2.jar
-zookeeper/3.6.2//zookeeper-3.6.2.jar
+zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar
+zookeeper/3.6.3//zookeeper-3.6.3.jar
zstd-jni/1.5.2-5//zstd-jni-1.5.2-5.jar
diff --git a/docs/sql-ref-syntax-qry-select.md b/docs/sql-ref-syntax-qry-select.md
index ea5c4a69d9ab..22c4d78605b4 100644
--- a/docs/sql-ref-syntax-qry-select.md
+++ b/docs/sql-ref-syntax-qry-select.md
@@ -83,6 +83,8 @@ SELECT [ hints , ... ] [ ALL | DISTINCT ] { [ [ named_expression | regex_column_
Specifies a source of input for the query. It can be one of the following:
* Table relation
* [Join relation](sql-ref-syntax-qry-select-join.html)
+    * [Pivot relation](sql-ref-syntax-qry-select-pivot.html)
+    * [Unpivot relation](sql-ref-syntax-qry-select-unpivot.html)
* [Table-value function](sql-ref-syntax-qry-select-tvf.html)
* [Inline table](sql-ref-syntax-qry-select-inline-table.html)
* [ [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html) ] ( Subquery )
diff --git a/pom.xml b/pom.xml
index 2dd898b8787e..e07b01ea9557 100644
--- a/pom.xml
+++ b/pom.xml
@@ -120,7 +120,7 @@
3.3.4
2.5.0
${hadoop.version}
- 3.6.2
+ 3.6.3
2.13.0
org.apache.hive
core
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 5bdf01afc99c..fdcf34b7a47e 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -16,7 +16,6 @@
#
-import io
import logging
import os
import typing
@@ -446,13 +445,9 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeRes
return AnalyzeResult.fromProto(resp)
def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
- import pandas as pd
-
if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
with pa.ipc.open_stream(b.arrow_batch.data) as rd:
return rd.read_pandas()
- elif b.json_batch is not None and len(b.json_batch.data) > 0:
- return pd.read_json(io.BytesIO(b.json_batch.data), lines=True)
return None
def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 7d9f98b243e1..daa1c25cc8fe 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xad\x07\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12M\n\njson_batch\x18\x03 \x01(\x0b\x32,.spark.connect.ExecutePlanResponse.JSONBatchH\x00R\tjsonBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 
\x01(\tR\nmetricTypeB\r\n\x0bresult_type2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x8f\x06\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12N\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchR\narrowBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)
@@ -48,7 +48,6 @@
_EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"]
_EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"]
_EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"]
-_EXECUTEPLANRESPONSE_JSONBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["JSONBatch"]
_EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"]
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
"MetricObject"
@@ -139,15 +138,6 @@
# @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ArrowBatch)
},
),
- "JSONBatch": _reflection.GeneratedProtocolMessageType(
- "JSONBatch",
- (_message.Message,),
- {
- "DESCRIPTOR": _EXECUTEPLANRESPONSE_JSONBATCH,
- "__module__": "spark.connect.base_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.JSONBatch)
- },
- ),
"Metrics": _reflection.GeneratedProtocolMessageType(
"Metrics",
(_message.Message,),
@@ -191,7 +181,6 @@
)
_sym_db.RegisterMessage(ExecutePlanResponse)
_sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch)
-_sym_db.RegisterMessage(ExecutePlanResponse.JSONBatch)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
@@ -219,19 +208,17 @@
_EXECUTEPLANREQUEST._serialized_start = 986
_EXECUTEPLANREQUEST._serialized_end = 1193
_EXECUTEPLANRESPONSE._serialized_start = 1196
- _EXECUTEPLANRESPONSE._serialized_end = 2137
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540
- _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542
- _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602
- _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605
- _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122
- _SPARKCONNECTSERVICE._serialized_start = 2140
- _SPARKCONNECTSERVICE._serialized_end = 2339
+ _EXECUTEPLANRESPONSE._serialized_end = 1979
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459
+ _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462
+ _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979
+ _SPARKCONNECTSERVICE._serialized_start = 1982
+ _SPARKCONNECTSERVICE._serialized_end = 2181
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 18b70de57a3c..64bb51d4c0b9 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -401,28 +401,6 @@ class ExecutePlanResponse(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
) -> None: ...
- class JSONBatch(google.protobuf.message.Message):
- """Message type when the result is returned as JSON. This is essentially a bulk wrapper
- for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
- of `{col -> row}`.
- """
-
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- ROW_COUNT_FIELD_NUMBER: builtins.int
- DATA_FIELD_NUMBER: builtins.int
- row_count: builtins.int
- data: builtins.bytes
- def __init__(
- self,
- *,
- row_count: builtins.int = ...,
- data: builtins.bytes = ...,
- ) -> None: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
- ) -> None: ...
-
class Metrics(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -530,14 +508,11 @@ class ExecutePlanResponse(google.protobuf.message.Message):
CLIENT_ID_FIELD_NUMBER: builtins.int
ARROW_BATCH_FIELD_NUMBER: builtins.int
- JSON_BATCH_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@property
- def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ...
- @property
def metrics(self) -> global___ExecutePlanResponse.Metrics:
"""Metrics for the query execution. Typically, this field is only present in the last
batch of results and then represent the overall state of the query execution.
@@ -547,39 +522,17 @@ class ExecutePlanResponse(google.protobuf.message.Message):
*,
client_id: builtins.str = ...,
arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ...,
- json_batch: global___ExecutePlanResponse.JSONBatch | None = ...,
metrics: global___ExecutePlanResponse.Metrics | None = ...,
) -> None: ...
def HasField(
self,
- field_name: typing_extensions.Literal[
- "arrow_batch",
- b"arrow_batch",
- "json_batch",
- b"json_batch",
- "metrics",
- b"metrics",
- "result_type",
- b"result_type",
- ],
+ field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "arrow_batch",
- b"arrow_batch",
- "client_id",
- b"client_id",
- "json_batch",
- b"json_batch",
- "metrics",
- b"metrics",
- "result_type",
- b"result_type",
+ "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics"
],
) -> None: ...
- def WhichOneof(
- self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
- ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ...
global___ExecutePlanResponse = ExecutePlanResponse
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d3de94a379f8..9e7a5f2f4a54 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -221,7 +221,60 @@ def test_create_global_temp_view(self):
with self.assertRaises(_MultiThreadedRendezvous):
self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
- @unittest.skip("test_fill_na is flaky")
+ def test_to_pandas(self):
+        # SPARK-41005: Test toPandas
+ query = """
+ SELECT * FROM VALUES
+ (false, 1, NULL),
+ (false, NULL, float(2.0)),
+ (NULL, 3, float(3.0))
+ AS tab(a, b, c)
+ """
+
+ self.assert_eq(
+ self.connect.sql(query).toPandas(),
+ self.spark.sql(query).toPandas(),
+ )
+
+ query = """
+ SELECT * FROM VALUES
+ (1, 1, NULL),
+ (2, NULL, float(2.0)),
+ (3, 3, float(3.0))
+ AS tab(a, b, c)
+ """
+
+ self.assert_eq(
+ self.connect.sql(query).toPandas(),
+ self.spark.sql(query).toPandas(),
+ )
+
+ query = """
+ SELECT * FROM VALUES
+ (double(1.0), 1, "1"),
+ (NULL, NULL, NULL),
+ (double(2.0), 3, "3")
+ AS tab(a, b, c)
+ """
+
+ self.assert_eq(
+ self.connect.sql(query).toPandas(),
+ self.spark.sql(query).toPandas(),
+ )
+
+ query = """
+ SELECT * FROM VALUES
+ (float(1.0), double(1.0), 1, "1"),
+ (float(2.0), double(2.0), 2, "2"),
+ (float(3.0), double(3.0), 3, "3")
+ AS tab(a, b, c, d)
+ """
+
+ self.assert_eq(
+ self.connect.sql(query).toPandas(),
+ self.spark.sql(query).toPandas(),
+ )
+
def test_fill_na(self):
# SPARK-41128: Test fill na
query = """
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index a3c5f4a7b070..21747a0a021f 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -697,7 +697,13 @@ setQuantifier
;
relation
- : LATERAL? relationPrimary joinRelation*
+ : LATERAL? relationPrimary relationExtension*
+ ;
+
+relationExtension
+ : joinRelation
+ | pivotClause
+ | unpivotClause
;
joinRelation
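Because pivotClause and unpivotClause are now alternatives of relationExtension, PIVOT/UNPIVOT can trail a relation or a join chain directly, as the parser tests below exercise. A hedged sketch of how to check this (table and column names are made up; parsePlan only parses, so no tables need to exist):

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// Under the default (non-ANSI) relation precedence, the UNPIVOT applies to the join result.
CatalystSqlParser.parsePlan("SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col IN (a, b))")

// A PIVOT can now be followed by further relation extensions, e.g. a join.
CatalystSqlParser.parsePlan("SELECT * FROM t1 PIVOT (SUM(a) FOR b IN (1, 2)) JOIN t2")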
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 104c5c1e0805..8cdf83f4be75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -113,9 +113,9 @@ object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog {
* @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the
* depth of nested views.
* @param maxNestedViewDepth The maximum allowed depth of nested view resolution.
- * @param relationCache A mapping from qualified table names to resolved relations. This can ensure
- * that the table is resolved only once if a table is used multiple times
- * in a query.
+ * @param relationCache A mapping from qualified table names plus the time travel spec to
+ *                      resolved relations. This ensures that a table is resolved only once
+ *                      even if it is used multiple times in a query.
* @param referredTempViewNames All the temp view names referred by the current view we are
* resolving. It's used to make sure the relation resolution is
* consistent between view creation and view resolution. For example,
@@ -129,7 +129,8 @@ case class AnalysisContext(
catalogAndNamespace: Seq[String] = Nil,
nestedViewDepth: Int = 0,
maxNestedViewDepth: Int = -1,
- relationCache: mutable.Map[Seq[String], LogicalPlan] = mutable.Map.empty,
+ relationCache: mutable.Map[(Seq[String], Option[TimeTravelSpec]), LogicalPlan] =
+ mutable.Map.empty,
referredTempViewNames: Seq[Seq[String]] = Seq.empty,
// 1. If we are resolving a view, this field will be restored from the view metadata,
// by calling `AnalysisContext.withAnalysisContext(viewDesc)`.
@@ -1239,7 +1240,7 @@ class Analyzer(override val catalogManager: CatalogManager)
resolveTempView(u.multipartIdentifier, u.isStreaming, timeTravelSpec.isDefined).orElse {
expandIdentifier(u.multipartIdentifier) match {
case CatalogAndIdentifier(catalog, ident) =>
- val key = catalog.name +: ident.namespace :+ ident.name
+ val key = ((catalog.name +: ident.namespace :+ ident.name).toSeq, timeTravelSpec)
AnalysisContext.get.relationCache.get(key).map(_.transform {
case multi: MultiInstanceRelation =>
val newRelation = multi.newInstance()
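Including the time travel spec in the cache key keeps two reads of the same table at different versions from sharing one resolved relation. A toy illustration of the keying, with TimeTravelSpec stood in by an Option[String] version purely for the sketch:

import scala.collection.mutable

case class Resolved(table: String, version: Option[String])  // stand-in for a resolved relation

val relationCache = mutable.Map.empty[(Seq[String], Option[String]), Resolved]

def resolve(name: Seq[String], timeTravel: Option[String]): Resolved =
  relationCache.getOrElseUpdate((name, timeTravel), Resolved(name.mkString("."), timeTravel))

resolve(Seq("cat", "db", "t"), None)        // resolves and caches the current snapshot
resolve(Seq("cat", "db", "t"), Some("v1"))  // different key: VERSION AS OF v1 gets its own entry
resolve(Seq("cat", "db", "t"), None)        // cache hit on the first entry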
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 99e5f411bdb6..8dd28e9aaae3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -464,8 +464,9 @@ object FileSourceMetadataAttribute {
val FILE_SOURCE_METADATA_COL_ATTR_KEY = "__file_source_metadata_col"
- def apply(name: String, dataType: DataType, nullable: Boolean = true): AttributeReference =
- AttributeReference(name, dataType, nullable,
+ def apply(name: String, dataType: DataType): AttributeReference =
+    // The metadata column for file sources is always non-nullable.
+ AttributeReference(name, dataType, nullable = false,
new MetadataBuilder()
.putBoolean(METADATA_COL_ATTR_KEY, value = true)
.putBoolean(FILE_SOURCE_METADATA_COL_ATTR_KEY, value = true).build())()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4adb70bc3909..d56ef28bcc32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -929,7 +929,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
Join(left, right, Inner, None, JoinHint.NONE)
}
}
- if (conf.ansiRelationPrecedence) join else withJoinRelations(join, relation)
+ if (conf.ansiRelationPrecedence) join else withRelationExtensions(relation, join)
}
if (ctx.pivotClause() != null) {
if (ctx.unpivotClause() != null) {
@@ -1263,60 +1263,71 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
* }}}
*/
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
- withJoinRelations(plan(ctx.relationPrimary), ctx)
+ withRelationExtensions(ctx, plan(ctx.relationPrimary))
+ }
+
+ private def withRelationExtensions(ctx: RelationContext, query: LogicalPlan): LogicalPlan = {
+ ctx.relationExtension().asScala.foldLeft(query) { (left, extension) =>
+ if (extension.joinRelation() != null) {
+ withJoinRelation(extension.joinRelation(), left)
+ } else if (extension.pivotClause() != null) {
+ withPivot(extension.pivotClause(), left)
+ } else {
+ assert(extension.unpivotClause() != null)
+ withUnpivot(extension.unpivotClause(), left)
+ }
+ }
}
/**
- * Join one more [[LogicalPlan]]s to the current logical plan.
+   * Join another [[LogicalPlan]] to the current logical plan.
*/
- private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
- ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
- withOrigin(join) {
- val baseJoinType = join.joinType match {
- case null => Inner
- case jt if jt.CROSS != null => Cross
- case jt if jt.FULL != null => FullOuter
- case jt if jt.SEMI != null => LeftSemi
- case jt if jt.ANTI != null => LeftAnti
- case jt if jt.LEFT != null => LeftOuter
- case jt if jt.RIGHT != null => RightOuter
- case _ => Inner
- }
+ private def withJoinRelation(ctx: JoinRelationContext, base: LogicalPlan): LogicalPlan = {
+ withOrigin(ctx) {
+ val baseJoinType = ctx.joinType match {
+ case null => Inner
+ case jt if jt.CROSS != null => Cross
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
- if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) {
- throw QueryParsingErrors.invalidLateralJoinRelationError(join.right)
- }
+ if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) {
+ throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right)
+ }
- // Resolve the join type and join condition
- val (joinType, condition) = Option(join.joinCriteria) match {
- case Some(c) if c.USING != null =>
- if (join.LATERAL != null) {
- throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx)
- }
- (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
- case Some(c) if c.booleanExpression != null =>
- (baseJoinType, Option(expression(c.booleanExpression)))
- case Some(c) =>
- throw new IllegalStateException(s"Unimplemented joinCriteria: $c")
- case None if join.NATURAL != null =>
- if (join.LATERAL != null) {
- throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx)
- }
- if (baseJoinType == Cross) {
- throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx)
- }
- (NaturalJoin(baseJoinType), None)
- case None =>
- (baseJoinType, None)
- }
- if (join.LATERAL != null) {
- if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
- throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql)
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(ctx.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ if (ctx.LATERAL != null) {
+ throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx)
}
- LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition)
- } else {
- Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
+ (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case Some(c) =>
+ throw new IllegalStateException(s"Unimplemented joinCriteria: $c")
+ case None if ctx.NATURAL != null =>
+ if (ctx.LATERAL != null) {
+ throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx)
+ }
+ if (baseJoinType == Cross) {
+ throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx)
+ }
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ if (ctx.LATERAL != null) {
+ if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
+ throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql)
}
+ LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition)
+ } else {
+ Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3c35ba9b6004..bbda9eb76b10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -137,11 +137,22 @@ object ScanOperation extends OperationHelper {
val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
// `collectProjectsAndFilters` transforms the plan bottom-up, so the bottom-most filter are
- // placed at the beginning of `filters` list. According to the SQL semantic, we can only
- // push down the bottom deterministic filters.
- val filtersCanPushDown = filters.takeWhile(_.deterministic).flatMap(splitConjunctivePredicates)
- val filtersStayUp = filters.dropWhile(_.deterministic)
- Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child))
+    // placed at the beginning of the `filters` list. According to SQL semantics, we cannot merge
+    // Filters if one or more of them are nondeterministic. This means we can only push down the
+    // bottom-most Filter, plus any immediately following deterministic Filters if the bottom-most
+    // Filter is itself deterministic.
+ if (filters.isEmpty) {
+ Some((fields.getOrElse(child.output), Nil, Nil, child))
+ } else if (filters.head.deterministic) {
+ val filtersCanPushDown = filters.takeWhile(_.deterministic)
+ .flatMap(splitConjunctivePredicates)
+ val filtersStayUp = filters.dropWhile(_.deterministic)
+ Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child))
+ } else {
+ val filtersCanPushDown = splitConjunctivePredicates(filters.head)
+ val filtersStayUp = filters.drop(1)
+ Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child))
+ }
}
}
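The branching above amounts to: push the longest deterministic prefix of the bottom-up filter list, except that a nondeterministic bottom-most Filter may still be pushed alone, since it already sits directly on the scan. A simplified stand-alone sketch of that split (conjunct splitting omitted; Pred is a stand-in for Catalyst's Expression):

final case class Pred(sql: String, deterministic: Boolean)

// Returns (filters that can be pushed down, filters that must stay above the scan).
def splitFilters(filters: Seq[Pred]): (Seq[Pred], Seq[Pred]) = {
  if (filters.isEmpty) {
    (Nil, Nil)
  } else if (filters.head.deterministic) {
    // Push the deterministic prefix; everything from the first nondeterministic Filter up stays.
    (filters.takeWhile(_.deterministic), filters.dropWhile(_.deterministic))
  } else {
    // Nondeterministic bottom-most Filter: push it alone, keep the rest above.
    (Seq(filters.head), filters.drop(1))
  }
}

// Example, bottom-up order: rand() > 0.5 is nondeterministic, id > 5 is deterministic.
splitFilters(Seq(Pred("rand() > 0.5", false), Pred("id > 5", true)))
// -> (Seq(Pred("rand() > 0.5", false)), Seq(Pred("id > 5", true)))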
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 8608d4ff306e..9624a06d80a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -99,15 +99,15 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
ctx)
}
- def unpivotWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
+ def unpivotWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException("UNPIVOT cannot be used together with PIVOT in FROM clause", ctx)
}
- def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
+ def lateralWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0013", ctx)
}
- def lateralWithUnpivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
+ def lateralWithUnpivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException("LATERAL cannot be used together with UNPIVOT in FROM clause", ctx)
}
@@ -164,7 +164,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
ctx)
}
- def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = {
+ def naturalCrossJoinUnsupportedError(ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "UNSUPPORTED_FEATURE.NATURAL_CROSS_JOIN",
messageParameters = Map.empty,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala
index dd7e4ec4916f..c680e08c1c83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Unpivot}
+import org.apache.spark.sql.internal.SQLConf
class UnpivotParserSuite extends AnalysisTest {
@@ -192,4 +193,151 @@ class UnpivotParserSuite extends AnalysisTest {
)
}
+ test("unpivot - with joins") {
+ // unpivot the left table
+ assertEqual(
+ "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1")
+ ).where(coalesce($"val").isNotNull).join(table("t2")).select(star()))
+
+ // unpivot the join result
+ assertEqual(
+ "SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1").join(table("t2"))
+ ).where(coalesce($"val").isNotNull).select(star()))
+
+ // unpivot the right table
+ assertEqual(
+ "SELECT * FROM t1 JOIN (t2 UNPIVOT (val FOR col in (a, b)))",
+ table("t1").join(
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t2")
+ ).where(coalesce($"val").isNotNull)
+ ).select(star()))
+ }
+
+ test("unpivot - with implicit joins") {
+ // unpivot the left table
+ assertEqual(
+ "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1")
+ ).where(coalesce($"val").isNotNull).join(table("t2")).select(star()))
+
+ // unpivot the join result
+ assertEqual(
+ "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1").join(table("t2"))
+ ).where(coalesce($"val").isNotNull).select(star()))
+
+ // unpivot the right table - same SQL as above but with ANSI mode
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") {
+ assertEqual(
+ "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))",
+ table("t1").join(
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t2")
+ ).where(coalesce($"val").isNotNull)
+ ).select(star()))
+ }
+
+ // unpivot the right table
+ assertEqual(
+ "SELECT * FROM t1, (t2 UNPIVOT (val FOR col in (a, b)))",
+ table("t1").join(
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t2")
+ ).where(coalesce($"val").isNotNull)
+ ).select(star()))
+
+ // mixed with explicit joins
+ assertEqual(
+ // unpivot the join result of t1, t2 and t3
+ "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1").join(table("t2")).join(table("t3"))
+ ).where(coalesce($"val").isNotNull).select(star()))
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") {
+ assertEqual(
+ // unpivot the join result of t2 and t3
+ "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))",
+ table("t1").join(
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t2").join(table("t3"))
+ ).where(coalesce($"val").isNotNull)
+ ).select(star()))
+ }
+ }
+
+ test("unpivot - nested unpivot") {
+ assertEqual(
+ "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))",
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ Unpivot(
+ None,
+ Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
+ None,
+ "col",
+ Seq("val"),
+ table("t1")
+ ).where(coalesce($"val").isNotNull)
+ ).where(coalesce($"val").isNotNull).select(star()))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a8ccc39ac478..6b3744fe02d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
*
* - Analyzer Rules.
* - Check Analysis Rules.
+ * - Cache Plan Normalization Rules.
* - Optimizer Rules.
* - Pre CBO Rules.
* - Planning Strategies.
@@ -217,6 +218,22 @@ class SparkSessionExtensions {
checkRuleBuilders += builder
}
+ private[this] val planNormalizationRules = mutable.Buffer.empty[RuleBuilder]
+
+ def buildPlanNormalizationRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ planNormalizationRules.map(_.apply(session)).toSeq
+ }
+
+ /**
+   * Inject a plan normalization `Rule` builder into the [[SparkSession]]. The injected rules
+   * will be executed just before query caching decisions are made. Such rules can be used to
+   * improve the cache hit rate by normalizing different plans to the same form. These rules
+   * must never change the results of the LogicalPlan.
+ */
+ def injectPlanNormalizationRules(builder: RuleBuilder): Unit = {
+ planNormalizationRules += builder
+ }
+
private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]
private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
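As a usage sketch, an extension registered through the existing spark.sql.extensions mechanism could inject a normalization rule like this (NoOpNormalize is a made-up placeholder; a real rule would rewrite plans into a canonical shape without changing results):

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Placeholder normalization rule: the identity transformation.
object NoOpNormalize extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Runs just before cache lookups, via QueryExecution.normalized (see the changes below).
    extensions.injectPlanNormalizationRules((_: SparkSession) => NoOpNormalize)
  }
}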
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index e9bbbc717d1e..d41611439f0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -89,7 +89,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = {
- cacheQuery(query.sparkSession, query.logicalPlan, tableName, storageLevel)
+ cacheQuery(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel)
}
/**
@@ -143,7 +143,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
def uncacheQuery(
query: Dataset[_],
cascade: Boolean): Unit = {
- uncacheQuery(query.sparkSession, query.logicalPlan, cascade)
+ uncacheQuery(query.sparkSession, query.queryExecution.normalized, cascade)
}
/**
@@ -281,7 +281,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = {
- lookupCachedData(query.logicalPlan)
+ lookupCachedData(query.queryExecution.normalized)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 3706d5a1e3d4..796ec41ab51c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,12 +105,29 @@ class QueryExecution(
case other => other
}
+  // The plan normalized by custom rules, so that it is more likely to hit the cache.
+ lazy val normalized: LogicalPlan = {
+ val normalizationRules = sparkSession.sessionState.planNormalizationRules
+ if (normalizationRules.isEmpty) {
+ commandExecuted
+ } else {
+ val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+ val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
+ val result = rule.apply(p)
+ planChangeLogger.logRule(rule.ruleName, p, result)
+ result
+ }
+ planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
+ normalized
+ }
+ }
+
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone())
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
}
def assertCommandExecuted(): Unit = commandExecuted
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 576801d3dd59..476d6579b383 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -275,8 +275,12 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
.get.withName(FileFormat.ROW_INDEX)
}
}
+      // SPARK-41151: the metadata column is not nullable for file sources.
+      // Here we explicitly mark `CreateStruct(structColumns)` as not null
+      // to avoid any risk of inconsistent schema nullability.
val metadataAlias =
- Alias(CreateStruct(structColumns), METADATA_NAME)(exprId = metadataStruct.exprId)
+ Alias(KnownNotNull(CreateStruct(structColumns)),
+ METADATA_NAME)(exprId = metadataStruct.exprId)
execution.ProjectExec(
readDataColumns ++ partitionColumns :+ metadataAlias, scan)
}.getOrElse(scan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index b12a86c08d18..f81b12796ce9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -317,6 +317,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildRuntimeOptimizerRules(session))
}
+ protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildPlanNormalizationRules(session)
+ }
+
/**
* Create a query execution object.
*/
@@ -371,7 +375,8 @@ abstract class BaseSessionStateBuilder(
createQueryExecution,
createClone,
columnarRules,
- adaptiveRulesHolder)
+ adaptiveRulesHolder,
+ planNormalizationRules)
}
}
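The new `planNormalizationRules` hook is fed from `SparkSessionExtensions.buildPlanNormalizationRules`, so rules are registered the same way as other extension points. A minimal usage sketch, reusing the existing `PushDownPredicates` optimizer rule as the normalization rule (as the extension-suite test below also does):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer.PushDownPredicates

val spark = SparkSession.builder()
  .master("local[2]")
  .withExtensions(_.injectPlanNormalizationRules(_ => PushDownPredicates))
  .getOrCreate()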
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 1d5e61aab269..eb0b71d155ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
@@ -79,7 +80,8 @@ private[sql] class SessionState(
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
- val adaptiveRulesHolder: AdaptiveRulesHolder) {
+ val adaptiveRulesHolder: AdaptiveRulesHolder,
+ val planNormalizationRules: Seq[Rule[LogicalPlan]]) {
// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 98cb54ccbbc3..cf5f8d990f79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -30,10 +30,10 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path}
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.SimpleMode
+import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
@@ -1074,6 +1074,24 @@ class FileBasedDataSourceSuite extends QueryTest
checkAnswer(df, Row("v1", "v2"))
}
}
+
+ test("SPARK-41017: filter pushdown with nondeterministic predicates") {
+ withTempPath { path =>
+ val pathStr = path.getCanonicalPath
+ spark.range(10).write.parquet(pathStr)
+ Seq("parquet", "").foreach { useV1SourceList =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1SourceList) {
+ val scan = spark.read.parquet(pathStr)
+ val df = scan.where(rand() > 0.5 && $"id" > 5)
+ val filters = df.queryExecution.executedPlan.collect {
+ case f: FileSourceScanLike => f.dataFilters
+ case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
+ }.flatten
+ assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L))))
+ }
+ }
+ }
+ }
}
object TestingUDT {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 10d2227324f1..f5f04eabec03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -192,6 +192,23 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
testInjectColumnar(false)
}
+ test("inject plan normalization rules") {
+ val extensions = create { extensions =>
+ extensions.injectPlanNormalizationRules { session =>
+ org.apache.spark.sql.catalyst.optimizer.PushDownPredicates
+ }
+ }
+ withSession(extensions) { session =>
+ import session.implicits._
+ val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
+ df.select("i").filter($"i" > 1).cache()
+ assert(df.filter($"i" > 1).select("i").queryExecution.executedPlan.find {
+ case _: org.apache.spark.sql.execution.columnar.InMemoryTableScanExec => true
+ case _ => false
+ }.isDefined)
+ }
+ }
+
test("SPARK-39991: AQE should retain column statistics from completed query stages") {
val extensions = create { extensions =>
extensions.injectColumnar(_ =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index de8612c3348e..ea93367b3dbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2775,6 +2775,23 @@ class DataSourceV2SQLSuiteV1Filter extends DataSourceV2SQLSuite with AlterTableT
}
}
+ test("SPARK-41154: Incorrect relation caching for queries with time travel spec") {
+ sql("use testcat")
+ val t1 = "testcat.t1"
+ val t2 = "testcat.t2"
+ withTable(t1, t2) {
+ sql(s"CREATE TABLE $t1 USING foo AS SELECT 1 as c")
+ sql(s"CREATE TABLE $t2 USING foo AS SELECT 2 as c")
+ assert(
+ sql("""
+ |SELECT * FROM t VERSION AS OF '1'
+ |UNION ALL
+ |SELECT * FROM t VERSION AS OF '2'
+ |""".stripMargin
+ ).collect() === Array(Row(1), Row(2)))
+ }
+ }
+
private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
checkError(
exception = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
index e0e208b62f1c..a39a36a4f83b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
@@ -600,7 +600,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
val df2 = spark.read.format("json")
.load(dir.getCanonicalPath + "/target/new-streaming-data-join")
// Verify self-join results
- assert(streamQuery2.lastProgress.numInputRows == 4L)
+ assert(streamQuery2.lastProgress.numInputRows == 2L)
assert(df2.count() == 2L)
assert(df2.select("*").columns.toSet == Set("name", "age", "info", "_metadata"))
}
@@ -654,4 +654,19 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
}
}
}
+
+ metadataColumnsTest("SPARK-41151: consistent _metadata nullability " +
+ "between analyzed and executed", schema) { (df, _, _) =>
+ val queryExecution = df.select("_metadata").queryExecution
+ val analyzedSchema = queryExecution.analyzed.schema
+ val executedSchema = queryExecution.executedPlan.schema
+ assert(analyzedSchema.fields.head.name == "_metadata")
+ assert(executedSchema.fields.head.name == "_metadata")
+ // For stateful streaming, we store the schema in the state store
+ // and check consistency across batches.
+ // To avoid state schema compatibility mismatches,
+ // we should keep nullability consistent for the _metadata struct.
+ assert(!analyzedSchema.fields.head.nullable)
+ assert(analyzedSchema.fields.head.nullable == executedSchema.fields.head.nullable)
+ }
}