diff --git a/.github/workflows/util/install-spark-resources.sh b/.github/workflows/util/install-spark-resources.sh
index 4d1dd27a9c45..d245dbb4ac31 100755
--- a/.github/workflows/util/install-spark-resources.sh
+++ b/.github/workflows/util/install-spark-resources.sh
@@ -119,6 +119,11 @@ case "$1" in
cd ${INSTALL_DIR} && \
install_spark "4.0.1" "3" "2.12"
;;
+4.1)
+ # Spark-4.x, scala 2.12 // using 2.12 as a hack as 4.0 does not have 2.13 suffix
+ cd ${INSTALL_DIR} && \
+ install_spark "4.1.0" "3" "2.12"
+ ;;
*)
echo "Spark version is expected to be specified."
exit 1
diff --git a/.github/workflows/velox_backend_x86.yml b/.github/workflows/velox_backend_x86.yml
index 9adde6c5ce66..5ef70a79bc0f 100644
--- a/.github/workflows/velox_backend_x86.yml
+++ b/.github/workflows/velox_backend_x86.yml
@@ -1481,3 +1481,109 @@ jobs:
**/target/*.log
**/gluten-ut/**/hs_err_*.log
**/gluten-ut/**/core.*
+
+ spark-test-spark41:
+ needs: build-native-lib-centos-7
+ runs-on: ubuntu-22.04
+ env:
+ SPARK_TESTING: true
+ container: apache/gluten:centos-8-jdk17
+ steps:
+ - uses: actions/checkout@v2
+ - name: Download All Artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: velox-native-lib-centos-7-${{github.sha}}
+ path: ./cpp/build/releases
+ - name: Download Arrow Jars
+ uses: actions/download-artifact@v4
+ with:
+ name: arrow-jars-centos-7-${{github.sha}}
+ path: /root/.m2/repository/org/apache/arrow/
+ - name: Prepare
+ run: |
+ dnf module -y install python39 && \
+ alternatives --set python3 /usr/bin/python3.9 && \
+ pip3 install setuptools==77.0.3 && \
+ pip3 install pyspark==3.5.5 cython && \
+ pip3 install pandas==2.2.3 pyarrow==20.0.0
+ - name: Prepare Spark Resources for Spark 4.1.0 #TODO remove after image update
+ run: |
+ rm -rf /opt/shims/spark41
+ bash .github/workflows/util/install-spark-resources.sh 4.1
+ mv /opt/shims/spark41/spark_home/assembly/target/scala-2.12 /opt/shims/spark41/spark_home/assembly/target/scala-2.13
+ - name: Build and Run unit test for Spark 4.1.0 with scala-2.13 (other tests)
+ run: |
+ cd $GITHUB_WORKSPACE/
+ export SPARK_SCALA_VERSION=2.13
+ yum install -y java-17-openjdk-devel
+ export JAVA_HOME=/usr/lib/jvm/java-17-openjdk
+ export PATH=$JAVA_HOME/bin:$PATH
+ java -version
+ $MVN_CMD clean test -Pspark-4.1 -Pscala-2.13 -Pjava-17 -Pbackends-velox \
+ -Pspark-ut -DargLine="-Dspark.test.home=/opt/shims/spark41/spark_home/" \
+ -DtagsToExclude=org.apache.spark.tags.ExtendedSQLTest,org.apache.gluten.tags.UDFTest,org.apache.gluten.tags.EnhancedFeaturesTest,org.apache.gluten.tags.SkipTest
+ - name: Upload test report
+ if: always()
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ github.job }}-report
+ path: '**/surefire-reports/TEST-*.xml'
+ - name: Upload unit tests log files
+ if: ${{ !success() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ github.job }}-test-log
+ path: |
+ **/target/*.log
+ **/gluten-ut/**/hs_err_*.log
+ **/gluten-ut/**/core.*
+
+ spark-test-spark41-slow:
+ needs: build-native-lib-centos-7
+ runs-on: ubuntu-22.04
+ env:
+ SPARK_TESTING: true
+ container: apache/gluten:centos-8-jdk17
+ steps:
+ - uses: actions/checkout@v2
+ - name: Download All Artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: velox-native-lib-centos-7-${{github.sha}}
+ path: ./cpp/build/releases
+ - name: Download Arrow Jars
+ uses: actions/download-artifact@v4
+ with:
+ name: arrow-jars-centos-7-${{github.sha}}
+ path: /root/.m2/repository/org/apache/arrow/
+ - name: Prepare Spark Resources for Spark 4.1.0 #TODO remove after image update
+ run: |
+ rm -rf /opt/shims/spark41
+ bash .github/workflows/util/install-spark-resources.sh 4.1
+ mv /opt/shims/spark41/spark_home/assembly/target/scala-2.12 /opt/shims/spark41/spark_home/assembly/target/scala-2.13
+ - name: Build and Run unit test for Spark 4.0 (slow tests)
+ run: |
+ cd $GITHUB_WORKSPACE/
+ yum install -y java-17-openjdk-devel
+ export JAVA_HOME=/usr/lib/jvm/java-17-openjdk
+ export PATH=$JAVA_HOME/bin:$PATH
+ java -version
+ $MVN_CMD clean test -Pspark-4.1 -Pscala-2.13 -Pjava-17 -Pbackends-velox -Pspark-ut \
+ -DargLine="-Dspark.test.home=/opt/shims/spark41/spark_home/" \
+ -DtagsToInclude=org.apache.spark.tags.ExtendedSQLTest
+ - name: Upload test report
+ if: always()
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ github.job }}-report
+ path: '**/surefire-reports/TEST-*.xml'
+ - name: Upload unit tests log files
+ if: ${{ !success() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ github.job }}-test-log
+ path: |
+ **/target/*.log
+ **/gluten-ut/**/hs_err_*.log
+ **/gluten-ut/**/core.*
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala
index 25371be8d1fa..925f2a6be94f 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala
@@ -38,6 +38,24 @@ import java.nio.charset.StandardCharsets
import scala.collection.convert.ImplicitConversions.`map AsScala`
+/**
+ * Extracts a CSVTable from a DataSourceV2Relation.
+ *
+ * Only the table variable of DataSourceV2Relation is accessed to improve compatibility across
+ * different Spark versions.
+ * @since Spark
+ * 4.1
+ */
+private object CSVTableExtractor {
+ def unapply(relation: DataSourceV2Relation): Option[(DataSourceV2Relation, CSVTable)] = {
+ relation.table match {
+ case t: CSVTable =>
+ Some((relation, t))
+ case _ => None
+ }
+ }
+}
+
@Experimental
case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
@@ -56,25 +74,15 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] {
l.copy(relation = r.copy(fileFormat = new ArrowCSVFileFormat(csvOptions))(session))
case _ => l
}
- case d @ DataSourceV2Relation(
- t @ CSVTable(
- name,
- sparkSession,
- options,
- paths,
- userSpecifiedSchema,
- fallbackFileFormat),
- _,
- _,
- _,
- _) if validate(session, t.dataSchema, options.asCaseSensitiveMap().toMap) =>
+ case CSVTableExtractor(d, t)
+ if validate(session, t.dataSchema, t.options.asCaseSensitiveMap().toMap) =>
d.copy(table = ArrowCSVTable(
- "arrow" + name,
- sparkSession,
- options,
- paths,
- userSpecifiedSchema,
- fallbackFileFormat))
+ "arrow" + t.name,
+ t.sparkSession,
+ t.options,
+ t.paths,
+ t.userSpecifiedSchema,
+ t.fallbackFileFormat))
case r =>
r
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala
index 6239ab5ad749..ab76cba4aa5d 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.DataSourceUtils
-import org.apache.spark.sql.execution.datasources.parquet.{ParquetFooterReader, ParquetOptions}
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFooterReaderShim, ParquetOptions}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, LocatedFileStatus, Path}
@@ -135,7 +135,7 @@ object ParquetMetadataUtils extends Logging {
parquetOptions: ParquetOptions): Option[String] = {
val footer =
try {
- ParquetFooterReader.readFooter(conf, fileStatus, ParquetMetadataConverter.NO_FILTER)
+ ParquetFooterReaderShim.readFooter(conf, fileStatus, ParquetMetadataConverter.NO_FILTER)
} catch {
case e: Exception if ExceptionUtils.hasCause(e, classOf[ParquetCryptoRuntimeException]) =>
return Some("Encrypted Parquet footer detected.")
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index c3132420916a..cfddfb8e21e3 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -753,7 +753,11 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
val df = sql("SELECT 1")
checkAnswer(df, Row(1))
val plan = df.queryExecution.executedPlan
- assert(plan.find(_.isInstanceOf[RDDScanExec]).isDefined)
+ if (isSparkVersionGE("4.1")) {
+ assert(plan.find(_.getClass.getSimpleName == "OneRowRelationExec").isDefined)
+ } else {
+ assert(plan.find(_.isInstanceOf[RDDScanExec]).isDefined)
+ }
assert(plan.find(_.isInstanceOf[ProjectExecTransformer]).isDefined)
assert(plan.find(_.isInstanceOf[RowToVeloxColumnarExec]).isDefined)
}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
index 53f44a2ccc81..5958baa3771f 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
@@ -92,12 +92,9 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
// The computing is combined into one single whole stage transformer.
val wholeStages = plan.collect { case wst: WholeStageTransformer => wst }
- if (SparkShimLoader.getSparkVersion.startsWith("3.2.")) {
+ if (isSparkVersionLE("3.2")) {
assert(wholeStages.length == 1)
- } else if (
- SparkShimLoader.getSparkVersion.startsWith("3.5.") ||
- SparkShimLoader.getSparkVersion.startsWith("4.0.")
- ) {
+ } else if (isSparkVersionGE("3.5")) {
assert(wholeStages.length == 5)
} else {
assert(wholeStages.length == 3)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala
index 06f0acb784b8..37f13bcea8c3 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala
@@ -544,7 +544,8 @@ class VeloxStringFunctionsSuite extends VeloxWholeStageTransformerSuite {
s"from $LINEITEM_TABLE limit 5") { _ => }
}
- testWithMinSparkVersion("split", "3.4") {
+ // TODO: fix on spark-4.1
+ testWithSpecifiedSparkVersion("split", "3.4", "3.5") {
runQueryAndCompare(
s"select l_orderkey, l_comment, split(l_comment, '') " +
s"from $LINEITEM_TABLE limit 5") {
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
index 52a17995f386..f8e2554da7c7 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
@@ -39,7 +39,8 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
.set("spark.executor.cores", "1")
}
- test("arrow_udf test: without projection") {
+ // TODO: fix on spark-4.1
+ testWithMaxSparkVersion("arrow_udf test: without projection", "4.0") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
@@ -59,7 +60,8 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
checkAnswer(df2, expected)
}
- test("arrow_udf test: with unrelated projection") {
+ // TODO: fix on spark-4.1
+ testWithMaxSparkVersion("arrow_udf test: with unrelated projection", "4.0") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
diff --git a/dev/format-scala-code.sh b/dev/format-scala-code.sh
index 96a782405bbf..4d2fba5ca166 100755
--- a/dev/format-scala-code.sh
+++ b/dev/format-scala-code.sh
@@ -22,7 +22,7 @@ MVN_CMD="${BASEDIR}/../build/mvn"
# If a new profile is introduced for new modules, please add it here to ensure
# the new modules are covered.
PROFILES="-Pbackends-velox -Pceleborn,uniffle -Piceberg,delta,hudi,paimon \
- -Pspark-3.2,spark-3.3,spark-3.4,spark-3.5,spark-4.0 -Pspark-ut"
+ -Pspark-3.2,spark-3.3,spark-3.4,spark-3.5,spark-4.0,spark-4.1 -Pspark-ut"
COMMAND=$1
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala
index 732c898285dd..7dfaaaa774c5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.extension.caller
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.util.SparkVersionUtil
/**
* Helper API that stores information about the call site of the columnar rule. Specific columnar
@@ -70,7 +70,12 @@ object CallerInfo {
}
private def inStreamingCall(stack: Seq[StackTraceElement]): Boolean = {
- stack.exists(_.getClassName.equals(StreamExecution.getClass.getName.split('$').head))
+ val streamName = if (SparkVersionUtil.gteSpark41) {
+ "org.apache.spark.sql.execution.streaming.runtime.StreamExecution"
+ } else {
+ "org.apache.spark.sql.execution.streaming.StreamExecution"
+ }
+ stack.exists(_.getClassName.equals(streamName))
}
private def inBloomFilterStatFunctionCall(stack: Seq[StackTraceElement]): Boolean = {
diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala
index efa0c63dca52..50114ab7023e 100644
--- a/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala
@@ -25,6 +25,7 @@ object SparkVersionUtil {
val gteSpark33: Boolean = comparedWithSpark33 >= 0
val gteSpark35: Boolean = comparedWithSpark35 >= 0
val gteSpark40: Boolean = compareMajorMinorVersion((4, 0)) >= 0
+ val gteSpark41: Boolean = compareMajorMinorVersion((4, 1)) >= 0
// Returns X. X < 0 if one < other, x == 0 if one == other, x > 0 if one > other.
def compareMajorMinorVersion(other: (Int, Int)): Int = {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 52f6d31d1d1b..418de8578f5c 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -856,13 +856,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
dateAdd.children,
dateAdd
)
- case timeAdd: TimeAdd =>
- BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer(
- attributeSeq,
- substraitExprName,
- timeAdd.children,
- timeAdd
- )
case ss: StringSplit =>
BackendsApiManager.getSparkPlanExecApiInstance.genStringSplitTransformer(
substraitExprName,
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index b0b7c8079315..b13aced2a62c 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -179,7 +179,6 @@ object ExpressionMappings {
Sig[Second](EXTRACT),
Sig[FromUnixTime](FROM_UNIXTIME),
Sig[DateAdd](DATE_ADD),
- Sig[TimeAdd](TIMESTAMP_ADD),
Sig[DateSub](DATE_SUB),
Sig[DateDiff](DATE_DIFF),
Sig[ToUnixTimestamp](TO_UNIX_TIMESTAMP),
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
index 474a27176906..7267ce56ba1c 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
@@ -18,8 +18,8 @@ package org.apache.spark.sql.execution
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.{GlutenPlan, WholeStageTransformer}
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PlanUtil
-
import org.apache.spark.sql.{Column, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan}
@@ -130,10 +130,8 @@ object GlutenImplicits {
val (innerNumGlutenNodes, innerFallbackNodeToReason) =
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
// re-plan manually to skip cached data
- val newSparkPlan = QueryExecution.createSparkPlan(
- spark,
- spark.sessionState.planner,
- p.inputPlan.logicalLink.get)
+ val newSparkPlan = SparkShimLoader.getSparkShims.createSparkPlan(
+ spark, spark.sessionState.planner, p.inputPlan.logicalLink.get)
val newExecutedPlan = QueryExecution.prepareExecutedPlan(
spark,
newSparkPlan
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala
index 5cf41b7a9ed5..570b6d5e0c1a 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala
@@ -42,7 +42,7 @@ class GlutenParquetRowIndexSuite extends ParquetRowIndexSuite with GlutenSQLTest
import testImplicits._
private def readRowGroupRowCounts(path: String): Seq[Long] = {
- ParquetFooterReader
+ ParquetFooterReaderShim
.readFooter(
spark.sessionState.newHadoopConf(),
new Path(path),
diff --git a/pom.xml b/pom.xml
index c87d46ceb5c1..b88fa8dfee2b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -938,7 +938,7 @@
scala-2.13
- 2.13.16
+ 2.13.17
2.13
3.8.3
@@ -1270,6 +1270,86 @@
+
+ spark-4.1
+
+ 4.1
+ spark-sql-columnar-shims-spark41
+ 4.1.0
+ 1.10.0
+ delta-spark
+ 4.0.0
+ 40
+ 1.1.0
+ 1.3.0
+ 2.18.2
+ 2.18.2
+ 2.18.2
+ 3.4.1
+ 4.13.1
+ 33.4.0-jre
+ 2.0.16
+ 2.24.3
+ 3.17.0
+ 18.1.0
+ 18.1.0
+
+
+
+ org.slf4j
+ slf4j-api
+ ${slf4j.version}
+ provided
+
+
+ org.apache.logging.log4j
+ log4j-slf4j2-impl
+ ${log4j.version}
+ provided
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+
+
+ enforce-java-17+
+
+ enforce
+
+
+
+
+
+ java-17,java-21
+ false
+ "-P spark-4.1" requires JDK 17+
+
+
+
+
+
+ enforce-scala-213
+
+ enforce
+
+
+
+
+
+ scala-2.13
+ "-P spark-4.1" requires Scala 2.13
+
+
+
+
+
+
+
+
+
hadoop-2.7.4
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index de220cab82e4..9164f4b7c43a 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -355,4 +355,14 @@ trait SparkShims {
def unsupportedCodec: Seq[CompressionCodecName] = {
Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ * @since Spark
+ * 4.1
+ */
+ def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan
}
diff --git a/shims/pom.xml b/shims/pom.xml
index 9a8e639e04ae..8c10ce640fee 100644
--- a/shims/pom.xml
+++ b/shims/pom.xml
@@ -98,6 +98,12 @@
spark40
+
+ spark-4.1
+
+ spark41
+
+
default
diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
index 367b66f9c424..d1bf07e64b12 100644
--- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -303,4 +303,16 @@ class Spark32Shims extends SparkShims {
override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
Some(raiseError.child)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(sparkSession, planner, plan)
}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..b1419e5e6233
--- /dev/null
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, fileStatus, filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, file, filter)
+ }
+}
diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index 2b6affcded6e..a18fb3171991 100644
--- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TimestampFormatte
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -408,4 +408,16 @@ class Spark33Shims extends SparkShims {
override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
Some(raiseError.child)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(sparkSession, planner, plan)
}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..b1419e5e6233
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, fileStatus, filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, file, filter)
+ }
+}
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 39dd71a6bc3c..cdbeaa47838b 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -651,4 +651,16 @@ class Spark34Shims extends SparkShims {
override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
Some(raiseError.child)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(sparkSession, planner, plan)
}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..b1419e5e6233
--- /dev/null
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, fileStatus, filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, file, filter)
+ }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 78da08190f17..d993cc0bfd20 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -702,4 +702,16 @@ class Spark35Shims extends SparkShims {
override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
Some(raiseError.child)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(sparkSession, planner, plan)
}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..b1419e5e6233
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, fileStatus, filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, file, filter)
+ }
+}
diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index 80c22f2fad25..e5258eafa46c 100644
--- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
@@ -752,4 +753,16 @@ class Spark40Shims extends SparkShims {
override def unsupportedCodec: Seq[CompressionCodecName] = {
Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI, CompressionCodecName.LZ4_RAW)
}
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(sparkSession, planner, plan)
}
diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..b1419e5e6233
--- /dev/null
+++ b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, fileStatus, filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(configuration, file, filter)
+ }
+}
diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala
deleted file mode 100644
index 98fcfa548384..000000000000
--- a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
-import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
-
-class Spark35Scan extends DataSourceV2ScanExecBase {
-
- override def scan: Scan = throw new UnsupportedOperationException("Spark35Scan")
-
- override def ordering: Option[Seq[SortOrder]] = throw new UnsupportedOperationException(
- "Spark35Scan")
-
- override def readerFactory: PartitionReaderFactory =
- throw new UnsupportedOperationException("Spark35Scan")
-
- override def keyGroupedPartitioning: Option[Seq[Expression]] =
- throw new UnsupportedOperationException("Spark35Scan")
-
- override protected def inputPartitions: Seq[InputPartition] =
- throw new UnsupportedOperationException("Spark35Scan")
-
- override def inputRDD: RDD[InternalRow] = throw new UnsupportedOperationException("Spark35Scan")
-
- override def output: Seq[Attribute] = throw new UnsupportedOperationException("Spark35Scan")
-
- override def productElement(n: Int): Any = throw new UnsupportedOperationException("Spark35Scan")
-
- override def productArity: Int = throw new UnsupportedOperationException("Spark35Scan")
-
- override def canEqual(that: Any): Boolean = throw new UnsupportedOperationException("Spark35Scan")
-
-}
diff --git a/shims/spark41/pom.xml b/shims/spark41/pom.xml
new file mode 100644
index 000000000000..3d39f1a9322d
--- /dev/null
+++ b/shims/spark41/pom.xml
@@ -0,0 +1,118 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.gluten
+ spark-sql-columnar-shims
+ 1.6.0-SNAPSHOT
+ ../pom.xml
+
+
+ spark-sql-columnar-shims-spark41
+ jar
+ Gluten Shims for Spark 4.1
+
+
+
+ org.apache.gluten
+ ${project.prefix}-shims-common
+ ${project.version}
+ compile
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ provided
+ true
+
+
+ org.apache.spark
+ spark-catalyst_${scala.binary.version}
+ provided
+ true
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ provided
+ true
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop.version}
+ provided
+
+
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ test-jar
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ test-jar
+
+
+ org.apache.spark
+ spark-catalyst_${scala.binary.version}
+ test-jar
+
+
+ org.apache.spark
+ spark-hive_${scala.binary.version}
+ provided
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-checkstyle-plugin
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+
+
+ com.diffplug.spotless
+ spotless-maven-plugin
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+
+
diff --git a/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 000000000000..7d1347345a95
--- /dev/null
+++ b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.SparkUnsupportedOperationException;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+ return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public GeographyVal getGeography(int ordinal) {
+ return data.getGeography(offset + ordinal);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int ordinal) {
+ return data.getGeometry(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return data.getVariant(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+}
diff --git a/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java
new file mode 100644
index 000000000000..513f3d2d92ca
--- /dev/null
+++ b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.nio.ByteBuffer;
+
+/**
+ * because spark33 add new function abstract method 'putBooleans(int, byte)' in
+ * 'WritableColumnVector' And function getByteBuffer()
+ */
+public class WritableColumnVectorShim extends WritableColumnVector {
+ /**
+ * Sets up the common state and also handles creating the child columns if this is a nested type.
+ *
+ * @param capacity
+ * @param type
+ */
+ protected WritableColumnVectorShim(int capacity, DataType type) {
+ super(capacity, type);
+ }
+
+ protected void releaseMemory() {}
+
+ @Override
+ public int getDictId(int rowId) {
+ return 0;
+ }
+
+ @Override
+ protected void reserveInternal(int capacity) {}
+
+ @Override
+ public void putNotNull(int rowId) {}
+
+ @Override
+ public void putNull(int rowId) {}
+
+ @Override
+ public void putNulls(int rowId, int count) {}
+
+ @Override
+ public void putNotNulls(int rowId, int count) {}
+
+ @Override
+ public void putBoolean(int rowId, boolean value) {}
+
+ @Override
+ public void putBooleans(int rowId, int count, boolean value) {}
+
+ @Override
+ public void putBooleans(int rowId, byte src) {
+ throw new UnsupportedOperationException("Unsupported function");
+ }
+
+ @Override
+ public void putByte(int rowId, byte value) {}
+
+ @Override
+ public void putBytes(int rowId, int count, byte value) {}
+
+ @Override
+ public void putBytes(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putShort(int rowId, short value) {}
+
+ @Override
+ public void putShorts(int rowId, int count, short value) {}
+
+ @Override
+ public void putShorts(int rowId, int count, short[] src, int srcIndex) {}
+
+ @Override
+ public void putShorts(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putInt(int rowId, int value) {}
+
+ @Override
+ public void putInts(int rowId, int count, int value) {}
+
+ @Override
+ public void putInts(int rowId, int count, int[] src, int srcIndex) {}
+
+ @Override
+ public void putInts(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putLong(int rowId, long value) {}
+
+ @Override
+ public void putLongs(int rowId, int count, long value) {}
+
+ @Override
+ public void putLongs(int rowId, int count, long[] src, int srcIndex) {}
+
+ @Override
+ public void putLongs(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putFloat(int rowId, float value) {}
+
+ @Override
+ public void putFloats(int rowId, int count, float value) {}
+
+ @Override
+ public void putFloats(int rowId, int count, float[] src, int srcIndex) {}
+
+ @Override
+ public void putFloats(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putFloatsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putDouble(int rowId, double value) {}
+
+ @Override
+ public void putDoubles(int rowId, int count, double value) {}
+
+ @Override
+ public void putDoubles(int rowId, int count, double[] src, int srcIndex) {}
+
+ @Override
+ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putDoublesLittleEndian(int rowId, int count, byte[] src, int srcIndex) {}
+
+ @Override
+ public void putArray(int rowId, int offset, int length) {}
+
+ @Override
+ public int putByteArray(int rowId, byte[] value, int offset, int count) {
+ return 0;
+ }
+
+ @Override
+ protected UTF8String getBytesAsUTF8String(int rowId, int count) {
+ return null;
+ }
+
+ @Override
+ public ByteBuffer getByteBuffer(int rowId, int count) {
+ throw new UnsupportedOperationException("Unsupported this function");
+ }
+
+ @Override
+ public int getArrayLength(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public int getArrayOffset(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public WritableColumnVector reserveNewColumn(int capacity, DataType type) {
+ return null;
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ return false;
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ return false;
+ }
+
+ @Override
+ public byte getByte(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public short getShort(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ return 0;
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ return 0;
+ }
+}
diff --git a/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider b/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider
new file mode 100644
index 000000000000..6c2e7e518171
--- /dev/null
+++ b/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider
@@ -0,0 +1 @@
+org.apache.gluten.sql.shims.spark41.SparkShimProvider
\ No newline at end of file
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala b/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala
new file mode 100644
index 000000000000..3d329a8f31ce
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * Spark 3.5 has changed the parameter type of the generateTreeString API in TreeNode. In order to
+ * support multiple versions of Spark, we cannot directly override the generateTreeString method in
+ * WhostageTransformer. Therefore, we have defined the GenerateTreeStringShim trait in the shim to
+ * allow different Spark versions to override their own generateTreeString.
+ */
+
+trait WholeStageTransformerGenerateTreeStringShim extends UnaryExecNode {
+
+ def stageId: Int
+
+ def substraitPlanJson: String
+
+ def wholeStageTransformerContextDefined: Boolean
+
+ override def generateTreeString(
+ depth: Int,
+ lastChildren: java.util.ArrayList[Boolean],
+ append: String => Unit,
+ verbose: Boolean,
+ prefix: String = "",
+ addSuffix: Boolean = false,
+ maxFields: Int,
+ printNodeId: Boolean,
+ printOutputColumns: Boolean,
+ indent: Int = 0): Unit = {
+ val prefix = if (printNodeId) "^ " else s"^($stageId) "
+ child.generateTreeString(
+ depth,
+ lastChildren,
+ append,
+ verbose,
+ prefix,
+ addSuffix = false,
+ maxFields,
+ printNodeId = printNodeId,
+ printOutputColumns = printOutputColumns,
+ indent = indent)
+
+ if (verbose && wholeStageTransformerContextDefined) {
+ append(prefix + "Substrait plan:\n")
+ append(substraitPlanJson)
+ append("\n")
+ }
+ }
+}
+
+trait InputAdapterGenerateTreeStringShim extends UnaryExecNode {
+
+ override def generateTreeString(
+ depth: Int,
+ lastChildren: java.util.ArrayList[Boolean],
+ append: String => Unit,
+ verbose: Boolean,
+ prefix: String = "",
+ addSuffix: Boolean = false,
+ maxFields: Int,
+ printNodeId: Boolean,
+ printOutputColumns: Boolean,
+ indent: Int = 0): Unit = {
+ child.generateTreeString(
+ depth,
+ lastChildren,
+ append,
+ verbose,
+ prefix = "",
+ addSuffix = false,
+ maxFields,
+ printNodeId,
+ printOutputColumns,
+ indent)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala b/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala
new file mode 100644
index 000000000000..2d3287afe4d5
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.PartitionedFileUtil
+import org.apache.spark.sql.execution.datasources.{FileStatusWithMetadata, PartitionedFile}
+
+import org.apache.hadoop.fs.Path
+
+import java.lang.reflect.Method
+
+object PartitionedFileUtilShim {
+
+ private val clz: Class[_] = PartitionedFileUtil.getClass
+ private val module = clz.getField("MODULE$").get(null)
+
+ private lazy val getPartitionedFileMethod: Method = {
+ try {
+ val m = clz.getDeclaredMethod(
+ "getPartitionedFile",
+ classOf[FileStatusWithMetadata],
+ classOf[InternalRow])
+ m.setAccessible(true)
+ m
+ } catch {
+ case _: NoSuchMethodException => null
+ }
+ }
+
+ private lazy val getPartitionedFileByPathMethod: Method = {
+ try {
+ val m = clz.getDeclaredMethod(
+ "getPartitionedFile",
+ classOf[FileStatusWithMetadata],
+ classOf[Path],
+ classOf[InternalRow])
+ m.setAccessible(true)
+ m
+ } catch {
+ case _: NoSuchMethodException => null
+ }
+ }
+
+ private lazy val getPartitionedFileByPathSizeMethod: Method = {
+ try {
+ val m = clz.getDeclaredMethod(
+ "getPartitionedFile",
+ classOf[FileStatusWithMetadata],
+ classOf[Path],
+ classOf[InternalRow],
+ classOf[Long],
+ classOf[Long])
+ m.setAccessible(true)
+ m
+ } catch {
+ case _: NoSuchMethodException => null
+ }
+ }
+
+ def getPartitionedFile(
+ file: FileStatusWithMetadata,
+ partitionValues: InternalRow): PartitionedFile = {
+ if (getPartitionedFileMethod != null) {
+ getPartitionedFileMethod
+ .invoke(module, file, partitionValues)
+ .asInstanceOf[PartitionedFile]
+ } else if (getPartitionedFileByPathMethod != null) {
+ getPartitionedFileByPathMethod
+ .invoke(module, file, file.getPath, partitionValues)
+ .asInstanceOf[PartitionedFile]
+ } else if (getPartitionedFileByPathSizeMethod != null) {
+ getPartitionedFileByPathSizeMethod
+ .invoke(module, file, file.getPath, partitionValues, 0, file.getLen)
+ .asInstanceOf[PartitionedFile]
+ } else {
+ val params = clz.getDeclaredMethods
+ .find(_.getName == "getPartitionedFile")
+ .map(_.getGenericParameterTypes.mkString(", "))
+ throw new RuntimeException(
+ s"getPartitionedFile with $params is not correctly shimmed " +
+ "in PartitionedFileUtilShim")
+ }
+ }
+
+ private lazy val splitFilesMethod: Method = {
+ try {
+ val m = clz.getDeclaredMethod(
+ "splitFiles",
+ classOf[SparkSession],
+ classOf[FileStatusWithMetadata],
+ classOf[Boolean],
+ classOf[Long],
+ classOf[InternalRow])
+ m.setAccessible(true)
+ m
+ } catch {
+ case _: NoSuchMethodException => null
+ }
+ }
+
+ private lazy val splitFilesByPathMethod: Method = {
+ try {
+ val m = clz.getDeclaredMethod(
+ "splitFiles",
+ classOf[FileStatusWithMetadata],
+ classOf[Path],
+ classOf[Boolean],
+ classOf[Long],
+ classOf[InternalRow])
+ m.setAccessible(true)
+ m
+ } catch {
+ case _: NoSuchMethodException => null
+ }
+ }
+
+ def splitFiles(
+ sparkSession: SparkSession,
+ file: FileStatusWithMetadata,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow): Seq[PartitionedFile] = {
+ if (splitFilesMethod != null) {
+ splitFilesMethod
+ .invoke(
+ module,
+ sparkSession,
+ file,
+ java.lang.Boolean.valueOf(isSplitable),
+ java.lang.Long.valueOf(maxSplitBytes),
+ partitionValues)
+ .asInstanceOf[Seq[PartitionedFile]]
+ } else if (splitFilesByPathMethod != null) {
+ splitFilesByPathMethod
+ .invoke(
+ module,
+ file,
+ file.getPath,
+ java.lang.Boolean.valueOf(isSplitable),
+ java.lang.Long.valueOf(maxSplitBytes),
+ partitionValues)
+ .asInstanceOf[Seq[PartitionedFile]]
+ } else {
+ val params = clz.getDeclaredMethods
+ .find(_.getName == "splitFiles")
+ .map(_.getGenericParameterTypes.mkString(", "))
+ throw new RuntimeException(
+ s"splitFiles with $params is not correctly shimmed " +
+ "in PartitionedFileUtilShim")
+ }
+ }
+
+ // Helper method to create PartitionedFile from path and length.
+ def makePartitionedFileFromPath(path: String, length: Long): PartitionedFile = {
+ PartitionedFile(null, SparkPath.fromPathString(path), 0, length, Array.empty)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
new file mode 100644
index 000000000000..ea5f733614cb
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
@@ -0,0 +1,767 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.sql.shims.spark41
+
+import org.apache.gluten.execution.PartitionedFileUtilShim
+import org.apache.gluten.expression.{ExpressionNames, Sig}
+import org.apache.gluten.sql.shims.SparkShims
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.scheduler.TaskInfo
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.DecimalPrecisionTypeCoercion
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
+import org.apache.spark.sql.execution.datasources.v2.text.TextScan
+import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.window.{Final, Partial, _}
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.storage.{BlockId, BlockManagerId}
+
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata}
+import org.apache.parquet.hadoop.metadata.FileMetaData.EncryptionType
+import org.apache.parquet.schema.MessageType
+
+import java.time.ZoneOffset
+import java.util.{Map => JMap}
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
+
+class Spark41Shims extends SparkShims {
+
+ override def getDistribution(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression]): Seq[Distribution] = {
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ }
+
+ override def scalarExpressionMappings: Seq[Sig] = {
+ Seq(
+ Sig[SplitPart](ExpressionNames.SPLIT_PART),
+ Sig[Sec](ExpressionNames.SEC),
+ Sig[Csc](ExpressionNames.CSC),
+ Sig[KnownNullable](ExpressionNames.KNOWN_NULLABLE),
+ Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
+ Sig[Mask](ExpressionNames.MASK),
+ Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
+ Sig[TimestampDiff](ExpressionNames.TIMESTAMP_DIFF),
+ Sig[RoundFloor](ExpressionNames.FLOOR),
+ Sig[RoundCeil](ExpressionNames.CEIL),
+ Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT),
+ Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT),
+ Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND),
+ Sig[UrlEncode](ExpressionNames.URL_ENCODE),
+ Sig[KnownNotContainsNull](ExpressionNames.KNOWN_NOT_CONTAINS_NULL),
+ Sig[UrlDecode](ExpressionNames.URL_DECODE)
+ )
+ }
+
+ override def aggregateExpressionMappings: Seq[Sig] = {
+ Seq(
+ Sig[RegrR2](ExpressionNames.REGR_R2),
+ Sig[RegrSlope](ExpressionNames.REGR_SLOPE),
+ Sig[RegrIntercept](ExpressionNames.REGR_INTERCEPT),
+ Sig[RegrSXY](ExpressionNames.REGR_SXY),
+ Sig[RegrReplacement](ExpressionNames.REGR_REPLACEMENT)
+ )
+ }
+
+ override def runtimeReplaceableExpressionMappings: Seq[Sig] = {
+ Seq(
+ Sig[ArrayCompact](ExpressionNames.ARRAY_COMPACT),
+ Sig[ArrayPrepend](ExpressionNames.ARRAY_PREPEND),
+ Sig[ArraySize](ExpressionNames.ARRAY_SIZE),
+ Sig[EqualNull](ExpressionNames.EQUAL_NULL),
+ Sig[ILike](ExpressionNames.ILIKE),
+ Sig[MapContainsKey](ExpressionNames.MAP_CONTAINS_KEY),
+ Sig[Get](ExpressionNames.GET),
+ Sig[Luhncheck](ExpressionNames.LUHN_CHECK)
+ )
+ }
+
+ override def convertPartitionTransforms(
+ partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
+ CatalogUtil.convertPartitionTransforms(partitions)
+ }
+
+ override def generateFileScanRDD(
+ sparkSession: SparkSession,
+ readFunction: PartitionedFile => Iterator[InternalRow],
+ filePartitions: Seq[FilePartition],
+ fileSourceScanExec: FileSourceScanExec): FileScanRDD = {
+ new FileScanRDD(
+ sparkSession,
+ readFunction,
+ filePartitions,
+ new StructType(
+ fileSourceScanExec.requiredSchema.fields ++
+ fileSourceScanExec.relation.partitionSchema.fields),
+ fileSourceScanExec.fileConstantMetadataColumns
+ )
+ }
+
+ override def getTextScan(
+ sparkSession: SparkSession,
+ fileIndex: PartitioningAwareFileIndex,
+ dataSchema: StructType,
+ readDataSchema: StructType,
+ readPartitionSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ partitionFilters: Seq[Expression],
+ dataFilters: Seq[Expression]): TextScan = {
+ TextScan(
+ sparkSession,
+ fileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ options,
+ partitionFilters,
+ dataFilters)
+ }
+
+ override def filesGroupedToBuckets(
+ selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = {
+ selectedPartitions
+ .flatMap(p => p.files.map(f => PartitionedFileUtilShim.getPartitionedFile(f, p.values)))
+ .groupBy {
+ f =>
+ BucketingUtils
+ .getBucketId(f.toPath.getName)
+ .getOrElse(throw invalidBucketFile(f.urlEncodedPath))
+ }
+ }
+
+ override def getBatchScanExecTable(batchScan: BatchScanExec): Table = batchScan.table
+
+ override def generatePartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long,
+ @transient locations: Array[String] = Array.empty): PartitionedFile =
+ PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
+
+ override def bloomFilterExpressionMappings(): Seq[Sig] = Seq(
+ Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
+ Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
+ )
+
+ override def newBloomFilterAggregate[T](
+ child: Expression,
+ estimatedNumItemsExpression: Expression,
+ numBitsExpression: Expression,
+ mutableAggBufferOffset: Int,
+ inputAggBufferOffset: Int): TypedImperativeAggregate[T] = {
+ BloomFilterAggregate(
+ child,
+ estimatedNumItemsExpression,
+ numBitsExpression,
+ mutableAggBufferOffset,
+ inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]]
+ }
+
+ override def newMightContain(
+ bloomFilterExpression: Expression,
+ valueExpression: Expression): BinaryExpression = {
+ BloomFilterMightContain(bloomFilterExpression, valueExpression)
+ }
+
+ override def replaceBloomFilterAggregate[T](
+ expr: Expression,
+ bloomFilterAggReplacer: (
+ Expression,
+ Expression,
+ Expression,
+ Int,
+ Int) => TypedImperativeAggregate[T]): Expression = expr match {
+ case BloomFilterAggregate(
+ child,
+ estimatedNumItemsExpression,
+ numBitsExpression,
+ mutableAggBufferOffset,
+ inputAggBufferOffset) =>
+ bloomFilterAggReplacer(
+ child,
+ estimatedNumItemsExpression,
+ numBitsExpression,
+ mutableAggBufferOffset,
+ inputAggBufferOffset)
+ case other => other
+ }
+
+ override def replaceMightContain[T](
+ expr: Expression,
+ mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
+ case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
+ mightContainReplacer(bloomFilterExpression, valueExpression)
+ case other => other
+ }
+
+ override def getFileSizeAndModificationTime(
+ file: PartitionedFile): (Option[Long], Option[Long]) = {
+ (Some(file.fileSize), Some(file.modificationTime))
+ }
+
+ override def generateMetadataColumns(
+ file: PartitionedFile,
+ metadataColumnNames: Seq[String]): Map[String, String] = {
+ val originMetadataColumn = super.generateMetadataColumns(file, metadataColumnNames)
+ val metadataColumn: mutable.Map[String, String] = mutable.Map(originMetadataColumn.toSeq: _*)
+ val path = new Path(file.filePath.toString)
+ for (columnName <- metadataColumnNames) {
+ columnName match {
+ case FileFormat.FILE_PATH => metadataColumn += (FileFormat.FILE_PATH -> path.toString)
+ case FileFormat.FILE_NAME => metadataColumn += (FileFormat.FILE_NAME -> path.getName)
+ case FileFormat.FILE_SIZE =>
+ metadataColumn += (FileFormat.FILE_SIZE -> file.fileSize.toString)
+ case FileFormat.FILE_MODIFICATION_TIME =>
+ val fileModifyTime = TimestampFormatter
+ .getFractionFormatter(ZoneOffset.UTC)
+ .format(file.modificationTime * 1000L)
+ metadataColumn += (FileFormat.FILE_MODIFICATION_TIME -> fileModifyTime)
+ case FileFormat.FILE_BLOCK_START =>
+ metadataColumn += (FileFormat.FILE_BLOCK_START -> file.start.toString)
+ case FileFormat.FILE_BLOCK_LENGTH =>
+ metadataColumn += (FileFormat.FILE_BLOCK_LENGTH -> file.length.toString)
+ case _ =>
+ }
+ }
+ metadataColumn.toMap
+ }
+
+ // https://issues.apache.org/jira/browse/SPARK-40400
+ private def invalidBucketFile(path: String): Throwable = {
+ new SparkException(
+ errorClass = "INVALID_BUCKET_FILE",
+ messageParameters = Map("path" -> path),
+ cause = null)
+ }
+
+ private def getLimit(limit: Int, offset: Int): Int = {
+ if (limit == -1) {
+ // Only offset specified, so fetch the maximum number rows
+ Int.MaxValue
+ } else {
+ assert(limit > offset)
+ limit - offset
+ }
+ }
+
+ override def getLimitAndOffsetFromGlobalLimit(plan: GlobalLimitExec): (Int, Int) = {
+ (getLimit(plan.limit, plan.offset), plan.offset)
+ }
+
+ override def isWindowGroupLimitExec(plan: SparkPlan): Boolean = plan match {
+ case _: WindowGroupLimitExec => true
+ case _ => false
+ }
+
+ override def getWindowGroupLimitExecShim(plan: SparkPlan): WindowGroupLimitExecShim = {
+ val windowGroupLimitPlan = plan.asInstanceOf[WindowGroupLimitExec]
+ val mode = windowGroupLimitPlan.mode match {
+ case Partial => GlutenPartial
+ case Final => GlutenFinal
+ }
+ WindowGroupLimitExecShim(
+ windowGroupLimitPlan.partitionSpec,
+ windowGroupLimitPlan.orderSpec,
+ windowGroupLimitPlan.rankLikeFunction,
+ windowGroupLimitPlan.limit,
+ mode,
+ windowGroupLimitPlan.child
+ )
+ }
+
+ override def getWindowGroupLimitExec(
+ windowGroupLimitExecShim: WindowGroupLimitExecShim): SparkPlan = {
+ val mode = windowGroupLimitExecShim.mode match {
+ case GlutenPartial => Partial
+ case GlutenFinal => Final
+ }
+ WindowGroupLimitExec(
+ windowGroupLimitExecShim.partitionSpec,
+ windowGroupLimitExecShim.orderSpec,
+ windowGroupLimitExecShim.rankLikeFunction,
+ windowGroupLimitExecShim.limit,
+ mode,
+ windowGroupLimitExecShim.child
+ )
+ }
+
+ override def getLimitAndOffsetFromTopK(plan: TakeOrderedAndProjectExec): (Int, Int) = {
+ (getLimit(plan.limit, plan.offset), plan.offset)
+ }
+
+ override def getExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = List()
+
+ override def writeFilesExecuteTask(
+ description: WriteJobDescription,
+ jobTrackerID: String,
+ sparkStageId: Int,
+ sparkPartitionId: Int,
+ sparkAttemptNumber: Int,
+ committer: FileCommitProtocol,
+ iterator: Iterator[InternalRow]): WriteTaskResult = {
+ GlutenFileFormatWriter.writeFilesExecuteTask(
+ description,
+ jobTrackerID,
+ sparkStageId,
+ sparkPartitionId,
+ sparkAttemptNumber,
+ committer,
+ iterator
+ )
+ }
+
+ override def enableNativeWriteFilesByDefault(): Boolean = true
+
+ override def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
+ SparkContextUtils.broadcastInternal(sc, value)
+ }
+
+ override def setJobDescriptionOrTagForBroadcastExchange(
+ sc: SparkContext,
+ broadcastExchange: BroadcastExchangeLike): Unit = {
+ // Setup a job tag here so later it may get cancelled by tag if necessary.
+ sc.addJobTag(broadcastExchange.jobTag)
+ sc.setInterruptOnCancel(true)
+ }
+
+ override def cancelJobGroupForBroadcastExchange(
+ sc: SparkContext,
+ broadcastExchange: BroadcastExchangeLike): Unit = {
+ sc.cancelJobsWithTag(broadcastExchange.jobTag)
+ }
+
+ override def getShuffleReaderParam[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
+ ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
+ }
+
+ override def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] =
+ shuffle.advisoryPartitionSize
+
+ override def getPartitionId(taskInfo: TaskInfo): Int = {
+ taskInfo.partitionId
+ }
+
+ override def supportDuplicateReadingTracking: Boolean = true
+
+ def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] =
+ partition.files.map(f => (f.fileStatus, f.metadata))
+
+ def isFileSplittable(
+ relation: HadoopFsRelation,
+ filePath: Path,
+ sparkSchema: StructType): Boolean = {
+ relation.fileFormat
+ .isSplitable(relation.sparkSession, relation.options, filePath)
+ }
+
+ def isRowIndexMetadataColumn(name: String): Boolean =
+ name == ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME ||
+ name.equalsIgnoreCase("__delta_internal_is_row_deleted")
+
+ def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = {
+ sparkSchema.fields.zipWithIndex.find {
+ case (field: StructField, _: Int) =>
+ field.name == ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
+ } match {
+ case Some((field: StructField, idx: Int)) =>
+ if (field.dataType != LongType && field.dataType != IntegerType) {
+ throw new RuntimeException(
+ s"${ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} " +
+ "must be of LongType or IntegerType")
+ }
+ idx
+ case _ => -1
+ }
+ }
+
+ def splitFiles(
+ sparkSession: SparkSession,
+ file: FileStatus,
+ filePath: Path,
+ isSplitable: Boolean,
+ maxSplitBytes: Long,
+ partitionValues: InternalRow,
+ metadata: Map[String, Any] = Map.empty): Seq[PartitionedFile] = {
+ PartitionedFileUtilShim.splitFiles(
+ sparkSession,
+ FileStatusWithMetadata(file, metadata),
+ isSplitable,
+ maxSplitBytes,
+ partitionValues)
+ }
+
+ def structFromAttributes(attrs: Seq[Attribute]): StructType = {
+ DataTypeUtils.fromAttributes(attrs)
+ }
+
+ def attributesFromStruct(structType: StructType): Seq[Attribute] = {
+ DataTypeUtils.toAttributes(structType)
+ }
+
+ def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan] = {
+ ae match {
+ case eae: ExtendedAnalysisException =>
+ eae.plan
+ case _ =>
+ None
+ }
+ }
+ override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] = {
+ batchScan.keyGroupedPartitioning
+ }
+
+ override def getCommonPartitionValues(
+ batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] = {
+ batchScan.spjParams.commonPartitionValues
+ }
+
+ // please ref BatchScanExec::inputRDD
+ override def orderPartitions(
+ batchScan: DataSourceV2ScanExecBase,
+ scan: Scan,
+ keyGroupedPartitioning: Option[Seq[Expression]],
+ filteredPartitions: Seq[Seq[InputPartition]],
+ outputPartitioning: Partitioning,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+ applyPartialClustering: Boolean,
+ replicatePartitions: Boolean,
+ joinKeyPositions: Option[Seq[Int]] = None): Seq[Seq[InputPartition]] = {
+ val original = batchScan.asInstanceOf[BatchScanExecShim]
+ scan match {
+ case _ if keyGroupedPartitioning.isDefined =>
+ outputPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ assert(keyGroupedPartitioning.isDefined)
+ val expressions = keyGroupedPartitioning.get
+
+ // Re-group the input partitions if we are projecting on a subset of join keys
+ val (groupedPartitions, partExpressions) = joinKeyPositions match {
+ case Some(projectPositions) =>
+ val projectedExpressions = projectPositions.map(i => expressions(i))
+ val parts = filteredPartitions.flatten
+ .groupBy(
+ part => {
+ val row = part.asInstanceOf[HasPartitionKey].partitionKey()
+ val projectedRow =
+ KeyGroupedPartitioning.project(expressions, projectPositions, row)
+ InternalRowComparableWrapper(projectedRow, projectedExpressions)
+ })
+ .map { case (wrapper, splits) => (wrapper.row, splits) }
+ .toSeq
+ (parts, projectedExpressions)
+ case _ =>
+ val groupedParts = filteredPartitions.map(
+ splits => {
+ assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+ (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+ })
+ (groupedParts, expressions)
+ }
+
+ // Also re-group the partitions if we are reducing compatible partition expressions
+ val finalGroupedPartitions = original.reducers match {
+ case Some(reducers) =>
+ val result = groupedPartitions
+ .groupBy {
+ case (row, _) =>
+ KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+ }
+ .map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }
+ .toSeq
+ val rowOrdering =
+ RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType))
+ result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+ case _ => groupedPartitions
+ }
+
+ // When partially clustered, the input partitions are not grouped by partition
+ // values. Here we'll need to check `commonPartitionValues` and decide how to group
+ // and replicate splits within a partition.
+ if (commonPartitionValues.isDefined && applyPartialClustering) {
+ // A mapping from the common partition values to how many splits the partition
+ // should contain.
+ val commonPartValuesMap = commonPartitionValues.get
+ .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
+ .toMap
+ val filteredGroupedPartitions = finalGroupedPartitions.filter {
+ case (partValues, _) =>
+ commonPartValuesMap.keySet.contains(
+ InternalRowComparableWrapper(partValues, partExpressions))
+ }
+ val nestGroupedPartitions = filteredGroupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue, partExpressions))
+ assert(
+ numSplits.isDefined,
+ s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (replicatePartitions) {
+ // We need to also replicate partitions according to the other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the side with partially
+ // clustered distribution. Because of dynamic filtering, we'll need to check if
+ // the final number of splits of a partition is smaller than the original
+ // number, and fill with empty splits if so. This is necessary so that both
+ // sides of a join will have the same number of partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, partExpressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ // either `commonPartitionValues` is not defined, or it is defined but
+ // `applyPartialClustering` is false.
+ val partitionMapping = finalGroupedPartitions.map {
+ case (partValue, splits) =>
+ InternalRowComparableWrapper(partValue, partExpressions) -> splits
+ }.toMap
+
+ // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+ // could exist duplicated partition values, as partition grouping is not done
+ // at the beginning and postponed to this method. It is important to use unique
+ // partition values here so that grouped partitions won't get duplicated.
+ p.uniquePartitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, partExpressions),
+ Seq.empty)
+ }
+ }
+
+ case _ => filteredPartitions
+ }
+ case _ =>
+ filteredPartitions
+ }
+ }
+
+ override def withTryEvalMode(expr: Expression): Boolean = {
+ expr match {
+ case a: Add => a.evalMode == EvalMode.TRY
+ case s: Subtract => s.evalMode == EvalMode.TRY
+ case d: Divide => d.evalMode == EvalMode.TRY
+ case m: Multiply => m.evalMode == EvalMode.TRY
+ case c: Cast => c.evalMode == EvalMode.TRY
+ case _ => false
+ }
+ }
+
+ override def withAnsiEvalMode(expr: Expression): Boolean = {
+ expr match {
+ case a: Add => a.evalMode == EvalMode.ANSI
+ case s: Subtract => s.evalMode == EvalMode.ANSI
+ case d: Divide => d.evalMode == EvalMode.ANSI
+ case m: Multiply => m.evalMode == EvalMode.ANSI
+ case c: Cast => c.evalMode == EvalMode.ANSI
+ case i: IntegralDivide => i.evalMode == EvalMode.ANSI
+ case _ => false
+ }
+ }
+
+ override def dateTimestampFormatInReadIsDefaultValue(
+ csvOptions: CSVOptions,
+ timeZone: String): Boolean = {
+ val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone)
+ csvOptions.dateFormatInRead == default.dateFormatInRead &&
+ csvOptions.timestampFormatInRead == default.timestampFormatInRead &&
+ csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead
+ }
+
+ override def createParquetFilters(
+ conf: SQLConf,
+ schema: MessageType,
+ caseSensitive: Option[Boolean] = None): ParquetFilters = {
+ new ParquetFilters(
+ schema,
+ conf.parquetFilterPushDownDate,
+ conf.parquetFilterPushDownTimestamp,
+ conf.parquetFilterPushDownDecimal,
+ conf.parquetFilterPushDownStringPredicate,
+ conf.parquetFilterPushDownInFilterThreshold,
+ caseSensitive.getOrElse(conf.caseSensitiveAnalysis),
+ RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
+ )
+ }
+
+ override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+ val expr = arrayInsert.asInstanceOf[ArrayInsert]
+ Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+ }
+
+ override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
+ val prevIdMap = QueryPlan.localIdMap.get()
+ try {
+ QueryPlan.localIdMap.set(idMap)
+ body
+ } finally {
+ QueryPlan.localIdMap.set(prevIdMap)
+ }
+ }
+
+ override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
+ Option(QueryPlan.localIdMap.get().get(plan))
+ }
+
+ override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
+ val map = QueryPlan.localIdMap.get()
+ assert(!map.containsKey(plan))
+ map.put(plan, opId)
+ }
+
+ override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
+ QueryPlan.localIdMap.get().remove(plan)
+ }
+
+ override def isParquetFileEncrypted(footer: ParquetMetadata): Boolean = {
+ footer.getFileMetaData.getEncryptionType match {
+ // UNENCRYPTED file has a plaintext footer and no file encryption,
+ // We can leverage file metadata for this check and return unencrypted.
+ case EncryptionType.UNENCRYPTED =>
+ false
+ // PLAINTEXT_FOOTER has a plaintext footer however the file is encrypted.
+ // In such cases, read the footer and use the metadata for encryption check.
+ case EncryptionType.PLAINTEXT_FOOTER =>
+ true
+ case _ =>
+ false
+ }
+ }
+
+ override def getOtherConstantMetadataColumnValues(file: PartitionedFile): JMap[String, Object] =
+ file.otherConstantMetadataColumnValues.asJava.asInstanceOf[JMap[String, Object]]
+
+ override def getCollectLimitOffset(plan: CollectLimitExec): Int = {
+ plan.offset
+ }
+
+ override def unBase64FunctionFailsOnError(unBase64: UnBase64): Boolean = unBase64.failOnError
+
+ override def extractExpressionTimestampAddUnit(exp: Expression): Option[Seq[String]] = {
+ exp match {
+ // Velox does not support quantity larger than Int.MaxValue.
+ case TimestampAdd(_, LongLiteral(quantity), _, _) if quantity > Integer.MAX_VALUE =>
+ Option.empty
+ case timestampAdd: TimestampAdd =>
+ Option.apply(Seq(timestampAdd.unit, timestampAdd.timeZoneId.getOrElse("")))
+ case _ => Option.empty
+ }
+ }
+
+ override def extractExpressionTimestampDiffUnit(exp: Expression): Option[String] = {
+ exp match {
+ case timestampDiff: TimestampDiff =>
+ Some(timestampDiff.unit)
+ case _ => Option.empty
+ }
+ }
+
+ override def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+ DecimalPrecisionTypeCoercion.widerDecimalType(d1, d2)
+ }
+
+ override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
+ raiseError.errorParms match {
+ case CreateMap(children, _)
+ if children.size == 2 && children.head.isInstanceOf[Literal]
+ && children.head.asInstanceOf[Literal].value.toString == "errorMessage" =>
+ Some(children(1))
+ case _ =>
+ None
+ }
+ }
+
+ override def throwExceptionInWrite(
+ t: Throwable,
+ writePath: String,
+ descriptionPath: String): Unit = {
+ throw t
+ }
+
+ override def enrichWriteException(cause: Throwable, path: String): Nothing = {
+ GlutenFileFormatWriter.wrapWriteError(cause, path)
+ }
+ override def getFileSourceScanStream(scan: FileSourceScanExec): Option[SparkDataStream] = {
+ scan.stream
+ }
+
+ override def unsupportedCodec: Seq[CompressionCodecName] = {
+ Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI, CompressionCodecName.LZ4_RAW)
+ }
+
+ /**
+ * Shim layer for QueryExecution to maintain compatibility across different Spark versions.
+ *
+ * @since Spark
+ * 4.1
+ */
+ override def createSparkPlan(
+ sparkSession: SparkSession,
+ planner: SparkPlanner,
+ plan: LogicalPlan): SparkPlan =
+ QueryExecution.createSparkPlan(planner, plan)
+}
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala
new file mode 100644
index 000000000000..9a5b6a63196d
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.sql.shims.spark41
+
+import org.apache.gluten.sql.shims.SparkShims
+
+class SparkShimProvider extends org.apache.gluten.sql.shims.SparkShimProvider {
+ def createShim: SparkShims = {
+ new Spark41Shims()
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala b/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala
new file mode 100644
index 000000000000..654e43cbd03f
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.utils
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.types.StructType
+
+object InternalRowUtl {
+ def toString(struct: StructType, rows: Iterator[InternalRow]): String = {
+ val encoder = ExpressionEncoder(struct).resolveAndBind()
+ val deserializer = encoder.createDeserializer()
+ rows.map(deserializer).mkString(System.lineSeparator())
+ }
+
+ def toString(struct: StructType, rows: Iterator[InternalRow], start: Int, length: Int): String = {
+ toString(struct, rows.slice(start, start + length))
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala b/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala
new file mode 100644
index 000000000000..c2a6cd5cffc1
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleHandle}
+import org.apache.spark.storage.{BlockId, BlockManagerId}
+
+object ShuffleUtils {
+ def getReaderParam[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
+ val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
+ if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) {
+ val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId(
+ handle.shuffleId,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ (res.iter.map(b => (b._1, b._2.toSeq)), res.enableBatchFetch)
+ } else {
+ val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ (address.map(b => (b._1, b._2.toSeq)), true)
+ }
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala b/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala
new file mode 100644
index 000000000000..3cbf2b602d47
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import org.apache.spark.broadcast.Broadcast
+
+import scala.reflect.ClassTag
+
+object SparkContextUtils {
+ def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
+ sc.broadcastInternal(value, serializedOnly = true)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
new file mode 100644
index 000000000000..95b15f04e7cb
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil {
+ def create[K, V, C](
+ handle: BaseShuffleHandle[K, V, C],
+ mapId: Long,
+ context: TaskContext,
+ writeMetrics: ShuffleWriteMetricsReporter,
+ shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+ new SortShuffleWriter(handle, mapId, context, writeMetrics, shuffleExecutorComponents)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala
new file mode 100644
index 000000000000..b18a79b864e2
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+
+case class PromotePrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ child.genCode(ctx)
+ override def prettyName: String = "promote_precision"
+ override def sql: String = child.sql
+ override lazy val canonicalized: Expression = child.canonicalized
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ copy(child = newChild)
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala
new file mode 100644
index 000000000000..8abebd0d2a8b
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.objects
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
+
+/**
+ * Extractors for Invoke expressions to ensure compatibility across different Spark versions.
+ *
+ * Since Spark 4.0, StructsToJson has been replaced with Invoke expressions using
+ * StructsToJsonEvaluator. This extractor provides a unified interface to extract evaluator options,
+ * child expression, and timeZoneId from the Invoke pattern.
+ */
+object StructsToJsonInvoke {
+ def unapply(expr: Expression): Option[(Map[String, String], Expression, Option[String])] = {
+ expr match {
+ case Invoke(
+ Literal(evaluator: StructsToJsonEvaluator, _),
+ "evaluate",
+ _,
+ Seq(child),
+ _,
+ _,
+ _,
+ _) =>
+ Some((evaluator.options, child, evaluator.timeZoneId))
+ case _ =>
+ None
+ }
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala
new file mode 100644
index 000000000000..1df1456f4011
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+
+object CollapseProjectShim {
+ def canCollapseExpressions(
+ consumers: Seq[Expression],
+ producers: Seq[NamedExpression],
+ alwaysInline: Boolean): Boolean = {
+ CollapseProject.canCollapseExpressions(consumers, producers, alwaysInline)
+ }
+
+ def buildCleanedProjectList(
+ upper: Seq[NamedExpression],
+ lower: Seq[NamedExpression]): Seq[NamedExpression] = {
+ CollapseProject.buildCleanedProjectList(upper, lower)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala
new file mode 100644
index 000000000000..7fe63d706e5d
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.classic.ClassicConversions._
+
+/**
+ * Ensures compatibility with Spark 4.0. The implicit class ColumnConstructorExt from
+ * ClassicConversions is used to construct a Column from an Expression. Since Spark 4.0, the Column
+ * class is private to the package org.apache.spark. This class provides a way to construct a Column
+ * in code that is outside the org.apache.spark package. Developers can directly call Column(e) if
+ * ColumnConstructorExt is imported and the caller code is within this package.
+ */
+object ClassicColumn {
+ def apply(e: Expression): Column = {
+ Column(e)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala
new file mode 100644
index 000000000000..1a48a8e110a7
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic
+
+/** Since Spark 4.0, the method ofRows cannot be invoked directly from sql.Dataset. */
+object ClassicDataset {
+ def ofRows(sparkSession: classic.SparkSession, logicalPlan: LogicalPlan): DataFrame = {
+ // Redirect to the classic.Dataset companion method.
+ classic.Dataset.ofRows(sparkSession, logicalPlan)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala
new file mode 100644
index 000000000000..ea3fb8893f9b
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql
+
+/**
+ * Prior to Spark 4.0, `ClassicSparkSession` refers to `sql.SparkSession`. Since Spark 4.0, it
+ * refers to `sql.classic.SparkSession`.
+ */
+object ClassicTypes {
+
+ type ClassicSparkSession = sql.classic.SparkSession
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala
new file mode 100644
index 000000000000..b168838f058c
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql
+
+/**
+ * Enables access to the methods in the companion object classic.SparkSession via sql.SparkSession.
+ * Since Spark 4.0, these methods have been moved from sql.SparkSession to sql.classic.SparkSession.
+ */
+object ExtendedClassicConversions {
+
+ implicit class RichSqlSparkSession(sqlSparkSession: sql.SparkSession.type) {
+ def cleanupAnyExistingSession(): Unit = {
+ // Redirect to the classic.SparkSession companion method.
+ sql.classic.SparkSession.cleanupAnyExistingSession()
+ }
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala
new file mode 100644
index 000000000000..ac17c259b87c
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.execution.PartitionedFileUtilShim
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.collection.BitSet
+
+import org.apache.hadoop.fs.Path
+
+import java.util.concurrent.TimeUnit._
+
+/**
+ * Physical plan node for scanning data from HadoopFsRelations.
+ *
+ * @param relation
+ * The file-based relation to scan.
+ * @param output
+ * Output attributes of the scan, including data attributes and partition attributes.
+ * @param requiredSchema
+ * Required schema of the underlying relation, excluding partition columns.
+ * @param partitionFilters
+ * Predicates to use for partition pruning.
+ * @param optionalBucketSet
+ * Bucket ids for bucket pruning.
+ * @param optionalNumCoalescedBuckets
+ * Number of coalesced buckets.
+ * @param dataFilters
+ * Filters on non-partition columns.
+ * @param tableIdentifier
+ * Identifier for the table in the metastore.
+ * @param disableBucketedScan
+ * Disable bucketed scan based on physical query plan, see rule [[DisableUnnecessaryBucketedScan]]
+ * for details.
+ */
+abstract class AbstractFileSourceScanExec(
+ @transient override val relation: HadoopFsRelation,
+ override val output: Seq[Attribute],
+ override val requiredSchema: StructType,
+ override val partitionFilters: Seq[Expression],
+ override val optionalBucketSet: Option[BitSet],
+ override val optionalNumCoalescedBuckets: Option[Int],
+ override val dataFilters: Seq[Expression],
+ override val tableIdentifier: Option[TableIdentifier],
+ override val disableBucketedScan: Boolean = false)
+ extends FileSourceScanLike {
+
+ override def supportsColumnar: Boolean = {
+ // The value should be defined in GlutenPlan.
+ throw new UnsupportedOperationException(
+ "Unreachable code from org.apache.spark.sql.execution.AbstractFileSourceScanExec" +
+ ".supportsColumnar")
+ }
+
+ private lazy val needsUnsafeRowConversion: Boolean = {
+ if (relation.fileFormat.isInstanceOf[ParquetSource]) {
+ conf.parquetVectorizedReaderEnabled
+ } else {
+ false
+ }
+ }
+
+ lazy val inputRDD: RDD[InternalRow] = {
+ val options = relation.options +
+ (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString)
+ val readFile: (PartitionedFile) => Iterator[InternalRow] =
+ relation.fileFormat.buildReaderWithPartitionValues(
+ sparkSession = relation.sparkSession,
+ dataSchema = relation.dataSchema,
+ partitionSchema = relation.partitionSchema,
+ requiredSchema = requiredSchema,
+ filters = pushedDownFilters,
+ options = options,
+ hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)
+ )
+
+ val readRDD = if (bucketedScan) {
+ createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions)
+ } else {
+ createReadRDD(readFile, dynamicallySelectedPartitions)
+ }
+ sendDriverMetrics()
+ readRDD
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ inputRDD :: Nil
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+ if (needsUnsafeRowConversion) {
+ inputRDD.mapPartitionsWithIndexInternal {
+ (index, iter) =>
+ val toUnsafe = UnsafeProjection.create(schema)
+ toUnsafe.initialize(index)
+ iter.map {
+ row =>
+ numOutputRows += 1
+ toUnsafe(row)
+ }
+ }
+ } else {
+ inputRDD.mapPartitionsInternal {
+ iter =>
+ iter.map {
+ row =>
+ numOutputRows += 1
+ row
+ }
+ }
+ }
+ }
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val scanTime = longMetric("scanTime")
+ inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal {
+ batches =>
+ new Iterator[ColumnarBatch] {
+
+ override def hasNext: Boolean = {
+ // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call.
+ val startNs = System.nanoTime()
+ val res = batches.hasNext
+ scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs)
+ res
+ }
+
+ override def next(): ColumnarBatch = {
+ val batch = batches.next()
+ numOutputRows += batch.numRows()
+ batch
+ }
+ }
+ }
+ }
+
+ override val nodeNamePrefix: String = "File"
+
+ /**
+ * Create an RDD for bucketed reads. The non-bucketed variant of this function is
+ * [[createReadRDD]].
+ *
+ * The algorithm is pretty simple: each RDD partition being returned should include all the files
+ * with the same bucket id from all the given Hive partitions.
+ *
+ * @param bucketSpec
+ * the bucketing spec.
+ * @param readFile
+ * a function to read each (part of a) file.
+ * @param selectedPartitions
+ * Hive-style partition that are part of the read.
+ */
+ private def createBucketedReadRDD(
+ bucketSpec: BucketSpec,
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: ScanFileListing): RDD[InternalRow] = {
+ logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
+ val partitionArray = selectedPartitions.toPartitionArray
+ val filesGroupedToBuckets = partitionArray.groupBy {
+ f =>
+ BucketingUtils
+ .getBucketId(f.toPath.getName)
+ .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
+ }
+
+ val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+ val bucketSet = optionalBucketSet.get
+ filesGroupedToBuckets.filter(f => bucketSet.get(f._1))
+ } else {
+ filesGroupedToBuckets
+ }
+
+ val filePartitions = optionalNumCoalescedBuckets
+ .map {
+ numCoalescedBuckets =>
+ logInfo(s"Coalescing to $numCoalescedBuckets buckets")
+ val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
+ Seq.tabulate(numCoalescedBuckets) {
+ bucketId =>
+ val partitionedFiles = coalescedBuckets
+ .get(bucketId)
+ .map {
+ _.values.flatten.toArray
+ }
+ .getOrElse(Array.empty)
+ FilePartition(bucketId, partitionedFiles)
+ }
+ }
+ .getOrElse {
+ Seq.tabulate(bucketSpec.numBuckets) {
+ bucketId =>
+ FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
+ }
+ }
+
+ new FileScanRDD(
+ relation.sparkSession,
+ readFile,
+ filePartitions,
+ new StructType(requiredSchema.fields ++ relation.partitionSchema.fields),
+ fileConstantMetadataColumns,
+ relation.fileFormat.fileConstantMetadataExtractors,
+ new FileSourceOptions(CaseInsensitiveMap(relation.options))
+ )
+ }
+
+ /**
+ * Create an RDD for non-bucketed reads. The bucketed variant of this function is
+ * [[createBucketedReadRDD]].
+ *
+ * @param readFile
+ * a function to read each (part of a) file.
+ * @param selectedPartitions
+ * Hive-style partition that are part of the read.
+ */
+ private def createReadRDD(
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: ScanFileListing): RDD[InternalRow] = {
+ val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
+ val maxSplitBytes =
+ FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
+ logInfo(
+ s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
+ s"open cost is considered as scanning $openCostInBytes bytes.")
+
+ // Filter files with bucket pruning if possible
+ val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled
+ val shouldProcess: Path => Boolean = optionalBucketSet match {
+ case Some(bucketSet) if bucketingEnabled =>
+ // Do not prune the file if bucket file name is invalid
+ filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+ case _ =>
+ _ => true
+ }
+
+ val splitFiles = selectedPartitions.filePartitionIterator
+ .flatMap {
+ partition =>
+ partition.files.flatMap {
+ file =>
+ if (shouldProcess(file.getPath)) {
+ val isSplitable = relation.fileFormat.isSplitable(
+ relation.sparkSession,
+ relation.options,
+ file.getPath)
+ PartitionedFileUtilShim.splitFiles(
+ sparkSession = relation.sparkSession,
+ file = file,
+ isSplitable = isSplitable,
+ maxSplitBytes = maxSplitBytes,
+ partitionValues = partition.values
+ )
+ } else {
+ Seq.empty
+ }
+ }
+ }
+ .toArray
+ .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+
+ val partitions =
+ FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
+
+ new FileScanRDD(
+ relation.sparkSession,
+ readFile,
+ partitions,
+ new StructType(requiredSchema.fields ++ relation.partitionSchema.fields),
+ fileConstantMetadataColumns,
+ relation.fileFormat.fileConstantMetadataExtractors,
+ new FileSourceOptions(CaseInsensitiveMap(relation.options))
+ )
+ }
+
+ // Filters unused DynamicPruningExpression expressions - one which has been replaced
+ // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning
+ protected def filterUnusedDynamicPruningExpressions(
+ predicates: Seq[Expression]): Seq[Expression] = {
+ predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
new file mode 100644
index 000000000000..791490064a96
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection}
+
+import scala.collection.mutable
+
+// https://issues.apache.org/jira/browse/SPARK-31869
+class ExpandOutputPartitioningShim(
+ streamedKeyExprs: Seq[Expression],
+ buildKeyExprs: Seq[Expression],
+ expandLimit: Int) {
+ // An one-to-many mapping from a streamed key to build keys.
+ private lazy val streamedKeyToBuildKeyMapping = {
+ val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+ streamedKeyExprs.zip(buildKeyExprs).foreach {
+ case (streamedKey, buildKey) =>
+ val key = streamedKey.canonicalized
+ mapping.get(key) match {
+ case Some(v) => mapping.put(key, v :+ buildKey)
+ case None => mapping.put(key, Seq(buildKey))
+ }
+ }
+ mapping.toMap
+ }
+
+ def expandPartitioning(partitioning: Partitioning): Partitioning = {
+ partitioning match {
+ case h: HashPartitioningLike => expandOutputPartitioning(h)
+ case c: PartitioningCollection => expandOutputPartitioning(c)
+ case _ => partitioning
+ }
+ }
+
+ // Expands the given partitioning collection recursively.
+ private def expandOutputPartitioning(
+ partitioning: PartitioningCollection): PartitioningCollection = {
+ PartitioningCollection(partitioning.partitionings.flatMap {
+ case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
+ case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+ case other => Seq(other)
+ })
+ }
+
+ // Expands the given hash partitioning by substituting streamed keys with build keys.
+ // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+ // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+ // the expanded partitioning will have the following expressions:
+ // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+ // The expanded expressions are returned as PartitioningCollection.
+ private def expandOutputPartitioning(
+ partitioning: HashPartitioningLike): PartitioningCollection = {
+ val maxNumCombinations = expandLimit
+ var currentNumCombinations = 0
+
+ def generateExprCombinations(
+ current: Seq[Expression],
+ accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+ if (currentNumCombinations >= maxNumCombinations) {
+ Nil
+ } else if (current.isEmpty) {
+ currentNumCombinations += 1
+ Seq(accumulated)
+ } else {
+ val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+ generateExprCombinations(current.tail, accumulated :+ current.head) ++
+ buildKeysOpt
+ .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+ .getOrElse(Nil)
+ }
+ }
+
+ PartitioningCollection(
+ generateExprCombinations(partitioning.expressions, Nil)
+ .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
new file mode 100644
index 000000000000..d9a50f65c711
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.metrics.GlutenTimeMetric
+import org.apache.gluten.sql.shims.SparkShimLoader
+
+import org.apache.spark.Partition
+import org.apache.spark.internal.LogKeys.{COUNT, MAX_SPLIT_BYTES, OPEN_COST_IN_BYTES}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Expression, FileSourceConstantMetadataAttribute, FileSourceGeneratedMetadataAttribute, PlanExpression, Predicate}
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, HadoopFsRelation, PartitionDirectory}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.collection.BitSet
+
+import org.apache.hadoop.fs.Path
+
+abstract class FileSourceScanExecShim(
+ @transient relation: HadoopFsRelation,
+ output: Seq[Attribute],
+ requiredSchema: StructType,
+ partitionFilters: Seq[Expression],
+ optionalBucketSet: Option[BitSet],
+ optionalNumCoalescedBuckets: Option[Int],
+ dataFilters: Seq[Expression],
+ tableIdentifier: Option[TableIdentifier],
+ disableBucketedScan: Boolean = false)
+ extends AbstractFileSourceScanExec(
+ relation,
+ output,
+ requiredSchema,
+ partitionFilters,
+ optionalBucketSet,
+ optionalNumCoalescedBuckets,
+ dataFilters,
+ tableIdentifier,
+ disableBucketedScan) {
+
+ // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
+ @transient override lazy val metrics: Map[String, SQLMetric] = Map()
+
+ lazy val metadataColumns: Seq[AttributeReference] = output.collect {
+ case FileSourceConstantMetadataAttribute(attr) => attr
+ case FileSourceGeneratedMetadataAttribute(attr, _) => attr
+ }
+
+ protected lazy val driverMetricsAlias = driverMetrics
+
+ def dataFiltersInScan: Seq[Expression] = dataFilters.filterNot(_.references.exists {
+ attr => SparkShimLoader.getSparkShims.isRowIndexMetadataColumn(attr.name)
+ })
+
+ def hasUnsupportedColumns: Boolean = {
+ // TODO, fallback if user define same name column due to we can't right now
+ // detect which column is metadata column which is user defined column.
+ val metadataColumnsNames = metadataColumns.map(_.name)
+ output
+ .filterNot(metadataColumns.toSet)
+ .exists(v => metadataColumnsNames.contains(v.name))
+ }
+
+ def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr)
+
+ def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
+
+ protected def isDynamicPruningFilter(e: Expression): Boolean =
+ e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+ protected def setFilesNumAndSizeMetric(partitions: ScanFileListing, static: Boolean): Unit = {
+ val filesNum = partitions.totalNumberOfFiles
+ val filesSize = partitions.totalFileSize
+ if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
+ driverMetrics("numFiles").set(filesNum)
+ driverMetrics("filesSize").set(filesSize)
+ } else {
+ driverMetrics("staticFilesNum").set(filesNum)
+ driverMetrics("staticFilesSize").set(filesSize)
+ }
+ if (relation.partitionSchema.nonEmpty) {
+ driverMetrics("numPartitions").set(partitions.partitionCount)
+ }
+ }
+
+ @transient override lazy val dynamicallySelectedPartitions: ScanFileListing = {
+ val dynamicDataFilters = dataFilters.filter(isDynamicPruningFilter)
+ val dynamicPartitionFilters =
+ partitionFilters.filter(isDynamicPruningFilter)
+ if (dynamicPartitionFilters.nonEmpty) {
+ GlutenTimeMetric.withMillisTime {
+ // call the file index for the files matching all filters except dynamic partition filters
+ val boundedFilters = dynamicPartitionFilters.map {
+ dynamicPartitionFilter =>
+ dynamicPartitionFilter.transform {
+ case a: AttributeReference =>
+ val index = relation.partitionSchema.indexWhere(a.name == _.name)
+ BoundReference(index, relation.partitionSchema(index).dataType, nullable = true)
+ }
+ }
+ val boundPredicate = Predicate.create(boundedFilters.reduce(And), Nil)
+ val returnedFiles =
+ selectedPartitions.filterAndPruneFiles(boundPredicate, dynamicDataFilters)
+ setFilesNumAndSizeMetric(returnedFiles, false)
+ returnedFiles
+ }(t => driverMetrics("pruningTime").set(t))
+ } else {
+ selectedPartitions
+ }
+ }
+
+ def getPartitionArray: Array[PartitionDirectory] = {
+ // TODO: fix the value of partiton directories in dynamic pruning
+ val staticDataFilters = dataFilters.filterNot(isDynamicPruningFilter)
+ val staticPartitionFilters = partitionFilters.filterNot(isDynamicPruningFilter)
+ val partitionDirectories =
+ relation.location.listFiles(staticPartitionFilters, staticDataFilters)
+ partitionDirectories.toArray
+ }
+
+ /**
+ * Create an RDD for bucketed reads. The non-bucketed variant of this function is
+ * [[createReadRDD]].
+ *
+ * The algorithm is pretty simple: each RDD partition being returned should include all the files
+ * with the same bucket id from all the given Hive partitions.
+ *
+ * @param bucketSpec
+ * the bucketing spec.
+ * @param selectedPartitions
+ * Hive-style partition that are part of the read.
+ */
+ private def createBucketedReadPartition(
+ bucketSpec: BucketSpec,
+ selectedPartitions: ScanFileListing): Seq[FilePartition] = {
+ logInfo(log"Planning with ${MDC(COUNT, bucketSpec.numBuckets)} buckets")
+ val partitionArray = selectedPartitions.toPartitionArray
+ val filesGroupedToBuckets = partitionArray.groupBy {
+ f =>
+ BucketingUtils
+ .getBucketId(f.toPath.getName)
+ .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
+ }
+
+ val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+ val bucketSet = optionalBucketSet.get
+ filesGroupedToBuckets.filter(f => bucketSet.get(f._1))
+ } else {
+ filesGroupedToBuckets
+ }
+
+ val filePartitions = optionalNumCoalescedBuckets
+ .map {
+ numCoalescedBuckets =>
+ logInfo(log"Coalescing to ${MDC(COUNT, numCoalescedBuckets)} buckets")
+ val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
+ Seq.tabulate(numCoalescedBuckets) {
+ bucketId =>
+ val partitionedFiles = coalescedBuckets
+ .get(bucketId)
+ .map {
+ _.values.flatten.toArray
+ }
+ .getOrElse(Array.empty)
+ FilePartition(bucketId, partitionedFiles)
+ }
+ }
+ .getOrElse {
+ Seq.tabulate(bucketSpec.numBuckets) {
+ bucketId =>
+ FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
+ }
+ }
+ filePartitions
+ }
+
+ /**
+ * Create an RDD for non-bucketed reads. The bucketed variant of this function is
+ * [[createBucketedReadRDD]].
+ *
+ * @param selectedPartitions
+ * Hive-style partition that are part of the read.
+ */
+ private def createReadPartitions(selectedPartitions: ScanFileListing): Seq[FilePartition] = {
+ val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
+ val maxSplitBytes =
+ FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
+ logInfo(log"Planning scan with bin packing, max size: ${MDC(MAX_SPLIT_BYTES, maxSplitBytes)} " +
+ log"bytes, open cost is considered as scanning ${MDC(OPEN_COST_IN_BYTES, openCostInBytes)} " +
+ log"bytes.")
+
+ // Filter files with bucket pruning if possible
+ val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled
+ val shouldProcess: Path => Boolean = optionalBucketSet match {
+ case Some(bucketSet) if bucketingEnabled =>
+ // Do not prune the file if bucket file name is invalid
+ filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+ case _ =>
+ _ => true
+ }
+
+ val splitFiles = selectedPartitions.filePartitionIterator
+ .flatMap {
+ partition =>
+ val ListingPartition(partitionVals, _, fileStatusIterator) = partition
+ fileStatusIterator.flatMap {
+ file =>
+ // getPath() is very expensive so we only want to call it once in this block:
+ val filePath = file.getPath
+ if (shouldProcess(filePath)) {
+ val isSplitable =
+ relation.fileFormat.isSplitable(relation.sparkSession, relation.options, filePath)
+ PartitionedFileUtil.splitFiles(
+ file = file,
+ filePath = filePath,
+ isSplitable = isSplitable,
+ maxSplitBytes = maxSplitBytes,
+ partitionValues = partitionVals
+ )
+ } else {
+ Seq.empty
+ }
+ }
+ }
+ .toArray
+ .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+
+ val partitions = FilePartition
+ .getFilePartitions(relation.sparkSession, splitFiles.toImmutableArraySeq, maxSplitBytes)
+ partitions
+ }
+
+ def getPartitionsSeq(): Seq[Partition] = {
+ if (bucketedScan) {
+ createBucketedReadPartition(relation.bucketSpec.get, dynamicallySelectedPartitions)
+ } else {
+ createReadPartitions(dynamicallySelectedPartitions)
+ }
+ }
+}
+
+abstract class ArrowFileSourceScanLikeShim(original: FileSourceScanExec)
+ extends FileSourceScanLike {
+ override val nodeNamePrefix: String = "ArrowFile"
+
+ override def tableIdentifier: Option[TableIdentifier] = original.tableIdentifier
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = original.inputRDDs()
+
+ override def dataFilters: Seq[Expression] = original.dataFilters
+
+ override def disableBucketedScan: Boolean = original.disableBucketedScan
+
+ override def optionalBucketSet: Option[BitSet] = original.optionalBucketSet
+
+ override def optionalNumCoalescedBuckets: Option[Int] = original.optionalNumCoalescedBuckets
+
+ override def partitionFilters: Seq[Expression] = original.partitionFilters
+
+ override def relation: HadoopFsRelation = original.relation
+
+ override def requiredSchema: StructType = original.requiredSchema
+
+ override def getStream: Option[SparkDataStream] = original.stream
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala
new file mode 100644
index 000000000000..8dc5fbc96434
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.{FileFormatWriter, WriteJobDescription, WriteTaskResult}
+
+object GlutenFileFormatWriter {
+ def writeFilesExecuteTask(
+ description: WriteJobDescription,
+ jobTrackerID: String,
+ sparkStageId: Int,
+ sparkPartitionId: Int,
+ sparkAttemptNumber: Int,
+ committer: FileCommitProtocol,
+ iterator: Iterator[InternalRow]): WriteTaskResult = {
+ FileFormatWriter.executeTask(
+ description,
+ jobTrackerID,
+ sparkStageId,
+ sparkPartitionId,
+ sparkAttemptNumber,
+ committer,
+ iterator,
+ None
+ )
+ }
+
+ // Wrapper for throwing standardized write error using QueryExecutionErrors
+ def wrapWriteError(cause: Throwable, writePath: String): Nothing = {
+ throw QueryExecutionErrors.taskFailedWhileWritingRowsError(writePath, cause)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala
new file mode 100644
index 000000000000..f0ffcc8b690a
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+trait OrderPreservingNodeShim extends OrderPreservingUnaryExecNode
+trait PartitioningPreservingNodeShim extends PartitioningPreservingUnaryExecNode
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
new file mode 100644
index 000000000000..f4cc013ad4da
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+import org.apache.parquet.hadoop.util.HadoopInputFile
+
+/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */
+object ParquetFooterReaderShim {
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ fileStatus: FileStatus,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(HadoopInputFile.fromStatus(fileStatus, configuration), filter)
+ }
+
+ /** @since Spark 4.1 */
+ def readFooter(
+ configuration: Configuration,
+ file: Path,
+ filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = {
+ ParquetFooterReader.readFooter(HadoopInputFile.fromPath(file, configuration), filter)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala
new file mode 100644
index 000000000000..a6a46b10f49d
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams
+import org.apache.spark.util.ArrayImplicits._
+
+import com.google.common.base.Objects
+
+/**
+ * Physical plan node for scanning a batch of data from a data source v2. Please ref BatchScanExec
+ * in Spark
+ */
+abstract class AbstractBatchScanExec(
+ output: Seq[AttributeReference],
+ @transient scan: Scan,
+ val runtimeFilters: Seq[Expression],
+ ordering: Option[Seq[SortOrder]] = None,
+ @transient table: Table,
+ val spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
+) extends DataSourceV2ScanExecBase {
+
+ @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
+
+ // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+ override def equals(other: Any): Boolean = other match {
+ case other: AbstractBatchScanExec =>
+ this.batch != null && this.batch == other.batch &&
+ this.runtimeFilters == other.runtimeFilters &&
+ this.spjParams == other.spjParams
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
+
+ @transient override lazy val inputPartitions: Seq[InputPartition] = inputPartitionsShim
+
+ @transient protected lazy val inputPartitionsShim: Seq[InputPartition] =
+ batch.planInputPartitions()
+
+ @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
+ val dataSourceFilters = runtimeFilters.flatMap {
+ case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
+ case _ => None
+ }
+
+ if (dataSourceFilters.nonEmpty) {
+ val originalPartitioning = outputPartitioning
+
+ // the cast is safe as runtime filters are only assigned if the scan can be filtered
+ val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+ filterableScan.filter(dataSourceFilters.toArray)
+
+ // call toBatch again to get filtered partitions
+ val newPartitions = scan.toBatch.planInputPartitions()
+
+ originalPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+ throw new SparkException(
+ "Data source must have preserved the original partitioning " +
+ "during runtime filtering: not all partitions implement HasPartitionKey after " +
+ "filtering")
+ }
+ val newPartitionValues = newPartitions
+ .map(
+ partition =>
+ InternalRowComparableWrapper(
+ partition.asInstanceOf[HasPartitionKey],
+ p.expressions))
+ .toSet
+ val oldPartitionValues = p.partitionValues
+ .map(partition => InternalRowComparableWrapper(partition, p.expressions))
+ .toSet
+ // We require the new number of partition values to be equal or less than the old number
+ // of partition values here. In the case of less than, empty partitions will be added for
+ // those missing values that are not present in the new input partitions.
+ if (oldPartitionValues.size < newPartitionValues.size) {
+ throw new SparkException(
+ "During runtime filtering, data source must either report " +
+ "the same number of partition values, or a subset of partition values from the " +
+ s"original. Before: ${oldPartitionValues.size} partition values. " +
+ s"After: ${newPartitionValues.size} partition values")
+ }
+
+ if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+ throw new SparkException(
+ "During runtime filtering, data source must not report new " +
+ "partition values that are not present in the original partitioning.")
+ }
+
+ groupPartitions(newPartitions.toImmutableArraySeq)
+ .map(_.groupedParts.map(_.parts))
+ .getOrElse(Seq.empty)
+ case _ =>
+ // no validation is needed as the data source did not report any specific partitioning
+ newPartitions.map(Seq(_))
+ }
+
+ } else {
+ partitions
+ }
+ }
+
+ override def outputPartitioning: Partitioning = {
+ super.outputPartitioning match {
+ case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
+ // We allow duplicated partition values if
+ // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+ val newPartValues = spjParams.commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+ }
+ k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+ case p => p
+ }
+ }
+
+ override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
+
+ override lazy val inputRDD: RDD[InternalRow] = {
+ val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+ // return an empty RDD with 1 partition if dynamic filtering removed the only split
+ sparkContext.parallelize(Array.empty[InternalRow], 1)
+ } else {
+ val finalPartitions = outputPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ assert(spjParams.keyGroupedPartitioning.isDefined)
+ val expressions = spjParams.keyGroupedPartitioning.get
+
+ // Re-group the input partitions if we are projecting on a subset of join keys
+ val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
+ case Some(projectPositions) =>
+ val projectedExpressions = projectPositions.map(i => expressions(i))
+ val parts = filteredPartitions.flatten
+ .groupBy(
+ part => {
+ val row = part.asInstanceOf[HasPartitionKey].partitionKey()
+ val projectedRow =
+ KeyGroupedPartitioning.project(expressions, projectPositions, row)
+ InternalRowComparableWrapper(projectedRow, projectedExpressions)
+ })
+ .map { case (wrapper, splits) => (wrapper.row, splits) }
+ .toSeq
+ (parts, projectedExpressions)
+ case _ =>
+ val groupedParts = filteredPartitions.map(
+ splits => {
+ assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+ (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+ })
+ (groupedParts, expressions)
+ }
+
+ // Also re-group the partitions if we are reducing compatible partition expressions
+ val finalGroupedPartitions = spjParams.reducers match {
+ case Some(reducers) =>
+ val result = groupedPartitions
+ .groupBy {
+ case (row, _) =>
+ KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+ }
+ .map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }
+ .toSeq
+ val rowOrdering =
+ RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType))
+ result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+ case _ => groupedPartitions
+ }
+
+ // When partially clustered, the input partitions are not grouped by partition
+ // values. Here we'll need to check `commonPartitionValues` and decide how to group
+ // and replicate splits within a partition.
+ if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
+ // A mapping from the common partition values to how many splits the partition
+ // should contain.
+ val commonPartValuesMap = spjParams.commonPartitionValues.get
+ .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
+ .toMap
+ val filteredGroupedPartitions = finalGroupedPartitions.filter {
+ case (partValues, _) =>
+ commonPartValuesMap.keySet.contains(
+ InternalRowComparableWrapper(partValues, partExpressions))
+ }
+ val nestGroupedPartitions = filteredGroupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue, partExpressions))
+ assert(
+ numSplits.isDefined,
+ s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (spjParams.replicatePartitions) {
+ // We need to also replicate partitions according to the other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the side with partially
+ // clustered distribution. Because of dynamic filtering, we'll need to check if
+ // the final number of splits of a partition is smaller than the original
+ // number, and fill with empty splits if so. This is necessary so that both
+ // sides of a join will have the same number of partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ spjParams.commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, partExpressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ // either `commonPartitionValues` is not defined, or it is defined but
+ // `applyPartialClustering` is false.
+ val partitionMapping = finalGroupedPartitions.map {
+ case (partValue, splits) =>
+ InternalRowComparableWrapper(partValue, partExpressions) -> splits
+ }.toMap
+
+ // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+ // could exist duplicated partition values, as partition grouping is not done
+ // at the beginning and postponed to this method. It is important to use unique
+ // partition values here so that grouped partitions won't get duplicated.
+ p.uniquePartitionValues.map {
+ partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, partExpressions),
+ Seq.empty)
+ }
+ }
+
+ case _ => filteredPartitions
+ }
+
+ new DataSourceRDD(
+ sparkContext,
+ finalPartitions,
+ readerFactory,
+ supportsColumnar,
+ customMetrics)
+ }
+ postDriverMetrics()
+ rdd
+ }
+
+ override def keyGroupedPartitioning: Option[Seq[Expression]] =
+ spjParams.keyGroupedPartitioning
+
+ override def simpleString(maxFields: Int): String = {
+ val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
+ val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
+ val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
+ redact(result)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated
new file mode 100644
index 000000000000..d43331d57c47
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import com.google.common.base.Objects
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Physical plan node for scanning a batch of data from a data source v2.
+ */
+case class BatchScanExec(
+ output: Seq[AttributeReference],
+ @transient scan: Scan,
+ runtimeFilters: Seq[Expression],
+ keyGroupedPartitioning: Option[Seq[Expression]] = None,
+ ordering: Option[Seq[SortOrder]] = None,
+ @transient table: Table,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+ applyPartialClustering: Boolean = false,
+ replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
+
+ @transient lazy val batch = if (scan == null) null else scan.toBatch
+
+ // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+ override def equals(other: Any): Boolean = other match {
+ case other: BatchScanExec =>
+ this.batch != null && this.batch == other.batch &&
+ this.runtimeFilters == other.runtimeFilters &&
+ this.commonPartitionValues == other.commonPartitionValues &&
+ this.replicatePartitions == other.replicatePartitions &&
+ this.applyPartialClustering == other.applyPartialClustering
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
+
+ @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions()
+
+ @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
+ val dataSourceFilters = runtimeFilters.flatMap {
+ case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
+ case _ => None
+ }
+
+ if (dataSourceFilters.nonEmpty) {
+ val originalPartitioning = outputPartitioning
+
+ // the cast is safe as runtime filters are only assigned if the scan can be filtered
+ val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+ filterableScan.filter(dataSourceFilters.toArray)
+
+ // call toBatch again to get filtered partitions
+ val newPartitions = scan.toBatch.planInputPartitions()
+
+ originalPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+ throw new SparkException("Data source must have preserved the original partitioning " +
+ "during runtime filtering: not all partitions implement HasPartitionKey after " +
+ "filtering")
+ }
+ val newPartitionValues = newPartitions.map(partition =>
+ InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+ .toSet
+ val oldPartitionValues = p.partitionValues
+ .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+ // We require the new number of partition values to be equal or less than the old number
+ // of partition values here. In the case of less than, empty partitions will be added for
+ // those missing values that are not present in the new input partitions.
+ if (oldPartitionValues.size < newPartitionValues.size) {
+ throw new SparkException("During runtime filtering, data source must either report " +
+ "the same number of partition values, or a subset of partition values from the " +
+ s"original. Before: ${oldPartitionValues.size} partition values. " +
+ s"After: ${newPartitionValues.size} partition values")
+ }
+
+ if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+ throw new SparkException("During runtime filtering, data source must not report new " +
+ "partition values that are not present in the original partitioning.")
+ }
+
+ groupPartitions(newPartitions).get.map(_._2)
+
+ case _ =>
+ // no validation is needed as the data source did not report any specific partitioning
+ newPartitions.map(Seq(_))
+ }
+
+ } else {
+ partitions
+ }
+ }
+
+ override def outputPartitioning: Partitioning = {
+ super.outputPartitioning match {
+ case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+ // We allow duplicated partition values if
+ // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+ val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+ Seq.fill(numSplits)(partValue)
+ }
+ k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+ case p => p
+ }
+ }
+
+ override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
+
+ override lazy val inputRDD: RDD[InternalRow] = {
+ val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+ // return an empty RDD with 1 partition if dynamic filtering removed the only split
+ sparkContext.parallelize(Array.empty[InternalRow], 1)
+ } else {
+ var finalPartitions = filteredPartitions
+
+ outputPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ if (conf.v2BucketingPushPartValuesEnabled &&
+ conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+ assert(filteredPartitions.forall(_.size == 1),
+ "Expect partitions to be not grouped when " +
+ s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+ "is enabled")
+
+ val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+
+ // This means the input partitions are not grouped by partition values. We'll need to
+ // check `groupByPartitionValues` and decide whether to group and replicate splits
+ // within a partition.
+ if (commonPartitionValues.isDefined && applyPartialClustering) {
+ // A mapping from the common partition values to how many splits the partition
+ // should contain. Note this no longer maintain the partition key ordering.
+ val commonPartValuesMap = commonPartitionValues
+ .get
+ .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+ .toMap
+ val nestGroupedPartitions = groupedPartitions.map {
+ case (partValue, splits) =>
+ // `commonPartValuesMap` should contain the part value since it's the super set.
+ val numSplits = commonPartValuesMap
+ .get(InternalRowComparableWrapper(partValue, p.expressions))
+ assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+ "common partition values from Spark plan")
+
+ val newSplits = if (replicatePartitions) {
+ // We need to also replicate partitions according to the other side of join
+ Seq.fill(numSplits.get)(splits)
+ } else {
+ // Not grouping by partition values: this could be the side with partially
+ // clustered distribution. Because of dynamic filtering, we'll need to check if
+ // the final number of splits of a partition is smaller than the original
+ // number, and fill with empty splits if so. This is necessary so that both
+ // sides of a join will have the same number of partitions & splits.
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+ }
+ (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+ }
+
+ // Now fill missing partition keys with empty partitions
+ val partitionMapping = nestGroupedPartitions.toMap
+ finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.fill(numSplits)(Seq.empty))
+ }
+ } else {
+ val partitionMapping = groupedPartitions.map { case (row, parts) =>
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map { partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+ }
+ }
+ } else {
+ val partitionMapping = finalPartitions.map { parts =>
+ val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+ InternalRowComparableWrapper(row, p.expressions) -> parts
+ }.toMap
+ finalPartitions = p.partitionValues.map { partValue =>
+ // Use empty partition for those partition values that are not present
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+ }
+ }
+
+ case _ =>
+ }
+
+ new DataSourceRDD(
+ sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics)
+ }
+ postDriverMetrics()
+ rdd
+ }
+
+ override def doCanonicalize(): BatchScanExec = {
+ this.copy(
+ output = output.map(QueryPlan.normalizeExpressions(_, output)),
+ runtimeFilters = QueryPlan.normalizePredicates(
+ runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
+ output))
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
+ val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
+ val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
+ redact(result)
+ }
+
+ override def nodeName: String = {
+ s"BatchScan ${table.name()}".trim
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
new file mode 100644
index 000000000000..046c42ea7b6e
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.functions.Reducer
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ArrayImplicits._
+
+abstract class BatchScanExecShim(
+ output: Seq[AttributeReference],
+ @transient scan: Scan,
+ runtimeFilters: Seq[Expression],
+ keyGroupedPartitioning: Option[Seq[Expression]] = None,
+ ordering: Option[Seq[SortOrder]] = None,
+ @transient val table: Table,
+ val joinKeyPositions: Option[Seq[Int]] = None,
+ val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+ val reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
+ val applyPartialClustering: Boolean = false,
+ val replicatePartitions: Boolean = false)
+ extends AbstractBatchScanExec(
+ output,
+ scan,
+ runtimeFilters,
+ ordering,
+ table,
+ StoragePartitionJoinParams(
+ keyGroupedPartitioning,
+ joinKeyPositions,
+ commonPartitionValues,
+ reducers,
+ applyPartialClustering,
+ replicatePartitions)
+ ) {
+
+ // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
+ @transient override lazy val metrics: Map[String, SQLMetric] = Map()
+
+ lazy val metadataColumns: Seq[AttributeReference] = output.collect {
+ case FileSourceConstantMetadataAttribute(attr) => attr
+ case FileSourceGeneratedMetadataAttribute(attr, _) => attr
+ }
+
+ def hasUnsupportedColumns: Boolean = {
+ // TODO, fallback if user define same name column due to we can't right now
+ // detect which column is metadata column which is user defined column.
+ val metadataColumnsNames = metadataColumns.map(_.name)
+ output
+ .filterNot(metadataColumns.toSet)
+ .exists(v => metadataColumnsNames.contains(v.name))
+ }
+
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ throw new UnsupportedOperationException("Need to implement this method")
+ }
+
+ @transient protected lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
+ val dataSourceFilters = runtimeFilters.flatMap {
+ case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
+ case _ => None
+ }
+
+ if (dataSourceFilters.nonEmpty) {
+ val originalPartitioning = outputPartitioning
+
+ // the cast is safe as runtime filters are only assigned if the scan can be filtered
+ val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+ filterableScan.filter(dataSourceFilters.toArray)
+
+ // call toBatch again to get filtered partitions
+ val newPartitions = scan.toBatch.planInputPartitions()
+
+ originalPartitioning match {
+ case p: KeyGroupedPartitioning =>
+ if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+ throw new SparkException(
+ "Data source must have preserved the original partitioning " +
+ "during runtime filtering: not all partitions implement HasPartitionKey after " +
+ "filtering")
+ }
+ val newPartitionValues = newPartitions
+ .map(
+ partition =>
+ InternalRowComparableWrapper(
+ partition.asInstanceOf[HasPartitionKey],
+ p.expressions))
+ .toSet
+ val oldPartitionValues = p.partitionValues
+ .map(partition => InternalRowComparableWrapper(partition, p.expressions))
+ .toSet
+ // We require the new number of partition values to be equal or less than the old number
+ // of partition values here. In the case of less than, empty partitions will be added for
+ // those missing values that are not present in the new input partitions.
+ if (oldPartitionValues.size < newPartitionValues.size) {
+ throw new SparkException(
+ "During runtime filtering, data source must either report " +
+ "the same number of partition values, or a subset of partition values from the " +
+ s"original. Before: ${oldPartitionValues.size} partition values. " +
+ s"After: ${newPartitionValues.size} partition values")
+ }
+
+ if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+ throw new SparkException(
+ "During runtime filtering, data source must not report new " +
+ "partition values that are not present in the original partitioning.")
+ }
+
+ groupPartitions(newPartitions.toImmutableArraySeq)
+ .map(_.groupedParts.map(_.parts))
+ .getOrElse(Seq.empty)
+
+ case _ =>
+ // no validation is needed as the data source did not report any specific partitioning
+ newPartitions.map(Seq(_))
+ }
+
+ } else {
+ partitions
+ }
+ }
+
+ @transient lazy val pushedAggregate: Option[Aggregation] = {
+ scan match {
+ case s: ParquetScan => s.pushedAggregate
+ case o: OrcScan => o.pushedAggregate
+ case _ => None
+ }
+ }
+}
+
+abstract class ArrowBatchScanExecShim(original: BatchScanExec)
+ extends BatchScanExecShim(
+ original.output,
+ original.scan,
+ original.runtimeFilters,
+ original.spjParams.keyGroupedPartitioning,
+ original.ordering,
+ original.table,
+ original.spjParams.joinKeyPositions,
+ original.spjParams.commonPartitionValues,
+ original.spjParams.reducers,
+ original.spjParams.applyPartialClustering,
+ original.spjParams.replicatePartitions
+ ) {
+ override def scan: Scan = original.scan
+
+ override def ordering: Option[Seq[SortOrder]] = original.ordering
+
+ override def output: Seq[Attribute] = original.output
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala
new file mode 100644
index 000000000000..c6686ba55299
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.utils
+
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.connector.expressions.Transform
+
+object CatalogUtil {
+
+ def convertPartitionTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TransformHelper
+ val (identityCols, bucketSpec, _) = partitions.convertTransforms
+ (identityCols, bucketSpec)
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
new file mode 100644
index 000000000000..127a0fc3cfc9
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import java.io.DataOutputStream
+
+abstract class BasePythonRunnerShim(
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argMetas: Array[Array[(Int, Option[String])]],
+ pythonMetrics: Map[String, SQLMetric])
+ extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+ funcs.map(_._1),
+ evalType,
+ argMetas.map(_.map(_._1)),
+ None,
+ pythonMetrics) {
+
+ protected def createNewWriter(
+ env: SparkEnv,
+ worker: PythonWorker,
+ inputIterator: Iterator[ColumnarBatch],
+ partitionIndex: Int,
+ context: TaskContext): Writer
+
+ protected def writeUdf(
+ dataOut: DataOutputStream,
+ argOffsets: Array[Array[(Int, Option[String])]]): Unit = {
+ PythonUDFRunner.writeUDFs(
+ dataOut,
+ funcs,
+ argOffsets.map(_.map(pair => ArgumentMetadata(pair._1, pair._2))),
+ None)
+ }
+
+ override protected def newWriter(
+ env: SparkEnv,
+ worker: PythonWorker,
+ inputIterator: Iterator[ColumnarBatch],
+ partitionIndex: Int,
+ context: TaskContext): Writer = {
+ createNewWriter(env, worker, inputIterator, partitionIndex, context)
+ }
+
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
new file mode 100644
index 000000000000..7ad7ca6b09ee
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
+
+abstract class EvalPythonExecBase extends EvalPythonExec {
+
+ override protected def evaluatorFactory: EvalPythonEvaluatorFactory = {
+ throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
+ }
+}
+
+object EvalPythonExecBase {
+ object NamedArgumentExpressionShim {
+ def unapply(expr: Expression): Option[(String, Expression)] = expr match {
+ case NamedArgumentExpression(key, value) => Some((key, value))
+ case _ => None
+ }
+ }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala
new file mode 100644
index 000000000000..40c66590e022
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.ui
+
+/**
+ * Ensures compatibility for the type HttpServletRequest across Spark 4.0 and earlier versions.
+ * Starting from Spark 4.0, `jakarta.servlet.http.HttpServletRequest` is used.
+ */
+object TypeAlias {
+ type HttpServletRequest = jakarta.servlet.http.HttpServletRequest
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
new file mode 100644
index 000000000000..8422d33e521d
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.hive._
+import org.apache.spark.sql.hive.client.HiveClientImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
+import org.apache.spark.util.Utils
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.ql.io.{DelegateSymlinkTextInputFormat, SymlinkTextInputFormat}
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
+import org.apache.hadoop.mapred.InputFormat
+
+import scala.collection.JavaConverters._
+
+/**
+ * The Hive table scan operator. Column and partition pruning are both handled.
+ *
+ * @param requestedAttributes
+ * Attributes to be fetched from the Hive table.
+ * @param relation
+ * The Hive table be scanned.
+ * @param partitionPruningPred
+ * An optional partition pruning predicate for partitioned table.
+ * @param prunedOutput
+ * The pruned output.
+ */
+abstract private[hive] class AbstractHiveTableScanExec(
+ requestedAttributes: Seq[Attribute],
+ relation: HiveTableRelation,
+ partitionPruningPred: Seq[Expression],
+ prunedOutput: Seq[Attribute] = Seq.empty[Attribute])(
+ @transient protected val sparkSession: SparkSession)
+ extends LeafExecNode
+ with CastSupport {
+
+ require(
+ partitionPruningPred.isEmpty || relation.isPartitioned,
+ "Partition pruning predicates only supported for partitioned tables.")
+
+ override def conf: SQLConf = sparkSession.sessionState.conf
+
+ override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}"
+
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+ override def producedAttributes: AttributeSet = outputSet ++
+ AttributeSet(partitionPruningPred.flatMap(_.references))
+
+ private val originalAttributes = AttributeMap(relation.output.map(a => a -> a))
+
+ override def output: Seq[Attribute] = {
+ if (prunedOutput.nonEmpty) {
+ prunedOutput
+ } else {
+ // Retrieve the original attributes based on expression ID so that capitalization matches.
+ requestedAttributes.map(attr => originalAttributes.getOrElse(attr, attr)).distinct
+ }
+ }
+
+ // Bind all partition key attribute references in the partition pruning predicate for later
+ // evaluation.
+ private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map {
+ pred =>
+ require(
+ pred.dataType == BooleanType,
+ s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " +
+ s"${pred.dataType.catalogString}.")
+
+ BindReferences.bindReference(pred, relation.partitionCols)
+ }
+
+ @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
+ @transient private lazy val tableDesc = new TableDesc(
+ getInputFormat(hiveQlTable.getInputFormatClass, conf),
+ hiveQlTable.getOutputFormatClass,
+ hiveQlTable.getMetadata)
+
+ // Create a local copy of hadoopConf,so that scan specific modifications should not impact
+ // other queries
+ @transient private lazy val hadoopConf = {
+ val c = sparkSession.sessionState.newHadoopConf()
+ // append columns ids and names before broadcast
+ addColumnMetadataToConf(c)
+ c
+ }
+
+ @transient private lazy val hadoopReader =
+ new HadoopTableReader(output, relation.partitionCols, tableDesc, sparkSession, hadoopConf)
+
+ private def castFromString(value: String, dataType: DataType) = {
+ cast(Literal(value), dataType).eval(null)
+ }
+
+ private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
+ // Specifies needed column IDs for those non-partitioning columns.
+ val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex)
+ val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer)
+ val neededColumnNames = output.filter(columnOrdinals.contains).map(_.name)
+
+ HiveShim.appendReadColumns(hiveConf, neededColumnIDs, neededColumnNames)
+
+ val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance()
+ deserializer.initialize(hiveConf, tableDesc.getProperties)
+
+ // Specifies types and object inspectors of columns to be scanned.
+ val structOI = ObjectInspectorUtils
+ .getStandardObjectInspector(deserializer.getObjectInspector, ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ val columnTypeNames = structOI.getAllStructFieldRefs.asScala
+ .map(_.getFieldObjectInspector)
+ .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName)
+ .mkString(",")
+
+ hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames)
+ hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataCols.map(_.name).mkString(","))
+ }
+
+ /**
+ * Prunes partitions not involve the query plan.
+ *
+ * @param partitions
+ * All partitions of the relation.
+ * @return
+ * Partitions that are involved in the query plan.
+ */
+ private[hive] def prunePartitions(partitions: Seq[HivePartition]): Seq[HivePartition] = {
+ boundPruningPred match {
+ case None => partitions
+ case Some(shouldKeep) =>
+ partitions.filter {
+ part =>
+ val dataTypes = relation.partitionCols.map(_.dataType)
+ val castedValues = part.getValues.asScala
+ .zip(dataTypes)
+ .map { case (value, dataType) => castFromString(value, dataType) }
+
+ // Only partitioned values are needed here, since the predicate has
+ // already been bound to partition key attribute references.
+ val row = InternalRow.fromSeq(castedValues.toSeq)
+ shouldKeep.eval(row).asInstanceOf[Boolean]
+ }
+ }
+ }
+
+ @transient lazy val prunedPartitions: Seq[HivePartition] = {
+ if (relation.prunedPartitions.nonEmpty) {
+ val hivePartitions =
+ relation.prunedPartitions.get.map(HiveClientImpl.toHivePartition(_, hiveQlTable))
+ if (partitionPruningPred.forall(!ExecSubqueryExpression.hasSubquery(_))) {
+ hivePartitions
+ } else {
+ prunePartitions(hivePartitions)
+ }
+ } else {
+ if (
+ sparkSession.sessionState.conf.metastorePartitionPruning &&
+ partitionPruningPred.nonEmpty
+ ) {
+ rawPartitions
+ } else {
+ prunePartitions(rawPartitions)
+ }
+ }
+ }
+
+ // exposed for tests
+ @transient lazy val rawPartitions: Seq[HivePartition] = {
+ val prunedPartitions =
+ if (
+ sparkSession.sessionState.conf.metastorePartitionPruning &&
+ partitionPruningPred.nonEmpty
+ ) {
+ // Retrieve the original attributes based on expression ID so that capitalization matches.
+ val normalizedFilters = partitionPruningPred.map(_.transform {
+ case a: AttributeReference => originalAttributes(a)
+ })
+ sparkSession.sessionState.catalog
+ .listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters)
+ } else {
+ sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier)
+ }
+ prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable))
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ // Using dummyCallSite, as getCallSite can turn out to be expensive with
+ // multiple partitions.
+ val rdd = if (!relation.isPartitioned) {
+ Utils.withDummyCallSite(sparkContext) {
+ hadoopReader.makeRDDForTable(hiveQlTable)
+ }
+ } else {
+ Utils.withDummyCallSite(sparkContext) {
+ hadoopReader.makeRDDForPartitionedTable(prunedPartitions)
+ }
+ }
+ val numOutputRows = longMetric("numOutputRows")
+ // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649)
+ val outputSchema = schema
+ rdd.mapPartitionsWithIndexInternal {
+ (index, iter) =>
+ val proj = UnsafeProjection.create(outputSchema)
+ proj.initialize(index)
+ iter.map {
+ r =>
+ numOutputRows += 1
+ proj(r)
+ }
+ }
+ }
+
+ // Filters unused DynamicPruningExpression expressions - one which has been replaced
+ // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning
+ private def filterUnusedDynamicPruningExpressions(
+ predicates: Seq[Expression]): Seq[Expression] = {
+ predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
+ }
+
+ // Optionally returns a delegate input format based on the provided input format class.
+ // This is currently used to replace SymlinkTextInputFormat with DelegateSymlinkTextInputFormat
+ // in order to fix SPARK-40815.
+ private def getInputFormat(
+ inputFormatClass: Class[_ <: InputFormat[_, _]],
+ conf: SQLConf): Class[_ <: InputFormat[_, _]] = {
+ if (
+ inputFormatClass == classOf[SymlinkTextInputFormat] &&
+ conf != null && conf.getConf(HiveUtils.USE_DELEGATE_FOR_SYMLINK_TEXT_INPUT_FORMAT)
+ ) {
+ classOf[DelegateSymlinkTextInputFormat]
+ } else {
+ inputFormatClass
+ }
+ }
+
+ override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
+
+ def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = {
+ SchemaPruning.pruneSchema(schema, requestedFields)
+ }
+}