diff --git a/.github/workflows/util/install-spark-resources.sh b/.github/workflows/util/install-spark-resources.sh index 4d1dd27a9c45..d245dbb4ac31 100755 --- a/.github/workflows/util/install-spark-resources.sh +++ b/.github/workflows/util/install-spark-resources.sh @@ -119,6 +119,11 @@ case "$1" in cd ${INSTALL_DIR} && \ install_spark "4.0.1" "3" "2.12" ;; +4.1) + # Spark-4.x, scala 2.12 // using 2.12 as a hack as 4.0 does not have 2.13 suffix + cd ${INSTALL_DIR} && \ + install_spark "4.1.0" "3" "2.12" + ;; *) echo "Spark version is expected to be specified." exit 1 diff --git a/.github/workflows/velox_backend_x86.yml b/.github/workflows/velox_backend_x86.yml index 9adde6c5ce66..5ef70a79bc0f 100644 --- a/.github/workflows/velox_backend_x86.yml +++ b/.github/workflows/velox_backend_x86.yml @@ -1481,3 +1481,109 @@ jobs: **/target/*.log **/gluten-ut/**/hs_err_*.log **/gluten-ut/**/core.* + + spark-test-spark41: + needs: build-native-lib-centos-7 + runs-on: ubuntu-22.04 + env: + SPARK_TESTING: true + container: apache/gluten:centos-8-jdk17 + steps: + - uses: actions/checkout@v2 + - name: Download All Artifacts + uses: actions/download-artifact@v4 + with: + name: velox-native-lib-centos-7-${{github.sha}} + path: ./cpp/build/releases + - name: Download Arrow Jars + uses: actions/download-artifact@v4 + with: + name: arrow-jars-centos-7-${{github.sha}} + path: /root/.m2/repository/org/apache/arrow/ + - name: Prepare + run: | + dnf module -y install python39 && \ + alternatives --set python3 /usr/bin/python3.9 && \ + pip3 install setuptools==77.0.3 && \ + pip3 install pyspark==3.5.5 cython && \ + pip3 install pandas==2.2.3 pyarrow==20.0.0 + - name: Prepare Spark Resources for Spark 4.1.0 #TODO remove after image update + run: | + rm -rf /opt/shims/spark41 + bash .github/workflows/util/install-spark-resources.sh 4.1 + mv /opt/shims/spark41/spark_home/assembly/target/scala-2.12 /opt/shims/spark41/spark_home/assembly/target/scala-2.13 + - name: Build and Run unit test for 
Spark 4.1.0 with scala-2.13 (other tests) + run: | + cd $GITHUB_WORKSPACE/ + export SPARK_SCALA_VERSION=2.13 + yum install -y java-17-openjdk-devel + export JAVA_HOME=/usr/lib/jvm/java-17-openjdk + export PATH=$JAVA_HOME/bin:$PATH + java -version + $MVN_CMD clean test -Pspark-4.1 -Pscala-2.13 -Pjava-17 -Pbackends-velox \ + -Pspark-ut -DargLine="-Dspark.test.home=/opt/shims/spark41/spark_home/" \ + -DtagsToExclude=org.apache.spark.tags.ExtendedSQLTest,org.apache.gluten.tags.UDFTest,org.apache.gluten.tags.EnhancedFeaturesTest,org.apache.gluten.tags.SkipTest + - name: Upload test report + if: always() + uses: actions/upload-artifact@v4 + with: + name: ${{ github.job }}-report + path: '**/surefire-reports/TEST-*.xml' + - name: Upload unit tests log files + if: ${{ !success() }} + uses: actions/upload-artifact@v4 + with: + name: ${{ github.job }}-test-log + path: | + **/target/*.log + **/gluten-ut/**/hs_err_*.log + **/gluten-ut/**/core.* + + spark-test-spark41-slow: + needs: build-native-lib-centos-7 + runs-on: ubuntu-22.04 + env: + SPARK_TESTING: true + container: apache/gluten:centos-8-jdk17 + steps: + - uses: actions/checkout@v2 + - name: Download All Artifacts + uses: actions/download-artifact@v4 + with: + name: velox-native-lib-centos-7-${{github.sha}} + path: ./cpp/build/releases + - name: Download Arrow Jars + uses: actions/download-artifact@v4 + with: + name: arrow-jars-centos-7-${{github.sha}} + path: /root/.m2/repository/org/apache/arrow/ + - name: Prepare Spark Resources for Spark 4.1.0 #TODO remove after image update + run: | + rm -rf /opt/shims/spark41 + bash .github/workflows/util/install-spark-resources.sh 4.1 + mv /opt/shims/spark41/spark_home/assembly/target/scala-2.12 /opt/shims/spark41/spark_home/assembly/target/scala-2.13 + - name: Build and Run unit test for Spark 4.1.0 with scala-2.13 (slow tests) + run: | + cd $GITHUB_WORKSPACE/ + yum install -y java-17-openjdk-devel + export JAVA_HOME=/usr/lib/jvm/java-17-openjdk + export PATH=$JAVA_HOME/bin:$PATH + java -version 
+ $MVN_CMD clean test -Pspark-4.1 -Pscala-2.13 -Pjava-17 -Pbackends-velox -Pspark-ut \ + -DargLine="-Dspark.test.home=/opt/shims/spark41/spark_home/" \ + -DtagsToInclude=org.apache.spark.tags.ExtendedSQLTest + - name: Upload test report + if: always() + uses: actions/upload-artifact@v4 + with: + name: ${{ github.job }}-report + path: '**/surefire-reports/TEST-*.xml' + - name: Upload unit tests log files + if: ${{ !success() }} + uses: actions/upload-artifact@v4 + with: + name: ${{ github.job }}-test-log + path: | + **/target/*.log + **/gluten-ut/**/hs_err_*.log + **/gluten-ut/**/core.* diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala index 25371be8d1fa..925f2a6be94f 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala @@ -38,6 +38,24 @@ import java.nio.charset.StandardCharsets import scala.collection.convert.ImplicitConversions.`map AsScala` +/** + * Extracts a CSVTable from a DataSourceV2Relation. + * + * Only the table variable of DataSourceV2Relation is accessed to improve compatibility across + * different Spark versions. 
+ * @since Spark + * 4.1 + */ +private object CSVTableExtractor { + def unapply(relation: DataSourceV2Relation): Option[(DataSourceV2Relation, CSVTable)] = { + relation.table match { + case t: CSVTable => + Some((relation, t)) + case _ => None + } + } +} + @Experimental case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { @@ -56,25 +74,15 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { l.copy(relation = r.copy(fileFormat = new ArrowCSVFileFormat(csvOptions))(session)) case _ => l } - case d @ DataSourceV2Relation( - t @ CSVTable( - name, - sparkSession, - options, - paths, - userSpecifiedSchema, - fallbackFileFormat), - _, - _, - _, - _) if validate(session, t.dataSchema, options.asCaseSensitiveMap().toMap) => + case CSVTableExtractor(d, t) + if validate(session, t.dataSchema, t.options.asCaseSensitiveMap().toMap) => d.copy(table = ArrowCSVTable( - "arrow" + name, - sparkSession, - options, - paths, - userSpecifiedSchema, - fallbackFileFormat)) + "arrow" + t.name, + t.sparkSession, + t.options, + t.paths, + t.userSpecifiedSchema, + t.fallbackFileFormat)) case r => r } diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala index 6239ab5ad749..ab76cba4aa5d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/utils/ParquetMetadataUtils.scala @@ -21,7 +21,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFooterReader, ParquetOptions} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFooterReaderShim, ParquetOptions} import 
org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, LocatedFileStatus, Path} @@ -135,7 +135,7 @@ object ParquetMetadataUtils extends Logging { parquetOptions: ParquetOptions): Option[String] = { val footer = try { - ParquetFooterReader.readFooter(conf, fileStatus, ParquetMetadataConverter.NO_FILTER) + ParquetFooterReaderShim.readFooter(conf, fileStatus, ParquetMetadataConverter.NO_FILTER) } catch { case e: Exception if ExceptionUtils.hasCause(e, classOf[ParquetCryptoRuntimeException]) => return Some("Encrypted Parquet footer detected.") diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala index c3132420916a..cfddfb8e21e3 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala @@ -753,7 +753,11 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa val df = sql("SELECT 1") checkAnswer(df, Row(1)) val plan = df.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[RDDScanExec]).isDefined) + if (isSparkVersionGE("4.1")) { + assert(plan.find(_.getClass.getSimpleName == "OneRowRelationExec").isDefined) + } else { + assert(plan.find(_.isInstanceOf[RDDScanExec]).isDefined) + } assert(plan.find(_.isInstanceOf[ProjectExecTransformer]).isDefined) assert(plan.find(_.isInstanceOf[RowToVeloxColumnarExec]).isDefined) } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala index 53f44a2ccc81..5958baa3771f 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala @@ -92,12 +92,9 @@ class VeloxHashJoinSuite extends 
VeloxWholeStageTransformerSuite { // The computing is combined into one single whole stage transformer. val wholeStages = plan.collect { case wst: WholeStageTransformer => wst } - if (SparkShimLoader.getSparkVersion.startsWith("3.2.")) { + if (isSparkVersionLE("3.2")) { assert(wholeStages.length == 1) - } else if ( - SparkShimLoader.getSparkVersion.startsWith("3.5.") || - SparkShimLoader.getSparkVersion.startsWith("4.0.") - ) { + } else if (isSparkVersionGE("3.5")) { assert(wholeStages.length == 5) } else { assert(wholeStages.length == 3) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala index 06f0acb784b8..37f13bcea8c3 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxStringFunctionsSuite.scala @@ -544,7 +544,8 @@ class VeloxStringFunctionsSuite extends VeloxWholeStageTransformerSuite { s"from $LINEITEM_TABLE limit 5") { _ => } } - testWithMinSparkVersion("split", "3.4") { + // TODO: fix on spark-4.1 + testWithSpecifiedSparkVersion("split", "3.4", "3.5") { runQueryAndCompare( s"select l_orderkey, l_comment, split(l_comment, '') " + s"from $LINEITEM_TABLE limit 5") { diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala index 52a17995f386..f8e2554da7c7 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala @@ -39,7 +39,8 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite { .set("spark.executor.cores", "1") } - test("arrow_udf test: without projection") { + // TODO: fix on spark-4.1 
+ testWithMaxSparkVersion("arrow_udf test: without projection", "4.0") { lazy val base = Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0)) .toDF("a", "b") @@ -59,7 +60,8 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite { checkAnswer(df2, expected) } - test("arrow_udf test: with unrelated projection") { + // TODO: fix on spark-4.1 + testWithMaxSparkVersion("arrow_udf test: with unrelated projection", "4.0") { lazy val base = Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0)) .toDF("a", "b") diff --git a/dev/format-scala-code.sh b/dev/format-scala-code.sh index 96a782405bbf..4d2fba5ca166 100755 --- a/dev/format-scala-code.sh +++ b/dev/format-scala-code.sh @@ -22,7 +22,7 @@ MVN_CMD="${BASEDIR}/../build/mvn" # If a new profile is introduced for new modules, please add it here to ensure # the new modules are covered. PROFILES="-Pbackends-velox -Pceleborn,uniffle -Piceberg,delta,hudi,paimon \ - -Pspark-3.2,spark-3.3,spark-3.4,spark-3.5,spark-4.0 -Pspark-ut" + -Pspark-3.2,spark-3.3,spark-3.4,spark-3.5,spark-4.0,spark-4.1 -Pspark-ut" COMMAND=$1 diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala index 732c898285dd..7dfaaaa774c5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala @@ -18,7 +18,7 @@ package org.apache.gluten.extension.caller import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.util.SparkVersionUtil /** * Helper API that stores information about the call site of the columnar rule. 
Specific columnar @@ -70,7 +70,12 @@ object CallerInfo { } private def inStreamingCall(stack: Seq[StackTraceElement]): Boolean = { - stack.exists(_.getClassName.equals(StreamExecution.getClass.getName.split('$').head)) + val streamName = if (SparkVersionUtil.gteSpark41) { + "org.apache.spark.sql.execution.streaming.runtime.StreamExecution" + } else { + "org.apache.spark.sql.execution.streaming.StreamExecution" + } + stack.exists(_.getClassName.equals(streamName)) } private def inBloomFilterStatFunctionCall(stack: Seq[StackTraceElement]): Boolean = { diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala index efa0c63dca52..50114ab7023e 100644 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkVersionUtil.scala @@ -25,6 +25,7 @@ object SparkVersionUtil { val gteSpark33: Boolean = comparedWithSpark33 >= 0 val gteSpark35: Boolean = comparedWithSpark35 >= 0 val gteSpark40: Boolean = compareMajorMinorVersion((4, 0)) >= 0 + val gteSpark41: Boolean = compareMajorMinorVersion((4, 1)) >= 0 // Returns X. X < 0 if one < other, x == 0 if one == other, x > 0 if one > other. 
def compareMajorMinorVersion(other: (Int, Int)): Int = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 52f6d31d1d1b..418de8578f5c 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -856,13 +856,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { dateAdd.children, dateAdd ) - case timeAdd: TimeAdd => - BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer( - attributeSeq, - substraitExprName, - timeAdd.children, - timeAdd - ) case ss: StringSplit => BackendsApiManager.getSparkPlanExecApiInstance.genStringSplitTransformer( substraitExprName, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index b0b7c8079315..b13aced2a62c 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -179,7 +179,6 @@ object ExpressionMappings { Sig[Second](EXTRACT), Sig[FromUnixTime](FROM_UNIXTIME), Sig[DateAdd](DATE_ADD), - Sig[TimeAdd](TIMESTAMP_ADD), Sig[DateSub](DATE_SUB), Sig[DateDiff](DATE_DIFF), Sig[ToUnixTimestamp](TO_UNIX_TIMESTAMP), diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala index 474a27176906..7267ce56ba1c 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution 
import org.apache.gluten.exception.GlutenException import org.apache.gluten.execution.{GlutenPlan, WholeStageTransformer} +import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.PlanUtil - import org.apache.spark.sql.{Column, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan} @@ -130,10 +130,8 @@ object GlutenImplicits { val (innerNumGlutenNodes, innerFallbackNodeToReason) = withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { // re-plan manually to skip cached data - val newSparkPlan = QueryExecution.createSparkPlan( - spark, - spark.sessionState.planner, - p.inputPlan.logicalLink.get) + val newSparkPlan = SparkShimLoader.getSparkShims.createSparkPlan( + spark, spark.sessionState.planner, p.inputPlan.logicalLink.get) val newExecutedPlan = QueryExecution.prepareExecutedPlan( spark, newSparkPlan diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala index 5cf41b7a9ed5..570b6d5e0c1a 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/GlutenParquetRowIndexSuite.scala @@ -42,7 +42,7 @@ class GlutenParquetRowIndexSuite extends ParquetRowIndexSuite with GlutenSQLTest import testImplicits._ private def readRowGroupRowCounts(path: String): Seq[Long] = { - ParquetFooterReader + ParquetFooterReaderShim .readFooter( spark.sessionState.newHadoopConf(), new Path(path), diff --git a/pom.xml b/pom.xml index c87d46ceb5c1..b88fa8dfee2b 100644 --- a/pom.xml +++ b/pom.xml @@ -938,7 +938,7 @@ scala-2.13 - 2.13.16 + 2.13.17 2.13 3.8.3 @@ -1270,6 +1270,86 @@ + + spark-4.1 + + 4.1 + 
spark-sql-columnar-shims-spark41 + 4.1.0 + 1.10.0 + delta-spark + 4.0.0 + 40 + 1.1.0 + 1.3.0 + 2.18.2 + 2.18.2 + 2.18.2 + 3.4.1 + 4.13.1 + 33.4.0-jre + 2.0.16 + 2.24.3 + 3.17.0 + 18.1.0 + 18.1.0 + + + + org.slf4j + slf4j-api + ${slf4j.version} + provided + + + org.apache.logging.log4j + log4j-slf4j2-impl + ${log4j.version} + provided + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + enforce-java-17+ + + enforce + + + + + + java-17,java-21 + false + "-P spark-4.1" requires JDK 17+ + + + + + + enforce-scala-213 + + enforce + + + + + + scala-2.13 + "-P spark-4.1" requires Scala 2.13 + + + + + + + + + hadoop-2.7.4 diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index de220cab82e4..9164f4b7c43a 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -355,4 +355,14 @@ trait SparkShims { def unsupportedCodec: Seq[CompressionCodecName] = { Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. 
+ * @since Spark + * 4.1 + */ + def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan } diff --git a/shims/pom.xml b/shims/pom.xml index 9a8e639e04ae..8c10ce640fee 100644 --- a/shims/pom.xml +++ b/shims/pom.xml @@ -98,6 +98,12 @@ spark40 + + spark-4.1 + + spark41 + + default diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala index 367b66f9c424..d1bf07e64b12 100644 --- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala +++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan} +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters @@ -303,4 +303,16 @@ class Spark32Shims extends SparkShims { override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { Some(raiseError.child) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. 
+ * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(sparkSession, planner, plan) } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..b1419e5e6233 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. 
*/ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, fileStatus, filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, file, filter) + } +} diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index 2b6affcded6e..a18fb3171991 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TimestampFormatte import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan} +import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters @@ -408,4 +408,16 @@ class Spark33Shims extends SparkShims { override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { Some(raiseError.child) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. 
+ * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(sparkSession, planner, plan) } diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..b1419e5e6233 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. 
*/ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, fileStatus, filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, file, filter) + } +} diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 39dd71a6bc3c..cdbeaa47838b 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -651,4 +651,16 @@ class Spark34Shims extends SparkShims { override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { Some(raiseError.child) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. + * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(sparkSession, planner, plan) } diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..b1419e5e6233 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. 
*/ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, fileStatus, filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, file, filter) + } +} diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 78da08190f17..d993cc0bfd20 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -702,4 +702,16 @@ class Spark35Shims extends SparkShims { override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { Some(raiseError.child) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. + * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(sparkSession, planner, plan) } diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..b1419e5e6233 --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. 
*/ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, fileStatus, filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, file, filter) + } +} diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala index 80c22f2fad25..e5258eafa46c 100644 --- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala +++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan} @@ -752,4 +753,16 @@ class Spark40Shims extends SparkShims { override def unsupportedCodec: Seq[CompressionCodecName] = { Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI, CompressionCodecName.LZ4_RAW) } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. 
+ * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(sparkSession, planner, plan) } diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..b1419e5e6233 --- /dev/null +++ b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. 
*/ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, fileStatus, filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(configuration, file, filter) + } +} diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala deleted file mode 100644 index 98fcfa548384..000000000000 --- a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Spark35Scan.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} - -class Spark35Scan extends DataSourceV2ScanExecBase { - - override def scan: Scan = throw new UnsupportedOperationException("Spark35Scan") - - override def ordering: Option[Seq[SortOrder]] = throw new UnsupportedOperationException( - "Spark35Scan") - - override def readerFactory: PartitionReaderFactory = - throw new UnsupportedOperationException("Spark35Scan") - - override def keyGroupedPartitioning: Option[Seq[Expression]] = - throw new UnsupportedOperationException("Spark35Scan") - - override protected def inputPartitions: Seq[InputPartition] = - throw new UnsupportedOperationException("Spark35Scan") - - override def inputRDD: RDD[InternalRow] = throw new UnsupportedOperationException("Spark35Scan") - - override def output: Seq[Attribute] = throw new UnsupportedOperationException("Spark35Scan") - - override def productElement(n: Int): Any = throw new UnsupportedOperationException("Spark35Scan") - - override def productArity: Int = throw new UnsupportedOperationException("Spark35Scan") - - override def canEqual(that: Any): Boolean = throw new UnsupportedOperationException("Spark35Scan") - -} diff --git a/shims/spark41/pom.xml b/shims/spark41/pom.xml new file mode 100644 index 000000000000..3d39f1a9322d --- /dev/null +++ b/shims/spark41/pom.xml @@ -0,0 +1,118 @@ + + + + 4.0.0 + + + org.apache.gluten + spark-sql-columnar-shims + 1.6.0-SNAPSHOT + ../pom.xml + + + spark-sql-columnar-shims-spark41 + jar + Gluten Shims for Spark 4.1 + + + + org.apache.gluten + ${project.prefix}-shims-common + ${project.version} + compile + + + org.apache.spark + spark-sql_${scala.binary.version} + provided + true + + + org.apache.spark + 
spark-catalyst_${scala.binary.version} + provided + true + + + org.apache.spark + spark-core_${scala.binary.version} + provided + true + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + provided + + + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + + + org.apache.spark + spark-hive_${scala.binary.version} + provided + + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + com.diffplug.spotless + spotless-maven-plugin + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + + + + diff --git a/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..7d1347345a95 --- /dev/null +++ b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.SparkUnsupportedOperationException; +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.GeographyVal; +import org.apache.spark.unsafe.types.GeometryVal; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public GeographyVal getGeography(int ordinal) { + return data.getGeography(offset + ordinal); + } + + @Override + public GeometryVal getGeometry(int ordinal) { + return data.getGeometry(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public VariantVal getVariant(int ordinal) { + return data.getVariant(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw SparkUnsupportedOperationException.apply(); + } + + @Override + public void setNullAt(int ordinal) { + throw SparkUnsupportedOperationException.apply(); + } +} diff --git a/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java new file mode 100644 index 000000000000..513f3d2d92ca --- /dev/null +++ b/shims/spark41/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVectorShim.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.unsafe.types.UTF8String; + +import java.nio.ByteBuffer; + +/** + * because spark33 add new function abstract method 'putBooleans(int, byte)' in + * 'WritableColumnVector' And function getByteBuffer() + */ +public class WritableColumnVectorShim extends WritableColumnVector { + /** + * Sets up the common state and also handles creating the child columns if this is a nested type. 
+ * + * @param capacity + * @param type + */ + protected WritableColumnVectorShim(int capacity, DataType type) { + super(capacity, type); + } + + protected void releaseMemory() {} + + @Override + public int getDictId(int rowId) { + return 0; + } + + @Override + protected void reserveInternal(int capacity) {} + + @Override + public void putNotNull(int rowId) {} + + @Override + public void putNull(int rowId) {} + + @Override + public void putNulls(int rowId, int count) {} + + @Override + public void putNotNulls(int rowId, int count) {} + + @Override + public void putBoolean(int rowId, boolean value) {} + + @Override + public void putBooleans(int rowId, int count, boolean value) {} + + @Override + public void putBooleans(int rowId, byte src) { + throw new UnsupportedOperationException("Unsupported function"); + } + + @Override + public void putByte(int rowId, byte value) {} + + @Override + public void putBytes(int rowId, int count, byte value) {} + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putShort(int rowId, short value) {} + + @Override + public void putShorts(int rowId, int count, short value) {} + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) {} + + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putInt(int rowId, int value) {} + + @Override + public void putInts(int rowId, int count, int value) {} + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) {} + + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putLong(int rowId, long value) {} + + @Override + public void putLongs(int rowId, int count, long value) {} + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) {} + + 
@Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putFloat(int rowId, float value) {} + + @Override + public void putFloats(int rowId, int count, float value) {} + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) {} + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putFloatsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putDouble(int rowId, double value) {} + + @Override + public void putDoubles(int rowId, int count, double value) {} + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) {} + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putDoublesLittleEndian(int rowId, int count, byte[] src, int srcIndex) {} + + @Override + public void putArray(int rowId, int offset, int length) {} + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int count) { + return 0; + } + + @Override + protected UTF8String getBytesAsUTF8String(int rowId, int count) { + return null; + } + + @Override + public ByteBuffer getByteBuffer(int rowId, int count) { + throw new UnsupportedOperationException("Unsupported this function"); + } + + @Override + public int getArrayLength(int rowId) { + return 0; + } + + @Override + public int getArrayOffset(int rowId) { + return 0; + } + + @Override + public WritableColumnVector reserveNewColumn(int capacity, DataType type) { + return null; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + return false; + } + + @Override + public byte getByte(int rowId) { + return 0; + } + + @Override + public short getShort(int rowId) { + return 0; + } + + 
@Override + public int getInt(int rowId) { + return 0; + } + + @Override + public long getLong(int rowId) { + return 0; + } + + @Override + public float getFloat(int rowId) { + return 0; + } + + @Override + public double getDouble(int rowId) { + return 0; + } +} diff --git a/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider b/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider new file mode 100644 index 000000000000..6c2e7e518171 --- /dev/null +++ b/shims/spark41/src/main/resources/META-INF/services/org.apache.gluten.sql.shims.SparkShimProvider @@ -0,0 +1 @@ +org.apache.gluten.sql.shims.spark41.SparkShimProvider \ No newline at end of file diff --git a/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala b/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala new file mode 100644 index 000000000000..3d329a8f31ce --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/gluten/execution/GenerateTreeStringShim.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.spark.sql.execution.UnaryExecNode + +/** + * Spark 3.5 has changed the parameter type of the generateTreeString API in TreeNode. In order to + * support multiple versions of Spark, we cannot directly override the generateTreeString method in + * WhostageTransformer. Therefore, we have defined the GenerateTreeStringShim trait in the shim to + * allow different Spark versions to override their own generateTreeString. + */ + +trait WholeStageTransformerGenerateTreeStringShim extends UnaryExecNode { + + def stageId: Int + + def substraitPlanJson: String + + def wholeStageTransformerContextDefined: Boolean + + override def generateTreeString( + depth: Int, + lastChildren: java.util.ArrayList[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean, + printOutputColumns: Boolean, + indent: Int = 0): Unit = { + val prefix = if (printNodeId) "^ " else s"^($stageId) " + child.generateTreeString( + depth, + lastChildren, + append, + verbose, + prefix, + addSuffix = false, + maxFields, + printNodeId = printNodeId, + printOutputColumns = printOutputColumns, + indent = indent) + + if (verbose && wholeStageTransformerContextDefined) { + append(prefix + "Substrait plan:\n") + append(substraitPlanJson) + append("\n") + } + } +} + +trait InputAdapterGenerateTreeStringShim extends UnaryExecNode { + + override def generateTreeString( + depth: Int, + lastChildren: java.util.ArrayList[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean, + printOutputColumns: Boolean, + indent: Int = 0): Unit = { + child.generateTreeString( + depth, + lastChildren, + append, + verbose, + prefix = "", + addSuffix = false, + maxFields, + printNodeId, + printOutputColumns, + indent) + } +} diff --git 
a/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala b/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala new file mode 100644 index 000000000000..2d3287afe4d5 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/gluten/execution/PartitionedFileUtilShim.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.spark.paths.SparkPath +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.PartitionedFileUtil +import org.apache.spark.sql.execution.datasources.{FileStatusWithMetadata, PartitionedFile} + +import org.apache.hadoop.fs.Path + +import java.lang.reflect.Method + +object PartitionedFileUtilShim { + + private val clz: Class[_] = PartitionedFileUtil.getClass + private val module = clz.getField("MODULE$").get(null) + + private lazy val getPartitionedFileMethod: Method = { + try { + val m = clz.getDeclaredMethod( + "getPartitionedFile", + classOf[FileStatusWithMetadata], + classOf[InternalRow]) + m.setAccessible(true) + m + } catch { + case _: NoSuchMethodException => null + } + } + + private lazy val getPartitionedFileByPathMethod: Method = { + try { + val m = clz.getDeclaredMethod( + "getPartitionedFile", + classOf[FileStatusWithMetadata], + classOf[Path], + classOf[InternalRow]) + m.setAccessible(true) + m + } catch { + case _: NoSuchMethodException => null + } + } + + private lazy val getPartitionedFileByPathSizeMethod: Method = { + try { + val m = clz.getDeclaredMethod( + "getPartitionedFile", + classOf[FileStatusWithMetadata], + classOf[Path], + classOf[InternalRow], + classOf[Long], + classOf[Long]) + m.setAccessible(true) + m + } catch { + case _: NoSuchMethodException => null + } + } + + def getPartitionedFile( + file: FileStatusWithMetadata, + partitionValues: InternalRow): PartitionedFile = { + if (getPartitionedFileMethod != null) { + getPartitionedFileMethod + .invoke(module, file, partitionValues) + .asInstanceOf[PartitionedFile] + } else if (getPartitionedFileByPathMethod != null) { + getPartitionedFileByPathMethod + .invoke(module, file, file.getPath, partitionValues) + .asInstanceOf[PartitionedFile] + } else if (getPartitionedFileByPathSizeMethod != null) { + getPartitionedFileByPathSizeMethod + .invoke(module, file, 
file.getPath, partitionValues, 0, file.getLen) + .asInstanceOf[PartitionedFile] + } else { + val params = clz.getDeclaredMethods + .find(_.getName == "getPartitionedFile") + .map(_.getGenericParameterTypes.mkString(", ")) + throw new RuntimeException( + s"getPartitionedFile with $params is not correctly shimmed " + + "in PartitionedFileUtilShim") + } + } + + private lazy val splitFilesMethod: Method = { + try { + val m = clz.getDeclaredMethod( + "splitFiles", + classOf[SparkSession], + classOf[FileStatusWithMetadata], + classOf[Boolean], + classOf[Long], + classOf[InternalRow]) + m.setAccessible(true) + m + } catch { + case _: NoSuchMethodException => null + } + } + + private lazy val splitFilesByPathMethod: Method = { + try { + val m = clz.getDeclaredMethod( + "splitFiles", + classOf[FileStatusWithMetadata], + classOf[Path], + classOf[Boolean], + classOf[Long], + classOf[InternalRow]) + m.setAccessible(true) + m + } catch { + case _: NoSuchMethodException => null + } + } + + def splitFiles( + sparkSession: SparkSession, + file: FileStatusWithMetadata, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = { + if (splitFilesMethod != null) { + splitFilesMethod + .invoke( + module, + sparkSession, + file, + java.lang.Boolean.valueOf(isSplitable), + java.lang.Long.valueOf(maxSplitBytes), + partitionValues) + .asInstanceOf[Seq[PartitionedFile]] + } else if (splitFilesByPathMethod != null) { + splitFilesByPathMethod + .invoke( + module, + file, + file.getPath, + java.lang.Boolean.valueOf(isSplitable), + java.lang.Long.valueOf(maxSplitBytes), + partitionValues) + .asInstanceOf[Seq[PartitionedFile]] + } else { + val params = clz.getDeclaredMethods + .find(_.getName == "splitFiles") + .map(_.getGenericParameterTypes.mkString(", ")) + throw new RuntimeException( + s"splitFiles with $params is not correctly shimmed " + + "in PartitionedFileUtilShim") + } + } + + // Helper method to create PartitionedFile from path and length. 
+ def makePartitionedFileFromPath(path: String, length: Long): PartitionedFile = { + PartitionedFile(null, SparkPath.fromPathString(path), 0, length, Array.empty) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala new file mode 100644 index 000000000000..ea5f733614cb --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala @@ -0,0 +1,767 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.sql.shims.spark41 + +import org.apache.gluten.execution.PartitionedFileUtilShim +import org.apache.gluten.expression.{ExpressionNames, Sig} +import org.apache.gluten.sql.shims.SparkShims + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.paths.SparkPath +import org.apache.spark.scheduler.TaskInfo +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} +import org.apache.spark.sql.catalyst.analysis.DecimalPrecisionTypeCoercion +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan} +import org.apache.spark.sql.connector.read.streaming.SparkDataStream +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters} +import 
org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase} +import org.apache.spark.sql.execution.datasources.v2.text.TextScan +import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike} +import org.apache.spark.sql.execution.window.{Final, Partial, _} +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.storage.{BlockId, BlockManagerId} + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata} +import org.apache.parquet.hadoop.metadata.FileMetaData.EncryptionType +import org.apache.parquet.schema.MessageType + +import java.time.ZoneOffset +import java.util.{Map => JMap} + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +class Spark41Shims extends SparkShims { + + override def getDistribution( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Seq[Distribution] = { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + + override def scalarExpressionMappings: Seq[Sig] = { + Seq( + Sig[SplitPart](ExpressionNames.SPLIT_PART), + Sig[Sec](ExpressionNames.SEC), + Sig[Csc](ExpressionNames.CSC), + Sig[KnownNullable](ExpressionNames.KNOWN_NULLABLE), + Sig[Empty2Null](ExpressionNames.EMPTY2NULL), + Sig[Mask](ExpressionNames.MASK), + Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD), + Sig[TimestampDiff](ExpressionNames.TIMESTAMP_DIFF), + Sig[RoundFloor](ExpressionNames.FLOOR), + Sig[RoundCeil](ExpressionNames.CEIL), + Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT), + Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT), + Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND), + 
Sig[UrlEncode](ExpressionNames.URL_ENCODE), + Sig[KnownNotContainsNull](ExpressionNames.KNOWN_NOT_CONTAINS_NULL), + Sig[UrlDecode](ExpressionNames.URL_DECODE) + ) + } + + override def aggregateExpressionMappings: Seq[Sig] = { + Seq( + Sig[RegrR2](ExpressionNames.REGR_R2), + Sig[RegrSlope](ExpressionNames.REGR_SLOPE), + Sig[RegrIntercept](ExpressionNames.REGR_INTERCEPT), + Sig[RegrSXY](ExpressionNames.REGR_SXY), + Sig[RegrReplacement](ExpressionNames.REGR_REPLACEMENT) + ) + } + + override def runtimeReplaceableExpressionMappings: Seq[Sig] = { + Seq( + Sig[ArrayCompact](ExpressionNames.ARRAY_COMPACT), + Sig[ArrayPrepend](ExpressionNames.ARRAY_PREPEND), + Sig[ArraySize](ExpressionNames.ARRAY_SIZE), + Sig[EqualNull](ExpressionNames.EQUAL_NULL), + Sig[ILike](ExpressionNames.ILIKE), + Sig[MapContainsKey](ExpressionNames.MAP_CONTAINS_KEY), + Sig[Get](ExpressionNames.GET), + Sig[Luhncheck](ExpressionNames.LUHN_CHECK) + ) + } + + override def convertPartitionTransforms( + partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { + CatalogUtil.convertPartitionTransforms(partitions) + } + + override def generateFileScanRDD( + sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + fileSourceScanExec: FileSourceScanExec): FileScanRDD = { + new FileScanRDD( + sparkSession, + readFunction, + filePartitions, + new StructType( + fileSourceScanExec.requiredSchema.fields ++ + fileSourceScanExec.relation.partitionSchema.fields), + fileSourceScanExec.fileConstantMetadataColumns + ) + } + + override def getTextScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression]): TextScan = { + TextScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema, + readPartitionSchema, + options, + 
partitionFilters, + dataFilters) + } + + override def filesGroupedToBuckets( + selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = { + selectedPartitions + .flatMap(p => p.files.map(f => PartitionedFileUtilShim.getPartitionedFile(f, p.values))) + .groupBy { + f => + BucketingUtils + .getBucketId(f.toPath.getName) + .getOrElse(throw invalidBucketFile(f.urlEncodedPath)) + } + } + + override def getBatchScanExecTable(batchScan: BatchScanExec): Table = batchScan.table + + override def generatePartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long, + @transient locations: Array[String] = Array.empty): PartitionedFile = + PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations) + + override def bloomFilterExpressionMappings(): Seq[Sig] = Seq( + Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), + Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG) + ) + + override def newBloomFilterAggregate[T]( + child: Expression, + estimatedNumItemsExpression: Expression, + numBitsExpression: Expression, + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int): TypedImperativeAggregate[T] = { + BloomFilterAggregate( + child, + estimatedNumItemsExpression, + numBitsExpression, + mutableAggBufferOffset, + inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]] + } + + override def newMightContain( + bloomFilterExpression: Expression, + valueExpression: Expression): BinaryExpression = { + BloomFilterMightContain(bloomFilterExpression, valueExpression) + } + + override def replaceBloomFilterAggregate[T]( + expr: Expression, + bloomFilterAggReplacer: ( + Expression, + Expression, + Expression, + Int, + Int) => TypedImperativeAggregate[T]): Expression = expr match { + case BloomFilterAggregate( + child, + estimatedNumItemsExpression, + numBitsExpression, + mutableAggBufferOffset, + inputAggBufferOffset) => + bloomFilterAggReplacer( + child, + 
estimatedNumItemsExpression, + numBitsExpression, + mutableAggBufferOffset, + inputAggBufferOffset) + case other => other + } + + override def replaceMightContain[T]( + expr: Expression, + mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match { + case BloomFilterMightContain(bloomFilterExpression, valueExpression) => + mightContainReplacer(bloomFilterExpression, valueExpression) + case other => other + } + + override def getFileSizeAndModificationTime( + file: PartitionedFile): (Option[Long], Option[Long]) = { + (Some(file.fileSize), Some(file.modificationTime)) + } + + override def generateMetadataColumns( + file: PartitionedFile, + metadataColumnNames: Seq[String]): Map[String, String] = { + val originMetadataColumn = super.generateMetadataColumns(file, metadataColumnNames) + val metadataColumn: mutable.Map[String, String] = mutable.Map(originMetadataColumn.toSeq: _*) + val path = new Path(file.filePath.toString) + for (columnName <- metadataColumnNames) { + columnName match { + case FileFormat.FILE_PATH => metadataColumn += (FileFormat.FILE_PATH -> path.toString) + case FileFormat.FILE_NAME => metadataColumn += (FileFormat.FILE_NAME -> path.getName) + case FileFormat.FILE_SIZE => + metadataColumn += (FileFormat.FILE_SIZE -> file.fileSize.toString) + case FileFormat.FILE_MODIFICATION_TIME => + val fileModifyTime = TimestampFormatter + .getFractionFormatter(ZoneOffset.UTC) + .format(file.modificationTime * 1000L) + metadataColumn += (FileFormat.FILE_MODIFICATION_TIME -> fileModifyTime) + case FileFormat.FILE_BLOCK_START => + metadataColumn += (FileFormat.FILE_BLOCK_START -> file.start.toString) + case FileFormat.FILE_BLOCK_LENGTH => + metadataColumn += (FileFormat.FILE_BLOCK_LENGTH -> file.length.toString) + case _ => + } + } + metadataColumn.toMap + } + + // https://issues.apache.org/jira/browse/SPARK-40400 + private def invalidBucketFile(path: String): Throwable = { + new SparkException( + errorClass = 
"INVALID_BUCKET_FILE", + messageParameters = Map("path" -> path), + cause = null) + } + + private def getLimit(limit: Int, offset: Int): Int = { + if (limit == -1) { + // Only offset specified, so fetch the maximum number rows + Int.MaxValue + } else { + assert(limit > offset) + limit - offset + } + } + + override def getLimitAndOffsetFromGlobalLimit(plan: GlobalLimitExec): (Int, Int) = { + (getLimit(plan.limit, plan.offset), plan.offset) + } + + override def isWindowGroupLimitExec(plan: SparkPlan): Boolean = plan match { + case _: WindowGroupLimitExec => true + case _ => false + } + + override def getWindowGroupLimitExecShim(plan: SparkPlan): WindowGroupLimitExecShim = { + val windowGroupLimitPlan = plan.asInstanceOf[WindowGroupLimitExec] + val mode = windowGroupLimitPlan.mode match { + case Partial => GlutenPartial + case Final => GlutenFinal + } + WindowGroupLimitExecShim( + windowGroupLimitPlan.partitionSpec, + windowGroupLimitPlan.orderSpec, + windowGroupLimitPlan.rankLikeFunction, + windowGroupLimitPlan.limit, + mode, + windowGroupLimitPlan.child + ) + } + + override def getWindowGroupLimitExec( + windowGroupLimitExecShim: WindowGroupLimitExecShim): SparkPlan = { + val mode = windowGroupLimitExecShim.mode match { + case GlutenPartial => Partial + case GlutenFinal => Final + } + WindowGroupLimitExec( + windowGroupLimitExecShim.partitionSpec, + windowGroupLimitExecShim.orderSpec, + windowGroupLimitExecShim.rankLikeFunction, + windowGroupLimitExecShim.limit, + mode, + windowGroupLimitExecShim.child + ) + } + + override def getLimitAndOffsetFromTopK(plan: TakeOrderedAndProjectExec): (Int, Int) = { + (getLimit(plan.limit, plan.offset), plan.offset) + } + + override def getExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = List() + + override def writeFilesExecuteTask( + description: WriteJobDescription, + jobTrackerID: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: 
Iterator[InternalRow]): WriteTaskResult = { + GlutenFileFormatWriter.writeFilesExecuteTask( + description, + jobTrackerID, + sparkStageId, + sparkPartitionId, + sparkAttemptNumber, + committer, + iterator + ) + } + + override def enableNativeWriteFilesByDefault(): Boolean = true + + override def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = { + SparkContextUtils.broadcastInternal(sc, value) + } + + override def setJobDescriptionOrTagForBroadcastExchange( + sc: SparkContext, + broadcastExchange: BroadcastExchangeLike): Unit = { + // Setup a job tag here so later it may get cancelled by tag if necessary. + sc.addJobTag(broadcastExchange.jobTag) + sc.setInterruptOnCancel(true) + } + + override def cancelJobGroupForBroadcastExchange( + sc: SparkContext, + broadcastExchange: BroadcastExchangeLike): Unit = { + sc.cancelJobsWithTag(broadcastExchange.jobTag) + } + + override def getShuffleReaderParam[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = { + ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition) + } + + override def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] = + shuffle.advisoryPartitionSize + + override def getPartitionId(taskInfo: TaskInfo): Int = { + taskInfo.partitionId + } + + override def supportDuplicateReadingTracking: Boolean = true + + def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] = + partition.files.map(f => (f.fileStatus, f.metadata)) + + def isFileSplittable( + relation: HadoopFsRelation, + filePath: Path, + sparkSchema: StructType): Boolean = { + relation.fileFormat + .isSplitable(relation.sparkSession, relation.options, filePath) + } + + def isRowIndexMetadataColumn(name: String): Boolean = + name == ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME || + 
name.equalsIgnoreCase("__delta_internal_is_row_deleted") + + def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = { + sparkSchema.fields.zipWithIndex.find { + case (field: StructField, _: Int) => + field.name == ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME + } match { + case Some((field: StructField, idx: Int)) => + if (field.dataType != LongType && field.dataType != IntegerType) { + throw new RuntimeException( + s"${ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} " + + "must be of LongType or IntegerType") + } + idx + case _ => -1 + } + } + + def splitFiles( + sparkSession: SparkSession, + file: FileStatus, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow, + metadata: Map[String, Any] = Map.empty): Seq[PartitionedFile] = { + PartitionedFileUtilShim.splitFiles( + sparkSession, + FileStatusWithMetadata(file, metadata), + isSplitable, + maxSplitBytes, + partitionValues) + } + + def structFromAttributes(attrs: Seq[Attribute]): StructType = { + DataTypeUtils.fromAttributes(attrs) + } + + def attributesFromStruct(structType: StructType): Seq[Attribute] = { + DataTypeUtils.toAttributes(structType) + } + + def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan] = { + ae match { + case eae: ExtendedAnalysisException => + eae.plan + case _ => + None + } + } + override def getKeyGroupedPartitioning(batchScan: BatchScanExec): Option[Seq[Expression]] = { + batchScan.keyGroupedPartitioning + } + + override def getCommonPartitionValues( + batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] = { + batchScan.spjParams.commonPartitionValues + } + + // please ref BatchScanExec::inputRDD + override def orderPartitions( + batchScan: DataSourceV2ScanExecBase, + scan: Scan, + keyGroupedPartitioning: Option[Seq[Expression]], + filteredPartitions: Seq[Seq[InputPartition]], + outputPartitioning: Partitioning, + commonPartitionValues: Option[Seq[(InternalRow, Int)]], + applyPartialClustering: 
Boolean, + replicatePartitions: Boolean, + joinKeyPositions: Option[Seq[Int]] = None): Seq[Seq[InputPartition]] = { + val original = batchScan.asInstanceOf[BatchScanExecShim] + scan match { + case _ if keyGroupedPartitioning.isDefined => + outputPartitioning match { + case p: KeyGroupedPartitioning => + assert(keyGroupedPartitioning.isDefined) + val expressions = keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten + .groupBy( + part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = + KeyGroupedPartitioning.project(expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }) + .map { case (wrapper, splits) => (wrapper.row, splits) } + .toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map( + splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } + + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = original.reducers match { + case Some(reducers) => + val result = groupedPartitions + .groupBy { + case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + } + .map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) } + .toSeq + val rowOrdering = + RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + + // When partially clustered, the input partitions are not grouped by partition + // values. 
Here we'll need to check `commonPartitionValues` and decide how to group + // and replicate splits within a partition. + if (commonPartitionValues.isDefined && applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. + val commonPartValuesMap = commonPartitionValues.get + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) + .toMap + val filteredGroupedPartitions = finalGroupedPartitions.filter { + case (partValues, _) => + commonPartValuesMap.keySet.contains( + InternalRowComparableWrapper(partValues, partExpressions)) + } + val nestGroupedPartitions = filteredGroupedPartitions.map { + case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. + val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, partExpressions)) + assert( + numSplits.isDefined, + s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. + splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) + } + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) + } + + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. 
+ partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, partExpressions), + Seq.fill(numSplits)(Seq.empty)) + } + } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. + val partitionMapping = finalGroupedPartitions.map { + case (partValue, splits) => + InternalRowComparableWrapper(partValue, partExpressions) -> splits + }.toMap + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. + p.uniquePartitionValues.map { + partValue => + // Use empty partition for those partition values that are not present + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, partExpressions), + Seq.empty) + } + } + + case _ => filteredPartitions + } + case _ => + filteredPartitions + } + } + + override def withTryEvalMode(expr: Expression): Boolean = { + expr match { + case a: Add => a.evalMode == EvalMode.TRY + case s: Subtract => s.evalMode == EvalMode.TRY + case d: Divide => d.evalMode == EvalMode.TRY + case m: Multiply => m.evalMode == EvalMode.TRY + case c: Cast => c.evalMode == EvalMode.TRY + case _ => false + } + } + + override def withAnsiEvalMode(expr: Expression): Boolean = { + expr match { + case a: Add => a.evalMode == EvalMode.ANSI + case s: Subtract => s.evalMode == EvalMode.ANSI + case d: Divide => d.evalMode == EvalMode.ANSI + case m: Multiply => m.evalMode == EvalMode.ANSI + case c: Cast => c.evalMode == EvalMode.ANSI + case i: IntegralDivide => i.evalMode == EvalMode.ANSI + case _ => false + } + } + + override def dateTimestampFormatInReadIsDefaultValue( + csvOptions: CSVOptions, + timeZone: String): Boolean = { + val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone) + 
csvOptions.dateFormatInRead == default.dateFormatInRead && + csvOptions.timestampFormatInRead == default.timestampFormatInRead && + csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead + } + + override def createParquetFilters( + conf: SQLConf, + schema: MessageType, + caseSensitive: Option[Boolean] = None): ParquetFilters = { + new ParquetFilters( + schema, + conf.parquetFilterPushDownDate, + conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, + conf.parquetFilterPushDownStringPredicate, + conf.parquetFilterPushDownInFilterThreshold, + caseSensitive.getOrElse(conf.caseSensitiveAnalysis), + RebaseSpec(LegacyBehaviorPolicy.CORRECTED) + ) + } + + override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = { + val expr = arrayInsert.asInstanceOf[ArrayInsert] + Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) + } + + override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = { + val prevIdMap = QueryPlan.localIdMap.get() + try { + QueryPlan.localIdMap.set(idMap) + body + } finally { + QueryPlan.localIdMap.set(prevIdMap) + } + } + + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { + Option(QueryPlan.localIdMap.get().get(plan)) + } + + override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { + val map = QueryPlan.localIdMap.get() + assert(!map.containsKey(plan)) + map.put(plan, opId) + } + + override def unsetOperatorId(plan: QueryPlan[_]): Unit = { + QueryPlan.localIdMap.get().remove(plan) + } + + override def isParquetFileEncrypted(footer: ParquetMetadata): Boolean = { + footer.getFileMetaData.getEncryptionType match { + // UNENCRYPTED file has a plaintext footer and no file encryption, + // We can leverage file metadata for this check and return unencrypted. + case EncryptionType.UNENCRYPTED => + false + // PLAINTEXT_FOOTER has a plaintext footer however the file is encrypted. 
+ // In such cases, read the footer and use the metadata for encryption check. + case EncryptionType.PLAINTEXT_FOOTER => + true + case _ => + false + } + } + + override def getOtherConstantMetadataColumnValues(file: PartitionedFile): JMap[String, Object] = + file.otherConstantMetadataColumnValues.asJava.asInstanceOf[JMap[String, Object]] + + override def getCollectLimitOffset(plan: CollectLimitExec): Int = { + plan.offset + } + + override def unBase64FunctionFailsOnError(unBase64: UnBase64): Boolean = unBase64.failOnError + + override def extractExpressionTimestampAddUnit(exp: Expression): Option[Seq[String]] = { + exp match { + // Velox does not support quantity larger than Int.MaxValue. + case TimestampAdd(_, LongLiteral(quantity), _, _) if quantity > Integer.MAX_VALUE => + Option.empty + case timestampAdd: TimestampAdd => + Option.apply(Seq(timestampAdd.unit, timestampAdd.timeZoneId.getOrElse(""))) + case _ => Option.empty + } + } + + override def extractExpressionTimestampDiffUnit(exp: Expression): Option[String] = { + exp match { + case timestampDiff: TimestampDiff => + Some(timestampDiff.unit) + case _ => Option.empty + } + } + + override def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + DecimalPrecisionTypeCoercion.widerDecimalType(d1, d2) + } + + override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { + raiseError.errorParms match { + case CreateMap(children, _) + if children.size == 2 && children.head.isInstanceOf[Literal] + && children.head.asInstanceOf[Literal].value.toString == "errorMessage" => + Some(children(1)) + case _ => + None + } + } + + override def throwExceptionInWrite( + t: Throwable, + writePath: String, + descriptionPath: String): Unit = { + throw t + } + + override def enrichWriteException(cause: Throwable, path: String): Nothing = { + GlutenFileFormatWriter.wrapWriteError(cause, path) + } + override def getFileSourceScanStream(scan: FileSourceScanExec): Option[SparkDataStream] = { + 
scan.stream + } + + override def unsupportedCodec: Seq[CompressionCodecName] = { + Seq(CompressionCodecName.LZO, CompressionCodecName.BROTLI, CompressionCodecName.LZ4_RAW) + } + + /** + * Shim layer for QueryExecution to maintain compatibility across different Spark versions. + * + * @since Spark + * 4.1 + */ + override def createSparkPlan( + sparkSession: SparkSession, + planner: SparkPlanner, + plan: LogicalPlan): SparkPlan = + QueryExecution.createSparkPlan(planner, plan) +} diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala new file mode 100644 index 000000000000..9a5b6a63196d --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/SparkShimProvider.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.sql.shims.spark41 + +import org.apache.gluten.sql.shims.SparkShims + +class SparkShimProvider extends org.apache.gluten.sql.shims.SparkShimProvider { + def createShim: SparkShims = { + new Spark41Shims() + } +} diff --git a/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala b/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala new file mode 100644 index 000000000000..654e43cbd03f --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/gluten/utils/InternalRowUtl.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.utils + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.types.StructType + +object InternalRowUtl { + def toString(struct: StructType, rows: Iterator[InternalRow]): String = { + val encoder = ExpressionEncoder(struct).resolveAndBind() + val deserializer = encoder.createDeserializer() + rows.map(deserializer).mkString(System.lineSeparator()) + } + + def toString(struct: StructType, rows: Iterator[InternalRow], start: Int, length: Int): String = { + toString(struct, rows.slice(start, start + length)) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala b/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala new file mode 100644 index 000000000000..c2a6cd5cffc1 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/ShuffleUtils.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark + +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleHandle} +import org.apache.spark.storage.{BlockId, BlockManagerId} + +object ShuffleUtils { + def getReaderParam[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (res.iter.map(b => (b._1, b._2.toSeq)), res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (address.map(b => (b._1, b._2.toSeq)), true) + } + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala b/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala new file mode 100644 index 000000000000..3cbf2b602d47 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/SparkContextUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import org.apache.spark.broadcast.Broadcast + +import scala.reflect.ClassTag + +object SparkContextUtils { + def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = { + sc.broadcastInternal(value, serializedOnly = true) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala new file mode 100644 index 000000000000..95b15f04e7cb --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil { // version shim: isolates SortShuffleWriter's constructor signature from callers
+  def create[K, V, C](
+      handle: BaseShuffleHandle[K, V, C],
+      mapId: Long,
+      context: TaskContext,
+      writeMetrics: ShuffleWriteMetricsReporter,
+      shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle, mapId, context, writeMetrics, shuffleExecutorComponents) // delegates to vanilla sort-based shuffle writer
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala
new file mode 100644
index 000000000000..b18a79b864e2
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/PromotePrecision.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+
+case class PromotePrecision(child: Expression) extends UnaryExpression { // no-op wrapper: every member delegates to child (shim re-added for Gluten)
+  override def dataType: DataType = child.dataType
+  override def eval(input: InternalRow): Any = child.eval(input)
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    child.genCode(ctx)
+  override def prettyName: String = "promote_precision"
+  override def sql: String = child.sql
+  override lazy val canonicalized: Expression = child.canonicalized // canonicalizes away the wrapper entirely
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala
new file mode 100644
index 000000000000..8abebd0d2a8b
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/InvokeExtractors.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.objects
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
+
+/**
+ * Extractors for Invoke expressions to ensure compatibility across different Spark versions.
+ *
+ * Since Spark 4.0, StructsToJson has been replaced with Invoke expressions using
+ * StructsToJsonEvaluator. This extractor provides a unified interface to extract evaluator options,
+ * child expression, and timeZoneId from the Invoke pattern.
+ */
+object StructsToJsonInvoke {
+  def unapply(expr: Expression): Option[(Map[String, String], Expression, Option[String])] = {
+    expr match {
+      case Invoke( // 8-field Invoke pattern — NOTE(review): revisit if Invoke's arity changes in a future Spark release
+            Literal(evaluator: StructsToJsonEvaluator, _),
+            "evaluate",
+            _,
+            Seq(child),
+            _,
+            _,
+            _,
+            _) =>
+        Some((evaluator.options, child, evaluator.timeZoneId))
+      case _ =>
+        None
+    }
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala
new file mode 100644
index 000000000000..1df1456f4011
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectShim.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+
+object CollapseProjectShim { // thin forwarders to the CollapseProject companion helpers (package-scoped in Spark)
+  def canCollapseExpressions(
+      consumers: Seq[Expression],
+      producers: Seq[NamedExpression],
+      alwaysInline: Boolean): Boolean = {
+    CollapseProject.canCollapseExpressions(consumers, producers, alwaysInline)
+  }
+
+  def buildCleanedProjectList(
+      upper: Seq[NamedExpression],
+      lower: Seq[NamedExpression]): Seq[NamedExpression] = {
+    CollapseProject.buildCleanedProjectList(upper, lower)
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala
new file mode 100644
index 000000000000..7fe63d706e5d
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicColumn.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.classic.ClassicConversions._
+
+/**
+ * Ensures compatibility with Spark 4.0. The implicit class ColumnConstructorExt from
+ * ClassicConversions is used to construct a Column from an Expression. Since Spark 4.0, the Column
+ * class is private to the package org.apache.spark. This class provides a way to construct a Column
+ * in code that is outside the org.apache.spark package. Developers can directly call Column(e) if
+ * ColumnConstructorExt is imported and the caller code is within this package.
+ */
+object ClassicColumn {
+  def apply(e: Expression): Column = {
+    Column(e) // resolved via ColumnConstructorExt from the ClassicConversions._ import above
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala
new file mode 100644
index 000000000000..1a48a8e110a7
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicDataset.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic
+
+/** Since Spark 4.0, the method ofRows cannot be invoked directly from sql.Dataset. */
+object ClassicDataset {
+  def ofRows(sparkSession: classic.SparkSession, logicalPlan: LogicalPlan): DataFrame = { // DataFrame resolves from the enclosing sql.classic package — TODO confirm the alias
+    // Redirect to the classic.Dataset companion method.
+    classic.Dataset.ofRows(sparkSession, logicalPlan)
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala
new file mode 100644
index 000000000000..ea3fb8893f9b
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/ClassicTypes.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql
+
+/**
+ * Prior to Spark 4.0, `ClassicSparkSession` refers to `sql.SparkSession`. Since Spark 4.0, it
+ * refers to `sql.classic.SparkSession`.
+ */
+object ClassicTypes {
+
+  type ClassicSparkSession = sql.classic.SparkSession // version-dependent alias consumed by version-agnostic Gluten code
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala
new file mode 100644
index 000000000000..b168838f058c
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/classic/conversions.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import org.apache.spark.sql
+
+/**
+ * Enables access to the methods in the companion object classic.SparkSession via sql.SparkSession.
+ * Since Spark 4.0, these methods have been moved from sql.SparkSession to sql.classic.SparkSession.
+ */
+object ExtendedClassicConversions {
+
+  implicit class RichSqlSparkSession(sqlSparkSession: sql.SparkSession.type) { // enriches the sql.SparkSession companion (note: the .type singleton, not instances)
+    def cleanupAnyExistingSession(): Unit = {
+      // Redirect to the classic.SparkSession companion method.
+      sql.classic.SparkSession.cleanupAnyExistingSession()
+    }
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala
new file mode 100644
index 000000000000..ac17c259b87c
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/AbstractFileSourceScanExec.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.execution.PartitionedFileUtilShim
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.collection.BitSet
+
+import org.apache.hadoop.fs.Path
+
+import java.util.concurrent.TimeUnit._
+
+/**
+ * Physical plan node for scanning data from HadoopFsRelations.
+ *
+ * @param relation
+ *   The file-based relation to scan.
+ * @param output
+ *   Output attributes of the scan, including data attributes and partition attributes.
+ * @param requiredSchema
+ *   Required schema of the underlying relation, excluding partition columns.
+ * @param partitionFilters
+ *   Predicates to use for partition pruning.
+ * @param optionalBucketSet
+ *   Bucket ids for bucket pruning.
+ * @param optionalNumCoalescedBuckets
+ *   Number of coalesced buckets.
+ * @param dataFilters
+ *   Filters on non-partition columns.
+ * @param tableIdentifier
+ *   Identifier for the table in the metastore.
+ * @param disableBucketedScan
+ *   Disable bucketed scan based on physical query plan, see rule [[DisableUnnecessaryBucketedScan]]
+ *   for details.
+ */
+abstract class AbstractFileSourceScanExec(
+    @transient override val relation: HadoopFsRelation,
+    override val output: Seq[Attribute],
+    override val requiredSchema: StructType,
+    override val partitionFilters: Seq[Expression],
+    override val optionalBucketSet: Option[BitSet],
+    override val optionalNumCoalescedBuckets: Option[Int],
+    override val dataFilters: Seq[Expression],
+    override val tableIdentifier: Option[TableIdentifier],
+    override val disableBucketedScan: Boolean = false)
+  extends FileSourceScanLike {
+
+  override def supportsColumnar: Boolean = {
+    // The value should be defined in GlutenPlan.
+    throw new UnsupportedOperationException( // deliberately unreachable: subclasses must override
+      "Unreachable code from org.apache.spark.sql.execution.AbstractFileSourceScanExec" +
+        ".supportsColumnar")
+  }
+
+  private lazy val needsUnsafeRowConversion: Boolean = {
+    if (relation.fileFormat.isInstanceOf[ParquetSource]) {
+      conf.parquetVectorizedReaderEnabled // vectorized parquet emits ColumnarBatch rows needing UnsafeRow conversion
+    } else {
+      false
+    }
+  }
+
+  lazy val inputRDD: RDD[InternalRow] = {
+    val options = relation.options +
+      (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString)
+    val readFile: (PartitionedFile) => Iterator[InternalRow] =
+      relation.fileFormat.buildReaderWithPartitionValues(
+        sparkSession = relation.sparkSession,
+        dataSchema = relation.dataSchema,
+        partitionSchema = relation.partitionSchema,
+        requiredSchema = requiredSchema,
+        filters = pushedDownFilters,
+        options = options,
+        hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)
+      )
+
+    val readRDD = if (bucketedScan) {
+      createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions)
+    } else {
+      createReadRDD(readFile, dynamicallySelectedPartitions)
+    }
+    sendDriverMetrics()
+    readRDD
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    inputRDD :: Nil
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    if (needsUnsafeRowConversion) {
+      inputRDD.mapPartitionsWithIndexInternal {
+        (index, iter) =>
+          val toUnsafe = UnsafeProjection.create(schema)
+          toUnsafe.initialize(index) // partition index seeds codegen state (e.g. nondeterministic exprs)
+          iter.map {
+            row =>
+              numOutputRows += 1
+              toUnsafe(row)
+          }
+      }
+    } else {
+      inputRDD.mapPartitionsInternal {
+        iter =>
+          iter.map {
+            row =>
+              numOutputRows += 1
+              row
+          }
+      }
+    }
+  }
+
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val scanTime = longMetric("scanTime")
+    inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal {
+      batches =>
+        new Iterator[ColumnarBatch] {
+
+          override def hasNext: Boolean = {
+            // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call.
+            val startNs = System.nanoTime()
+            val res = batches.hasNext
+            scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs)
+            res
+          }
+
+          override def next(): ColumnarBatch = {
+            val batch = batches.next()
+            numOutputRows += batch.numRows()
+            batch
+          }
+        }
+    }
+  }
+
+  override val nodeNamePrefix: String = "File"
+
+  /**
+   * Create an RDD for bucketed reads. The non-bucketed variant of this function is
+   * [[createReadRDD]].
+   *
+   * The algorithm is pretty simple: each RDD partition being returned should include all the files
+   * with the same bucket id from all the given Hive partitions.
+   *
+   * @param bucketSpec
+   *   the bucketing spec.
+   * @param readFile
+   *   a function to read each (part of a) file.
+   * @param selectedPartitions
+   *   Hive-style partition that are part of the read.
+   */
+  private def createBucketedReadRDD(
+      bucketSpec: BucketSpec,
+      readFile: (PartitionedFile) => Iterator[InternalRow],
+      selectedPartitions: ScanFileListing): RDD[InternalRow] = {
+    logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
+    val partitionArray = selectedPartitions.toPartitionArray
+    val filesGroupedToBuckets = partitionArray.groupBy {
+      f =>
+        BucketingUtils
+          .getBucketId(f.toPath.getName) // bucket id is encoded in the file name
+          .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
+    }
+
+    val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+      val bucketSet = optionalBucketSet.get
+      filesGroupedToBuckets.filter(f => bucketSet.get(f._1))
+    } else {
+      filesGroupedToBuckets
+    }
+
+    val filePartitions = optionalNumCoalescedBuckets
+      .map {
+        numCoalescedBuckets =>
+          logInfo(s"Coalescing to $numCoalescedBuckets buckets")
+          val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
+          Seq.tabulate(numCoalescedBuckets) {
+            bucketId =>
+              val partitionedFiles = coalescedBuckets
+                .get(bucketId)
+                .map {
+                  _.values.flatten.toArray
+                }
+                .getOrElse(Array.empty)
+              FilePartition(bucketId, partitionedFiles)
+          }
+      }
+      .getOrElse {
+        Seq.tabulate(bucketSpec.numBuckets) {
+          bucketId =>
+            FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
+        }
+      }
+
+    new FileScanRDD(
+      relation.sparkSession,
+      readFile,
+      filePartitions,
+      new StructType(requiredSchema.fields ++ relation.partitionSchema.fields),
+      fileConstantMetadataColumns,
+      relation.fileFormat.fileConstantMetadataExtractors,
+      new FileSourceOptions(CaseInsensitiveMap(relation.options))
+    )
+  }
+
+  /**
+   * Create an RDD for non-bucketed reads. The bucketed variant of this function is
+   * [[createBucketedReadRDD]].
+   *
+   * @param readFile
+   *   a function to read each (part of a) file.
+   * @param selectedPartitions
+   *   Hive-style partition that are part of the read.
+   */
+  private def createReadRDD(
+      readFile: (PartitionedFile) => Iterator[InternalRow],
+      selectedPartitions: ScanFileListing): RDD[InternalRow] = {
+    val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
+    val maxSplitBytes =
+      FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
+    logInfo(
+      s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
+        s"open cost is considered as scanning $openCostInBytes bytes.")
+
+    // Filter files with bucket pruning if possible
+    val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled
+    val shouldProcess: Path => Boolean = optionalBucketSet match {
+      case Some(bucketSet) if bucketingEnabled =>
+        // Do not prune the file if bucket file name is invalid
+        filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+      case _ =>
+        _ => true
+    }
+
+    val splitFiles = selectedPartitions.filePartitionIterator
+      .flatMap {
+        partition =>
+          partition.files.flatMap {
+            file =>
+              if (shouldProcess(file.getPath)) {
+                val isSplitable = relation.fileFormat.isSplitable(
+                  relation.sparkSession,
+                  relation.options,
+                  file.getPath)
+                PartitionedFileUtilShim.splitFiles(
+                  sparkSession = relation.sparkSession,
+                  file = file,
+                  isSplitable = isSplitable,
+                  maxSplitBytes = maxSplitBytes,
+                  partitionValues = partition.values
+                )
+              } else {
+                Seq.empty
+              }
+          }
+      }
+      .toArray
+      .sortBy(_.length)(implicitly[Ordering[Long]].reverse) // largest splits first for better bin packing
+
+    val partitions =
+      FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
+
+    new FileScanRDD(
+      relation.sparkSession,
+      readFile,
+      partitions,
+      new StructType(requiredSchema.fields ++ relation.partitionSchema.fields),
+      fileConstantMetadataColumns,
+      relation.fileFormat.fileConstantMetadataExtractors,
+      new FileSourceOptions(CaseInsensitiveMap(relation.options))
+    )
+  }
+
+  // Filters unused DynamicPruningExpression expressions - one which has been replaced
+  // with
+  // DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning
+  protected def filterUnusedDynamicPruningExpressions(
+      predicates: Seq[Expression]): Seq[Expression] = {
+    predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
new file mode 100644
index 000000000000..791490064a96
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection}
+
+import scala.collection.mutable
+
+// https://issues.apache.org/jira/browse/SPARK-31869
+class ExpandOutputPartitioningShim(
+    streamedKeyExprs: Seq[Expression],
+    buildKeyExprs: Seq[Expression],
+    expandLimit: Int) { // caps the number of generated key combinations
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeyExprs.zip(buildKeyExprs).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized // canonicalize so semantically-equal keys share one entry
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  def expandPartitioning(partitioning: Partitioning): Partitioning = {
+    partitioning match {
+      case h: HashPartitioningLike => expandOutputPartitioning(h)
+      case c: PartitioningCollection => expandOutputPartitioning(c)
+      case _ => partitioning
+    }
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as PartitioningCollection.
+  private def expandOutputPartitioning(
+      partitioning: HashPartitioningLike): PartitioningCollection = {
+    val maxNumCombinations = expandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) { // stop once the combination budget is spent
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt
+            .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
+  }
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
new file mode 100644
index 000000000000..d9a50f65c711
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.gluten.metrics.GlutenTimeMetric +import org.apache.gluten.sql.shims.SparkShimLoader + +import org.apache.spark.Partition +import org.apache.spark.internal.LogKeys.{COUNT, MAX_SPLIT_BYTES, OPEN_COST_IN_BYTES} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Expression, FileSourceConstantMetadataAttribute, FileSourceGeneratedMetadataAttribute, PlanExpression, Predicate} +import org.apache.spark.sql.connector.read.streaming.SparkDataStream +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, HadoopFsRelation, PartitionDirectory} +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.collection.BitSet + +import org.apache.hadoop.fs.Path + +abstract class FileSourceScanExecShim( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends AbstractFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) { + + // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. 
+ @transient override lazy val metrics: Map[String, SQLMetric] = Map() + + lazy val metadataColumns: Seq[AttributeReference] = output.collect { + case FileSourceConstantMetadataAttribute(attr) => attr + case FileSourceGeneratedMetadataAttribute(attr, _) => attr + } + + protected lazy val driverMetricsAlias = driverMetrics + + def dataFiltersInScan: Seq[Expression] = dataFilters.filterNot(_.references.exists { + attr => SparkShimLoader.getSparkShims.isRowIndexMetadataColumn(attr.name) + }) + + def hasUnsupportedColumns: Boolean = { + // TODO, fallback if user define same name column due to we can't right now + // detect which column is metadata column which is user defined column. + val metadataColumnsNames = metadataColumns.map(_.name) + output + .filterNot(metadataColumns.toSet) + .exists(v => metadataColumnsNames.contains(v.name)) + } + + def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr) + + def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema) + + protected def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + protected def setFilesNumAndSizeMetric(partitions: ScanFileListing, static: Boolean): Unit = { + val filesNum = partitions.totalNumberOfFiles + val filesSize = partitions.totalFileSize + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles").set(filesNum) + driverMetrics("filesSize").set(filesSize) + } else { + driverMetrics("staticFilesNum").set(filesNum) + driverMetrics("staticFilesSize").set(filesSize) + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions").set(partitions.partitionCount) + } + } + + @transient override lazy val dynamicallySelectedPartitions: ScanFileListing = { + val dynamicDataFilters = dataFilters.filter(isDynamicPruningFilter) + val dynamicPartitionFilters = + partitionFilters.filter(isDynamicPruningFilter) + if (dynamicPartitionFilters.nonEmpty) { + 
GlutenTimeMetric.withMillisTime { + // call the file index for the files matching all filters except dynamic partition filters + val boundedFilters = dynamicPartitionFilters.map { + dynamicPartitionFilter => + dynamicPartitionFilter.transform { + case a: AttributeReference => + val index = relation.partitionSchema.indexWhere(a.name == _.name) + BoundReference(index, relation.partitionSchema(index).dataType, nullable = true) + } + } + val boundPredicate = Predicate.create(boundedFilters.reduce(And), Nil) + val returnedFiles = + selectedPartitions.filterAndPruneFiles(boundPredicate, dynamicDataFilters) + setFilesNumAndSizeMetric(returnedFiles, false) + returnedFiles + }(t => driverMetrics("pruningTime").set(t)) + } else { + selectedPartitions + } + } + + def getPartitionArray: Array[PartitionDirectory] = { + // TODO: fix the value of partition directories in dynamic pruning + val staticDataFilters = dataFilters.filterNot(isDynamicPruningFilter) + val staticPartitionFilters = partitionFilters.filterNot(isDynamicPruningFilter) + val partitionDirectories = + relation.location.listFiles(staticPartitionFilters, staticDataFilters) + partitionDirectories.toArray + } + + /** + * Create the file partitions for bucketed reads. The non-bucketed variant of this function is + * [[createReadPartitions]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the files + * with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param selectedPartitions + * Hive-style partition that are part of the read. 
+ */ + private def createBucketedReadPartition( + bucketSpec: BucketSpec, + selectedPartitions: ScanFileListing): Seq[FilePartition] = { + logInfo(log"Planning with ${MDC(COUNT, bucketSpec.numBuckets)} buckets") + val partitionArray = selectedPartitions.toPartitionArray + val filesGroupedToBuckets = partitionArray.groupBy { + f => + BucketingUtils + .getBucketId(f.toPath.getName) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter(f => bucketSet.get(f._1)) + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { + numCoalescedBuckets => + logInfo(log"Coalescing to ${MDC(COUNT, numCoalescedBuckets)} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { + bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { + bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + filePartitions + } + + /** + * Create the file partitions for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadPartition]]. + * + * @param selectedPartitions + * Hive-style partition that are part of the read. 
+ */ + private def createReadPartitions(selectedPartitions: ScanFileListing): Seq[FilePartition] = { + val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions) + logInfo(log"Planning scan with bin packing, max size: ${MDC(MAX_SPLIT_BYTES, maxSplitBytes)} " + + log"bytes, open cost is considered as scanning ${MDC(OPEN_COST_IN_BYTES, openCostInBytes)} " + + log"bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions.filePartitionIterator + .flatMap { + partition => + val ListingPartition(partitionVals, _, fileStatusIterator) = partition + fileStatusIterator.flatMap { + file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + if (shouldProcess(filePath)) { + val isSplitable = + relation.fileFormat.isSplitable(relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partitionVals + ) + } else { + Seq.empty + } + } + } + .toArray + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = FilePartition + .getFilePartitions(relation.sparkSession, splitFiles.toImmutableArraySeq, maxSplitBytes) + partitions + } + + def getPartitionsSeq(): Seq[Partition] = { + if (bucketedScan) { + createBucketedReadPartition(relation.bucketSpec.get, dynamicallySelectedPartitions) + } else { + createReadPartitions(dynamicallySelectedPartitions) + } + 
} +} + +abstract class ArrowFileSourceScanLikeShim(original: FileSourceScanExec) + extends FileSourceScanLike { + override val nodeNamePrefix: String = "ArrowFile" + + override def tableIdentifier: Option[TableIdentifier] = original.tableIdentifier + + override def inputRDDs(): Seq[RDD[InternalRow]] = original.inputRDDs() + + override def dataFilters: Seq[Expression] = original.dataFilters + + override def disableBucketedScan: Boolean = original.disableBucketedScan + + override def optionalBucketSet: Option[BitSet] = original.optionalBucketSet + + override def optionalNumCoalescedBuckets: Option[Int] = original.optionalNumCoalescedBuckets + + override def partitionFilters: Seq[Expression] = original.partitionFilters + + override def relation: HadoopFsRelation = original.relation + + override def requiredSchema: StructType = original.requiredSchema + + override def getStream: Option[SparkDataStream] = original.stream +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala new file mode 100644 index 000000000000..8dc5fbc96434 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.{FileFormatWriter, WriteJobDescription, WriteTaskResult} + +object GlutenFileFormatWriter { + def writeFilesExecuteTask( + description: WriteJobDescription, + jobTrackerID: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[InternalRow]): WriteTaskResult = { + FileFormatWriter.executeTask( + description, + jobTrackerID, + sparkStageId, + sparkPartitionId, + sparkAttemptNumber, + committer, + iterator, + None + ) + } + + // Wrapper for throwing standardized write error using QueryExecutionErrors + def wrapWriteError(cause: Throwable, writePath: String): Nothing = { + throw QueryExecutionErrors.taskFailedWhileWritingRowsError(writePath, cause) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala new file mode 100644 index 000000000000..f0ffcc8b690a --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/PartitioningAndOrderingPreservingNodeShim.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +trait OrderPreservingNodeShim extends OrderPreservingUnaryExecNode +trait PartitioningPreservingNodeShim extends PartitioningPreservingUnaryExecNode diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala new file mode 100644 index 000000000000..f4cc013ad4da --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFooterReaderShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.parquet.hadoop.util.HadoopInputFile + +/** Shim layer for ParquetFooterReader to maintain compatibility across different Spark versions. */ +object ParquetFooterReaderShim { + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + fileStatus: FileStatus, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(HadoopInputFile.fromStatus(fileStatus, configuration), filter) + } + + /** @since Spark 4.1 */ + def readFooter( + configuration: Configuration, + file: Path, + filter: ParquetMetadataConverter.MetadataFilter): ParquetMetadata = { + ParquetFooterReader.readFooter(HadoopInputFile.fromPath(file, configuration), filter) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala new file mode 100644 index 000000000000..a6a46b10f49d --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AbstractBatchScanExec.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams +import org.apache.spark.util.ArrayImplicits._ + +import com.google.common.base.Objects + +/** + * Physical plan node for scanning a batch of data from a data source v2. 
Please ref BatchScanExec + * in Spark + */ +abstract class AbstractBatchScanExec( + output: Seq[AttributeReference], + @transient scan: Scan, + val runtimeFilters: Seq[Expression], + ordering: Option[Seq[SortOrder]] = None, + @transient table: Table, + val spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams() +) extends DataSourceV2ScanExecBase { + + @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: AbstractBatchScanExec => + this.batch != null && this.batch == other.batch && + this.runtimeFilters == other.runtimeFilters && + this.spjParams == other.spjParams + case _ => + false + } + + override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters) + + @transient override lazy val inputPartitions: Seq[InputPartition] = inputPartitionsShim + + @transient protected lazy val inputPartitionsShim: Seq[InputPartition] = + batch.planInputPartitions() + + @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case _ => None + } + + if (dataSourceFilters.nonEmpty) { + val originalPartitioning = outputPartitioning + + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + filterableScan.filter(dataSourceFilters.toArray) + + // call toBatch again to get filtered partitions + val newPartitions = scan.toBatch.planInputPartitions() + + originalPartitioning match { + case p: KeyGroupedPartitioning => + if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { + throw new SparkException( + "Data source must have preserved the original partitioning " + + "during runtime filtering: not all partitions implement 
HasPartitionKey after " + + "filtering") + } + val newPartitionValues = newPartitions + .map( + partition => + InternalRowComparableWrapper( + partition.asInstanceOf[HasPartitionKey], + p.expressions)) + .toSet + val oldPartitionValues = p.partitionValues + .map(partition => InternalRowComparableWrapper(partition, p.expressions)) + .toSet + // We require the new number of partition values to be equal or less than the old number + // of partition values here. In the case of less than, empty partitions will be added for + // those missing values that are not present in the new input partitions. + if (oldPartitionValues.size < newPartitionValues.size) { + throw new SparkException( + "During runtime filtering, data source must either report " + + "the same number of partition values, or a subset of partition values from the " + + s"original. Before: ${oldPartitionValues.size} partition values. " + + s"After: ${newPartitionValues.size} partition values") + } + + if (!newPartitionValues.forall(oldPartitionValues.contains)) { + throw new SparkException( + "During runtime filtering, data source must not report new " + + "partition values that are not present in the original partitioning.") + } + + groupPartitions(newPartitions.toImmutableArraySeq) + .map(_.groupedParts.map(_.parts)) + .getOrElse(Seq.empty) + case _ => + // no validation is needed as the data source did not report any specific partitioning + newPartitions.map(Seq(_)) + } + + } else { + partitions + } + } + + override def outputPartitioning: Partitioning = { + super.outputPartitioning match { + case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined => + // We allow duplicated partition values if + // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true + val newPartValues = spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => Seq.fill(numSplits)(partValue) + } + k.copy(numPartitions = newPartValues.length, partitionValues = 
newPartValues) + case p => p + } + } + + override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() + + override lazy val inputRDD: RDD[InternalRow] = { + val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) { + // return an empty RDD with 1 partition if dynamic filtering removed the only split + sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + val finalPartitions = outputPartitioning match { + case p: KeyGroupedPartitioning => + assert(spjParams.keyGroupedPartitioning.isDefined) + val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten + .groupBy( + part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = + KeyGroupedPartitioning.project(expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }) + .map { case (wrapper, splits) => (wrapper.row, splits) } + .toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map( + splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } + + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = spjParams.reducers match { + case Some(reducers) => + val result = groupedPartitions + .groupBy { + case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + } + .map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) } + .toSeq + val rowOrdering = + 
RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + + // When partially clustered, the input partitions are not grouped by partition + // values. Here we'll need to check `commonPartitionValues` and decide how to group + // and replicate splits within a partition. + if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. + val commonPartValuesMap = spjParams.commonPartitionValues.get + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) + .toMap + val filteredGroupedPartitions = finalGroupedPartitions.filter { + case (partValues, _) => + commonPartValuesMap.keySet.contains( + InternalRowComparableWrapper(partValues, partExpressions)) + } + val nestGroupedPartitions = filteredGroupedPartitions.map { + case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. + val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, partExpressions)) + assert( + numSplits.isDefined, + s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (spjParams.replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. 
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) + } + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) + } + + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, partExpressions), + Seq.fill(numSplits)(Seq.empty)) + } + } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. + val partitionMapping = finalGroupedPartitions.map { + case (partValue, splits) => + InternalRowComparableWrapper(partValue, partExpressions) -> splits + }.toMap + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. 
+ p.uniquePartitionValues.map { + partValue => + // Use empty partition for those partition values that are not present + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, partExpressions), + Seq.empty) + } + } + + case _ => filteredPartitions + } + + new DataSourceRDD( + sparkContext, + finalPartitions, + readerFactory, + supportsColumnar, + customMetrics) + } + postDriverMetrics() + rdd + } + + override def keyGroupedPartitioning: Option[Seq[Expression]] = + spjParams.keyGroupedPartitioning + + override def simpleString(maxFields: Int): String = { + val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields) + val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}" + val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString" + redact(result) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated new file mode 100644 index 000000000000..d43331d57c47 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala.deprecated @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import com.google.common.base.Objects + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Physical plan node for scanning a batch of data from a data source v2. + */ +case class BatchScanExec( + output: Seq[AttributeReference], + @transient scan: Scan, + runtimeFilters: Seq[Expression], + keyGroupedPartitioning: Option[Seq[Expression]] = None, + ordering: Option[Seq[SortOrder]] = None, + @transient table: Table, + commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + applyPartialClustering: Boolean = false, + replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase { + + @transient lazy val batch = if (scan == null) null else scan.toBatch + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
+ override def equals(other: Any): Boolean = other match { + case other: BatchScanExec => + this.batch != null && this.batch == other.batch && + this.runtimeFilters == other.runtimeFilters && + this.commonPartitionValues == other.commonPartitionValues && + this.replicatePartitions == other.replicatePartitions && + this.applyPartialClustering == other.applyPartialClustering + case _ => + false + } + + override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters) + + @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions() + + @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case _ => None + } + + if (dataSourceFilters.nonEmpty) { + val originalPartitioning = outputPartitioning + + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + filterableScan.filter(dataSourceFilters.toArray) + + // call toBatch again to get filtered partitions + val newPartitions = scan.toBatch.planInputPartitions() + + originalPartitioning match { + case p: KeyGroupedPartitioning => + if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { + throw new SparkException("Data source must have preserved the original partitioning " + + "during runtime filtering: not all partitions implement HasPartitionKey after " + + "filtering") + } + val newPartitionValues = newPartitions.map(partition => + InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions)) + .toSet + val oldPartitionValues = p.partitionValues + .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet + // We require the new number of partition values to be equal or less than the old number + // of partition values here. 
In the case of less than, empty partitions will be added for + // those missing values that are not present in the new input partitions. + if (oldPartitionValues.size < newPartitionValues.size) { + throw new SparkException("During runtime filtering, data source must either report " + + "the same number of partition values, or a subset of partition values from the " + + s"original. Before: ${oldPartitionValues.size} partition values. " + + s"After: ${newPartitionValues.size} partition values") + } + + if (!newPartitionValues.forall(oldPartitionValues.contains)) { + throw new SparkException("During runtime filtering, data source must not report new " + + "partition values that are not present in the original partitioning.") + } + + groupPartitions(newPartitions).get.map(_._2) + + case _ => + // no validation is needed as the data source did not report any specific partitioning + newPartitions.map(Seq(_)) + } + + } else { + partitions + } + } + + override def outputPartitioning: Partitioning = { + super.outputPartitioning match { + case k: KeyGroupedPartitioning if commonPartitionValues.isDefined => + // We allow duplicated partition values if + // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true + val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) => + Seq.fill(numSplits)(partValue) + } + k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + case p => p + } + } + + override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() + + override lazy val inputRDD: RDD[InternalRow] = { + val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) { + // return an empty RDD with 1 partition if dynamic filtering removed the only split + sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + var finalPartitions = filteredPartitions + + outputPartitioning match { + case p: KeyGroupedPartitioning => + if 
(conf.v2BucketingPushPartValuesEnabled && + conf.v2BucketingPartiallyClusteredDistributionEnabled) { + assert(filteredPartitions.forall(_.size == 1), + "Expect partitions to be not grouped when " + + s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + + "is enabled") + + val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get + + // This means the input partitions are not grouped by partition values. We'll need to + // check `groupByPartitionValues` and decide whether to group and replicate splits + // within a partition. + if (commonPartitionValues.isDefined && applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. Note this no longer maintain the partition key ordering. + val commonPartValuesMap = commonPartitionValues + .get + .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .toMap + val nestGroupedPartitions = groupedPartitions.map { + case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. + val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, p.expressions)) + assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. 
+ splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) + } + (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + } + + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, p.expressions), + Seq.fill(numSplits)(Seq.empty)) + } + } else { + val partitionMapping = groupedPartitions.map { case (row, parts) => + InternalRowComparableWrapper(row, p.expressions) -> parts + }.toMap + finalPartitions = p.partitionValues.map { partValue => + // Use empty partition for those partition values that are not present + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + } + } + } else { + val partitionMapping = finalPartitions.map { parts => + val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey() + InternalRowComparableWrapper(row, p.expressions) -> parts + }.toMap + finalPartitions = p.partitionValues.map { partValue => + // Use empty partition for those partition values that are not present + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + } + } + + case _ => + } + + new DataSourceRDD( + sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) + } + postDriverMetrics() + rdd + } + + override def doCanonicalize(): BatchScanExec = { + this.copy( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + runtimeFilters = QueryPlan.normalizePredicates( + runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), + output)) + } + + override def simpleString(maxFields: Int): String = { + val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields) + val runtimeFiltersString = s"RuntimeFilters: 
${runtimeFilters.mkString("[", ",", "]")}" + val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString" + redact(result) + } + + override def nodeName: String = { + s"BatchScan ${table.name()}".trim + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala new file mode 100644 index 000000000000..046c42ea7b6e --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExecShim.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.functions.Reducer +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan, SupportsRuntimeV2Filtering} +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ + +abstract class BatchScanExecShim( + output: Seq[AttributeReference], + @transient scan: Scan, + runtimeFilters: Seq[Expression], + keyGroupedPartitioning: Option[Seq[Expression]] = None, + ordering: Option[Seq[SortOrder]] = None, + @transient val table: Table, + val joinKeyPositions: Option[Seq[Int]] = None, + val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + val reducers: Option[Seq[Option[Reducer[_, _]]]] = None, + val applyPartialClustering: Boolean = false, + val replicatePartitions: Boolean = false) + extends AbstractBatchScanExec( + output, + scan, + runtimeFilters, + ordering, + table, + StoragePartitionJoinParams( + keyGroupedPartitioning, + joinKeyPositions, + commonPartitionValues, + reducers, + applyPartialClustering, + replicatePartitions) + ) { + + // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. 
+ @transient override lazy val metrics: Map[String, SQLMetric] = Map() + + lazy val metadataColumns: Seq[AttributeReference] = output.collect { + case FileSourceConstantMetadataAttribute(attr) => attr + case FileSourceGeneratedMetadataAttribute(attr, _) => attr + } + + def hasUnsupportedColumns: Boolean = { + // TODO: fall back if the user defines a column with the same name, since we currently + // cannot tell a metadata column apart from a user-defined column. + val metadataColumnsNames = metadataColumns.map(_.name) + output + .filterNot(metadataColumns.toSet) + .exists(v => metadataColumnsNames.contains(v.name)) + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new UnsupportedOperationException("Need to implement this method") + } + + @transient protected lazy val filteredPartitions: Seq[Seq[InputPartition]] = { + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case _ => None + } + + if (dataSourceFilters.nonEmpty) { + val originalPartitioning = outputPartitioning + + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + filterableScan.filter(dataSourceFilters.toArray) + + // call toBatch again to get filtered partitions + val newPartitions = scan.toBatch.planInputPartitions() + + originalPartitioning match { + case p: KeyGroupedPartitioning => + if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { + throw new SparkException( + "Data source must have preserved the original partitioning " + + "during runtime filtering: not all partitions implement HasPartitionKey after " + + "filtering") + } + val newPartitionValues = newPartitions + .map( + partition => + InternalRowComparableWrapper( + partition.asInstanceOf[HasPartitionKey], + p.expressions)) + .toSet + val oldPartitionValues = p.partitionValues + .map(partition =>
InternalRowComparableWrapper(partition, p.expressions)) + .toSet + // We require the new number of partition values to be equal or less than the old number + // of partition values here. In the case of less than, empty partitions will be added for + // those missing values that are not present in the new input partitions. + if (oldPartitionValues.size < newPartitionValues.size) { + throw new SparkException( + "During runtime filtering, data source must either report " + + "the same number of partition values, or a subset of partition values from the " + + s"original. Before: ${oldPartitionValues.size} partition values. " + + s"After: ${newPartitionValues.size} partition values") + } + + if (!newPartitionValues.forall(oldPartitionValues.contains)) { + throw new SparkException( + "During runtime filtering, data source must not report new " + + "partition values that are not present in the original partitioning.") + } + + groupPartitions(newPartitions.toImmutableArraySeq) + .map(_.groupedParts.map(_.parts)) + .getOrElse(Seq.empty) + + case _ => + // no validation is needed as the data source did not report any specific partitioning + newPartitions.map(Seq(_)) + } + + } else { + partitions + } + } + + @transient lazy val pushedAggregate: Option[Aggregation] = { + scan match { + case s: ParquetScan => s.pushedAggregate + case o: OrcScan => o.pushedAggregate + case _ => None + } + } +} + +abstract class ArrowBatchScanExecShim(original: BatchScanExec) + extends BatchScanExecShim( + original.output, + original.scan, + original.runtimeFilters, + original.spjParams.keyGroupedPartitioning, + original.ordering, + original.table, + original.spjParams.joinKeyPositions, + original.spjParams.commonPartitionValues, + original.spjParams.reducers, + original.spjParams.applyPartialClustering, + original.spjParams.replicatePartitions + ) { + override def scan: Scan = original.scan + + override def ordering: Option[Seq[SortOrder]] = original.ordering + + override def output: 
Seq[Attribute] = original.output +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala new file mode 100644 index 000000000000..c6686ba55299 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/datasources/v2/utils/CatalogUtil.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.utils + +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.connector.expressions.Transform + +object CatalogUtil { + + def convertPartitionTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TransformHelper + val (identityCols, bucketSpec, _) = partitions.convertTransforms + (identityCols, bucketSpec) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala new file mode 100644 index 000000000000..127a0fc3cfc9 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.SparkEnv +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata +import org.apache.spark.sql.vectorized.ColumnarBatch + +import java.io.DataOutputStream + +abstract class BasePythonRunnerShim( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argMetas: Array[Array[(Int, Option[String])]], + pythonMetrics: Map[String, SQLMetric]) + extends BasePythonRunner[ColumnarBatch, ColumnarBatch]( + funcs.map(_._1), + evalType, + argMetas.map(_.map(_._1)), + None, + pythonMetrics) { + + protected def createNewWriter( + env: SparkEnv, + worker: PythonWorker, + inputIterator: Iterator[ColumnarBatch], + partitionIndex: Int, + context: TaskContext): Writer + + protected def writeUdf( + dataOut: DataOutputStream, + argOffsets: Array[Array[(Int, Option[String])]]): Unit = { + PythonUDFRunner.writeUDFs( + dataOut, + funcs, + argOffsets.map(_.map(pair => ArgumentMetadata(pair._1, pair._2))), + None) + } + + override protected def newWriter( + env: SparkEnv, + worker: PythonWorker, + inputIterator: Iterator[ColumnarBatch], + partitionIndex: Int, + context: TaskContext): Writer = { + createNewWriter(env, worker, inputIterator, partitionIndex, context) + } + +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala new file mode 100644 index 000000000000..7ad7ca6b09ee --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} + +abstract class EvalPythonExecBase extends EvalPythonExec { + + override protected def evaluatorFactory: EvalPythonEvaluatorFactory = { + throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate") + } +} + +object EvalPythonExecBase { + object NamedArgumentExpressionShim { + def unapply(expr: Expression): Option[(String, Expression)] = expr match { + case NamedArgumentExpression(key, value) => Some((key, value)) + case _ => None + } + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala new file mode 100644 index 000000000000..40c66590e022 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/execution/ui/TypeAlias.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.ui + +/** + * Ensures compatibility for the type HttpServletRequest across Spark 4.0 and earlier versions. + * Starting from Spark 4.0, `jakarta.servlet.http.HttpServletRequest` is used. + */ +object TypeAlias { + type HttpServletRequest = jakarta.servlet.http.HttpServletRequest +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala new file mode 100644 index 000000000000..8422d33e521d --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BooleanType, DataType, StructType} +import org.apache.spark.util.Utils + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.io.{DelegateSymlinkTextInputFormat, SymlinkTextInputFormat} +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.mapred.InputFormat + +import scala.collection.JavaConverters._ + +/** + * The Hive table scan operator. Column and partition pruning are both handled. + * + * @param requestedAttributes + * Attributes to be fetched from the Hive table. 
+ * @param relation + * The Hive table to be scanned. + * @param partitionPruningPred + * An optional partition pruning predicate for a partitioned table. + * @param prunedOutput + * The pruned output. + */ +abstract private[hive] class AbstractHiveTableScanExec( + requestedAttributes: Seq[Attribute], + relation: HiveTableRelation, + partitionPruningPred: Seq[Expression], + prunedOutput: Seq[Attribute] = Seq.empty[Attribute])( + @transient protected val sparkSession: SparkSession) + extends LeafExecNode + with CastSupport { + + require( + partitionPruningPred.isEmpty || relation.isPartitioned, + "Partition pruning predicates only supported for partitioned tables.") + + override def conf: SQLConf = sparkSession.sessionState.conf + + override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}" + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override def producedAttributes: AttributeSet = outputSet ++ + AttributeSet(partitionPruningPred.flatMap(_.references)) + + private val originalAttributes = AttributeMap(relation.output.map(a => a -> a)) + + override def output: Seq[Attribute] = { + if (prunedOutput.nonEmpty) { + prunedOutput + } else { + // Retrieve the original attributes based on expression ID so that capitalization matches. + requestedAttributes.map(attr => originalAttributes.getOrElse(attr, attr)).distinct + } + } + + // Bind all partition key attribute references in the partition pruning predicate for later + // evaluation.
+ private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { + pred => + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " + + s"${pred.dataType.catalogString}.") + + BindReferences.bindReference(pred, relation.partitionCols) + } + + @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private lazy val tableDesc = new TableDesc( + getInputFormat(hiveQlTable.getInputFormatClass, conf), + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata) + + // Create a local copy of hadoopConf, so that scan-specific modifications do not impact + // other queries + @transient private lazy val hadoopConf = { + val c = sparkSession.sessionState.newHadoopConf() + // append column ids and names before broadcast + addColumnMetadataToConf(c) + c + } + + @transient private lazy val hadoopReader = + new HadoopTableReader(output, relation.partitionCols, tableDesc, sparkSession, hadoopConf) + + private def castFromString(value: String, dataType: DataType) = { + cast(Literal(value), dataType).eval(null) + } + + private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { + // Specifies needed column IDs for those non-partitioning columns. + val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) + val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) + val neededColumnNames = output.filter(columnOrdinals.contains).map(_.name) + + HiveShim.appendReadColumns(hiveConf, neededColumnIDs, neededColumnNames) + + val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance() + deserializer.initialize(hiveConf, tableDesc.getProperties) + + // Specifies types and object inspectors of columns to be scanned.
+ val structOI = ObjectInspectorUtils + .getStandardObjectInspector(deserializer.getObjectInspector, ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val columnTypeNames = structOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector) + .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) + .mkString(",") + + hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataCols.map(_.name).mkString(",")) + } + + /** + * Prunes partitions not involve the query plan. + * + * @param partitions + * All partitions of the relation. + * @return + * Partitions that are involved in the query plan. + */ + private[hive] def prunePartitions(partitions: Seq[HivePartition]): Seq[HivePartition] = { + boundPruningPred match { + case None => partitions + case Some(shouldKeep) => + partitions.filter { + part => + val dataTypes = relation.partitionCols.map(_.dataType) + val castedValues = part.getValues.asScala + .zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } + + // Only partitioned values are needed here, since the predicate has + // already been bound to partition key attribute references. 
+ val row = InternalRow.fromSeq(castedValues.toSeq) + shouldKeep.eval(row).asInstanceOf[Boolean] + } + } + } + + @transient lazy val prunedPartitions: Seq[HivePartition] = { + if (relation.prunedPartitions.nonEmpty) { + val hivePartitions = + relation.prunedPartitions.get.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) + if (partitionPruningPred.forall(!ExecSubqueryExpression.hasSubquery(_))) { + hivePartitions + } else { + prunePartitions(hivePartitions) + } + } else { + if ( + sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.nonEmpty + ) { + rawPartitions + } else { + prunePartitions(rawPartitions) + } + } + } + + // exposed for tests + @transient lazy val rawPartitions: Seq[HivePartition] = { + val prunedPartitions = + if ( + sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.nonEmpty + ) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sessionState.catalog + .listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) + } else { + sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier) + } + prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) + } + + override protected def doExecute(): RDD[InternalRow] = { + // Using dummyCallSite, as getCallSite can turn out to be expensive with + // multiple partitions. + val rdd = if (!relation.isPartitioned) { + Utils.withDummyCallSite(sparkContext) { + hadoopReader.makeRDDForTable(hiveQlTable) + } + } else { + Utils.withDummyCallSite(sparkContext) { + hadoopReader.makeRDDForPartitionedTable(prunedPartitions) + } + } + val numOutputRows = longMetric("numOutputRows") + // Avoid to serialize MetastoreRelation because schema is lazy. 
(see SPARK-15649) + val outputSchema = schema + rdd.mapPartitionsWithIndexInternal { + (index, iter) => + val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) + iter.map { + r => + numOutputRows += 1 + proj(r) + } + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + // Optionally returns a delegate input format based on the provided input format class. + // This is currently used to replace SymlinkTextInputFormat with DelegateSymlinkTextInputFormat + // in order to fix SPARK-40815. + private def getInputFormat( + inputFormatClass: Class[_ <: InputFormat[_, _]], + conf: SQLConf): Class[_ <: InputFormat[_, _]] = { + if ( + inputFormatClass == classOf[SymlinkTextInputFormat] && + conf != null && conf.getConf(HiveUtils.USE_DELEGATE_FOR_SYMLINK_TEXT_INPUT_FORMAT) + ) { + classOf[DelegateSymlinkTextInputFormat] + } else { + inputFormatClass + } + } + + override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) + + def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = { + SchemaPruning.pruneSchema(schema, requestedFields) + } +}